torch.jit.script(model)使用示例

import torch

class LinearModel(torch.nn.Module):
	def __init__(self):
		super().__init__()
		self.linear = torch.nn.Linear(1, 1)

	def forward(self, x):
		return self.linear(x)

创建模型实例

model = LinearModel()

准备输入数据

x = torch.tensor([[1.0], [2.0]])

运行模型

y = model(x)

将模型转换为Torch Script

scripted_model = torch.jit.script(model)

使用Torch Script进行推理

y_ts = scripted_model(x)

比较两种方式的输出是否相同

print("PyTorch output:", y)
print("Torch Script output:", y_ts)

你可能感兴趣的:(03-Python,pytorch)