TorchScript中的核心数据结构是ScriptModule。 它和Torch的nn.Module的类似,是an entire model as a tree of submodules。 与普通nn.Module一样,ScriptModule中单个Module可以包含submodules, parameters, and methods。 在nn.Modules中,methods是作为Python函数实现的,而在ScriptModules中通常用TorchScript函数实现。TorchScript函数是Python函数子集,包含PyTorch所有的内置Tensor操作。 这主要是为了实现用户运行ScriptModules代码时,无需Python interpreter。
ScriptModules和它们内部的TorchScript函数可以通过两种方式创建:
Tracing:
使用torch.jit.trace,提供示例输入,导入已有的Module对象或python函数,然后运行该函数,记录在所有张量上执行的操作。 将生成的记录转换为TorchScript方法。这个TorchScript方法就是ScriptModule的forward方法。 ScriptModule对象还包含Module对象所具有的任何参数。
Example:
import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
trace函数会生成一个ScriptModule对象,该对象只包含一个实现foo函数的forward方法,不包含任何参数。
import torch
import torchvision
traced_net = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
Trace 仅记录函数在给定张量上执行的操作。因此,生成的ScriptModule对象对于任何输入执行的是相同的操作集。当对象需要运行不同的操作集时(根据输入的不同改变对张量操作),就会产生错误,具体取决于输入和/或模型状态。例如:
Trace 不会记录任何控制流,如if语句或循环。当这个控制流在你的模型中保持不变时,Trace能够正常运行 ,它通常只是内联的配置决策。但有时控制流实际上是模型本身的一部分。例如,序列到序列转换中的beam搜索是基于input的序列长度(变化的)上的循环操作,在这种情况下(模型需要运行不同的操作集),Tracing 不合适,Scripting是更好的选择。
在返回的ScriptModule中,在训练和eval模式下具有不同行为的操作将始终表现为处于跟踪期间的模式(无论ScriptModule处于哪种模式)。
可以使用Python语法直接编写TorchScript代码。通过在ScriptModule的子类上使用torch.jit.script批注(对于函数)或torch.jit.script_method批注(对于方法)来实现。 拥有注释的函数的主体将直接转换为TorchScript。 因为TorchScript本身是Python语言的一个子集,所以并非python中的所有功能都可以工作,但官方提供了足够的函数来计算张量并执行与控制相关的操作。
转换函数Example:
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
nn.Mdule中的forward Example:
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
@torch.jit.script_method
def forward(self, input):
return self.weight.mv(input)
nn.Mdule中的网络层和forward Example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace
class MyScriptModule(ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# trace produces a ScriptModule's conv1 and conv2
self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
@script_method
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
函数
save(filename)
保存模型,以便在C++程序中被调用。 保存模型的所有方法和参数。 它可以使用torch :: jit :: load(filename)加载到C ++ API中,也可以使用torch.jit.load(filename)加载到Python API中。
为了能够保存模型,它不能对本机python函数进行任何调用。 这意味着所有子模型也必须是ScriptModules的子类。
注意:
所有module,无论其设备如何,在加载期间始终都会加载到CPU上。 这与torch.load()的语义不同,将来可能会发生变化。
torch.jit.load(f, map_location=None)
加载用save保存的ScriptModule模型
所有先前保存的模型,无论其设备如何,都首先加载到CPU上,然后移动到它们保存的设备。 如果此操作失败(例如,因为运行时
系统没有某些设备),则会引发异常。 但是,可以使用map_location参数将存储重新映射到另一组设备。 与torch.load()相比,
此函数中的map_location被简化,只接受字符串(例如'cpu','cuda:0')或torch.device(例如,torch.device('cpu'))
参数:
f - 类文件对象(必须实现read,readline,tell和seek),或包含文件名的字符串
map_location - 可以是一个字符串(例如,'cpu','cuda:0'),一个设备(例如,torch.device('cpu'))
返回:ScriptModule对象。
Example
>>> torch.jit.load('scriptmodule.pt')
# Load ScriptModule from io.BytesIO object
>>> with open('scriptmodule.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
# Load all tensors to the original device
>>> torch.jit.load(buffer)
# Load all tensors onto CPU, using a device
>>> torch.jit.load(buffer, map_location=torch.device('cpu'))
# Load all tensors onto CPU, using a string
>>> torch.jit.load(buffer, map_location='cpu')
# Load with extra files.
>>> files = {'metadata.json' : ''}
>>> torch.jit.load('scriptmodule.pt', _extra_files = files)
>>> print (files['metadata.json'])
torch.jit.trace(func, example_inputs, optimize=True, check_trace=True, check_inputs=None, check_tolerance=1e-05, _force_outplace=False)
trace一个函数并返回一个可执行的trace,该跟踪将使用即时编译进行优化。
警告
trace仅能正确地记录不依赖于数据的函数和模型(例如,对张量中的数据有依赖条件)并且没有任何无法trace的外部依赖性(例
如,执行输入/输出或访问全局变量)。如果跟踪此类模型,则可能会在随后的静态模型调用中获取不正确的结果。当执行可能导致
生成错误跟踪的内容时,跟踪器会发出警告。
参数:
func(callable或torch.nn.Module) - 输入参数为example_inputs的python函数或torch.nn.Module。参数和返回必须是Tensors或
(可能是嵌套的)包含张量的元组。
example_inputs(tuple) - trace时输入的元组。只要是跟踪操作支持的类型和大小,都可以作为输入来生成跟
踪。example_inputs也可以是单个Tensor,在这种情况下,它会自动包装在元组中
关键字参数:
optimize(bool,optional) - 是否应用优化。默认值:True。
check_trace(bool,optional) - 检查通过跟踪代码运行的相同输入是否产生相同的输出。默认值:True。例如,如果您的网络
包含非确定性操作,或者您确定尽管检查程序失败,网络仍然正确,您可能希望禁用此功能。
check_inputs(元组列表,可选) - 输入参数元组列表,用于检查跟踪和预期对比。Each tuple is equivalent to a seet of input
arguments that would be specified in args. For best results, pass in a set of checking inputs representative of the space of shapes
and types of inputs you expect the network to see. If not specified, the original args is used for checking
check_tolerance(float,optional) - Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
返回:
ScriptModule对象,有一个包含跟踪代码的forward()方法。当func是torch.nn.Module时,返回的ScriptModule将具有与
func相同的子模型和参数集。
Example
>>> def f(x):
... return x * 2
>>> traced_f = torch.jit.trace(f, torch.rand(1))
Mixing Tracing and Scripting
在许多情况下,Tracing和Scripting都是转换模型的简便方法。 可以编写Tracing 和Scripting来满足模型的特定要求。
Scripted 函数可以调用跟踪函数。 当在前馈模型需要使用控制流时,这尤其有用。 例如,序列到序列模型的beam 搜索通常将以Script编写,但可以通过Tracing来调用。
Example:
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
Traced函数可以调用scripte函数。 当模型的一小部分需要一些控制流时,即使模型大部分只是一个前馈网络,这样子结合就会很有用。 Traced函数调用scripte函数内部的控制流 is preserved correctly:
Example:
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))
这种组合也适用于模型,它可以用于:Script module方法调用Tracing生成的子模型:
Example:
import torch
import torchvision
class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
@torch.jit.script_method
def forward(self, input):
return self.resnet(input - self.means)
实际例子:
转换resnet18模型
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# ScriptModule
output = traced_script_module(torch.ones(1, 3, 224, 224))
traced_script_module.save("model.pt")
c++调用
#include // One-stop header.
#include
#include
int main(){
// Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr module = torch::jit::load("model.pt");
assert(module != nullptr);
std::cout << "ok\n";
// Create a vector of inputs.
std::vector inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
module->forward(inputs);
auto out_tensor = module->forward(inputs).toTensor();
std::tuple result = out_tensor.sort(-1, true);
torch::Tensor top_scores = std::get<0>(result)[0];
torch::Tensor top_idxs = std::get<1>(result)[0].toType(torch::kInt32);
// Load labels
std::string label_file = "synset_words.txt";
std::ifstream rf(label_file.c_str());
CHECK(rf) << "Unable to open labels file " << label_file;
std::string line;
std::vector labels;
while (std::getline(rf, line))
labels.push_back(line);
auto top_scores_a = top_scores.accessor();
auto top_idxs_a = top_idxs.accessor();
for (int i = 0; i < 5; ++i) {
int idx = top_idxs_a[i];
std::cout << "top-" << i+1 << " label: ";
std::cout <<" "<