General Workflow for Calling Complex PyTorch Models Built on Pre-trained Models Such as BERT from Java

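The motivation for calling a PyTorch-trained model from Java is to take full advantage of Java's multithreading capabilities and improve the model's computational efficiency in production.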

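There are currently two main known approaches for calling a PyTorch model from Java. The first is to convert the PyTorch model into a TorchScript model with torch.jit.trace or torch.jit.script. The second is to convert it into an ONNX (Open Neural Network Exchange, ONNX | Home) model. Java then loads and runs the converted TorchScript or ONNX model through a library such as DJL.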

First, an introduction to TorchScript. The following is excerpted from the official documentation (PyTorch; Introduction to TorchScript — PyTorch Tutorials 1.11.0+cu102 documentation; Loading a TorchScript Model in C++ — PyTorch Tutorials 1.11.0+cu102 documentation):

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.
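The following is a minimal sketch of the export step described above. It traces a small stand-in model and round-trips it through torch.jit.save and torch.jit.load. TinyEncoder is hypothetical: it only mimics the (input_ids, attention_mask) inputs of a BERT-style model and is not a real BERT implementation. In practice the model would be saved to a file, for example torch.jit.save(traced, "tiny_encoder.pt"). A Python-free runtime would then load that file; in the scenario of this article, that is a Java library such as DJL, whose side is not shown here.

```python
import io

import torch


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for a BERT-style encoder: embeddings + one linear layer."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        h = self.proj(self.embed(input_ids))
        # Zero out padding positions, loosely mimicking how BERT ignores masked tokens.
        return h * attention_mask.unsqueeze(-1).to(h.dtype)


model = TinyEncoder().eval()
input_ids = torch.randint(0, 100, (2, 5))
attention_mask = torch.ones(2, 5, dtype=torch.long)

with torch.no_grad():
    # Record the operations executed for these example inputs.
    traced = torch.jit.trace(model, (input_ids, attention_mask))

# Serialize to an in-memory buffer here; a real deployment would write a .pt file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

with torch.no_grad():
    expected = model(input_ids, attention_mask)
    actual = loaded(input_ids, attention_mask)
print(torch.allclose(expected, actual))  # True
```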

For a gentle introduction to TorchScript, see the Introduction to TorchScript tutorial.

TorchScript Language

TorchScript is a statically typed subset of Python, so many Python features apply directly to TorchScript. See the full TorchScript Language Reference for details.

Built-in Functions and Modules

TorchScript supports the use of most PyTorch functions and many Python built-ins. See TorchScript Builtins for a full reference of supported functions.

PyTorch Functions and Modules

TorchScript supports a subset of the tensor and neural network functions that PyTorch provides. Most methods on Tensor as well as functions in the torch namespace, all functions in torch.nn.functional and most modules from torch.nn are supported in TorchScript.

See TorchScript Unsupported Pytorch Constructs for a list of unsupported PyTorch functions and modules.

Python Functions and Modules

Many of Python’s built-in functions are supported in TorchScript. The math module is also supported (see math Module for details), but no other Python modules (built-in or third party) are supported.
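The small scripted function below illustrates the static typing and the math support mentioned above. scale_rows is a made-up example. TorchScript assumes that unannotated arguments are Tensors, so non-Tensor parameters such as a List[float] or an Optional[float] need explicit type hints.

```python
import math
from typing import List, Optional

import torch


@torch.jit.script
def scale_rows(x: torch.Tensor, factors: List[float], bias: Optional[float] = None) -> torch.Tensor:
    # Without annotations every argument would be treated as a Tensor,
    # so the List/Optional types are declared explicitly.
    out = x * torch.tensor(factors, dtype=x.dtype, device=x.device).unsqueeze(1)
    if bias is not None:
        # Inside this branch `bias` is refined from Optional[float] to float.
        out = out + bias
    # The math module is one of the few Python modules TorchScript supports.
    return out / math.sqrt(float(x.size(1)))


result = scale_rows(torch.ones(2, 4), [1.0, 2.0], 1.0)
print(result)  # rows of 1.0 and 1.5
```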

Python Language Reference Comparison

For a full listing of supported Python features, see Python Language Reference Coverage.

Debugging

Disable JIT for Debugging

PYTORCH_JIT

Setting the environment variable PYTORCH_JIT=0 will disable all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run using native Python. Since TorchScript (scripting and tracing) is disabled with this flag, you can use tools like pdb to debug the model code. For example:

import torch

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

Debugging this script with pdb works except for when we invoke the @torch.jit.script function. We can globally disable JIT, so that we can call the @torch.jit.script function as a normal Python function and not compile it. If the above script is called disable_jit_example.py, we can invoke it like so:

$ PYTORCH_JIT=0 python disable_jit_example.py

and we will be able to step into the @torch.jit.script function as a normal Python function. To disable the TorchScript compiler for a specific function, see @torch.jit.ignore.
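PYTORCH_JIT=0 disables compilation globally. The @torch.jit.ignore decorator mentioned above is the narrower option: it leaves a single method as plain Python. Here is a minimal sketch with a hypothetical Model class. Its ignored log_shape method can use arbitrary Python (print here, pdb or NumPy in practice), while the rest of forward is compiled. According to the PyTorch documentation, models containing ignored functions cannot be exported, and @torch.jit.unused is the suggested alternative in that case. That makes this a debugging aid rather than something to ship.

```python
import torch


class Model(torch.nn.Module):
    @torch.jit.ignore
    def log_shape(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: runs as ordinary Python even inside the scripted module.
        print("log_shape sees", tuple(x.shape))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * 2
        x = self.log_shape(x)  # dispatched back to the Python interpreter
        return x + 1


scripted = torch.jit.script(Model())
out = scripted(torch.ones(2, 3))  # prints: log_shape sees (2, 3)
```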

Inspecting Code

TorchScript provides a code pretty-printer for all ScriptModule instances. This pretty-printer gives an interpretation of the script method’s code as valid Python syntax. For example:

import torch

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

A ScriptModule with a single forward method will have an attribute code, which you can use to inspect the ScriptModule’s code. If the ScriptModule has more than one method, you will need to access .code on the method itself and not the module. We can inspect the code of a method named foo on a ScriptModule by accessing .foo.code. The example above produces this output:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

This is TorchScript’s compilation of the code for the forward method. You can use this to ensure TorchScript (tracing or scripting) has captured your model code correctly.
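Here is a sketch of the .foo.code access pattern described above, using a hypothetical module with two methods. @torch.jit.export marks helper as a method of the compiled module, so scripted.helper.code returns its compiled source, while scripted.code shows forward.

```python
import torch


class TwoMethods(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.helper(x) * 2

    @torch.jit.export
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        # Exported, so it is callable on the compiled module and has its own .code.
        return x + 1


scripted = torch.jit.script(TwoMethods())
print(scripted.code)         # compiled forward (contains the call to helper)
print(scripted.helper.code)  # compiled helper, accessed on the method itself
```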

Interpreting Graphs

TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ATen (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals. As an example:

import torch

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph follows the same rules described in the Inspecting Code section with regard to forward method lookup.

The example script above produces the graph:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

Take the instruction %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 for example.

  • %rv.1 : Tensor means we assign the output to a (unique) value named rv.1, that value is of Tensor type and that we do not know its concrete shape.

  • aten::zeros is the operator (equivalent to torch.zeros) and the input list (%4, %6, %6, %10, %12) specifies which values in scope should be passed as inputs. The schema for built-in functions like aten::zeros can be found at Builtin Functions.

  • # test.py:9:10 is the location in the original source file that generated this instruction. In this case, it is a file named test.py, on line 9, and at character 10.

Notice that operators can also have associated blocks, namely the prim::Loop and prim::If operators. In the graph print-out, these operators are formatted to reflect their equivalent source code forms to facilitate easy debugging.

Graphs can be inspected as shown to confirm, either manually or in automated checks, that the computation described by a ScriptModule is correct.
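Besides reading the printed graph, you can walk the same IR programmatically, which is one way to automate such checks. This sketch reuses the foo example, with the parameter renamed from len to n to avoid shadowing the built-in. It relies on the nodes(), kind() and blocks() accessors of the graph objects. These come from the torch._C bindings and may change between releases.

```python
import torch


@torch.jit.script
def foo(n: int) -> torch.Tensor:
    rv = torch.zeros(3, 4)
    for i in range(n):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


# Top-level nodes only: the comparison and the if live inside the loop's block.
top_level = [node.kind() for node in foo.graph.nodes()]
print(top_level)


def all_kinds(block):
    """Collect operator kinds recursively, descending into prim::Loop / prim::If blocks."""
    kinds = []
    for node in block.nodes():
        kinds.append(node.kind())
        for sub in node.blocks():
            kinds.extend(all_kinds(sub))
    return kinds


kinds = all_kinds(foo.graph)
print(sorted(set(kinds)))
```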

Mixing Tracing and Scripting

In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. Tracing and scripting can be composed to suit the particular requirements of a part of a model.

Scripted functions can call traced functions. This is particularly useful when you need to use control-flow around a simple feed-forward model. For instance the beam search of a sequence to sequence model will typically be written in script but can call an encoder module generated using tracing.

Example (calling a traced function in script):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

Traced functions can call script functions. This is useful when a small part of a model requires some control-flow even though most of the model is just a feed-forward network. Control-flow inside of a script function called by a traced function is preserved correctly.

Example (calling a script function in a traced function):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
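One way to check that the control flow inside the scripted foo survives tracing is to call the traced bar with inputs that take each branch. This self-contained sketch repeats the example above. If the if statement had been frozen during tracing, one of the two calls would return the smaller tensor instead of the larger one.

```python
import torch


@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z


traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

big, small, zero = torch.full((3,), 5.0), torch.full((3,), 1.0), torch.zeros(3)
print(traced_bar(big, small, zero))  # x branch: tensor([5., 5., 5.])
print(traced_bar(small, big, zero))  # y branch: tensor([5., 5., 5.])
```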

This composition also works for nn.Modules, where tracing can be used to generate a submodule that is then called from the methods of a script module.

Example (using a traced module):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

Next, an introduction to ONNX. The following is mainly excerpted from the relevant official sites (ONNX | Home; torch.onnx — PyTorch 1.11.0 documentation; GitHub - microsoft/onnxruntime: ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator):

ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.

On the PyTorch side, the torch.onnx module lets you export a PyTorch model to an ONNX model with torch.onnx.export.

The torch.onnx module can export PyTorch models to ONNX. The model can then be consumed by any of the many runtimes that support ONNX.

Here is an example:

Example: AlexNet from PyTorch to ONNX

Here is a simple script which exports a pretrained AlexNet to an ONNX file named alexnet.onnx. The call to torch.onnx.export runs the model once to trace its execution and then exports the traced model to the specified file:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

The resulting alexnet.onnx file contains a binary protocol buffer which contains both the network structure and parameters of the model you exported (in this case, AlexNet). The argument verbose=True causes the exporter to print out a human-readable representation of the model:

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

You can also verify the output using the ONNX library, which you can install using conda:

conda install -c conda-forge onnx

Then, you can run:

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

You can also run the exported model with one of the many runtimes that support ONNX. For example after installing ONNX Runtime, you can load and run the model:

import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

Note the following:

Internally, torch.onnx.export() requires a torch.jit.ScriptModule rather than a torch.nn.Module. If the passed-in model is not already a ScriptModule, export() will use tracing to convert it to one:

  • Tracing: If torch.onnx.export() is called with a Module that is not already a ScriptModule, it first does the equivalent of torch.jit.trace(), which executes the model once with the given args and records all operations that happen during that execution. This means that if your model is dynamic, e.g., changes behavior depending on input data, the exported model will not capture this dynamic behavior. Similarly, a trace is likely to be valid only for a specific input size. We recommend examining the exported model and making sure the operators look reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly the same as the traced run. If you want to export your model with dynamic control flow, you will need to use scripting.

  • Scripting: Compiling a model via scripting preserves dynamic control flow and is valid for inputs of different sizes. To use scripting:

    • Use torch.jit.script() to produce a ScriptModule.

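    • Call torch.onnx.export() with the ScriptModule as the model, and set the example_outputs arg. This is required so that the types and shapes of the outputs can be captured without executing the model. (Note that example_outputs does not appear in the torch.onnx.export signature listed in (9) below; it was deprecated in later PyTorch releases, so check the documentation for your version.)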
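torch.onnx.export traces any model that is not already a ScriptModule. Running torch.jit.trace directly is therefore a quick way to preview what the exporter would capture. The Gate module below is a made-up example whose output depends on the data. The traced version keeps only the branch taken for the example input. The scripted version keeps the if, and it is the one that could then be passed to torch.onnx.export.

```python
import torch


class Gate(torch.nn.Module):
    # Hypothetical module with data-dependent control flow.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


pos, neg = torch.ones(3), -torch.ones(3)

traced = torch.jit.trace(Gate(), (pos,))  # records only the `x * 2` path (emits a TracerWarning)
scripted = torch.jit.script(Gate())       # keeps the `if` as prim::If

print(traced(neg))    # tensor([-2., -2., -2.]): wrong, the branch was frozen
print(scripted(neg))  # tensor([1., 1., 1.]): matches eager Gate()(neg)
```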

Trace modules and script modules were introduced in the TorchScript section above.

Things to watch out for when exporting ONNX models:

(1) Avoid NumPy and built-in Python types

PyTorch models can be written using NumPy or Python types and functions, but during tracing, any variables of NumPy or Python types (rather than torch.Tensor) are converted to constants, which will produce the wrong result if those values should change depending on the inputs.

For example, rather than using numpy functions on numpy.ndarrays:

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

Use torch operators on torch.Tensors:

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

And rather than using torch.Tensor.item() (which converts a Tensor to a Python built-in number):

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

Use torch’s support for implicit casting of single-element tensors:

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)
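You can reproduce the .item() pitfall with torch.jit.trace alone, because tracing is what torch.onnx.export uses for plain nn.Modules. In this sketch (BadReshape is a hypothetical module), the value 2 seen during tracing is baked into the graph. A later call with y = 3 therefore still produces two rows. PyTorch also emits a TracerWarning at trace time.

```python
import torch


class BadReshape(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # y.item() is a Python int, so the tracer records its value as a constant.
        return x.reshape(y.item(), -1)


model = BadReshape()
x = torch.arange(6.0)
traced = torch.jit.trace(model, (x, torch.tensor(2)))

print(model(x, torch.tensor(3)).shape)   # torch.Size([3, 2])  (eager)
print(traced(x, torch.tensor(3)).shape)  # torch.Size([2, 3])  (the 2 was baked in)
```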

(2) Avoid using Tensor.data

Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph. Use torch.Tensor.detach() instead. (Work is ongoing to remove Tensor.data entirely).

(3) Avoid in-place operations when using tensor.shape in tracing mode

In tracing mode, shape values obtained from tensor.shape are traced as tensors, and share the same memory. This might cause a mismatch in values of the final outputs. As a workaround, avoid use of inplace operations in these scenarios. For example, in the model:

class Model(torch.nn.Module):
    def forward(self, states):
        batch_size, seq_length = states.shape[:2]
        real_seq_length = seq_length
        real_seq_length += 2
        return real_seq_length + seq_length

real_seq_length and seq_length share the same memory in tracing mode. This could be avoided by rewriting the inplace operation:

real_seq_length = real_seq_length + 2

(4) Data types of inputs and outputs

  • Only torch.Tensors, numeric types that can be trivially converted to torch.Tensors (e.g. float, int), and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and outputs are accepted in tracing mode, but:

    • Any computation that depends on the value of a dict or a str input will be replaced with the constant value seen during the one traced execution.

    • Any output that is a dict will be silently replaced with a flattened sequence of its values (keys will be removed). E.g. {"foo": 1, "bar": 2} becomes (1, 2).

    • Any output that is a str will be silently removed.

  • Certain operations involving tuples and lists are not supported in scripting mode due to limited support in ONNX for nested sequences. In particular appending a tuple to a list is not supported. In tracing mode, the nested sequences will be flattened automatically during the tracing.

(5) Differences in operator implementations

Due to differences in implementations of operators, running the exported model on different runtimes may produce different results from each other or from PyTorch. Normally these differences are numerically small, so this should only be a concern if your application is sensitive to these small differences.

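The target ONNX operator set is usually specified with the opset_version parameter, and generally the newer the better (for example opset_version=16 on newer PyTorch releases; note that the PyTorch 1.11 documentation quoted in (9) below only accepts values from 7 to 15).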

(6) Unsupported tensor indexing patterns

Tensor indexing patterns that cannot be exported are listed below. If you are experiencing issues exporting a model that does not include any of the unsupported patterns below, please double check that you are exporting with the latest opset_version.

Reads / Gets

When indexing into a tensor for reading, the following patterns are not supported:

# Tensor indices that include negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

Writes / Sets

When indexing into a Tensor for writing, the following patterns are not supported:

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that include negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]
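The two write workarounds listed above (positive indices and explicit expansion) can be sketched in eager PyTorch as follows. The tensors are made-up examples. The intent is for each assignment to produce the same values as the unsupported form while avoiding the specific patterns listed as unsupported. Converting negative indices with torch.where is one option among others.

```python
import torch

# Workaround 1: map negative indices to positive ones before writing.
data_a = torch.zeros(3, 4, 5)
rows = torch.tensor([1, -2])
cols = torch.tensor([-2, 3])
rows_pos = torch.where(rows < 0, rows + data_a.size(0), rows)  # tensor([1, 1])
cols_pos = torch.where(cols < 0, cols + data_a.size(1), cols)  # tensor([2, 3])
data_a[rows_pos, cols_pos] = torch.ones(2, 5)  # the indexed region has shape [2, 5]

# Workaround 2: expand new_data explicitly instead of relying on broadcasting.
data_b = torch.zeros(3, 4, 5)
new_data = torch.arange(5.0)            # shape [5]
expanded = new_data.expand(2, 2, 2, 5)  # shape [2, 2, 2, 5]
data_b[torch.tensor([[0, 2], [1, 1]]), 1:3] = expanded
```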

(7) Custom operators

If a model uses a custom operator implemented in C++ as described in Extending TorchScript with Custom C++ Operators, you can export it by following this example:

import torch
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# Define custom symbolic function
@parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)

# Register custom symbolic function
register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)

class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super(FooModel, self).__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)

model = FooModel(attr1, attr2)
torch.onnx.export(
  model,
  (example_input1, example_input2),
  "model.onnx",
  # only needed if you want to specify an opset version > 1.
  custom_opsets={"custom_domain": 2})

You can export it as one or a combination of standard ONNX ops, or as a custom operator. The example above exports it as a custom operator in the “custom_domain” opset. When exporting a custom operator, you can specify the custom domain version using the custom_opsets dictionary at export. If not specified, the custom opset version defaults to 1. The runtime that consumes the model needs to support the custom op. See Caffe2 custom ops, ONNX Runtime custom ops, or your runtime of choice’s documentation.

(8)Discovering all unconvertible ATen ops at once

When export fails due to an unconvertible ATen op, there may in fact be more than one such op but the error message only mentions the first. To discover all of the unconvertible ops in one go you can:

from torch.onnx import utils as onnx_utils

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = onnx_utils.unconvertible_ops(
    model, args, opset_version=opset_version)

print(set(unconvertible_ops))

(9) The torch.onnx.export API and its parameters

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, input_names=None, output_names=None, operator_export_type=None, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False)

Exports a model into ONNX format. If model is not a torch.jit.ScriptModule nor a torch.jit.ScriptFunction, this runs model once in order to convert it to a TorchScript graph to be exported (the equivalent of torch.jit.trace()). Thus this has the same limited support for dynamic control flow as torch.jit.trace().

Parameters

  • model (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction) – the model to be exported.

  • args (tuple or torch.Tensor) –

    args can be structured either as:

    1. ONLY A TUPLE OF ARGUMENTS:

      args = (x, y, z)
      

    The tuple should contain model inputs such that model(*args) is a valid invocation of the model. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in the tuple.

    2. A TENSOR:

      args = torch.Tensor([1])
      

    This is equivalent to a 1-ary tuple of that Tensor.

    3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:

      args = (x,
              {'y': input_y,
               'z': input_z})
      

    All but the last element of the tuple will be passed as non-keyword arguments, and named arguments will be set from the last element. If a named argument is not present in the dictionary, it is assigned the default value, or None if a default value is not provided.

    NOTE

    If a dictionary is the last element of the args tuple, it will be interpreted as containing named arguments. In order to pass a dict as the last non-keyword arg, provide an empty dict as the last element of the args tuple. For example, instead of:

    torch.onnx.export(
        model,
        (x,
         # WRONG: will be interpreted as named arguments
         {y: z}),
        "test.onnx.pb")

    Write:

    torch.onnx.export(
        model,
        (x,
         {y: z},
         {}),
        "test.onnx.pb")
  • f – a file-like object (such that f.fileno() returns a file descriptor) or a string containing a file name. A binary protocol buffer will be written to this file.

  • export_params (bool, default True) – if True, all parameters will be exported. Set this to False if you want to export an untrained model. In this case, the exported model will first take all of its parameters as arguments, with the ordering as specified by model.state_dict().values().

  • verbose (bool, default False) – if True, prints a description of the model being exported to stdout. In addition, the final ONNX graph will include the field doc_string from the exported model which mentions the source code locations for model.

  • training (enum, default TrainingMode.EVAL) –

    • TrainingMode.EVAL: export the model in inference mode.

    • TrainingMode.PRESERVE: export the model in inference mode if model.training is False and in training mode if model.training is True.

    • TrainingMode.TRAINING: export the model in training mode. Disables optimizations which might interfere with training.

  • input_names (list of str, default empty list) – names to assign to the input nodes of the graph, in order.

  • output_names (list of str, default empty list) – names to assign to the output nodes of the graph, in order.

  • operator_export_type (enum, default None) –

    None usually means OperatorExportTypes.ONNX. However if PyTorch was built with -DPYTORCH_ONNX_CAFFE2_BUNDLE, None means OperatorExportTypes.ONNX_ATEN_FALLBACK.

    • OperatorExportTypes.ONNX: Export all ops as regular ONNX ops (in the default opset domain).

    • OperatorExportTypes.ONNX_FALLTHROUGH: Try to convert all ops to standard ONNX ops in the default opset domain. If unable to do so (e.g. because support has not been added to convert a particular torch op to ONNX), fall back to exporting the op into a custom opset domain without conversion. Applies to custom ops as well as ATen ops. For the exported model to be usable, the runtime must support these non-standard ops.

    • OperatorExportTypes.ONNX_ATEN: All ATen ops (in the TorchScript namespace “aten”) are exported as ATen ops (in opset domain “org.pytorch.aten”). ATen is PyTorch’s built-in tensor library, so this instructs the runtime to use PyTorch’s implementation of these ops.

      WARNING

      Models exported this way are probably runnable only by Caffe2.

      This may be useful if the numeric differences in implementations of operators are causing large differences in behavior between PyTorch and Caffe2 (which is more common on untrained models).

    • OperatorExportTypes.ONNX_ATEN_FALLBACK: Try to export each ATen op (in the TorchScript namespace “aten”) as a regular ONNX op. If we are unable to do so (e.g. because support has not been added to convert a particular torch op to ONNX), fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for context. For example:

      graph(%0 : Float):
        %3 : int = prim::Constant[value=0]()
        # conversion unsupported
        %4 : Float = aten::triu(%0, %3)
        # conversion supported
        %5 : Float = aten::mul(%4, %0)
        return (%5)
      

      Assuming aten::triu is not supported in ONNX, this will be exported as:

      graph(%0 : Float):
        %1 : Long() = onnx::Constant[value={0}]()
        # not converted
        %2 : Float = aten::ATen[operator="triu"](%0, %1)
        # converted
        %3 : Float = onnx::Mul(%2, %0)
        return (%3)
      

      If PyTorch was built with Caffe2 (i.e. with BUILD_CAFFE2=1), then Caffe2-specific behavior will be enabled, including special support for ops that are produced by the modules described in Quantization.

      WARNING

      Models exported this way are probably runnable only by Caffe2.

  • opset_version (int, default 9) – The version of the default (ai.onnx) opset to target. Must be >= 7 and <= 15.

  • do_constant_folding (bool, default True) – Apply the constant-folding optimization. Constant-folding will replace some of the ops that have all constant inputs with pre-computed constant nodes.

  • dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict) –

    By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema:

    • KEY (str): an input or output name. Each name must also be provided in input_names or output_names.

    • VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a list, each element is an axis index.

    For example:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"])
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    While:

    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"],
                      dynamic_axes={
                          # dict value: manually named axes
                          "x": {0: "my_custom_axis_name"},
                          # list value: automatic names
                          "sum": [0],
                      })
    

    Produces:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool, default None) –

    If True, all the initializers (typically corresponding to parameters) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the non-parameter inputs are added as inputs. This may allow for better optimizations (e.g. constant folding) by backends/runtimes.

    If opset_version < 9, initializers MUST be part of graph inputs and this argument will be ignored and the behavior will be equivalent to setting this argument to True.

    If None, then the behavior is chosen automatically as follows:

    • If operator_export_type=OperatorExportTypes.ONNX, the behavior is equivalent to setting this argument to False.

    • Else, the behavior is equivalent to setting this argument to True.

  • custom_opsets (dict<str, int>, default empty dict) –

    A dict with schema:

    • KEY (str): opset domain name

    • VALUE (int): opset version

    If a custom opset is referenced by model but not mentioned in this dictionary, the opset version is set to 1. Only custom opset domain name and version should be indicated through this argument.

  • export_modules_as_functions (bool or set of type of nn.Module, default False) –

    Flag to enable exporting all nn.Module forward calls as local functions in ONNX. Or a set to indicate the particular types of modules to export as local functions in ONNX. This feature requires opset_version >= 15, otherwise the export will fail. This is because opset_version < 15 implies IR version < 8, which means no local function support.

    • False (default): export nn.Module forward calls as fine-grained nodes.

    • True: export all nn.Module forward calls as local function nodes.

    • Set of type of nn.Module: export nn.Module forward calls as local function nodes, only if the type of the nn.Module is found in the set.
