onnx模型:查看、转换与使用

1. onnx简介

onnx(Open Neural Network Exchange),是一种用于表示神经网络的规范的模型,便于模型在不同框架下进行转换。

2. 可视化

可视化网页:https://netron.app/
onnx模型:查看、转换与使用_第1张图片
结果:
onnx模型:查看、转换与使用_第2张图片

3. pytorch框架下的转换

pytorch框架下,使用torch.onnx.export()函数进行转换。

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, 
				output_names=None,aten=False, export_raw_ir=False, operator_export_type=None, 
				opset_version=None, _retain_param_name=True,do_constant_folding=False, 
				example_outputs=None, strip_doc_string=True,
				dynamic_axes=None, keep_initializers_as_inputs=None)

几个比较重要的参数声明:

torch.onnx.export(model, 						# 网络模型,在dqn中就是保存eval_net
                torch.randn(1, 3, 224, 224), 	# 描述输入的维数,具体数值无关紧要
                export_onnx_file,   			# 输出onnx的名称,也可以限定位置
                input_names=["input"],   		# 输入节点的名称,可以不写,写就要对应上
                output_names=["output"], 		# 输出节点的名称,可以不写,写就要对应上
                )

4. pytorch下DQN网络转换样例与注意事项

	dummy_input = torch.randn([1,N_STATES])
	torch.onnx.export(dqn.eval_net, dummy_input, r".\models\dqn.onnx")

r".\models\dqn.onnx",通过相对地址的写法,限定了onnx生成的位置与名称
要注意,DQN输入虽然只有N_STATES维,但转换之后的onnx网络却是一个1*N_STATES维的输入,需要通过torch.randn([1,N_STATES])转换成恰当的格式,如果不添加,则会生成如下网络:
onnx模型:查看、转换与使用_第3张图片

参考:

Onnx简介以及使用
torch.onnx.export

你可能感兴趣的:(Python与强化学习,python,pytorch,经验分享)