2021SC@SDUSC Open-source project GFPGAN-3-2021-10-15-class ModulatedConv2d(nn.Module) analysis

2021SC@SDUSC

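Contents

  • I. The code snippet analyzed
    • 1. The code
    • 2. What the code does
    • 3. A closer look at how the code works
      • 1) The __init__() method
        • in_channels:
        • out_channels
        • kernel_size
        • demodulate=True
        • sample_mode
        • eps
      • 2) The forward(self, x, style) method
      • 3) The __repr__(self) method
  • II. How class ModulatedConv2d is used in the project
      • 1) The __init__() method
      • 2) The forward(self, x, style) method
      • 3) The __repr__(self) method
  • III. Summary and reflections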

I. The code snippet analyzed

1. The code

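import math

import torch
from basicsr.archs.arch_util import default_init_weights
from torch import nn
from torch.nn import functional as F


class ModulatedConv2d(nn.Module):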
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')


2. What the code does

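Judging from the key method forward(self, x, style), this module performs a style-modulated convolution:
it scales the convolution weights per sample using the style vector, optionally demodulates (normalizes) them,
optionally up- or downsamples x, and returns the result of convolving x with these per-sample weights.
The module is a neural-network layer (an nn.Module; the docstring says it is the modulated Conv2d used in StyleGAN2).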
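Written as formulas, the computation in forward() is roughly the following. The notation is chosen here for illustration and follows the tensor dimensions in the code: i is the sample index, o the output channel, c the input channel, (m, n) a kernel position, s_i the output of self.modulation for sample i, and x_i the (possibly resampled) input. The second line applies only when demodulate=True.

```latex
\begin{aligned}
w'_{i,o,c,m,n}  &= w_{o,c,m,n}\, s_{i,c} \\
w''_{i,o,c,m,n} &= \frac{w'_{i,o,c,m,n}}{\sqrt{\sum_{c,m,n} \left(w'_{i,o,c,m,n}\right)^2 + \epsilon}} \\
y_i             &= \operatorname{conv2d}\!\left(x_i,\; w''_{i}\right)
\end{aligned}
```

The sum in the denominator runs over the same axes as weight.pow(2).sum([2, 3, 4]) in the code, so every (sample, output channel) pair is normalized separately.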

3. A closer look at how the code works

1) The __init__() method

The parameter explanations below draw on a related article.

  A few parameters explained:
             in_channels,
             out_channels,
             kernel_size,
             num_style_feat,
             demodulate=True,
             sample_mode=None,
             eps=1e-8

in_channels:

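     This one is easy: it is the C in the input 4D tensor [N, C, H, W], i.e. the number of channels of the input tensor.
     It is needed to fix the shapes of the learnable parameters: the weight has shape (1, out_channels, in_channels, kernel_size, kernel_size), and the modulation layer nn.Linear(num_style_feat, in_channels) produces one scale per input channel.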
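A small sketch of how in_channels (together with out_channels, kernel_size and num_style_feat) determines the shapes and the initial scale of the learnable parameters in __init__. The helpers modulated_conv_param_shapes and weight_init_std are hypothetical and written only for illustration; they just restate the shapes and the 1 / sqrt(in_channels * kernel_size**2) factor from the code without needing PyTorch.

```python
import math


def modulated_conv_param_shapes(in_channels, out_channels, kernel_size, num_style_feat):
    """Hypothetical helper: shapes of the learnable tensors created in __init__."""
    return {
        # nn.Linear stores its weight as (out_features, in_features)
        'modulation.weight': (in_channels, num_style_feat),
        'modulation.bias': (in_channels,),
        # shared conv weight; the leading 1 lets it broadcast over the batch
        'weight': (1, out_channels, in_channels, kernel_size, kernel_size),
    }


def weight_init_std(in_channels, kernel_size):
    """torch.randn has std 1; dividing by sqrt(fan_in) gives std 1 / sqrt(fan_in)."""
    return 1 / math.sqrt(in_channels * kernel_size ** 2)


print(modulated_conv_param_shapes(512, 256, 3, 512))
print(round(weight_init_std(512, 3), 5))
```

Scaling by the fan-in keeps the initial weights small when there are many input channels or a large kernel.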

out_channels

The number of channels of the desired 4D output tensor; nothing more to say.

kernel_size

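 The size of the convolution kernel. For nn.Conv2d we usually use square kernels such as 5x5 or 3x3,
 so kernel_size = 5 is enough; a non-square kernel such as 3x5 is written kernel_size = (3, 5).
 ModulatedConv2d, however, only accepts a single int: it builds the weight as (..., kernel_size, kernel_size), scales it by kernel_size**2 and sets padding = kernel_size // 2, none of which work with a tuple.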
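Why padding = kernel_size // 2: a minimal sketch using the standard convolution output-size formula (stride 1, no dilation, as in forward()). The helper conv_output_size is hypothetical and only illustrates the arithmetic.

```python
def conv_output_size(size, kernel_size, padding, stride=1):
    """Standard conv output-size formula: floor((size + 2*padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# padding = kernel_size // 2, as in ModulatedConv2d.__init__
for k in (1, 3, 5, 4):
    print(k, conv_output_size(64, k, k // 2))
```

For odd kernel sizes the spatial size is preserved (64 stays 64); an even kernel such as 4 would grow it to 65, which is one more reason the layer is used with odd, square kernels.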

demodulate=True

  demodulate (bool): Whether to demodulate in the conv layer.
        Default: True.
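A NumPy sketch of what demodulate=True does in forward(). Assumptions: NumPy stands in for PyTorch, and random arrays stand in for self.weight and for the output of self.modulation. After demodulation, every (sample, output channel) slice of the weight has an L2 norm of about 1.

```python
import numpy as np

rng = np.random.default_rng(0)
b, c_out, c_in, k = 2, 4, 3, 3
eps = 1e-8

# stand-in for self.weight: (1, c_out, c_in, k, k)
weight = rng.standard_normal((1, c_out, c_in, k, k)) / np.sqrt(c_in * k * k)
# stand-in for self.modulation(style).view(b, 1, c_in, 1, 1)
style = rng.standard_normal((b, 1, c_in, 1, 1)) + 1.0

w = weight * style  # modulation, broadcast to (b, c_out, c_in, k, k)
demod = 1.0 / np.sqrt((w ** 2).sum(axis=(2, 3, 4)) + eps)  # like torch.rsqrt(...): (b, c_out)
w_demod = w * demod.reshape(b, c_out, 1, 1, 1)

norms = np.sqrt((w_demod ** 2).sum(axis=(2, 3, 4)))
print(w_demod.shape, norms.round(6))
```

With demodulate=False this rescaling is skipped and the modulated weight is used as is.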

sample_mode

    sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
        Default: None.
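How sample_mode changes the output shape: a hypothetical helper, modulated_conv_output_shape, that follows the branches in forward() (bilinear resize by 2 or 0.5, then a stride-1 conv with padding kernel_size // 2). Rounding down for odd sizes when downsampling is an assumption, so only even sizes are used below.

```python
import math


def modulated_conv_output_shape(x_shape, out_channels, kernel_size, sample_mode=None):
    """Hypothetical helper: output shape of ModulatedConv2d.forward for an input shape."""
    b, _, h, w = x_shape
    if sample_mode == 'upsample':
        h, w = h * 2, w * 2  # F.interpolate(scale_factor=2)
    elif sample_mode == 'downsample':
        # F.interpolate(scale_factor=0.5); assumed to round down for odd sizes
        h, w = math.floor(h * 0.5), math.floor(w * 0.5)
    padding = kernel_size // 2  # same padding rule as __init__
    h = h + 2 * padding - kernel_size + 1
    w = w + 2 * padding - kernel_size + 1
    return (b, out_channels, h, w)


for mode in (None, 'upsample', 'downsample'):
    print(mode, modulated_conv_output_shape((2, 64, 16, 16), 128, 3, mode))
```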

eps

    eps (float): A value added to the denominator for numerical stability.
        Default: 1e-8.
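Why eps is needed in the demodulation step: if the sum of squared modulated weights for some output channel were exactly 0, the reciprocal square root would blow up. A pure-Python sketch, where the hypothetical demod_scale stands in for torch.rsqrt(sum + eps). Python raises ZeroDivisionError here, whereas torch.rsqrt would return inf; either way, adding eps keeps the result finite.

```python
import math


def demod_scale(sum_sq, eps=1e-8):
    """Pure-Python stand-in for torch.rsqrt(sum_sq + eps) from forward()."""
    return 1.0 / math.sqrt(sum_sq + eps)


# Hypothetical worst case: all modulated weights of one output channel are 0.
try:
    demod_scale(0.0, eps=0.0)
except ZeroDivisionError:
    print("without eps: division by zero")
print("with eps:", demod_scale(0.0))  # 1 / sqrt(1e-8), about 1e4: large but finite
```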

2) The forward(self, x, style) method

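*Purpose: like NormStyleCode.forward(self, x), it normalizes with torch.rsqrt (here in the demodulation step),
       but its main job is the modulated convolution itself. Step by step:
   - style is mapped by self.modulation to one scale per input channel and reshaped to (b, 1, c_in, 1, 1);
   - the shared weight (1, c_out, c_in, k, k) is multiplied by it, giving a per-sample weight (b, c_out, c_in, k, k);
   - if demodulate is True, each (sample, output channel) slice of the weight is rescaled to unit L2 norm;
   - if sample_mode is set, x is up- or downsampled by a factor of 2;
   - the batch is folded into the channel dimension (x: (1, b*c_in, h, w), weight: (b*c_out, c_in, k, k)) and a single F.conv2d with groups=b convolves each sample with its own weight;
   - the result is reshaped back and returned as out = out.view(b, self.out_channels, *out.shape[2:4]),
       i.e. "Tensor: Modulated tensor after convolution".
*Key term
 weight: the PyTorch convolution weight, an nn.Parameter of shape (1, c_out, c_in, k, k) shared by all samples; inside forward() it becomes a per-sample weight.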
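To see why folding the batch into the channel dimension with groups=b gives every sample its own kernel, here is a NumPy sketch. Assumptions: NumPy replaces PyTorch, and conv2d_same / grouped_conv2d_same are naive hypothetical helpers written for illustration, not part of GFPGAN or PyTorch. The sketch compares the grouped path used in forward() with a plain per-sample loop.

```python
import numpy as np


def conv2d_same(x, w):
    """Naive stride-1 cross-correlation with zero padding k // 2, like F.conv2d.

    x: (c_in, h, w); w: (c_out, c_in, k, k). Returns (c_out, h', w').
    """
    c_out, _, k, _ = w.shape
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))  # zero padding on h and w only
    out_h = x.shape[1] + 2 * p - k + 1
    out_w = x.shape[2] + 2 * p - k + 1
    out = np.zeros((c_out, out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = xp[:, i:i + k, j:j + k]  # (c_in, k, k)
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


def grouped_conv2d_same(x, w, groups):
    """Naive grouped conv: x (1, groups*c_in, h, w), w (groups*c_out, c_in, k, k)."""
    c_in = x.shape[1] // groups
    c_out = w.shape[0] // groups
    outs = [conv2d_same(x[0, g * c_in:(g + 1) * c_in], w[g * c_out:(g + 1) * c_out])
            for g in range(groups)]
    return np.concatenate(outs, axis=0)[None]  # (1, groups*c_out, h', w')


rng = np.random.default_rng(0)
b, c_in, c_out, k, h, wd = 2, 3, 4, 3, 5, 5
x = rng.standard_normal((b, c_in, h, wd))
weight = rng.standard_normal((1, c_out, c_in, k, k))  # shared, like self.weight
style = rng.standard_normal((b, 1, c_in, 1, 1))  # stand-in for self.modulation(style)

w = weight * style  # modulation: (b, c_out, c_in, k, k)
w = w / np.sqrt((w ** 2).sum(axis=(2, 3, 4), keepdims=True) + 1e-8)  # demodulation

# Path used by forward(): fold the batch into channels and use groups=b.
out_grouped = grouped_conv2d_same(x.reshape(1, b * c_in, h, wd),
                                  w.reshape(b * c_out, c_in, k, k), groups=b)
out_grouped = out_grouped.reshape(b, c_out, h, wd)

# Reference path: convolve each sample with its own weight in a loop.
out_loop = np.stack([conv2d_same(x[i], w[i]) for i in range(b)])
print(out_grouped.shape, np.allclose(out_grouped, out_loop))  # (2, 4, 5, 5) True
```

The two paths agree because reshaping keeps sample i's input channels and sample i's weight rows contiguous, so group i sees exactly x[i] and w[i]. One grouped call replaces a Python loop over the batch.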

3) The __repr__(self) method

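Purpose: returns a readable string describing the layer; it is what print() and repr() show for a ModulatedConv2d object.
It reads, but does not modify, five attributes of self (in_channels, out_channels, kernel_size, demodulate, sample_mode), as the code shows:
   return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
           f'out_channels={self.out_channels}, '
           f'kernel_size={self.kernel_size}, '
           f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')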
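A torch-free sketch of the string this __repr__ produces. ModulatedConv2dStub is a hypothetical stand-in that copies only the attributes and the __repr__ body, so the class name in the output is the stub's name.

```python
class ModulatedConv2dStub:
    """Hypothetical stand-in that reproduces only ModulatedConv2d.__repr__."""

    def __init__(self, in_channels, out_channels, kernel_size, demodulate=True, sample_mode=None):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')


layer = ModulatedConv2dStub(512, 512, 3, sample_mode='upsample')
print(layer)
# ModulatedConv2dStub(in_channels=512, out_channels=512, kernel_size=3, demodulate=True, sample_mode=upsample)
```

print() falls back to __repr__ because the class defines no __str__. Note that num_style_feat and eps are not shown.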

II. How class ModulatedConv2d is used in the project

1) The __init__() method

A fairly common method that sets up the layer's dimension parameters (channel counts, kernel size, and so on).

[Figure 1: screenshot]

2) The forward(self, x, style) method

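Only defined inside the class itself. As with any nn.Module, it is normally not called by name: it runs when the layer object is called, e.g. layer(x, style).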
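In PyTorch, calling a module object dispatches to its forward() method. A simplified, hypothetical mimic of that pattern in plain Python (TinyModule and Scale are illustration-only names; the real nn.Module.__call__ also runs hooks and other bookkeeping):

```python
class TinyModule:
    """Hypothetical, simplified mimic of nn.Module's call -> forward dispatch."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Scale(TinyModule):
    def __init__(self, factor):
        self.factor = factor

    def forward(self, x, style):
        # toy computation standing in for the modulated convolution
        return [v * self.factor * style for v in x]


layer = Scale(2)
print(layer([1, 2, 3], style=10))  # [20, 40, 60]
```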

3) The __repr__(self) method

[Figure 2: screenshot]

III. Summary and reflections

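1. The code contains many variables and formulas I don't fully understand yet; since the goal here is to study what the code does, the internal details can be set aside for now.
2. The content involves PyTorch, so it is worth building a general understanding of deep learning.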
