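I have recently been reading the introductory material on TorchScript. After going through the tutorial linked from the official docs, I was still lost and could not see what it was getting at.
Then I found Rene Wang's article, and things became much clearer.
The official example showing the shortcoming of a TracedModule looks like this: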
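import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: the path taken depends on the input values
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

# x and h are not defined in the snippet as quoted; these shapes are assumed
# (Linear(4, 4) only requires the last dimension to be 4)
x, h = torch.rand(3, 4), torch.rand(3, 4)
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)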
The output is:
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(input, ), )
  _1 = torch.tanh(torch.add(_0, h, alpha=1))
  return (_1, _1)
The tutorial then introduces ScriptModule:
scripted_gate = torch.jit.script(MyDecisionGate())
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
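The output is then:
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)
At this point the tutorial cheers "hooray", but I was still completely baffled: I could not see any difference between the ScriptModule code and the TracedModule code.
Rene Wang's article explains it well. The key is to look at my_cell.dg.code; the two versions actually look like this: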
traced_gate = torch.jit.trace(my_cell.dg, (x,))
print(traced_gate.code)
--Output--
c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
after removing the cwd from sys.path.
def forward(self,
    x: Tensor) -> Tensor:
  return x
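To see what that traced gate does at runtime, here is a minimal sketch. The input tensors `pos` and `neg` are illustrative values of my own, not taken from the tutorial. The gate is traced with a positive-sum input, so only the `return x` branch is recorded, and a negative-sum input then gives a different result from the eager module.

```python
import warnings
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

gate = MyDecisionGate()
pos = torch.ones(3, 4)    # illustrative input with a positive sum
neg = -torch.ones(3, 4)   # illustrative input with a negative sum

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced_gate = torch.jit.trace(gate, (pos,))

eager_out = gate(neg)          # eager mode takes the else branch and negates
traced_out = traced_gate(neg)  # the trace only knows "return x"

print(torch.equal(eager_out, -neg))   # True
print(torch.equal(traced_out, neg))   # True: the input comes back unchanged
```

Tracing with a negative-sum input instead would presumably bake in the other branch, which is why the warning says the trace may not generalize to other inputs.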
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
print(traced_cell.code)
# The if/else control flow is only visible in dg.code
print(traced_cell.dg.code)
--Output--
def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1
RecursiveScriptModule(
  original_name=MyCell
  (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)
def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1
This makes it clear that the ScriptModule has captured the if/else control flow.
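The same point can be checked at runtime rather than by reading the printed code. Below is a minimal sketch, again with illustrative inputs `pos` and `neg` that I chose myself. It assumes the script is run from a file, since `torch.jit.script` needs access to the class source.

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

scripted_gate = torch.jit.script(MyDecisionGate())

pos = torch.ones(3, 4)    # illustrative input with a positive sum
neg = -torch.ones(3, 4)   # illustrative input with a negative sum

pos_out = scripted_gate(pos)  # takes the if branch
neg_out = scripted_gate(neg)  # takes the else branch

print(torch.equal(pos_out, pos))    # True
print(torch.equal(neg_out, -neg))   # True: the input was negated
```

Unlike a traced gate, the scripted gate picks the branch per input, which matches the `if`/`else` shown in its `.code`.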
All of the above was run on torch 1.4.0; the example in the official tutorial may have been written against an older version.