Flask是一种用Python编写的轻量级Web框架,可以帮助您快速构建Web应用程序。
如果我们正在使用PyTorch框架开发深度学习应用程序,并希望将其部署到Web服务器上,则可以使用Flask框架实现。本文将介绍如何使用Flask对前一篇博客中所编写的基于PyTorch框架的图像分类模型进行本地部署,共包含两个py文件(flask_server.py和flask_predict.py),分别表示服务端和客户端,以实现对该模型的远程访问和使用,下文将会详细介绍。(点击这里:基于PyTorch实现经典网络架构的花卉图像分类模型)
在使用Flask部署PyTorch应用程序之前,需要在本地计算机上安装Flask库,若pip install flask下载速度过慢,可换成conda install flask(安装了anaconda3),就能很快下载完毕。
创建一个名为app的Flask对象,并将__name__作为参数传递给它(__name__是一个特殊变量,它表示当前模块的名称,通常用于确定应用程序根目录的位置)。接着创建一个名为model的变量,并将其初始化为None,该变量将用于存储训练好的PyTorch模型。再创建一个名为use_gpu的布尔变量,并将其初始化为False,这个变量将用于控制是否使用GPU加速模型的计算(GPU不错的小伙伴建议为True)。
初始化的流程较为固定,可作为模板进行使用,代码如下:
app = flask.Flask(__name__)
model = None
use_gpu = False
定义一个load_model函数,传入训练模型model、相应结构和参数。需要注意的是,model的值需与训练时所用模型相同(重要!!),同时将model声明为全局变量。
接着重新定义全连接层(102表示最后输出的类别,需根据自身任务来确定),再加载best.pth文件(best.pth存储着训练时效果最好的参数,与前篇博客是同一文件),再使用model.load_state_dict()函数将保存的模型参数加载到我们定义的模型中.
最后使用model.eval()函数将模型设置为验证模式,这将禁用例如dropout和batch normalization等一些训练时的策略,输出分类的概率值。
def load_model():
"""Load the pre-trained model, you can use your model just as easily.
"""
global model
#这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 类别数自己根据自己任务来
#print(model)
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
#将模型指定为测试格式
model.eval()
#是否使用gpu
if use_gpu:
model.cuda()
数据预处理部分大致与验证集相似。不同之处在于添加了一个格式转换,有可能请求端所给image的格式不同,因此我们需要将其统一至RGB格式(训练时所用格式)。
def prepare_image(image, target_size):
#针对不同模型,image的格式不同,但需要统一至RGB格式
if image.mode != 'RGB':
image = image.convert("RGB")
# Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
image = transforms.Resize(target_size)(image)
image = transforms.ToTensor()(image)
# Convert to Torch.Tensor and normalize. mean与std (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# Add batch_size axis.增加一个维度,用于按batch测试 本次这里一次测试一张
image = image[None]
if use_gpu:
image = image.cuda()
return Variable(image, volatile=True) #不需要求导
定义一个predict函数用于接收POST请求并进行图像预测的Flask路由处理函数。当POST请求中包含一个名为“image”的文件时,该函数将读取该文件并使用预处理函数prepare_image()进行图像预处理。然后将预处理后的图像作为输入传递给已加载的PyTorch模型,使用softmax函数对预测结果进行归一化,选取前3个最高概率的结果,将它们以标签和概率的形式打包成字典,存入data字典的“predictions”列表中,最终以JSON格式返回该data字典。如果请求成功,则“success”键将被设置为True,代码如下:
@app.route("/predict", methods=["POST"])
def predict():
# Initialize the data dictionary that will be returned from the view.
#做一个标志,刚开始无图像传入时为false,传入图像时为true
data = {"success": False}
# 如果收到请求
if flask.request.method == 'POST':
#判断是否为图像
if flask.request.files.get("image"):
# Read the image in PIL format
# 将收到的图像进行读取
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image)) #二进制数据
# 利用上面的预处理函数将读入的图像进行预处理
image = prepare_image(image, target_size=(64, 64))
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
#将data字典增加一个key,value,其中value为list格式
data['predictions'] = list()
# Loop over the results and add them to the list of returned predictions
for prob, label in zip(results[0][0], results[1][0]):
#label_name = idx2label[str(label)]
r = {"label": str(label), "probability": float(prob)}
#将预测结果添加至data字典
data['predictions'].append(r)
# Indicate that the request was a success.
data["success"] = True
# 将最终结果以json格式文件传出
return flask.jsonify(data)
在最后加上下段代码,这段代码的作用是在服务器启动时加载PyTorch模型,然后启动Flask服务器,监听端口号为5012(自己定义)的请求。
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
#先加载模型
load_model()
#再开启服务
app.run(port='5012')
# url和端口写成自己的
flask_url = 'http://127.0.0.1:5012/predict'
def predict_result(image_path):
#啥方法都行
image = open(image_path, 'rb').read()
payload = {'image': image}
#request发给server.
r = requests.post(flask_url, files=payload).json()
# 成功的话在返回.
if r['success']:
# 输出结果.
for (i, result) in enumerate(r['predictions']):
print('{}. {}: {:.4f}'.format(i + 1, result['label'],
result['probability']))
# 失败了就打印.
else:
print('Request failed')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification demo')
parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file')
args = parser.parse_args()
predict_result(args.file)
客户端代码较简单,flask_url的由来下文会讲。首先读取待预测图像,将其转换为二进制格式,并构造请求payload。接着发送POST请求至服务端,等待服务端返回结果,若服务端返回的JSON文件中success字段为True,则表示预测成功,将预测结果输出;否则表示预测失败,输出相应信息。此外可以通过命令行参数–file指定待预测图像的路径。
首先,打开pycharm下部terminal面板,输出python+服务端文件(python flask_server.py),即会启动服务,并打印相应英文信息,如下图。
接着,将file属性中地址修改为所要预测图片的地址。
最后,点击运行flask_predict.py文件,即可输出结果。
结果如下:
各位小伙伴可以关注博主,博客内所涉及代码与数据集,私聊博主可以免费发给大家哦。