Flask를 이용하여 파이토치를 REST API로 베포하기2
반응형

본 게시글은 PyTorch 공식 홈페이지의 "FLASK로 REST API를 통해 PYTHON에서 PYTORCH 베포"를 진행하면서 작성한 글입니다!

 

이전 게시글에서는 Flask 서버를 구동하면서 파이썬 코드를 통해 실험해보았었다.

import requests
resp = requests.post("http://localhost:5000/predict", 
                     files={"file": open('_static/cat.jpg','rb')})
resp.json()

 

HTML 렌더링

이번에는 HTML 문서와 연결시켜서 해보겠다.

방법은 간단하다

랜더링이라 하는데 우선 미리 html 문서를 만들어 놓는다.

여기서 주목해야 될 부분은 form 부분이다.

이전에 API에서 POST 형식으로 'file'명칭을 기존 주소 + /predict 으로 호출하기로 정했었기 때문에 그 형식을 따라서 사진을 보내야 한다.

<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Classification</title>
</head>
<body>

    <form action="http://localhost:5000/predict"
          method="post"
          enctype="multipart/form-data">
        <input type="file" name="file" id="file" />
        <input type="submit" />
        <div id="preview"></div>

    </form>

    <script>
        var upload = document.querySelector('#file');
        var upload2 = document.querySelector('#preview');

        var reader = new FileReader();
        reader.onload = (function () {

        	this.image = document.createElement('img');
        	var vm = this;

        	return function (e) {
        		vm.image.src = e.target.result
        	}
        })()

        upload.addEventListener('change',function (e) {
        	var get_file = e.target.files;

        	if(get_file){
        		reader.readAsDataURL(get_file[0]);
       		 }
        	image.style.maxWidth='500px';
        	image.style.maxHeight='500px';
        	preview.appendChild(image);
        })
    </script>
</body>
</html>

HTML에서 type = file이면 자동으로 파일추가 버튼이 생성되어 파일을 선택할 수 있다.

밑의 스크립트는 사진을 추가했을때 밑에 띄워주는 코드인데

구조는 file 버튼에 이벤트리스너를 연결시켜 사진이 추가되면 자동으로 FileReader를 통해 입력받고 사용자에게 보여주는 코드이다.....

 

이전의 파이썬 코드를 한번 더 첨부한다.

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
from flask import render_template


app = Flask(__name__)
# 여기서 주소를 자기가 저장한 곳으로
imagenet_class_index = json.load(open('./_static/imagenet_class_index.json'))
model = models.resnet18(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

# html 연결
@app.route('/')
def main():
    return render_template('index.html')

# 사진 분류
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()

 

결과

주소가 바뀜에 주목하자!

 

전문가가 아니라 정확하지 않은 지식이 담겨있을 수 있습니다.
언제든지 댓글로 의견을 남겨주세요!

 

 

반응형