[기계학습][딥러닝] Flask를 이용하여 파이토치를 REST API로 베포하기

 

본 게시글은 PyTorch 공식 홈페이지의 "FLASK로 REST API를 통해 PYTHON에서 PYTORCH 베포"를 진행하면서 작성한 글입니다!

 

이미지 분류 모델 구축

미리 학습된 DenseNet 모델을 통하여, 주어진 이미지 파일이 뭔지 분류하려고 한다. 

DenseNet 모델은 224x224의 RGB 이미지를 분류하기 때문에, 우선 데이터셋을 정규화해야 한다.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

이제 미리 학습되어 있는 DenseNet 121 모델을 가지고와서 이미지 분류를 예측한다.

이전 CIFAR-10과 같이 torchvision 라이브러리의 모델을 사용하여 읽어오고 추론을 한다.

from torchvision import models

# 이미 학습된 가중치를 사용하기 위해 `pretrained` 에 `True` 값
model = models.densenet121(pretrained=True)
# 모델을 추론에만 사용할 것이므로, `eval` 모드로
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

이때 y_hat Tensor는 예측된 분류 ID의 인덱스를 포함한다.

근데 이거는 코드같은 것이고, 사람이 읽을 수 있는 분류명이 있어야 하기 때문에 이름-ID를 매핑하는 것이 필요하다.

제공되는 imagenet_class_index.json을 저장하여 이 JSON 파일을 통해 예측 결과의 인덱스에 해당하는 분류명을 표현해야한다.

따라서 get_prediction함수를 JSON을 포함하여 변경해준다.

파일링크 : imagenet_class_index.json (파이토치 홈페이지의 튜토리얼에서 제공해준다.)

import json
# 여기서 주소를 자기가 저장한 곳으로
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

이제 아래와 같은 코드로 한번 실행해보면 사진을 예측한 결과가 나온다.

# 여기서 주소를 자기가 저장한 곳으로
with open("_static/cat.jpg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
['n02123045', 'tabby']

 

API 정의

이제 REST API에서의 엔드포인트의 요청(request)와 응답(response)를 정의해야 한다.

엔드포인트는 이미지가 포함된 파일의 매개변수를 POST로, /predict에 요청하는 방식으로 한다고 한다.

응답은 JSON으로하고, 예측 결과는 다음과 같은 예시를 원한다.

경축! 아무것도 안하여 에스천사게임즈가 새로운 모습으로 재오픈 하였습니다.
어린이용이며, 설치가 필요없는 브라우저 게임입니다.
https://s1004games.com

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

일단 기본적인 Flask의 문서를 살펴보자. (Flask는 미리 설치해두어야 한다.)

from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, Flask!'

app.run()

아래 주소로 들어가면 Hello Flask가 실행되어서 웹페이지에 출력되어 있음을 알 수 있다.

Running on http://127.0.0.1:5000/ 

이제 위 API 정의에 맞게 코드를 수정해보자

메소드를 predict로 변경하고, 엔드포인트의 경로역시 /predict로 변경한다. (해당 URI로 접속해야 동작)

이미지는 POST에만 보내지기 때문에, POST만 허용하도록 수정한다.

또한 API 서버는 이미지를 받는 것만 가정하므로 요청으로부터 파일을 읽게 해야 한다.

from flask import request

@app.route('/')
def hello():
    return 'Image Classification Sample'

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # Request로부터 파일 받기
        file = request.files['file']

        # 파일을 바이트로
        img_bytes = file.read()

        #예측해서 반환
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()

주피터에서 이전 모델과 함께 위 코드들을 순서대로 정상적으로 실행시키면 위 주소에 들어갈 때 마다 로그가 기록되고, Flask가 구축된다. 

* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [02/Nov/2020 15:12:33] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [02/Nov/2020 15:12:42] "POST /predict HTTP/1.1" 200 -

이제 새로운 파일을 만들어서 연결을 시도해보자. 

API에서 정의했던 것처럼 POST 방식, 이때 /predict의 경로로, 파일을 첨부해서 요청한다.

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('_static/cat.jpg','rb')})
resp.json()

미리 고양이 사진을 저장해놓고 실행했는데, 잘 작동되는 것 같다.

(tabby는 페르시아 고양이 종류로 대충 검은줄이 있는 고양이라고 한다)

 

얘로 시험해보았습니다!

 

{'class_id': 'n02123045', 'class_name': 'tabby'}

 [출처] https://howtolivelikehuman.tistory.com/m/113

 

 

 

본 웹사이트는 광고를 포함하고 있습니다.
광고 클릭에서 발생하는 수익금은 모두 웹사이트 서버의 유지 및 관리, 그리고 기술 콘텐츠 향상을 위해 쓰여집니다.
번호 제목 글쓴이 날짜 조회 수
공지 오라클 기본 샘플 데이터베이스 졸리운_곰 2014.01.02 25085
공지 [SQL컨셉] 서적 "SQL컨셉"의 샘플 데이타 베이스 SAMPLE DATABASE of ORACLE 가을의 곰을... 2013.02.10 24564
공지 [G_SQL] Sample Database 가을의 곰을... 2012.05.20 25943
1025 [postgreSQL] PostgreSQL 계층형 쿼리 구현 방법 졸리운_곰 2023.01.29 35
1024 [postgreSQL] ORACLE쿼리에서 postgreSQL쿼리 변환 졸리운_곰 2023.01.29 26
1023 [postgreSQL] [PostgreSQL] stored function(stored procedures) 사용하기 졸리운_곰 2023.01.23 30
1022 [SQL] CRUD 기본 사용법 file 졸리운_곰 2023.01.23 30
1021 [postgreSQL] [Docker] Docker에 PostgreSQL 설치하기 file 졸리운_곰 2023.01.21 25
1020 [MYSQL] 테이블 스키마 설계 고려사항 졸리운_곰 2022.12.03 33
1019 [MySQL] "아는 만큼 빨라진다" 마이SQL 성능 튜닝 팁 10가지 file 졸리운_곰 2022.11.29 30
1018 [오라클] 오라클 연동 오류 [ORA-01017: invalid username/password; logon denied] 졸리운_곰 2022.11.28 76
1017 [오라클] 제약조건 확인 (FK 찾기) 졸리운_곰 2022.11.28 68
1016 [ADsP] 취업 깡패 ADP 뿌시기! "빅데이터 분석가 최고의 자격증이에요" file 졸리운_곰 2022.11.20 22
1015 [기계학습] [번역] TensorFlow Lite 튜토리얼 3 부 : Raspberry Pi의 음성 인식 졸리운_곰 2022.11.18 7
1014 [기계학습] [번역] TensorFlow Lite 튜토리얼 2 부 : 음성 인식 모델 교육 졸리운_곰 2022.11.18 13
1013 [기계학습] [번역] TensorFlow Lite 튜토리얼 1 부 : Wake Word 기능 추출 졸리운_곰 2022.11.18 10
1012 [기계학습][딥러닝] Generative Adversarial Net (GAN) PyTorch 구현: 손글씨 생성 file 졸리운_곰 2022.11.18 54
1011 [hadoop] Cloudera Quick Start VM in Hyper-V file 졸리운_곰 2022.11.14 14
» [기계학습][딥러닝] Flask를 이용하여 파이토치를 REST API로 베포하기 file 졸리운_곰 2022.11.12 44
1009 [기계학습][머신러닝][딥러닝] Vanilla GAN file 졸리운_곰 2022.11.08 13
1008 [기계학습][머신러닝][딥러닝] Generative Adversarial Net (GAN) PyTorch 구현: 손글씨 생성 file 졸리운_곰 2022.11.08 103
1007 [기계학습][머신러닝][딥러닝] DCGAN 튜토리얼 졸리운_곰 2022.11.08 4
1006 [PyTorch] pytorch 기본 문법 및 코드, 팁 snippets file 졸리운_곰 2022.10.20 30
대표 김성준 주소 : 경기 용인 분당수지 U타워 등록번호 : 142-07-27414
통신판매업 신고 : 제2012-용인수지-0185호 출판업 신고 : 수지구청 제 123호 개인정보보호최고책임자 : 김성준 sjkim70@stechstar.com
대표전화 : 010-4589-2193 [fax] 02-6280-1294 COPYRIGHT(C) stechstar.com ALL RIGHTS RESERVED