ACNet:用于图像超分的非对称卷积(附实现code)

https://blog.csdn.net/huohu728/article/details/115448262

class MeanShift(nn.Conv2d):
    def __init__(self,
                 rgb_range=1.0,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0),
                 sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        self.requires_grad = False

        
class Upsampler(nn.Sequential):
    def __init__(self,
                 scale,
                 channels,
                 bn=False,
                 act=False,
                 bias=True):
        m = []
        if (scale & (scale - 1)) == 0:
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(channels, 4 * channels, 3, 1, 1, bias=bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(channels))
                if act:
                    m.append(nn.ReLU(inplace=True))
        elif scale == 3:
            m.append(nn.Conv2d(channels, 9 * channels, 3, 1, 1, bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(channels))

            if act:
                m.append(nn.ReLU(inplace=True))
        else:
            raise NotImplementedError
        super().__init__(*m)
        
        
class ACBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ACBlock, self).__init__()
        self.conv1x3 = nn.Conv2d(in_channels, out_channels, (1, 3), 1, (0, 1))
        self.conv3x1 = nn.Conv2d(in_channels, out_channels, (3, 1), 1, (1, 0))
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, (3, 3), 1, (1, 1))

    def forward(self, x):
        conv3x1 = self.conv3x1(x)
        conv1x3 = self.conv1x3(x)
        conv3x3 = self.conv3x3(x)
        return conv3x1 + conv1x3 + conv3x3
      
      
class ACNet(nn.Module):
    def __init__(self,
                 scale=2,
                 in_channels=3,
                 out_channels=3,
                 num_features=64,
                 num_blocks=17,
                 rgb_range=1.0):
        super(ACNet, self).__init__()
        self.scale = scale
        self.num_blocks = num_blocks
        self.num_features = num_features

        # pre and post process
        self.sub_mean = MeanShift(rgb_range=rgb_range, sign=-1)
        self.add_mena = MeanShift(rgb_range=rgb_range, sign=1)

        # AB module
        self.blk1 = ACBlock(in_channels, num_features)
        for idx in range(1, num_blocks):
            self.__setattr__(f"blk{idx+1}", nn.Sequential(nn.ReLU(inplace=True), ACBlock(num_features, num_features)))

        # MEB
        self.lff = nn.Sequential(
            nn.ReLU(inplace=False),
            Upsampler(scale, num_features),
            nn.Conv2d(num_features, num_features, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, 3, 1, 1)
        )
        self.hff = nn.Sequential(
            nn.ReLU(inplace=False),
            Upsampler(scale, num_features),
            nn.Conv2d(num_features, num_features, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, 3, 1, 1)
        )

        # HFFEB
        self.fusion = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(num_features, num_features, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, out_channels, 3, 1, 1),
        )

    def forward(self, x):
        inputs = self.sub_mean(x)
        blk1 = self.blk1(inputs)

        high = blk1
        tmp = blk1
        for idx in range(1, self.num_blocks):
            tmp = self.__getattr__(f"blk{idx+1}")(tmp)
            high = high + tmp

        lff = self.lff(blk1)
        hff = self.hff(high)

        fusion = self.fusion(lff + hff)
        output = self.add_mena(fusion)
        return output   

你可能感兴趣的:(pytorch,难啃的深度学习,算法研究,python)