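As before, we use the example from the article "[Model Conversion] Modifying ONNX node attributes" to learn basic ONNX operations. This time, instead of modifying the attribute of the Squeeze node, we delete the original Squeeze node and add a new Squeeze node in the same position. First, print every node together with its index to locate the target: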
import onnx
model = onnx.load('my.onnx')
for node_id, node in enumerate(model.graph.node):
    print("######%s######" % node_id)
    print(node)
The Squeeze node we want to replace is node 45:
...
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
...
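import onnx
model = onnx.load('my.onnx')
# Delete the old node
old_squeeze_node = model.graph.node[45]
model.graph.node.remove(old_squeeze_node)
# Create the new node and insert it into the graph at the same position
# (axes is a Squeeze attribute only up to opset 12; from opset 13 it is an input)
new_squeeze_node = onnx.helper.make_node(
    name="Squeeze_250",
    op_type="Squeeze",
    inputs=["384"],
    outputs=["385"],
    axes=[0, 3],
)
model.graph.node.insert(45, new_squeeze_node)
# Note: this overwrites the original file
onnx.save(model, "my.onnx")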
The axes attribute that Squeeze needs is passed to make_node through its **kwargs. For how onnx.helper.make_node works, see its signature and source as printed by IPython:
In [4]: onnx.helper.make_node??
Signature:
onnx.helper.make_node(
    op_type:str,
    inputs:Sequence[str],
    outputs:Sequence[str],
    name:Union[str, NoneType]=None,
    doc_string:Union[str, NoneType]=None,
    domain:Union[str, NoneType]=None,
    **kwargs:Any,
) -> onnx.onnx_ml_pb2.NodeProto
Source:
def make_node(
    op_type: Text,
    inputs: Sequence[Text],
    outputs: Sequence[Text],
    name: Optional[Text] = None,
    doc_string: Optional[Text] = None,
    domain: Optional[Text] = None,
    **kwargs: Any
) -> NodeProto:
    """Construct a NodeProto.
    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        **kwargs (dict): the attributes of the node. The acceptable values
            are documented in :func:`make_attribute`.
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items())
            if value is not None)
    return node
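After running the node-replacement script and printing the nodes of the saved my.onnx again, node 45 now carries the axes attribute:
...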
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
attribute {
  name: "axes"
  ints: 0
  ints: 3
  type: INTS
}
...