Flask를 이용하여 파이토치를 REST API로 베포하기
반응형

본 게시글은 PyTorch 공식 홈페이지의 "FLASK로 REST API를 통해 PYTHON에서 PYTORCH 베포"를 진행하면서 작성한 글입니다!

 

이미지 분류 모델 구축

미리 학습된 DenseNet 모델을 통하여, 주어진 이미지 파일이 뭔지 분류하려고 한다. 

DenseNet 모델은 224x224의 RGB 이미지를 분류하기 때문에, 우선 데이터셋을 정규화해야 한다.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

이제 미리 학습되어 있는 DenseNet 121 모델을 가지고와서 이미지 분류를 예측한다.

이전 CIFAR-10과 같이 torchvision 라이브러리의 모델을 사용하여 읽어오고 추론을 한다.

from torchvision import models

# 이미 학습된 가중치를 사용하기 위해 `pretrained` 에 `True` 값
model = models.densenet121(pretrained=True)
# 모델을 추론에만 사용할 것이므로, `eval` 모드로
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

이때 y_hat Tensor는 예측된 분류 ID의 인덱스를 포함한다.

근데 이거는 코드같은 것이고, 사람이 읽을 수 있는 분류명이 있어야 하기 때문에 이름-ID를 매핑하는 것이 필요하다.

제공되는 imagenet_class_index.json을 저장하여 이 JSON 파일을 통해 예측 결과의 인덱스에 해당하는 분류명을 표현해야한다.

따라서 get_prediction함수를 JSON을 포함하여 변경해준다.

파일링크 : imagenet_class_index.json (파이토치 홈페이지의 튜토리얼에서 제공해준다.)

import json
# 여기서 주소를 자기가 저장한 곳으로
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

이제 아래와 같은 코드로 한번 실행해보면 사진을 예측한 결과가 나온다.

# 여기서 주소를 자기가 저장한 곳으로
with open("_static/cat.jpg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
['n02123045', 'tabby']

 

API 정의

이제 REST API에서의 엔드포인트의 요청(request)와 응답(response)를 정의해야 한다.

엔드포인트는 이미지가 포함된 파일의 매개변수를 POST로, /predict에 요청하는 방식으로 한다고 한다.

응답은 JSON으로하고, 예측 결과는 다음과 같은 예시를 원한다.

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

일단 기본적인 Flask의 문서를 살펴보자. (Flask는 미리 설치해두어야 한다.)

from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, Flask!'

app.run()

아래 주소로 들어가면 Hello Flask가 실행되어서 웹페이지에 출력되어 있음을 알 수 있다.

Running on http://127.0.0.1:5000/ 

이제 위 API 정의에 맞게 코드를 수정해보자

메소드를 predict로 변경하고, 엔드포인트의 경로역시 /predict로 변경한다. (해당 URI로 접속해야 동작)

이미지는 POST에만 보내지기 때문에, POST만 허용하도록 수정한다.

또한 API 서버는 이미지를 받는 것만 가정하므로 요청으로부터 파일을 읽게 해야 한다.

from flask import request

@app.route('/')
def hello():
    return 'Image Classification Sample'

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # Request로부터 파일 받기
        file = request.files['file']

        # 파일을 바이트로
        img_bytes = file.read()

        #예측해서 반환
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()

주피터에서 이전 모델과 함께 위 코드들을 순서대로 정상적으로 실행시키면 위 주소에 들어갈 때 마다 로그가 기록되고, Flask가 구축된다. 

* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [02/Nov/2020 15:12:33] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [02/Nov/2020 15:12:42] "POST /predict HTTP/1.1" 200 -

이제 새로운 파일을 만들어서 연결을 시도해보자. 

API에서 정의했던 것처럼 POST 방식, 이때 /predict의 경로로, 파일을 첨부해서 요청한다.

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('_static/cat.jpg','rb')})
resp.json()

미리 고양이 사진을 저장해놓고 실행했는데, 잘 작동되는 것 같다.

(tabby는 페르시아 고양이 종류로 대충 검은줄이 있는 고양이라고 한다)

얘로 시험해보았습니다!

{'class_id': 'n02123045', 'class_name': 'tabby'}

 

전문가가 아니라 정확하지 않은 지식이 담겨있을 수 있습니다.
언제든지 댓글로 의견을 남겨주세요!

 

 

반응형