Understanding the U-Net modules (attention mechanism, residuals, and how data dimensions change)

Table of contents

  • The attention mechanism
  • The U-Net building blocks
    • ① Residual block
    • ② Down block: one residual block, then one self-attention block
    • ③ Up block: skip connection, then one residual block and one attention block
    • ④ Middle block: two residual blocks and one self-attention block
    • ⑤ Upsample: channel count unchanged, height and width doubled
    • ⑥ Downsample: channel count unchanged, height and width halved
  • The full U-Net
    • U-Net schematic

Reference U-Net code: the labml annotated DDPM implementation that this article walks through.

The attention mechanism

Reference:
超详细图解Self-Attention (a detailed illustrated guide to self-attention)

First compute the Q, K and V matrices; then use Q and K to compute the pairwise similarities; finally reconstruct V by reweighting it according to those similarities.
[Figure 1: illustration of the Q/K/V self-attention computation]
The dimension transformations here are easiest to follow alongside the code.
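
To make this concrete before reading the class, here is a minimal single-head sketch of the same computation with plain torch operations (toy sizes; all names here are illustrative):

import torch

seq, d_k = 4, 8                       # 4 "pixels", 8 dimensions per head
x = torch.randn(1, seq, d_k * 3)      # pretend this came from the QKV projection
q, k, v = torch.chunk(x, 3, dim=-1)   # split into Q, K, V: each [1, 4, 8]

attn = (q @ k.transpose(-2, -1)) * d_k ** -0.5  # similarities, [1, 4, 4]
attn = attn.softmax(dim=-1)                     # normalize over the keys
out = attn @ v                                  # reweight V: back to [1, 4, 8]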

# Imports used by the snippets in this article (Module comes from the labml helper library)
from typing import Optional, Tuple, Union, List

import math

import torch
from torch import nn

from labml_helpers.module import Module


class AttentionBlock(Module):
    """
    ### Attention block

    This is similar to [transformer multi-head attention](../../transformers/mha.html).
    """
    # Only the number of input channels is needed; the channel count does not change inside this block
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input (the channel dimension of the input feature map)
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
       
        # Group normalization: e.g. with a 64-channel input and n_groups=8, each group
        # contains 8 channels, and the mean and standard deviation are computed and
        # normalized independently within each group
        self.norm = nn.GroupNorm(n_groups, n_channels)
        
        # A linear projection raises the channel count to n_heads * d_k * 3; the factor
        # of 3 is later split into Q, K and V
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
     
        # A linear layer maps the output back to the same dimension as the input
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        # Store the number of heads and the per-head dimension
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used; it is kept in the arguments so that the attention layer's
        # function signature matches `ResidualBlock`.
        _ = t
  
        # Unpack the batch size, channel count, height and width of the input
        batch_size, n_channels, height, width = x.shape
       
        # Merge the spatial dimensions into one axis and move the channels to the
        # last axis: [batch_size, height * width, n_channels]
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
    
        # The projection raises the dimension for multi-head attention, then the result
        # is reshaped to [batch_size, seq (height * width), n_heads, 3 * d_k]
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
   
        # Split along the last dimension to obtain the Q, K and V matrices
        q, k, v = torch.chunk(qkv, 3, dim=-1)
  
        # Scaled dot product of Q and K; the shape becomes [batch, seq, seq, heads]
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax over the key axis (dim=2)
        attn = attn.softmax(dim=2)
        # Weight V by attn; the shape matches that of Q/K/V again: [batch, seq, heads, d_k]
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Merge the last two axes (heads and per-head dims): [batch, seq, n_heads * d_k]
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Map the result back to the input channel count
        res = self.output(res)
        # Residual connection
        res += x
        # Restore the input layout by splitting the sequence axis back into height and width
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        return res
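
A quick shape check of the block (a sketch; it assumes torch and the AttentionBlock defined above are in scope):

x = torch.randn(2, 32, 16, 16)        # [batch, channels, height, width]
attn = AttentionBlock(n_channels=32)  # d_k defaults to n_channels
print(attn(x).shape)                  # torch.Size([2, 32, 16, 16]) -- same as the input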


The U-Net building blocks

① Residual block
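
The residual block below uses the Swish activation, which this article does not define. In the referenced labml code it is simply $x \cdot \sigma(x)$:

class Swish(Module):
    """
    ### Swish activation: $x \cdot \sigma(x)$
    """

    def forward(self, x):
        return x * torch.sigmoid(x)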

class ResidualBlock(Module):
    """
    ### Residual block

    A residual block has two convolution layers with group normalization.
    Each resolution is processed with two residual blocks.
    """
    # Takes the input and output channel counts; the block maps in_channels to out_channels

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        * `dropout` is the dropout rate
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        # The first convolution maps in_channels to out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels we have to
        # project the shortcut connection
        # If the channel counts of x and h differ, project x with a 1x1 convolution
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        # Project the time embedding from time_channels to out_channels
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings
        # Once the time embedding is projected to the same channel count as x, it can be added.
        # [:, :, None, None] appends two singleton dimensions: e.g. a (32, 100) embedding
        # becomes (32, 100, 1, 1); this only reshapes the tensor, the data is unchanged
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        # Add the shortcut connection and return
        # Add the (possibly projected) shortcut connection
        return h + self.shortcut(x)
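
The broadcasting step above is worth seeing in isolation. A minimal sketch with the shapes from the comment (assumes torch is imported):

h = torch.randn(32, 100, 8, 8)  # feature map [batch, out_channels, H, W]
emb = torch.randn(32, 100)      # time embedding after the linear layer
h = h + emb[:, :, None, None]   # (32, 100, 1, 1) broadcasts over H and W
print(h.shape)                  # torch.Size([32, 100, 8, 8])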

② Down block: one residual block, then one self-attention block

class DownBlock(Module):
    """
    ### Down block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
    """
	
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The residual block maps in_channels to out_channels
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            # When attention is not needed, Identity leaves the result unchanged
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # One residual block
        x = self.res(x, t)
        # One attention block; with Identity the result passes through unchanged
        x = self.attn(x)
        return x

③ Up block: skip connection, then one residual block and one attention block

class UpBlock(Module):
    """
    ### Up block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The input has `in_channels + out_channels` channels because the output of the
        # same resolution from the first half of the U-Net (the skip connection) is
        # concatenated with the input of this block
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
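
To see the channel bookkeeping of the skip connection, here is a sketch (it assumes the classes above, including Swish, are in scope; all numbers are illustrative):

up = UpBlock(in_channels=64, out_channels=64, time_channels=256, has_attn=False)
x = torch.randn(1, 64, 8, 8)  # output of the previous up-path module
s = torch.randn(1, 64, 8, 8)  # stored skip connection from the down path
t = torch.randn(1, 256)       # time embedding
out = up(torch.cat((x, s), dim=1), t)  # the residual block sees 64 + 64 = 128 channels
print(out.shape)              # torch.Size([1, 64, 8, 8])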



④ Middle block: two residual blocks and one self-attention block

class MiddleBlock(Module):
    """
    ### Middle block

    It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x

⑤ Upsample: channel count unchanged, height and width doubled

class Upsample(nn.Module):
    """
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        # Transposed convolution: channel count unchanged, height and width doubled
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the function
        # signature matches `ResidualBlock`.
        _ = t
        return self.conv(x)
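
The doubling follows from the transposed-convolution size formula $H_{out} = (H_{in} - 1) \cdot \text{stride} - 2 \cdot \text{padding} + \text{kernel} = (H - 1) \cdot 2 - 2 + 4 = 2H$. A quick check (assumes torch):

up = nn.ConvTranspose2d(64, 64, (4, 4), (2, 2), (1, 1))
print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 32, 32])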

⑥ Downsample: channel count unchanged, height and width halved

class Downsample(nn.Module):
    """
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        # Strided 3x3 convolution for downsampling: channel count unchanged, height and width halved
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the function
        # signature matches `ResidualBlock`.
        _ = t
        return self.conv(x)
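
Likewise, the halving follows from the convolution size formula $H_{out} = \lfloor (H + 2 \cdot \text{padding} - \text{kernel}) / \text{stride} \rfloor + 1 = \lfloor (H - 1) / 2 \rfloor + 1$, which is $H / 2$ for even $H$. A quick check (assumes torch):

down = nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1))
print(down(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 16, 16])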

The full U-Net

U-Net schematic

R = residual block, A = attention block, x = output of the previous step, up = upsampling, down = downsampling

[Figure 2: U-Net schematic built from the blocks above]
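
Besides Swish, the UNet below relies on a TimeEmbedding module that the article does not show. For completeness, in the referenced labml code it is (roughly) a sinusoidal embedding of the time step followed by a small MLP:

class TimeEmbedding(nn.Module):
    """
    ### Embeddings for the time step $t$ (reproduced from the referenced labml code)
    """

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        # The sinusoidal part has n_channels // 4 dimensions; the MLP lifts it to n_channels
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Sinusoidal position embedding of the time step
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # Two linear layers with Swish in between: [batch, n_channels // 4] -> [batch, n_channels]
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
        return emb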

class UNet(Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image; $3$ for RGB
        * `n_channels` is the number of channels in the initial feature map that we transform the image into (with latent diffusion the image is often first encoded by a VAE, e.g. to a 4x64x64 latent)
        * `ch_mults` is the list of channel multipliers at each resolution; the number of channels at level $i$ is `ch_mults[i] * n_channels` (1 keeps the count, 2 doubles it, and so on)
        * `is_attn` is a list of booleans indicating whether to use attention at each resolution; by default only the 3rd and 4th levels use it, so the 1st and 2nd only apply residual blocks
        * `n_blocks` is the number of `UpDownBlocks` at each resolution
        """
        super().__init__()

        # Number of resolutions
        # Number of resolution levels (4 with the defaults)
        n_resolutions = len(ch_mults)

        # Project image into feature map
        # Project the image channels (e.g. 3) up to n_channels (64) with a convolution
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        # The time embedding has n_channels * 4 = 256 channels with the defaults
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        # At the start, the input and output channel counts are both n_channels (64)
        out_channels = in_channels = n_channels
        # For each resolution
        # Build the first half of the U-Net: one iteration per resolution level (4 in total)
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            # The output channel count at this level is given by the multiplier ch_mults[i]
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            # Each level contains n_blocks down blocks (2 by default)
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            # Downsample at the end of every level except the last, so levels 1-3 downsample:
            # 3 steps in total (the schematic above may suggest 4, but the code performs 3)
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)
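
        # With the defaults (n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2),
        # the down path built above looks like this, channel-wise:
        #   level 0: 64 -> 64, 64 -> 64, Downsample
        #   level 1: 64 -> 128, 128 -> 128, Downsample
        #   level 2: 128 -> 256, 256 -> 256, Downsample
        #   level 3: 256 -> 1024, 1024 -> 1024 (no downsample)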

        # Middle block
        # Middle block: two residual blocks with one attention block in between
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            # These n_blocks up blocks keep the channel count unchanged
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            # One more up block to reduce the channel count
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        # Project the image channels to the feature dimension we need (64 by default)
        x = self.image_proj(x)
        # `h` stores the output at each resolution of the first half of the U-Net,
        # for the skip connections (concatenation) in the second half
        h = [x]
        # First half of U-Net
        # After each module in the down path, store the output in h
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                # Upsample blocks are applied directly
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                # Otherwise, first concatenate x with the stored output from the first half of the U-Net
                s = h.pop()
                # Concatenating along dim=1 increases the channel count
                x = torch.cat((x, s), dim=1)
                # Every non-upsample up block first concatenates a skip connection, so its
                # input channel count is enlarged (doubled when in_channels == out_channels)
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))
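
Finally, a minimal end-to-end shape check (a sketch; it assumes torch plus all the classes above, including the Swish and TimeEmbedding definitions reproduced earlier):

unet = UNet(image_channels=3, n_channels=64)
x = torch.randn(2, 3, 32, 32)     # [batch, channels, height, width]
t = torch.randint(0, 1000, (2,))  # one time step per sample
print(unet(x, t).shape)           # torch.Size([2, 3, 32, 32]) -- the predicted noise has the same shape as the input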
