【模型转换】onnx删除并新增节点

        仍然以【模型转化】修改onnx节点属性这篇文章中的例子为例来学习onnx的基本操作。这次不再修改Squeeze节点的属性(attribute),而是删除原有Squeeze节点,再在原来位置添加一个新的Squeeze节点。

1.定位要操作的节点

import onnx                                                                                                                                                                                                                                                           
model = onnx.load('my.onnx') 
for node_id,node in enumerate(model.graph.node):
    print("######%s######" % node_id)
    print(node)

要操作的Squeeze为45号节点。

...
 
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
                        
...

2. 删除并增加新节点

import onnx                                                                                                                                                                                                                                                           
model = onnx.load('my.onnx')

#删除老节点 
old_squeeze_node = model.graph.node[45]
model.graph.node.remove(old_squeeze_node)

#新建新节点并添加进graph
new_squeeze_node = onnx.helper.make_node(
    name = "Squeeze_250",
    op_type="Squeeze",
    inputs=["384"],
    outputs=["385"],
    axes=[0,3], 
)
model.graph.node.insert(45,new_squeeze_node)
onnx.save(model,"my.onnx")

 Squeeze操作需要的axes参数,我在make_node的时候通过kwargs参数作了指定。关于onnx.helper.make_node用法可参见如下说明。

In [4]: onnx.helper.make_node??                                                                                                                                                                                                                                       
Signature:
onnx.helper.make_node(
    op_type:str,
    inputs:Sequence[str],
    outputs:Sequence[str],
    name:Union[str, NoneType]=None,
    doc_string:Union[str, NoneType]=None,
    domain:Union[str, NoneType]=None,
    **kwargs:Any,
) -> onnx.onnx_ml_pb2.NodeProto
Source:   
def make_node(
        op_type: Text,
        inputs: Sequence[Text],
        outputs: Sequence[Text],
        name: Optional[Text] = None,
        doc_string: Optional[Text] = None,
        domain: Optional[Text] = None,
        **kwargs: Any
) -> NodeProto:
    """Construct a NodeProto.

    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        **kwargs (dict): the attributes of the node.  The acceptable values
            are documented in :func:`make_attribute`.
    """

    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items())
            if value is not None)
    return node

3.检查新生成的onnx

...
 
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
attribute {
  name: "axes"
  ints: 0
  ints: 3
  type: INTS 
}
                        
...

【模型转换】onnx删除并新增节点_第1张图片

你可能感兴趣的:(模型转换,onnx,onnx)