PyTorch 导出onnx模型没有输入节点

PyTorch 导出onnx模型没有输入节点

Tensor.data使得torch.Tensorrequires_grad=False,因此在torch.onnx.export导出模型时,该Tensor不被追踪,当作了常量参数,最终导出的模型没有输入节点。

torch.onnx转模型时,通过netron查看网络结构,发现没有输入结果,输入被当作常量参数放在原输入结点的下一个节点中。

PyTorch 导出onnx模型没有输入节点_第1张图片PyTorch 导出onnx模型没有输入节点_第2张图片

torch.onnx.export()

 Exports a model into ONNX format. 
 If ``model`` is not a 
 :class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, 
 this runs ``model`` once in order to convert it to
  a TorchScript graph to be exported
   (the equivalent of :func:`torch.jit.trace`). 
 Thus this has the same limited support for dynamic control flow as 
 :func:`torch.jit.trace`.

翻译

将模型导出为 ONNX 格式。
 如果 ``model`` 不是 :class:`torch.jit.ScriptModule` 也不是 
 :class:`torch.jit.ScriptFunction`,
 这将运行 ``model`` 一次,以便将其转换为 TorchScript 图 被导出(相当于:func:`torch.jit.trace`)。 
 因此,它对动态控制流的支持与 torch.jit.trace 相同。

如果对执行.data后的input进行.requires_grad_(),设置requires_grad=True,报错如下。

Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

PyTorch 导出onnx模型没有输入节点_第3张图片

注释代码中input.data操作。

PyTorch 导出onnx模型没有输入节点_第4张图片
PyTorch 导出onnx模型没有输入节点_第5张图片

torch==1.11.0
onnx==2.0.1
opset_version=11

你可能感兴趣的:(笔记,pytorch,深度学习,人工智能)