- PyTorch模型转换为TorchScript格式
文章目录
- 1、save weight
- 2、save model + weight
- 3、save model + weight (use jit)
import torch
from torch import nn
from torchvision.models import resnet18
from collections import OrderedDict
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into a single one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``;
    the input must therefore have at least two dimensions.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.flatten(start_dim=1)
class Model(nn.Module):
    """ResNet-18 feature extractor followed by a pooled linear classifier.

    The convolutional layers are borrowed from a randomly initialised
    torchvision ResNet-18 and regrouped under ``backbone``; ``fc`` is a fresh
    head (global average pool -> flatten -> Linear).

    Args:
        num_classes: number of logits produced by the final Linear layer.
        device: device the whole module (backbone and head) is moved to;
            anything accepted by ``nn.Module.to``.
    """

    def __init__(self, num_classes, device):
        super().__init__()
        # Randomly initialised ResNet-18 (no pretrained weights).
        try:
            _m = resnet18(weights=None)
        except TypeError:
            # torchvision < 0.13 has no ``weights`` argument; the positional
            # ``pretrained`` flag used previously is deprecated in newer releases.
            _m = resnet18(pretrained=False)
        self.backbone = nn.Sequential(
            OrderedDict([
                ("stem", nn.Sequential(_m.conv1, _m.bn1, _m.relu, _m.maxpool)),
                ("layer1", _m.layer1),
                ("layer2", _m.layer2),
                ("layer3", _m.layer3),
                ("layer4", _m.layer4),
            ])
        )
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            # Take the feature width from ResNet's own (public) Linear head
            # rather than the builder's internal ``inplanes`` counter.
            nn.Linear(_m.fc.in_features, num_classes),
        )
        # Move everything, including the new head, so ``device`` is honoured
        # for the whole model and not only for the borrowed ResNet layers.
        self.to(device)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)``.

        Assumes ``x`` is an ``(N, 3, H, W)`` image batch (3 channels is what
        ResNet-18's ``conv1`` expects) on the same device as the model.
        """
        x = self.backbone(x)
        return self.fc(x)
if __name__ == "__main__":
    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn([1, 3, 64, 64]).to(device)
    model = Model(10, device).to(device)
    print(model)
    # In train mode this forward pass would fold the random batch into the
    # BatchNorm running statistics, which torch.save would then persist.
    model.eval()
    with torch.no_grad():
        pred = model(x)
    print(pred.shape)
1、save weight
# Persist only the module's tensors (parameters and buffers), not its class.
torch.save(obj=model.state_dict(), f='weight.pth')
from train import Model, Flatten

# Use the GPU when present; map_location lets a checkpoint written on a GPU
# machine be loaded on a CPU-only one.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn([1, 3, 64, 64]).to(device)
# A state_dict holds only tensors, so the architecture is rebuilt in code first.
model = Model(10, device).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (torch >= 1.13 also accepts weights_only=True for plain state_dicts).
model.load_state_dict(torch.load('weight.pth', map_location=device))
model.eval()  # use BatchNorm running statistics for inference
with torch.no_grad():
    print(model(x).shape)
2、save model + weight
# Pickle the entire nn.Module object. Loading it later requires the Model and
# Flatten classes to be resolvable under the same module path and names.
torch.save(obj=model, f='model.pth')
class Flatten(nn.Module):
    """Flatten all non-batch dimensions: ``(N, *rest)`` -> ``(N, prod(rest))``.

    Redefined in the loading script because unpickling a whole-module
    checkpoint looks this class up by name.
    """

    # No custom __init__: nn.Module's own constructor is all that is needed.

    def forward(self, x):
        return x.flatten(start_dim=1)
class Model(nn.Module):
    """Minimal stand-in for the training-time ``Model`` class.

    torch.load unpickles the saved module by looking ``Model`` up by name and
    then restoring its attributes (``backbone``, ``fc``, ...) directly, so no
    ``__init__`` is required here. ``forward`` mirrors the original model so
    the restored object can be called (and traced) like the real one.
    """

    def forward(self, x):
        return self.fc(self.backbone(x))
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn([1, 3, 64, 64]).to(device)
# The whole nn.Module was pickled, so torch.load must be allowed to unpickle
# arbitrary objects: torch >= 2.6 defaults to weights_only=True, which rejects
# this file. Only do this for files you trust; unpickling can execute code.
try:
    model = torch.load("model.pth", map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 has no weights_only parameter and always unpickles fully.
    model = torch.load("model.pth", map_location=device)
print(model)
model.eval()
with torch.no_grad():
    # Call the submodules directly; this works whether or not the local
    # Model class defines forward.
    out = model.backbone(x)
    out = model.fc(out)
print(out.shape)
3、save model + weight (use jit)
# Trace in eval mode: the trace records BatchNorm's training flag as a
# constant, so a train-mode trace would keep normalising with per-batch stats.
model.eval()
# The example input must live on the same device as the model's parameters.
x = torch.randn([1, 3, 64, 64]).to(next(model.parameters()).device)
traced_script_module = torch.jit.trace(model, x)
traced_script_module.save("model_jit.pth")
# Fall back to the CPU so the TorchScript archive also loads without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn([1, 3, 64, 64]).to(device)
# A TorchScript archive is loaded without the Python Model/Flatten classes.
model = torch.jit.load('model_jit.pth', map_location=device)
print(model)
with torch.no_grad():
    print(model(x).shape)