pytorch读取部分网络

问题描述:在做无监督图像融合时涉及到了编码器和解码器。编码器输出的特征图,经过处理后要在通过解码器。这就要求参数文件要分开读取。

解决问题的核心;pytorch是以字典的方式存储参数(层的名字:对应参数)

下面进行举例:

class decode():   #  解码器
    def __init__(self):
        # initialize model
        self.device = torch.device('cpu')  # 设备选择
        self.model = Decodenet()

        self.model_path = os.path.join(os.getcwd(), "nets", "parameters", "lp+lssim_se_sf_net_times30.pkl")  # 编码器参数的路径
        self.save_model = torch.load(self.model_path, map_location=self.device)
        self.model_dict = self.model.state_dict()       # 模型key
        self.state_dict = {k: v for k, v in self.save_model.items() if k in self.model_dict.keys()}     # 这里还要修改
        self.model_dict.update(self.state_dict)
        self.model.load_state_dict(self.model_dict)
        self.model.to(self.device)
        self.model.eval()
    def reduction(self, f_m):
        img = self.model(f_m)
        return img


class Decodenet(nn.Module):        # 解码的网络
    def __init__(self):
        super(Decodenet, self).__init__()
        self.conv_decode_1 = self.conv_block(64, 64)
        self.conv_decode_2 = self.conv_block(64, 32)
        self.conv_decode_3 = self.conv_block(32, 16)
        self.conv_decode_4 = self.conv_block(16, 1)

    def forward(self , se_cat3):
        with torch.no_grad():
            decode_block1 = self.conv_decode_1(se_cat3)
            decode_block2 = self.conv_decode_2(decode_block1)
            decode_block3 = self.conv_decode_3(decode_block2)
            output = self.conv_decode_4(decode_block3)
            return output

    @staticmethod
    def conv_block(in_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                )
        return block

你可能感兴趣的:(pytorch自学笔记,pytorch,网络,深度学习)