使用Flask简单部署深度学习模型

使用Flask简单部署深度学习模型

一、安装 Flask 及相关依赖(第四部分的示例还需要 torch、torchvision、opencv-python、imageio、numpy)

pip install Flask==2.0.2
pip install Flask_Cors==3.0.9
pip install Pillow

二、Flask程序运行过程

  1. 当客户端想要获取资源时,一般会通过浏览器发起HTTP请求。
  2. 此时,Web服务器会把来自客户端的所有请求都交给Flask程序实例。
  3. 程序实例使用Werkzeug来做路由分发(URL请求和视图函数之间的对应关系)。
  4. 根据每个URL请求,找到对应的视图函数并调用。在Flask程序中,路由一般通过程序实例的 route 装饰器来定义。
  5. Flask调用视图函数后,常见的返回内容有两种:
  • 字符串内容:将视图函数的返回值作为响应的内容,返回给客户端(浏览器)。
  • HTML模板内容:获取到数据后,把数据传入HTML模板文件中,模板引擎负责渲染HTTP响应数据,然后返回响应数据给客户端(浏览器)。

三、 Flask开发

# 1.导入Flask扩展
from flask import Flask

# 2. Create the Flask application instance.
# __name__ is passed so Flask can locate the application's resources
# (templates, static files) relative to this module.
app = Flask(__name__)

# 3. Define routes and view functions.
# Flask registers routes with the app.route decorator.
# A route only handles GET by default (HEAD/OPTIONS are added automatically);
# list any other methods explicitly, as done here for POST.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Return a plain-text greeting for the root URL."""
    return "hello flask"

# 4. Start the program.
if __name__ == '__main__':
    # app.run() serves the app on Flask's built-in development server.
    app.run()

四、 使用Flask框架完成前后端交互

import os
import io
import json
import time
import argparse
import cv2
import torch
import imageio
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from data.custom_transforms import FixedResize, AddIgnoreRegions, ToTensor, Normalize
import base64
from utils.utils import get_output, mkdir_if_missing
import numpy as np
from flask_cors import CORS
from utils.common_config import get_model
from utils.config import create_config
import torchvision.transforms as transforms
# File extensions accepted for uploaded images.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG'}

# Flask application instance; CORS(app) enables flask_cors on every route
# so the page may call the API from a different origin.
app = Flask(__name__)
CORS(app)

# Load the colour palette used to render the semantic-segmentation mask.
palette_path = "palette.json"

# Raise explicitly: a bare `assert` is stripped when Python runs with -O.
if not os.path.exists(palette_path):
    raise FileNotFoundError(f"palette {palette_path} not found.")
with open(palette_path, "rb") as f:
    pallette_dict = json.load(f)

# Flatten the per-entry colour lists into one flat list, the form passed to
# PIL's Image.putpalette in get_prediction (presumably [r, g, b] per class —
# confirm against palette.json). The misspelled name is kept because
# get_prediction reads `pallette`.
pallette = [channel for colour in pallette_dict.values() for channel in colour]

weights_path = "configs/PADResults/PASCALContext/hrnet_w18/pad_net/best_model.pth.tar"
# Raise explicitly: a bare `assert` is stripped when Python runs with -O.
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"weights path {weights_path} does not exist.")

# Command-line options: environment and experiment config files.
parser = argparse.ArgumentParser(description='Vanilla Training')
parser.add_argument('--config_env', default='configs/env.yml',
                    help='Config file for the environment')
parser.add_argument('--config_exp', default='configs/pascal/pad_net.yml',
                    help='Config file for the experiment')
args = parser.parse_args()

# Disable OpenCV's own threading, then build the experiment config.
cv2.setNumThreads(0)
p = create_config(args.config_env, args.config_exp)

# Select the device: the second GPU when there is one, otherwise the first
# GPU, otherwise the CPU. A hard-coded "cuda:1" fails with an invalid device
# ordinal on single-GPU machines (torch.load / .to(device) below).
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Build the network and wrap it in DataParallel before loading weights —
# presumably the checkpoint was saved from a DataParallel model
# ("module."-prefixed keys); TODO confirm.
model = get_model(p)
model = torch.nn.DataParallel(model)
# Only move to CUDA when a GPU exists; an unconditional .cuda() crashes on
# CPU-only hosts. DataParallel expects its module on device_ids[0] (cuda:0),
# which is where .cuda() places it; inputs are scattered there automatically.
if device.type == "cuda":
    model = model.cuda()
# Load the trained weights.
model.load_state_dict(torch.load(weights_path, map_location=device))

model.eval()

def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends with a permitted extension.

    The comparison is case-insensitive, so ``photo.Jpg`` is accepted just
    like ``photo.jpg`` / ``photo.JPG``.

    Args:
        filename: Client-supplied file name; empty or None yields False.
        allowed_extensions: Iterable of permitted extensions without the dot.
            Defaults to the module-level ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the text after the last '.' is an allowed extension.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    if not filename or '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {ext.lower() for ext in allowed_extensions}

# Image preprocessing
def transform_image(image_bytes):
    """Decode uploaded image bytes into a normalized tensor on ``device``.

    The image is resized to 512x512, converted to a tensor and normalized
    with the standard ImageNet mean/std.

    Args:
        image_bytes: Raw bytes of the uploaded image file.

    Returns:
        A float tensor of shape (3, 512, 512) on the module-level ``device``.

    Raises:
        ValueError: if the decoded image is not in RGB mode.
    """
    my_transforms = transforms.Compose([transforms.ToPILImage(),
                                        transforms.Resize([512, 512]),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406],
                                                             [0.229, 0.224, 0.225])])

    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != "RGB":
        raise ValueError("input file does not RGB image...")
    image = np.array(image, dtype='uint8')
    # Apply the pipeline once and reuse the result for both the shape log and
    # the return value, instead of resizing/normalizing the image twice.
    tensor = my_transforms(image)
    print(tensor.shape)
    return tensor.to(device)

# Fractional-second part of a timestamp as a string, to 0.1 ms.
def get_secondFloat(timestamp):
    """Format the fractional part of *timestamp* with four decimals.

    The leading "0" is dropped, e.g. ``12.5 -> ".5000"``. A fraction that
    rounds up to 1 yields ``".0000"`` (the carry into the seconds is lost).
    """
    fraction = timestamp % 1
    formatted = f"{fraction:.4f}"
    return formatted[1:]

# Current local time as a "YYYYmmdd_HHMMSS" string (one-second resolution).
def get_timeString():
    """Return the current local time formatted as ``%Y%m%d_%H%M%S``.

    NOTE(review): no fractional seconds are included; append
    ``get_secondFloat(...)`` of the same timestamp if sub-second
    precision is needed.
    """
    return time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))

def get_prediction(p, image_bytes):
    """Run the multi-task model on one image and save per-task result PNGs.

    Args:
        p: Experiment config; ``p.TASKS.NAMES`` lists the tasks to render and
            ``p.TASKS.INFER_FLAGVALS[task]`` is the cv2 interpolation flag
            used when resizing that task's output.
        image_bytes: Raw bytes of the uploaded image.

    Returns:
        dict mapping ``"semseg"`` and ``"normals"`` to the saved PNG paths
        under ``static/results``. Both keys are always present, even when a
        task is not in ``p.TASKS.NAMES``.

    Raises:
        ValueError: propagated from ``transform_image`` for non-RGB input.
    """
    model.eval()
    tasks = p.TASKS.NAMES
    results_dirPath = 'static/results'
    # Create the output directory if it does not exist yet (no-op otherwise).
    os.makedirs(results_dirPath, exist_ok=True)

    inputs = transform_image(image_bytes=image_bytes)
    # Follow the model's device rather than hard-coding .cuda(), which fails
    # on CPU-only hosts; then add the batch dimension -> (1, 3, 512, 512).
    inputs = inputs.to(next(model.parameters()).device, non_blocking=True)
    inputs = inputs.reshape(1, 3, 512, 512)
    print(inputs.shape)
    with torch.no_grad():
        output = model(inputs)

    # NOTE(review): results are written to fixed file names, so concurrent
    # requests overwrite each other's images.
    for task in tasks:
        if task == 'normals':
            # Original note: normals output is (1, 512, 512, 3) — confirm.
            output_task = get_output(output[task], task).cpu().data.numpy()
            for jj in range(inputs.size(0)):
                result = cv2.resize(output_task[jj], dsize=(512, 512),
                                    interpolation=p.TASKS.INFER_FLAGVALS[task])
                imageio.imwrite(os.path.join(results_dirPath, task + '.png'),
                                result.astype(np.uint8))
        elif task == 'semseg':
            # Per-pixel class index, colourized with the loaded palette.
            prediction = output['semseg'].argmax(1).squeeze(0)
            prediction = prediction.to("cpu").numpy().astype(np.uint8)
            mask = Image.fromarray(prediction)
            mask.putpalette(pallette)
            mask.save(os.path.join(results_dirPath, task + '.png'))
        # Any other task is predicted by the model but not rendered.

    return {"semseg": os.path.join(results_dirPath, 'semseg.png'),
            "normals": os.path.join(results_dirPath, 'normals.png')
            }

# Front-end / back-end interaction
@app.route('/predict', methods=['GET', 'POST'])
@torch.no_grad()
def predict():
    """Handle an image upload from the page and return result image URLs.

    Expects a multipart form with the image under the ``file`` field.

    Returns:
        On success, JSON ``{'status': 1, 'semseg_url': ..., 'normals_url': ...}``.
        When no file was uploaded, or ``transform_image`` rejects the image
        (non-RGB), JSON ``{'status': 0, 'error': ...}`` with HTTP 400, rather
        than an HTML error page the front-end cannot parse.
    """
    image = request.files.get('file')
    if image is None or not image.filename:
        return jsonify({'status': 0, 'error': 'no file uploaded'}), 400
    print(image.filename)

    # NOTE(review): this directory is created but nothing is saved into it.
    # If saving is added, sanitize image.filename (e.g. werkzeug's
    # secure_filename) before joining it into a path.
    received_dirPath = 'webimage/received_images'
    os.makedirs(received_dirPath, exist_ok=True)

    img_bytes = image.read()
    try:
        result_info = get_prediction(p, img_bytes)
    except ValueError as err:
        return jsonify({'status': 0, 'error': str(err)}), 400
    print(result_info)
    return jsonify({'status': 1,
                    'semseg_url':  result_info['semseg'],
                    'normals_url': result_info['normals']
                    })

@app.route('/', methods=["GET", "POST"])
def root():
    """Serve the demo page rendered from the ``predict.html`` template."""
    template_name = "./predict.html"
    return render_template(template_name)

if __name__ == '__main__':
    # Development server, listening on the loopback interface, port 5005.
    app.run(port=5005, host="127.0.0.1")



五、 前端HTML页面

这部分借鉴了别人的代码

<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
    <title>多任务学习展示</title>
    <!-- jQuery is used by the upload/preview script further down the page. -->
    <script src="https://apps.bdimg.com/libs/jquery/2.1.4/jquery.min.js"></script>
</head>
<body>

<h1 style="background-color:lightcoral;text-align:center;font-family:arial;color:cornflowerblue;font-size:50px;">多任务学习</h1>

<div style="text-align: left;margin-left: 0px;margin-top: 0px;">
    <!-- Left column: input image preview and the styled file picker. -->
    <div style="float:left; margin-left: 100px;margin-top: 150px;">
        <img src="static/2008_000036.jpg" id="img0" alt="输入图片" style="margin-left:10px;width: 20rem;height: 20rem;">
        <br>
        <a href="javascript:;" class="file" style="text-align: center">选择文件
            <input type="file" name="file" id="file0" style="text-align: center"><br>
        </a>
    </div>
    <!-- Middle column: button that uploads the chosen image to /predict. -->
    <div style="margin-left: 525px; margin-top: 0px;width: 20px;height: 0px;">
        <input type="button" id="b0" onclick="test()" value="使用多任务模型进行预测" style="margin-top: 250px;margin-left: 75px;width: auto;">
    </div>
    <!-- Right column: prediction results, replaced after a successful request. -->
    <div style="margin-left: 880px;margin-top: 0px;">
        <div style="margin-right: 50px;margin-top: 0px;">
            <img src="static/sem_2008_000036.png" id="img1" alt="语义分割结果" style="width: 20rem;height: 20rem;margin-top: 0px;">
            语义分割
        </div>
        <div style="margin-right: 50px">
            <!-- NOTE(review): the other placeholders use "2008_000036"; confirm this file name. -->
            <img src="static/nor_008_000036.png" id="img2" alt="表面法线估计结果" style="margin-top:20px;width: 20rem;height: 20rem;">
            表面法线估计
        </div>
    </div>
</div>

<script type="text/javascript">
    $("#file0").change(function(){
        var objUrl = getObjectURL(this.files[0]) ;//获取文件信息
        console.log("objUrl = "+objUrl);
        if (objUrl) {
            $("#img0").attr("src", objUrl);
        }
    });

    function test() {
        var fileobj = $("#file0")[0].files[0];
        console.log(fileobj);
        var form = new FormData();
        form.append("file", fileobj);
        var Con1 = $("#img1");
        var Con2 = $("#img2");
        var out='';
        var flower='';
        var results = $.ajax({
            type: 'POST',
            url: "predict",
            data: form,
            async: false,       //同步执行
            processData: false, // 告诉jquery要传输data对象
            contentType: false, //告诉jquery不需要增加请求头对于contentType的设置
            dataType: "json",
            success: function (arg) {
                out = arg;
                console.log(out);
                var r = window.confirm("预测完成,显示图片");
                if(r == true) {
                    document.getElementById("img1").src=out['semseg_url'];
                    document.getElementById("img2").src=out['normals_url'];
                }


        },error:function(){
                console.log("后台处理错误");
            }
    });


    }

    function getObjectURL(file) {
        var url = null;
        if(window.createObjectURL!=undefined) {
            url = window.createObjectURL(file) ;
        }else if (window.URL!=undefined) { // mozilla(firefox)
            url = window.URL.createObjectURL(file) ;
        }else if (window.webkitURL!=undefined) { // webkit or chrome
            url = window.webkitURL.createObjectURL(file) ;
        }
        return url ;
    }
script>
<style>
    /* Wrapper that makes the native file input look like a button. */
    .file {
        position: relative;
        background: #CCC ;
        border: 1px solid #CCC;
        padding: 4px 4px;
        overflow: hidden;
        text-decoration: none;
        text-indent: 0;
        width:100px;
        height:30px;
        line-height: 30px;
        border-radius: 5px;
        color: #333;
        font-size: 13px;
    }
    /* The real file input sits over the wrapper, fully transparent, so clicks reach it. */
    .file input {
        position: absolute;
        font-size: 13px;
        right: 0;
        top: 0;
        opacity: 0;
        border: 1px solid #333;
        padding: 4px 4px;
        overflow: hidden;
        text-indent: 0;
        width:100px;
        height:30px;
        line-height: 30px;
        border-radius: 5px;
        color: #FFFFFF;
    }
    /* "Predict" button. */
    #b0{
        background: #1899FF;
        border: 1px solid #CCC;
        padding: 4px 10px;
        overflow: hidden;
        text-indent: 0;
        width:60px;
        height:28px;
        line-height: 20px;
        border-radius: 5px;
        color: #FFFFFF;
        font-size: 13px;
    }
    body{
        background: paleturquoise;
    }
</style>
</body>
</html>

六、 debug

1. 页面中引用的图片(以及预测结果图片)需要放在 static 文件夹下,否则浏览器读取不到——Flask 默认只把 static 目录下的文件作为静态资源对外提供。

2. 服务只监听服务器的 127.0.0.1,在本地电脑上测试时需要通过端口映射访问,例如 ssh -L 5005:127.0.0.1:5005 user@server,然后在本地浏览器打开 http://127.0.0.1:5005。

你可能感兴趣的:(CV,计算机视觉,python,flask,深度学习)