使用Flask将pytorch模型部署在本地服务器

整个项目的思路

训练模型

  • 使用pytorch对resnet18 进行迁移学习,实现对自己的数据进行图像分类。需要将最后一个全连接层中的输出节点数目修改,因为我的数据集中包含有5中图像,所以这里的输出节点数目修改成5
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')     # 使用GPU
# 加载预训练的resnet18模型
net = torchvision.models.resnet18(pretrained=True)
# 冻结原网络参数,仅训练最后新替换的全连接层
for param in net.parameters():
    param.requires_grad = False
num_ftrs = net.fc.in_features   # 原网络最后一层的输入维度
net.fc = nn.Linear(num_ftrs, 5) # 替换新的连接层,输出改为5,预测5个类别
net = net.to(DEVICE)
  • 然后resnet18模型训练好之后,保存训练过程中准确率最高的模型。
 # 保存模型参数net.state_dict()
 torch.save(net.state_dict(), 'net_dict.pt')
# 保存完整模型
torch.save(net, 'net.pt')
  • 然后可以随便找一张image,用保存的训练好的模型进行预测
    Flask—python服务器
    Flask和Django都是web框架。可以将模型发布在服务器(这里使用的是本地服务器)。在对应的URL中实现对模型的调用。
app = flask.Flask(__name__)
...
...
@app.route("/predict", methods=["POST"])
def predict():
	...
	...

出现下图,说明flask服务开启成功
使用Flask将pytorch模型部署在本地服务器_第1张图片
向浏览器中输入该网址,然后可以在终端向服务器,以POST方式传过去待识别的图像。并接收从服务器传过来的识别结果。这里的终端暂时使用的是anaconda虚拟环境中的python.exe,来执行.py文件中的代码。后续将Android作为终端。
遇到的问题及解决方法
错误一:
在启动flask服务程序的下段代码中,

preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)

 # Loop over the results and add them to the list of returned predictions
for prob, label in zip(results[0][0], results[1][0]):
   print(label)  # tensor(162)
   label_name = idx2label[label]
   r = {"label": label_name, "probability": float(prob)}
   data['predictions'].append(r)

这个地方报错:
KeyError: tensor(162)
使用Flask将pytorch模型部署在本地服务器_第2张图片
错误原因: label_name = idx2label[label]
idx2label是一个字典{key0: value0, key1: value1, key2: value2…}, 比如{0: ‘cardboard’, 1: ‘glass’, 2: ‘metal’, 3: ‘paper’, 4: ‘plastic’}。可是输出label,发现label并不是一个数,而是tensor。 所以需要将tensor转换为数值。
修改:将label_name = idx2label[label] ---->label_name = idx2label[int(label)]
错误二:
在用anaconda虚拟环境中的python.exe执行simple_request.py文件时,动态对函数参数赋值传入待预测图像的文件路径时,报错
image = open(image_path, 'rb').read() OSError: [Errno 22] Invalid argument: "'e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg'"
在这里插入图片描述
修改
>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file='e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg中的文件路径改为>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file=e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg
即去掉图片路径中的单引号,光是这个小小的错误让我头疼了一整天。。
完整代码有时间会传到github上的。。

你可能感兴趣的:(计算机视觉,工具使用&环境安装,web编程)