pytorch转onnx注意事项(翻译)

  1. 像data[index] = new_data这样的张量就地索引分配目前在导出中不受支持。解决这类问题的一种方法是使用算子散点,显式地更新原始张量。
    就是像tensorflow的静态图,不能随便改变tensor的值,可以用torch的scatter_方法解决
    错误的方式
# def forward(self, data, index, new_data):
#     data[index] = new_data          # 重新赋值
#     return data

正确的方式

class InPlaceIndexedAssignmentONNX(torch.nn.Module):
    def forward(self, data, index, new_data):
        new_data = new_data.unsqueeze(0)
        index = index.expand(1, new_data.size(1))
        data.scatter_(0, index, new_data)
        return data
  1. 装LSTM这种由动态变量的类

你可能感兴趣的:(pytorch,pytorch,onnx)