state_dice是一个简单的python字典,映射了每一层的参数名称和数值。
# save
torch.save(model.state_dict,PATH)
#load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
要清晰的理解这种模型保存和加载的方式,此保存和加载的过程比较直观,可以使用更少的代码。这种保存的方式是使用python的pickle模块保存整个模块。这种方式的缺点也是往往被大家误解忽视的地方是序列化的数据绑定到特定的类,并且使用确切的目录结构。pickle不会保存模型累本身,而且将其保存包含类的文件路径。该路径在加载时使用。
# save
torch.save(model, PATH)
# load
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
直观地来说,这种保存整个模型并不是表面意义上的,在加载中仍需使用模型的相关定义代码,只是不再代码中人为的显示定义,而是pickle反序列时,去加载训练时相同的文件路径下的模型定义。这就导致在其他项目中使用或者代码重构时,容易出现错误。
相关问题issue:torch.load() requires model module in the same folder · Issue #3678 · pytorch/pytorch · GitHub
注:torch.save使用的是python中的pickle来进行序列化,PyTorch 1.6 版本切换了新的zipfile-base 文件格式。torch.load仍然可以加载以前的格式,如果在保存中仍想使用之前的格式,通过传入参数_use_new_zipfile_serialization=False 来控制
针对PyTorch的模型部署,最简单的方式就是在python中使用restful的方式提供服务,这是最简单的方式,但是不适用于具有高性能要求的用例,一些对系统性能要求不高的场景下可以使用。 接下来我们以flask为例,你也可以使用django等其他python web framework。在生产环境,比较简单的一种部署方式就是每块卡对应一个容器环境,在每个容器中使用flask来启动服务。
flask 是轻量级的web应用框架,使用它可以轻松的来部署我们的模型推理服务。 下面给出简单的示例伪代码:
from flask import Flask
from flask import request
app = Flask(__name__)
model = load_model() # load your model
@app.route('/modelinference, methods=['POST'])
def model_inference():
input_data = request.form.get('input_data', '')
output_data = model(input_data)
return output
app.run(host='0.0.0.0', port=7000,debug=True)
上面的代码是一个简单的示例,在flask启动的时候,我们就将模型加载进去,这样在每次请求的时候无需在加载模型,减少了每次请求的时间。上述例子中使用的是表单的方式向服务端post 输入数据,然后进行模型推理的相关逻辑计算,可以将预处理和后处理等都集成在此服务中。
需要注意的是上述的代码是采用的debug方式启动的,在生产环境中,我们不能用这种方式来部署。常用的方式是借助gunicorn来启动flask应用。对于gunicore可以通过设置一些参数,来实现异步和多进程、多线程等。
在生产测试的时候,大量请求同时并发发送时,如遇到服务端返回502、504等错误代码,可通过抓包工具等进行请求排查,在相应的进行调整gunicore中的参数。
在python中训练,在python中部署,可以说真的十分简单。但是往往由于python本身的性能和应用场景,一些对系统稳定性等要求比较高的情况下,我们就要使用c++语言进行部署开发。接下来,我们就会一步一步的介绍如何将在python 中训练的模型在c++中部署。
TorchScript是Pytorch模型的intermediate representation(IR),可以在更高性能环境下运行,例如c++。TorchScript 代码可以在其自己的解释器中调用,同时这种格式允许我们将整个模型保存到磁盘上,然后将其加载到另一个环境中。如何获得TorchScript形式的模型,官方提供了两种方式,这两种方式有各自的限制,很多时候都是两种方式混合使用。
使用torch.jit.trace,并传入了Module和示例输入,该方法记录了运行Module时发生的操作,并创建了torch.jit.ScriptModule的实例。
#使用示例
traced_cell = torch.jit.trace(my_cell, (x, h))
另外我们可以通过.graph属性来检查图,但是这是一种非常低级的表示形式,图中包含的大多数信息对最终的用户没有用,我们也可以使用.code属性查看python的语法解释
# 查看图
print(traced_cell.graph)
#查看python语法解释
print(traced_cell.code)
notice——trace会完全按照我们所说的去做:运行代码,记录发生的操作,并构造一个可以做到这一点的ScriptModule。这样就导致控制流相关的操作不会被记录(只会记录输入示例的那种情况)
如上文所述,trace module虽然用法简单,但是无法捕捉到控制流相关的操作。因为官方还提供了脚本编译器来直接分析python源代码并将其转化为TorchScript。
然而script compiler的方式也不是万能的,例如操作有动态输入等情况,script就会出问题。因此,往往会将两者混合使用来得到我们期待的TorchScript模型。
具体的如果使用trace和script,此处就不将仔细叙述,如果有需要可以详细查看官方文档的使用示例。当我们成功地得到了TorchScript模型后,我们将如何保存和加载呢?
# save
traced.save('wrapped_rnn.zip')
# load
loaded = torch.jit.load('wrapped_rnn.zip')
前面我们介绍了TorchScript,并且介绍了两种将python中的Pytorch转化为TorchScript模型的方式。接下来我们将进入正题,在c++中部署PyTorch模型,显而易见地,我们需要先将模型转化并保存为TorchScript模型,然后在c++中进行模型的加载。 对于推理服务的接口通信,一般我们会使用grpc或者http service来处理。对于一个简单的c++中加载PyTorch模型示例,一般包括以下几个过程:
要想在c++中加载序列化的PyTorch模型,必须依赖PyTorch的c++ api 即LibTorch。所以请先按照官方所提供的方式进行相关的环境配置。
上面已经进行了相关阐述。
# 一个最小的CMakeLists.txt文件
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
#include // One-stop header.
#include
#include
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app \n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::cout << "ok\n";
}
// Create a vector of inputs.
std::vector inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
ONNX是Open Neural Network Exchange的缩写,ONNX为AI模型(深度学习模型和传统机器学习模型)提供了开源标准的格式。我们可以将模型转化为ONNX进行生产部署,同时也可以将其作为中介格式,使模型在不同的框架间相互转化。
相关扩展:ONNX RUNTIM 是一种跨平台的推理和训练加速器。使用ONNX RUNTIM 可以提升模型的推理性能、减少训练大型模型的时间和成本、可在python中训练并部署到其他语言应用中、可在不同的硬件和操作系统上运行。
我们以图片为例进行示例说明(动态尺寸输入设置):
# 导出模型
def export_onnx_model(model, input_shape, onnx_path, input_names=None, output_names=None, dynamic_axes=None):
inputs = torch.ones(*input_shape)
torch.onnx.export(model, inputs, onnx_path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
def convert_model():
model = load_model() # PyTorch模型
batch_size = 1
width = 112
height = 112
#将batch_size,width,height都设置成动态的大小
input_shape = (batch_size, 3, width, height)
input_names = ['input']
output_names = ['output']
dynamic_axes = {
'input': {
0: 'batch_size', 2: 'width', 3: 'height'},
'output': {
0: 'batch_size', 2: 'width', 3: 'height'}}
export_onnx_model(model, input_shape, onnx_path, input_names, output_names, dynamic_axes=dynamic_axes)
总结的来说,torch.onnx.export方法传入模型和相关的示例输入,通过设置动态输入可以使导出的模型接受不同的输入尺度。
对于onnx模型,很多推理框架都支持,我们也可以简单的使用onnx的API进行加载或者使用onnx runtime的api进行部署。
import onnx
onnx_model = onnx.load("your_model.onnx")
onnx.checker.check_model(onnx_model)
这里也推荐使用Netron可视化模型软件,进行模型前后转换的比对。
import onnxruntime
ort_session = onnxruntime.InferenceSession("your_model.onnx")
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {
ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
pip install onnx onnxruntime
注意:安装的onnxruntime 是cpu版本的还是gpu版本的,对于gpu版本的安装对cuda的版本有具体的要求。在安装时仔细看官方文档的版本对应关系,避免踩坑
在转化自己的模型时,经常会遇到一些不支持的操作。比如avg_pool的动态输入等,一般常发生在一些输入是动态的时候,如果遇到相应的问题,可以多去onnx GitHub查看有没有相应的解决办法和操作支持,在了解自己的网络的情况下,也可以考虑进行模型拆分或者底层op的自己实现。
今年亚马逊和facebook还联合推出了TorchServe,针对PyTorch进行服务部署。但是这一工作似乎评价不一,我并没有实际使用,有兴趣和需求的可以尝试。
相关推荐阅读:
TorchServe github: https://github.com/pytorch/serve
如何评价 PyTorch 在 2020 年 4 月推出的 TorchServe
另外现在有很多优秀的模型推理服务项目,比如nvidia家的triton等。这些开源服务已经将服务通信、模型管理等都做好了,我们也可以在这些开源项目上面进行应用和二次开发。值得一提的话,如果你真的希望模型服务性能上有很大的改善,要针对性的进行优化工作,寻找瓶颈在哪里。如果需要模型压缩和优化的就需要对模型的结构有一定的了解。
SAVING AND LOADING MODELS
INTRODUCTION TO TORCHSCRIPT
LOADING A TORCHSCRIPT MODEL IN C++
EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME