任务简介:
使用PyTorch训练的模型只能在Python环境中使用,在自动驾驶场景中,模型推理过程通常是在硬件设备上进行。TorchScript可以将PyTorch训练的模型转换为C++环境支持的模型,推理速度比Python环境更快。本文对整体转换流程做一个简单的记录,后续需要补充TorchScript的支持的各种语法规则以及注意点。
TorchScript:
TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。
TorchScript模型生成有torch.jit.trace和torch.jit.script两种方法。
传入Module和符合的示例输入。它会调用Moduel并将操作记录下来,当Module运行时记录下操作,然后创建torch.jit.ScriptModule的实例。对于有控制流的模型,直接使用torch.jit.trace()并不能跟踪到控制流,因为它只是对操作进行了记录,对于没有运行到的操作并不会记录,trace方式生成模型的示例如下:
class MyDecisionGate(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
my_cell = MyCell(MyDecisionGate())
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h)) # trace方式
print(traced_cell.dg.code)
print(traced_cell.code)
输出:
def forward(self,
argument_1: Tensor) -> None:
return None
def forward(self,
input: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.dg
_1 = (self.linear).forward(input, )
_2 = (_0).forward(_1, )
_3 = torch.tanh(torch.add(_1, h, alpha=1))
return (_3, _3)
可以看到.code的输出,if-else的分支没有了,控制流会被擦除。
前面提到的问题,可以使用script compiler来解决,可以直接分析Python源代码来把它转化为TrochScript。如下:
scripted_gate = torch.jit.script(MyDecisionGate()) # script方式
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell) # script方式
print(scripted_gate.code)
print(scripted_cell.code)
输出:
def forward(self,
x: Tensor) -> Tensor:
_0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
if _0:
_1 = x
else:
_1 = torch.neg(x)
return _1
def forward(self,
x: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = (self.dg).forward((self.linear).forward(x, ), )
new_h = torch.tanh(torch.add(_0, h, alpha=1))
return (new_h, new_h)
可以看到控制流保存下来了。
torch.jit.script()会转换传入Module的所有代码,在实际转换模型的过程中会增加修改代码的工作量,因此通常将torch.jit.trace()和torch.jit.script()进行混合使用,比较灵活。
在需要使用控制流,如不定长的for循环、if-else分支时,在该函数上方输入@torch.jit.script
即可,如:
@torch.jit.script
def get_goal_2D(topk_lane_vector: Tensor, topk_points_mask: Tensor) -> Tensor:
points = torch.zeros([1,2],device=topk_lane_vector.device)
visit: Dict[int,bool]= {}
for index_lane, lane_vector in enumerate(topk_lane_vector):
for i, point in enumerate(lane_vector):
if topk_points_mask[index_lane][i]:
hash: int = int(torch.round((point[0] + 500) * 100) * 1000000 + torch.round((point[1] + 500) * 100))
if hash not in visit:
visit[hash] = True
points = torch.cat([points,point.unsqueeze(0)],dim=0)
point_num, divide_num = _get_subdivide_num(lane_vector, topk_points_mask[index_lane])
if divide_num > 1:
subdivide_points = _get_subdivide_points(lane_vector, point_num, divide_num)
points = torch.cat([points,subdivide_points],dim=0)
return points[1:]
再使用torch.jit.trace()将实例化的model和输入传入,即可生成TorchScript模型。
示例代码:
model.eval()
with torch.no_grad():
traced_script_model = torch.jit.trace(model, script_inputs, strict=False)
traced_script_model.save("models.densetnt.1/model_save/model.16_script.bin")
print(traced_script_model.code)
print('Finish converting model!!!')
在trace之后使用 torch.jit.optimize_for_inference(),然后保存模型,代码:
model = torch.jit.trace(model, script_inputs, strict=False)
model = torch.jit.optimize_for_inference(model)
model.save("models.densetnt.1/model_save/model.16_script.bin")
生成的TorchScript模型推理时间能减少4~6ms左右。
Python中,在推理之前使用torch.inference_mode(),代码:
model = torch.jit.load("models.densetnt.1/model_save/model.16_script.bin")
with torch.inference_mode():
pred = model(vector_data, vector_mask, traj_polyline_mask, map_polyline_mask)
C++中,在推理之前使用c10::InferenceMode guard,代码:
c10::InferenceMode guard;
torch_output_tensor = model.forward(torch_inputs).toTensor().to(torch::kCPU); // inference
可以减少推理时间2~3ms。
问题说明: 在C++中加载完TorchScript模型之后,前20次左右的推理时间普遍很长,最长能达到十几秒。初步分析大概率是由于在不断读取数据推理时GPU显存不断加大分配,导致推理时间过长。最理想的解决方案是在加载模型之后就将显存分配够,那后续的推理时间就会非常短。下面记录一个比较有效的方法:
Python中,在模型加载前使用torch._C._jit_set_profiling_mode(False),代码:
torch._C._jit_set_profiling_mode(False) # 效果更好,原因未知
model = torch.jit.load("models.densetnt.1/model_save/model.16_script.bin")
也可以使用torch._C._jit_set_profiling_executor(False),但torch._C._jit_set_profiling_mode(False)效果更好,具体原因还没搞清楚。
C++中,在模型加载前使用torch::jit::getExecutorMode() = false,代码:
torch::jit::getExecutorMode() = false;
model = torch::jit::load(FLAGS_dense_tnt_torch_script_file, device_);
model.eval();
此后前20次推理时间恢复正常:
dense_tnt inference used time: 22.1718 ms.
dense_tnt inference used time: 22.6176 ms.
dense_tnt inference used time: 23.3946 ms.
dense_tnt inference used time: 7.24652 ms.
dense_tnt inference used time: 21.3756 ms.
dense_tnt inference used time: 19.7572 ms.
dense_tnt inference used time: 19.2617 ms.
dense_tnt inference used time: 7.18374 ms.
dense_tnt inference used time: 13.7297 ms.
dense_tnt inference used time: 9.89283 ms.
···