# Persist only the learned parameters (the state_dict), not the pickled module
# object; the network class has to be instantiated again before loading them.
state = model.state_dict()
torch.save(state, './LMX2net.pth')
其实也可以直接 save 整个 model,但只保存 state_dict() 更加轻便,也是 PyTorch 1.0 官网推荐的方法,load 的时候也更灵活。
load 之前要先写一个类,也就是 train 时用的那个网络结构类。比如我的:
class TheModelClass(nn.Module):
    """Two-stage convolutional network.

    The "mapping" stage (conv1..conv4) turns a single-channel input into a
    single-channel map ``x4``. Values of ``x4`` outside [0, 1] are zeroed and
    uniform noise is added at a 10-bit scale. The "reverse mapping" stage
    (conv5..conv8) maps ``x4`` back to a single-channel output ``x8``.

    The attribute names below become the state_dict keys. Renaming them
    breaks loading of checkpoints saved with the current names.
    """

    def __init__(self):
        super().__init__()
        # mapping
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv1_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 1, 1)
        # reverse mapping
        self.conv5 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv5_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv6_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv7 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv7_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv8 = nn.Conv2d(16, 1, 3, padding=1)

    def no_linear(self, img):
        """Return a copy of ``img`` with every element outside [0, 1] set to 0.

        NOTE(review): values above 1 are zeroed, not clamped to 1. This is kept
        unchanged so trained weights behave identically. Confirm it is intended;
        a clamp would be ``img.clamp(0, 1)``.
        """
        # masked_fill returns a new tensor rather than mutating the argument.
        return img.masked_fill((img < 0) | (img > 1), 0)

    def forward(self, x):
        """Run both stages of the network.

        Args:
            x: single-channel input, presumably (N, 1, H, W). conv1 expects
                1 input channel, and every 3x3 conv uses padding=1, so the
                spatial size is preserved throughout.

        Returns:
            A tuple ``(x4, x8)``: the noisy intermediate map from the mapping
            stage, and the reverse-mapping output multiplied by 65535.
        """
        # mapping stage
        x1 = F.relu(self.conv1(x))
        x1 = F.relu(self.conv1_2(x1))
        x2 = F.relu(self.conv2(x1))
        x2 = F.relu(self.conv2_2(x2))
        x3 = F.relu(self.conv3(x2))
        x3 = F.relu(self.conv3_2(x3))
        # Residual connection: conv4 outputs 1 channel, matching x.
        x4 = F.relu(self.conv4(x3)) + x
        x4 = self.no_linear(x4)

        # "10 bits feature map": scale to [0, 1023], add uniform noise in
        # [-0.5, 0.5) (presumably to mimic rounding error), then scale back.
        # rand_like follows x4's device/dtype instead of hard-coding .cuda(),
        # so this also runs on CPU.
        # NOTE: the noise is added in eval mode too, so repeated calls differ.
        x4 = x4 * 1023
        x4 = x4 + (torch.rand_like(x4) - 0.5)
        x4 = x4 / 1023

        # reverse-mapping stage
        x5 = F.relu(self.conv5(x4))
        x5 = F.relu(self.conv5_2(x5))
        x6 = F.relu(self.conv6(x5))
        x6 = F.relu(self.conv6_2(x6))
        x7 = F.relu(self.conv7(x6))
        x7 = F.relu(self.conv7_2(x7))
        x8 = F.relu(self.conv8(x7)) + x4
        x8 = x8 * 65535  # presumably rescales to a 16-bit range — confirm
        return x4, x8
然后 load 参数:
# Rebuild the network skeleton, then fill it with the saved parameters.
model = TheModelClass()
# map_location remaps the stored tensors onto `dev`, so a checkpoint saved from
# GPU tensors still loads on a CPU-only machine (and vice versa).
# `dev` is defined elsewhere — presumably a torch.device or a string such as
# 'cuda' / 'cpu'.
model.load_state_dict(torch.load('./LMX2net.pth', map_location=dev))
model.eval()
model.to(dev)
也就是说,本地的 .pth 文件里只有参数,网络骨架需要我们提前用 TheModelClass 定义好。
重点来了:
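如果只想用前 7 层(conv1 到 conv4)提取特征,可以这样定义 TheModelClass:__init__ 里仍然保留全部层,这样参数名与保存的 state_dict 一致,load_state_dict(默认 strict=True)才不会因多余的参数报错;只在 forward 里提前截断并返回 x4。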
class TheModelClass(nn.Module):
    """Truncated variant of the network that stops after the mapping stage.

    ``__init__`` still declares every layer of the full network, including the
    unused reverse-mapping layers conv5..conv8. This keeps the state_dict keys
    identical to a checkpoint saved from the full model, so ``load_state_dict``
    (strict by default) accepts it. ``forward`` runs only conv1..conv4 and
    returns the intermediate feature map ``x4``.
    """

    def __init__(self):
        super().__init__()
        # mapping (the only layers used by forward)
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv1_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 1, 1)
        # reverse mapping: unused by forward, kept only so parameter names and
        # shapes match the full model's checkpoint
        self.conv5 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv5_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv6_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv7 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv7_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv8 = nn.Conv2d(16, 1, 3, padding=1)

    def no_linear(self, img):
        """Return a copy of ``img`` with every element outside [0, 1] set to 0.

        NOTE(review): values above 1 are zeroed, not clamped to 1. This is kept
        unchanged so trained weights behave identically. Confirm it is intended;
        a clamp would be ``img.clamp(0, 1)``.
        """
        # masked_fill returns a new tensor rather than mutating the argument.
        return img.masked_fill((img < 0) | (img > 1), 0)

    def forward(self, x):
        """Run only the mapping stage and return its feature map ``x4``.

        Args:
            x: single-channel input, presumably (N, 1, H, W). conv1 expects
                1 input channel, and every 3x3 conv uses padding=1, so the
                spatial size is preserved.

        Returns:
            ``x4``: a single-channel map with the same spatial size as ``x``.
        """
        x1 = F.relu(self.conv1(x))
        x1 = F.relu(self.conv1_2(x1))
        x2 = F.relu(self.conv2(x1))
        x2 = F.relu(self.conv2_2(x2))
        x3 = F.relu(self.conv3(x2))
        x3 = F.relu(self.conv3_2(x3))
        # Residual connection: conv4 outputs 1 channel, matching x.
        x4 = F.relu(self.conv4(x3)) + x
        x4 = self.no_linear(x4)

        # "10 bits feature map": scale to [0, 1023], add uniform noise in
        # [-0.5, 0.5) (presumably to mimic rounding error), then scale back.
        # rand_like follows x4's device/dtype instead of hard-coding .cuda(),
        # so this also runs on CPU.
        # NOTE: the noise is added in eval mode too, so repeated calls differ.
        x4 = x4 * 1023
        x4 = x4 + (torch.rand_like(x4) - 0.5)
        x4 = x4 / 1023
        return x4
之后跑 test 的时候:
# Inference only: disable autograd so no computation graph is built.
# Assumes `img` is a 2-D (H, W) tensor — the two unsqueeze calls add the batch
# and channel dimensions to give (1, 1, H, W). `model` and `dev` are defined
# elsewhere. With the truncated TheModelClass, `yy` is the x4 feature map.
with torch.no_grad():
    yy = model(img.unsqueeze(0).unsqueeze(0).to(dev))
就可以得到 x4 层的 feature map 了。