写在前面
目前PyTorch在学术界几乎完全盖过了Tensorflow,从个人体验来说:
- PyTorch并没有表现出比Tensorflow性能差很多,基本是相当;
- 与Tensorflow的非Eager API相比,PyTorch上手、使用的难度相对来说非常低,并且调试非常容易(可断点可单步);
- 以及笔者还没有体会到的动态图与静态图之争
但是在应用场景下仍然还有将模型使用TF-Serving部署的需求,做转换还是有意义的。
本文将以一个transformer为例来介绍整个流程。
依赖:onnx
model = MultiTaskModel(encoder, sst_head, mrpc_head)
# 关注点1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 关注点2
tensor = torch.tensor((), dtype=torch.int64)
input_ids = tensor.new_ones((1, 48)).to(device)
token_type_ids = tensor.new_ones((1, 48)).to(device)
attention_mask = tensor.new_ones((1, 48)).to(device)
# 关注点3
torch.onnx.export(model,
(input_ids, attention_mask, token_type_ids),
"/path/to/model.onnx",
input_names=["input_ids", "attention_mask", "token_type_ids"],
output_names=["intent", "predicate"],
dynamic_axes={"input_ids": [0], "attention_mask": [0], "token_type_ids": [0], "intent": [0], "predicate": [0]})
关注点:
input_ids/token_type_ids/attention_mask
均是),数值不重要,也不用担心batch设为1了,后续模型就真的只能接收大小是(1, 48)
的参数了(固定参数不过仍然有调整空间).to(device)
如果都是在CPU上的时候可以不设,但是在GPU上就需要显式地让模型和参数位于相同设备torch.onnx.export
参数解释:
dynamic_axes
:设置之后相应的维度就会是可变的,有需要可以设(这里是将batch_size设置为了可变)在明确单点预测的情况下可以忽略。opset_version
: 使用默认值就好了,使用更高的值在下游转换时可能出现opset未实现的错误
依赖 onnx-tf
onnx转换出来的单pb模型不好部署(比如Java本地加载调试,笔者只走通了saved_model),需要再转换为saved_model。
其实看看代码,直接转换一步到位应该也是可以的,没有去探索偷个懒。
可以使用命令
onnx-tf convert --infile model.onnx --outfile model.pb
也可以上代码,方便集成
import onnx
from onnx_tf.backend import prepare
onnx_model = onnx.load("/path/model.onnx") # load onnx model
tf_rep = prepare(onnx_model)
tf_rep.export_graph("/path/model.pb")
2.1 中转出的单pb模型没有签名,加载有些麻烦,重新加载加签名加标签。
import tensorflow as tf
import os
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
with tf.gfile.GFile("/path/model.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
export_dir = "/path/saved_model"
if os.path.exists(export_dir): # 不会自动覆盖,需要手工删除
os.rmdir(export_dir)
sigs = {}
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
tf_graph = tf.get_default_graph()
input_ids = tf_graph.get_tensor_by_name("input_ids:0")
attention_mask = tf_graph.get_tensor_by_name("attention_mask:0")
token_type_ids = tf_graph.get_tensor_by_name("token_type_ids:0")
intent = tf_graph.get_tensor_by_name("intent:0")
predicate = tf_graph.get_tensor_by_name("predicate:0")
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids},
{"intent": intent, "predicate": predicate})
builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
signature_def_map=sigs)
这样看,TF2.0不像1.x有各种显示的变量名称,在保存模型时也这样来一次或者转换一次、打上签名和标签,能够让调用更简单。
最终会生成文件和目录如下:
saved_model
|-- variables
|-- saved_model.pb
INVALID_ARGUMENT: NodeDef mentions attr ‘incompatible_shape_error’ not in Op
z:bool; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_UINT8, …, DT_QINT8, DT_QINT32, DT_STR ING, DT_BOOL, DT_COMPLEX128]; is_commutative=true>; NodeDef: {{node NotEqual}}. (Check whether your GraphDef-interpreting binary i s up to date with your GraphDef-generating binary.)
这个问题折腾了我有点时间,如图加粗,还是应该好好看错误信息才对。当时使用的是1.15做的转换、导出,线上部署是1.14的。换到1.14上导出就能够避开这个问题了。
划重点:onnx转换时用的tensorflow版本最好是与线上部署使用的版本保持一致
本人在测试FP16导出时发现了TF1.14 很难找到一个合适的onnx/onnx-tf配置,会出现各种错误,需要同步做以下更改。上文中并未给出配置信息(抱歉),这里给出一套验证过的配置:
pytorch==1.7.1
tensorflow==2.1
tensorflow-addon==0.9.1 # 需要与tensorflow版本同步变化
onnx==1.8.0
onnx-tf==1.7.0
opset_version
需要设为10
:有比较明显的错误提示要这么做from sys import argv
import onnx
from onnx_tf.backend import prepare
if __name__ == '__main__':
path = argv[1]
print("onnx => tensorflow saved model")
graph_pb = f"{path}/saved_model"
onnx_model = onnx.load(f"{path}/model.onnx") # load onnx model
tf_rep = prepare(onnx_model) # run the loaded model
tf_rep.export_graph(graph_pb)
备注:已知这样操作后输出的名字丢失了(输入没问题),目前还不知道怎么前回来。
本文整合了网上的一些材料,对PyTorch模型转换到TF-Serving做了一点探索,同时也修正了一些网上的排名靠前的文章中的错误,还列举了一些遇到的其它问题。
想想验证完了所有流程,最后遇到2.3时的各种怀疑人生……嗯,做到这里,才算是能够愉快地使用PyTorch了,真香。
参考
Convert your Pytorch Models to Tensorflow (with ONNX)