# Web-serving dependencies for the demo and the inference server below.
pip install Flask==2.0.2
# Flask 2.0.x does not cap its Werkzeug requirement, and newer Werkzeug
# releases (3.x) removed helpers that older Flask imports; pin a release from
# the same 2.0 series so Flask stays importable.
pip install Werkzeug==2.0.3
pip install Flask_Cors==3.0.9
pip install Pillow
# The inference server also imports torch, torchvision, opencv-python (cv2),
# imageio and numpy; install builds that match your Python/CUDA setup.
# 1. Import Flask
from flask import Flask

# 2. Create the Flask application instance.
# __name__ is passed so Flask can work out where the app's resources live.
app = Flask(__name__)


# 3. Define routes and view functions.
# Routes are registered with a decorator. A route only accepts GET by
# default; any other methods must be listed explicitly.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Return a plain-text greeting for the root URL (GET or POST)."""
    return "hello flask"


# 4. Start the program
if __name__ == '__main__':
    # app.run() serves the application on Flask's simple built-in server.
    app.run()
import os
import io
import json
import time
import argparse
import cv2
import torch
import imageio
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from data.custom_transforms import FixedResize, AddIgnoreRegions, ToTensor, Normalize
import base64
from utils.utils import get_output, mkdir_if_missing
import numpy as np
from flask_cors import CORS
from utils.common_config import get_model
from utils.config import create_config
import torchvision.transforms as transforms
# File extensions accepted for uploads (checked by allowed_file()).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG'}
# The Flask application for the inference server; CORS is enabled on it so
# the endpoints can be called from other origins.
app = Flask(import_name=__name__)
CORS(app)
# Load the colour palette used to colour the semantic-segmentation mask
# (passed to Image.putpalette in get_prediction).
palette_path = "palette.json"
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.exists(palette_path):
    raise FileNotFoundError(f"palette {palette_path} not found.")
with open(palette_path, "rb") as f:
    pallette_dict = json.load(f)
# Flatten every value of the JSON object, in file order, into one flat list.
# NOTE(review): assumes each value is a list of RGB ints - confirm against
# palette.json.
pallette = [channel for colour in pallette_dict.values() for channel in colour]
# Checkpoint whose state dict is loaded into the model at startup.
weights_path = "configs/PADResults/PASCALContext/hrnet_w18/pad_net/best_model.pth.tar"
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"weights path {weights_path} does not exist.")
# Command-line options selecting the environment and experiment config files.
parser = argparse.ArgumentParser(description='Vanilla Training')
_CLI_OPTIONS = (
    ('--config_env', 'configs/env.yml', 'Config file for the environment'),
    ('--config_exp', 'configs/pascal/pad_net.yml', 'Config file for the experiment'),
)
for _flag, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, help=_help)
args = parser.parse_args()
# Turn off OpenCV's internal threading, then build the experiment config
# from the two config files given on the command line.
cv2.setNumThreads(0)
p = create_config(args.config_env, args.config_exp)

# Select the device. "cuda" is the current default CUDA device, which is
# where DataParallel expects the wrapped module (device_ids[0]). A fixed index
# such as "cuda:1" fails on single-GPU hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Build the model. It is wrapped in DataParallel before the weights are loaded
# so the state-dict keys carry DataParallel's "module." prefix.
# NOTE(review): presumably the checkpoint was saved from a DataParallel-wrapped
# model - confirm.
model = get_model(p)
model = torch.nn.DataParallel(model)
# .to(device) instead of .cuda() so the CPU fallback selected above works.
model = model.to(device)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()
def allowed_file(filename, allowed=None):
    """Return True if *filename* has an extension in the allowed set.

    The comparison is case-insensitive, so ``photo.Jpg`` is accepted when
    ``jpg`` is allowed.

    Args:
        filename: the uploaded file's name.
        allowed: iterable of permitted extensions without the dot; defaults
            to the module-level ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: False for names without a dot or with a non-allowed extension.
    """
    if allowed is None:
        allowed = ALLOWED_EXTENSIONS
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {ext.lower() for ext in allowed}
# Image preprocessing.
# Built once at import time instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize([512, 512]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode an uploaded image and turn it into a normalised model input.

    Args:
        image_bytes: raw bytes of an encoded image file.

    Returns:
        A (3, 512, 512) float tensor, normalised with the per-channel
        mean/std above and placed on the module-level ``device``.

    Raises:
        ValueError: if the decoded image is not in RGB mode.
    """
    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != "RGB":
        raise ValueError("input file does not RGB image...")
    # Resize/ToTensor accept the PIL image directly, so no numpy round trip is
    # needed. The pipeline runs once; the shape print reuses that result.
    tensor = _PREPROCESS(image)
    print(tensor.shape)
    return tensor.to(device)
# Fractional-second part of a timestamp as a string, to 0.1 ms resolution.
def get_secondFloat(timestamp):
    """Return the fractional-second part of *timestamp* formatted like ``'.1234'``.

    The fraction is rounded to 4 decimals (0.1 ms). A fraction of 0.99995 or
    more would round to ``'1.0000'``, and dropping its leading digit would give
    ``'.0000'``, i.e. the start of the same second. Such values are clamped to
    ``'.9999'`` instead.
    """
    text = '%.4f' % (timestamp % 1)
    if text.startswith('1'):
        # Rounded up to a whole second; this helper cannot carry into the
        # seconds field, so stay at the last representable fraction.
        return '.9999'
    return text[1:]
# Current local time as a string, with whole-second resolution.
def get_timeString():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``.

    Only whole seconds are included; the sub-second part is not appended
    (get_secondFloat produces that separately).
    """
    pattern = '%Y%m%d_%H%M%S'
    local_now = time.localtime(time.time())
    return time.strftime(pattern, local_now)
def get_prediction(p, image_bytes):
    """Run the multi-task model on one encoded image and save the outputs as PNGs.

    Args:
        p: experiment config; ``p.TASKS.NAMES`` lists the tasks to export and
            ``p.TASKS.INFER_FLAGVALS[task]`` is passed to cv2.resize as the
            interpolation flag.
        image_bytes: raw bytes of the uploaded image.

    Returns:
        dict mapping ``'semseg'`` and ``'normals'`` to PNG paths under
        ``static/results``. Both paths are returned even when a task is not in
        ``p.TASKS.NAMES``; that file is then not written by this call.

    Raises:
        ValueError: propagated from transform_image for non-RGB input.
    """
    model.eval()
    tasks = p.TASKS.NAMES
    results_dirPath = 'static/results'
    # Create the output folder on first use; exist_ok makes later calls a no-op.
    os.makedirs(results_dirPath, exist_ok=True)

    inputs = transform_image(image_bytes=image_bytes)
    # Use the configured device (not .cuda()) so the CPU path works, then add
    # the batch dimension.
    inputs = inputs.to(device, non_blocking=True)
    inputs = inputs.reshape(1, 3, 512, 512)
    print(inputs.shape)
    output = model(inputs)

    # Save each supported task's prediction as static/results/<task>.png.
    for task in tasks:
        save_path = os.path.join(results_dirPath, task + '.png')
        if task == 'normals':
            output_task = get_output(output[task], task).cpu().data.numpy()
            # The batch holds a single image here, so the file is written once.
            for jj in range(int(inputs.size()[0])):
                result = cv2.resize(output_task[jj], dsize=(512, 512),
                                    interpolation=p.TASKS.INFER_FLAGVALS[task])
                imageio.imwrite(save_path, result.astype(np.uint8))
        elif task == 'semseg':
            # Per-pixel argmax class index, coloured with the loaded palette.
            prediction = output[task].argmax(1).squeeze(0)
            prediction = prediction.to("cpu").numpy().astype(np.uint8)
            mask = Image.fromarray(prediction)
            mask.putpalette(pallette)
            mask.save(save_path)

    return {"semseg": os.path.join(results_dirPath, 'semseg.png'),
            "normals": os.path.join(results_dirPath, 'normals.png')}
# Front-end <-> back-end interaction: the page POSTs the selected image here.
@app.route('/predict', methods=['GET', 'POST'])
@torch.no_grad()
def predict():
    """Run the model on the uploaded image and return the result image URLs.

    Expects a multipart form field named ``file``. Responds with JSON
    ``{'status': 1, 'semseg_url': ..., 'normals_url': ...}``.
    """
    image = request.files['file']
    print(image.filename)
    # Folder reserved for received uploads. The upload itself is not saved.
    # NOTE(review): if saving is added, sanitise image.filename (e.g. with
    # werkzeug.utils.secure_filename) before joining it into a path.
    received_dirPath = 'webimage/received_images'
    # exist_ok avoids a check-then-create race between concurrent requests.
    os.makedirs(received_dirPath, exist_ok=True)
    img_bytes = image.read()
    result_info = get_prediction(p, img_bytes)
    print(result_info)
    return jsonify({'status': 1,
                    'semseg_url': result_info['semseg'],
                    'normals_url': result_info['normals'],
                    })
@app.route('/', methods=["GET", "POST"])
def root():
    """Render the predict.html front-end page."""
    template_name = "./predict.html"
    return render_template(template_name)


if __name__ == '__main__':
    # Development server, reachable from localhost only.
    app.run(port=5005, host="127.0.0.1")
下面的前端页面（predict.html）代码借鉴了别人的实现：
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <!-- Declare the encoding before any non-ASCII text such as the title. -->
    <meta charset="UTF-8">
    <title>多任务学习展示</title>
    <!-- jQuery is used for the file-preview handler and the AJAX upload. -->
    <script src="https://apps.bdimg.com/libs/jquery/2.1.4/jquery.min.js"></script>
</head>
<body>
<h1 style="background-color:lightcoral;text-align:center;font-family:arial;color:cornflowerblue;font-size:50px;">多任务学习</h1>
<div style="text-align: left;margin-left: 0px;margin-top: 0px;">
    <!-- Left column: input image preview and file picker. -->
    <div style="float:left; margin-left: 100px;margin-top: 150px;">
        <img src="static/2008_000036.jpg" id="img0" style="margin-left:10px;width: 20rem;height: 20rem;">
        <br>
        <!-- The file input is made transparent by the .file CSS and positioned over this link. -->
        <a href="javascript:;" class="file" style="text-align: center">选择文件
            <input type="file" name="file" id="file0" style="text-align: center"><br>
        </a>
    </div>
    <!-- Middle column: button that sends the selected image for prediction. -->
    <div style="margin-left: 525px; margin-top: 0px;width: 20px;height: 0px;">
        <input type="button" id="b0" onclick="test()" value="使用多任务模型进行预测" style="margin-top: 250px;margin-left: 75px;width: auto;">
    </div>
    <!-- Right column: semantic-segmentation and surface-normal results. -->
    <div style="margin-left: 880px;margin-top: 0px;">
        <div style="margin-right: 50px;margin-top: 0px;">
            <img src="static/sem_2008_000036.png" id="img1" style="width: 20rem;height: 20rem;margin-top: 0px;">
            语义分割
        </div>
        <div style="margin-right: 50px">
            <!-- NOTE(review): "nor_008_000036" may be a typo for "nor_2008_000036" (cf. the other sample images); confirm the file name in static/. -->
            <img src="static/nor_008_000036.png" id="img2" style="margin-top:20px;width: 20rem;height: 20rem;">
            表面法线估计
        </div>
    </div>
</div>
<script type="text/javascript">
    // Preview the chosen file in #img0 as soon as the user picks one.
    $("#file0").change(function () {
        var file = this.files[0];
        if (!file) {
            return; // picker was cancelled; nothing to preview
        }
        var objUrl = getObjectURL(file); // local object URL for the selected file
        console.log("objUrl = " + objUrl);
        if (objUrl) {
            $("#img0").attr("src", objUrl);
        }
    });

    // Upload the selected file to the "predict" endpoint and, on success,
    // show the returned result images in #img1 and #img2.
    function test() {
        var fileobj = $("#file0")[0].files[0];
        console.log(fileobj);
        if (!fileobj) {
            // Without a file the request has no "file" form field to send.
            alert("请先选择一张图片");
            return;
        }
        var form = new FormData();
        form.append("file", fileobj);
        // Asynchronous request: synchronous XHR freezes the page and is
        // deprecated on the main thread.
        $.ajax({
            type: 'POST',
            url: "predict",
            data: form,
            processData: false, // send the FormData object as-is
            contentType: false, // let the browser set the multipart Content-Type
            dataType: "json",
            success: function (out) {
                console.log(out);
                var r = window.confirm("预测完成,显示图片");
                if (r) {
                    // The result URLs are identical on every request, so append a
                    // timestamp query to stop the browser reusing a cached image.
                    var bust = "?t=" + Date.now();
                    document.getElementById("img1").src = out['semseg_url'] + bust;
                    document.getElementById("img2").src = out['normals_url'] + bust;
                }
            },
            error: function () {
                console.log("后台处理错误");
            }
        });
    }

    // Create an object URL for a local File. Tries the legacy global API
    // first, then the standard URL API, then the WebKit-prefixed one.
    function getObjectURL(file) {
        if (window.createObjectURL != undefined) {
            return window.createObjectURL(file);
        }
        if (window.URL != undefined) { // mozilla(firefox)
            return window.URL.createObjectURL(file);
        }
        if (window.webkitURL != undefined) { // webkit or chrome
            return window.webkitURL.createObjectURL(file);
        }
        return null;
    }
</script>
<style>
    /* Styled "choose file" link; the real file input sits on top of it. */
    .file {
        position: relative;
        background: #CCC;
        border: 1px solid #CCC;
        padding: 4px 4px;
        overflow: hidden;
        text-decoration: none;
        text-indent: 0;
        width: 100px;
        height: 30px;
        line-height: 30px;
        border-radius: 5px;
        color: #333;
        font-size: 13px;
    }

    /* Transparent file input covering the .file link so clicks reach it. */
    .file input {
        position: absolute;
        font-size: 13px;
        right: 0;
        top: 0;
        opacity: 0;
        border: 1px solid #333;
        padding: 4px 4px;
        overflow: hidden;
        text-indent: 0;
        width: 100px;
        height: 30px;
        line-height: 30px;
        border-radius: 5px;
        color: #FFFFFF;
    }

    /* Predict button. */
    #b0 {
        background: #1899FF;
        border: 1px solid #CCC;
        padding: 4px 10px;
        overflow: hidden;
        text-indent: 0;
        width: 60px;
        height: 28px;
        line-height: 20px;
        border-radius: 5px;
        color: #FFFFFF;
        font-size: 13px;
    }

    body {
        background: paleturquoise;
    }
</style>
</body>
</html>