PyTorch 1.0.0 如何载入部分模型

0.载入部分模型的motivation

  1. 可视化feature map
  2. 迁移学习
  3. ...

1.保存

# Persist only the parameters (the state_dict), not the pickled module object.
torch.save(model.state_dict(), './LMX2net.pth')

其实也可以直接保存整个 model,但只保存 state_dict() 更加轻便,也是 PyTorch 1.0 官方文档推荐的做法,载入(load)时也更灵活。

2.载入

首先要定义一个类,也就是训练(train)时使用的那个网络结构类。比如我的:

class TheModelClass(nn.Module):
    """Two-stage convolutional network: a mapping stage and a reverse-mapping stage.

    The mapping stage (``conv1`` .. ``conv4``) produces a single-channel
    intermediate map ``x4``; the reverse-mapping stage (``conv5`` .. ``conv8``)
    consumes ``x4`` and produces the final output ``x8``. All 3x3 convolutions
    use ``padding=1`` and ``conv4`` is 1x1, so spatial size is preserved
    throughout.
    """

    def __init__(self):
        super().__init__()
        # mapping
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv1_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 1, 1)

        # reverse mapping
        self.conv5 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv5_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv6_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv7 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv7_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv8 = nn.Conv2d(16, 1, 3, padding=1)

    def no_linear(self, img):
        """Zero, in place, every element of ``img`` that lies outside ``[0, 1]``.

        Args:
            img: Tensor to modify; it is mutated and also returned.

        Returns:
            The same tensor object as ``img``.
        """
        # NOTE(review): values above 1 are reset to 0 rather than saturated
        # at 1 -- confirm this is intended. Kept unchanged to preserve the
        # existing behavior.
        img[img < 0] = 0
        img[img > 1] = 0
        return img

    def forward(self, x):
        """Run both stages and return the intermediate and final maps.

        Args:
            x: Single-channel input (``conv1`` takes one input channel),
                e.g. of shape ``(N, 1, H, W)``.

        Returns:
            A tuple ``(x4, x8)``: ``x4`` is the intermediate map after the
            mapping stage (with noise applied, see below) and ``x8`` is the
            reverse-mapping output multiplied by 65535.

        Note:
            Uniform noise is added on every call; there is no check of
            ``self.training``, so it is applied in eval mode as well.
        """
        x1 = F.relu(self.conv1(x))
        x1 = F.relu(self.conv1_2(x1))
        x2 = F.relu(self.conv2(x1))
        x2 = F.relu(self.conv2_2(x2))
        x3 = F.relu(self.conv3(x2))
        x3 = F.relu(self.conv3_2(x3))
        # Residual connection back to the network input.
        x4 = F.relu(self.conv4(x3)) + x

        # 10 bits feature map: scale to 0..1023, perturb by uniform noise in
        # [-0.5, 0.5), then scale back.
        x4 = self.no_linear(x4)
        x4 = x4 * 1023
        # rand_like draws the noise on x4's own device, so the forward pass
        # works on CPU as well as on GPU instead of requiring CUDA.
        noise = torch.rand_like(x4) - 0.5
        x4 = x4 + noise
        x4 = x4 / 1023

        x5 = F.relu(self.conv5(x4))
        x5 = F.relu(self.conv5_2(x5))
        x6 = F.relu(self.conv6(x5))
        x6 = F.relu(self.conv6_2(x6))
        x7 = F.relu(self.conv7(x6))
        x7 = F.relu(self.conv7_2(x7))
        # Residual connection back to the intermediate map.
        x8 = F.relu(self.conv8(x7)) + x4
        x8 = x8 * 65535
        return x4, x8

然后load

model = TheModelClass()

# map_location remaps tensors that were saved from another device (e.g. a
# checkpoint written on a GPU and opened on a CPU-only machine) onto `dev`
# while loading, instead of failing to restore them.
model.load_state_dict(torch.load('./LMX2net.pth', map_location=dev))
model.eval()
model.to(dev)

也就是说,本地的 .pth 文件里只保存了参数,网络骨架 TheModelClass 需要我们事先定义好。

重点来了:

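如果我们想提取前 7 层(conv1 ~ conv4)的输出,可以这样定义 TheModelClass:__init__ 保持不变(这样 state_dict 的键才能与 .pth 文件一一对应,load_state_dict 才不会报错),只需让 forward 不再执行后半部分(conv5 ~ conv8),直接返回 x4: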

class TheModelClass(nn.Module):
    """Truncated variant of the network that stops after the mapping stage.

    Every layer, including the unused reverse-mapping layers ``conv5`` ..
    ``conv8``, is still registered in ``__init__`` so that the parameter names
    match a checkpoint saved from the full network. ``forward`` only runs
    ``conv1`` .. ``conv4`` and returns the intermediate map ``x4``.
    """

    def __init__(self):
        super().__init__()
        # mapping
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv1_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 1, 1)

        # reverse mapping (not used by forward; kept so state_dict keys match)
        self.conv5 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv5_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv6_2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv7 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv7_2 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv8 = nn.Conv2d(16, 1, 3, padding=1)

    def no_linear(self, img):
        """Zero, in place, every element of ``img`` that lies outside ``[0, 1]``.

        Args:
            img: Tensor to modify; it is mutated and also returned.

        Returns:
            The same tensor object as ``img``.
        """
        # NOTE(review): values above 1 are reset to 0 rather than saturated
        # at 1 -- confirm this is intended. Kept unchanged to preserve the
        # existing behavior.
        img[img < 0] = 0
        img[img > 1] = 0
        return img

    def forward(self, x):
        """Run only the mapping stage and return the intermediate map.

        Args:
            x: Single-channel input (``conv1`` takes one input channel),
                e.g. of shape ``(N, 1, H, W)``.

        Returns:
            ``x4``, the single-channel map produced after ``conv4``, the
            range masking and the noise step.

        Note:
            Uniform noise is added on every call; there is no check of
            ``self.training``, so it is applied in eval mode as well.
        """
        x1 = F.relu(self.conv1(x))
        x1 = F.relu(self.conv1_2(x1))
        x2 = F.relu(self.conv2(x1))
        x2 = F.relu(self.conv2_2(x2))
        x3 = F.relu(self.conv3(x2))
        x3 = F.relu(self.conv3_2(x3))
        # Residual connection back to the network input.
        x4 = F.relu(self.conv4(x3)) + x

        # 10 bits feature map: scale to 0..1023, perturb by uniform noise in
        # [-0.5, 0.5), then scale back.
        x4 = self.no_linear(x4)
        x4 = x4 * 1023
        # rand_like draws the noise on x4's own device, so the forward pass
        # works on CPU as well as on GPU instead of requiring CUDA.
        noise = torch.rand_like(x4) - 0.5
        x4 = x4 + noise
        x4 = x4 / 1023

        # The reverse-mapping stage (conv5 .. conv8) is intentionally not
        # executed here: this variant exists to expose x4 as the output.
        return x4

之后跑test的时候

# Inference only: disable autograd so the extracted feature map carries no
# graph and can be converted or plotted directly.
with torch.no_grad():
    yy = model(img.unsqueeze(0).unsqueeze(0).to(dev))

就可以提取x4层的feature map了

你可能感兴趣的:(deep,learning)