MDTA Module (Restormer)

[Figures 1 and 2: MDTA module (Restormer)]

From a layer-normalized tensor $\mathbf{Y} \in \mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$, our MDTA first generates query ($\mathbf{Q}$), key ($\mathbf{K}$) and value ($\mathbf{V}$) projections, enriched with local context. This is achieved by applying $1 \times 1$ convolutions to aggregate pixel-wise cross-channel context, followed by $3 \times 3$ depth-wise convolutions to encode channel-wise spatial context, yielding $\mathbf{Q}=W_d^Q W_p^Q \mathbf{Y}$, $\mathbf{K}=W_d^K W_p^K \mathbf{Y}$ and $\mathbf{V}=W_d^V W_p^V \mathbf{Y}$, where $W_p^{(\cdot)}$ is the $1 \times 1$ point-wise convolution and $W_d^{(\cdot)}$ is the $3 \times 3$ depth-wise convolution. We use bias-free convolutional layers in the network. Next, we reshape the query and key projections such that their dot-product interaction generates a transposed-attention map $\mathbf{A}$ of size $\mathbb{R}^{\hat{C} \times \hat{C}}$, instead of the huge regular attention map of size $\mathbb{R}^{\hat{H}\hat{W} \times \hat{H}\hat{W}}$. Overall, the MDTA process is defined as:

$$\hat{\mathbf{X}} = W_p \operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) + \mathbf{X}$$

$$\operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) = \hat{\mathbf{V}} \cdot \operatorname{Softmax}(\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}} / \alpha)$$

where $\mathbf{X}$ and $\hat{\mathbf{X}}$ are the input and output feature maps; the matrices $\hat{\mathbf{Q}} \in \mathbb{R}^{\hat{H}\hat{W} \times \hat{C}}$, $\hat{\mathbf{K}} \in \mathbb{R}^{\hat{C} \times \hat{H}\hat{W}}$ and $\hat{\mathbf{V}} \in \mathbb{R}^{\hat{H}\hat{W} \times \hat{C}}$ are obtained by reshaping tensors from the original size $\mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$. Here, $\alpha$ is a learnable scaling parameter that controls the magnitude of the dot product of $\hat{\mathbf{K}}$ and $\hat{\mathbf{Q}}$ before the softmax is applied. Similar to conventional multi-head SA, we divide the channels into 'heads' and learn separate attention maps in parallel.
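The formula multiplies $\hat{\mathbf{V}}$ (pixels × channels) by the $\hat{C} \times \hat{C}$ map $\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}}$, while the Attention class below computes `q @ k.transpose(-2, -1)` with q, k and v all laid out as (channels, pixels). The following PyTorch sketch, for a single head with made-up sizes and a made-up $\alpha$, checks that the two layouts agree under our reading: the code's map is the transpose of $\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}}$, its row-wise softmax corresponds to normalizing $\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}}$ over the key-channel index, and its `temperature` plays the role of $1/\alpha$. The quoted text does not state the softmax axis, and the sketch deliberately leaves out the L2 normalization of q and k that the code also applies.

```python
import torch

torch.manual_seed(0)
HW, C = 64, 8   # hypothetical sizes: HW = H*W pixels, C channels in one head
alpha = 2.0     # hypothetical value of the learnable scale alpha

# Paper layout: Q_hat is (HW, C), K_hat is (C, HW), V_hat is (HW, C)
Q_hat = torch.randn(HW, C)
K_hat = torch.randn(C, HW)
V_hat = torch.randn(HW, C)

# Paper form: V_hat . Softmax(K_hat . Q_hat / alpha), softmax over the key-channel index (dim 0)
A = torch.softmax(K_hat @ Q_hat / alpha, dim=0)  # (C, C) transposed-attention map
out_paper = V_hat @ A                              # (HW, C)

# Code layout: q, k, v are all (C, HW), as after 'b (head c) h w -> b head c (h w)'
q, k, v = Q_hat.T, K_hat, V_hat.T
temperature = 1.0 / alpha                          # code multiplies by temperature instead of dividing
attn = torch.softmax((q @ k.transpose(-2, -1)) * temperature, dim=-1)  # equals A transposed
out_code = attn @ v                                # (C, HW): out_paper transposed
```

The map size is what makes this affordable: for a hypothetical 256×256 feature map with $\hat{C} = 48$, a single regular spatial attention map would hold $65536^2 \approx 4.3 \times 10^9$ entries, whereas the transposed map holds $48^2 = 2304$.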

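```python
import torch
import torch.nn as nn
from einops import rearrange


## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head scale; it multiplies the logits, so it acts as 1/alpha
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # W_p: one 1x1 point-wise conv producing Q, K and V together (3*dim channels)
        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        # W_d: 3x3 depth-wise conv (groups = channels) adding local spatial context
        self.qkv_dwconv = nn.Conv2d(
            dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        # Final 1x1 projection W_p applied to the attention output
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Split channels into heads and flatten space: (b, head, c/head, h*w)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)

        # L2-normalize q and k along the spatial axis before the dot product
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Transposed attention: one (c/head x c/head) map per head instead of (hw x hw)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)  # (b, head, c/head, h*w)

        # Merge the heads back and restore the spatial layout: (b, c, h, w)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
```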
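The code fuses the three point-wise convolutions $W_p^Q, W_p^K, W_p^V$ into one `self.qkv` layer and the three depth-wise ones into one `self.qkv_dwconv` layer, then splits the result with `chunk(3, dim=1)`. Below is a minimal sketch (hypothetical channel count, random weights) of why this matches the per-projection form $W_d^{(\cdot)} W_p^{(\cdot)} \mathbf{Y}$: slicing the fused weights into thirds and running the convolutions separately yields the same Q, K and V, because a depth-wise convolution never mixes channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 8  # hypothetical channel count for the demo

# Fused form, as in the Attention class: 1x1 conv to 3C channels, then depth-wise 3x3 over all 3C
fused_pw = nn.Conv2d(C, 3 * C, kernel_size=1, bias=False)
fused_dw = nn.Conv2d(3 * C, 3 * C, kernel_size=3, padding=1, groups=3 * C, bias=False)


def split_pair(i):
    """Build separate (W_p, W_d) for projection i (0=Q, 1=K, 2=V) from slices of the fused weights."""
    pw = nn.Conv2d(C, C, kernel_size=1, bias=False)
    dw = nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C, bias=False)
    with torch.no_grad():
        pw.weight.copy_(fused_pw.weight[i * C:(i + 1) * C])  # (C, C, 1, 1) slice
        dw.weight.copy_(fused_dw.weight[i * C:(i + 1) * C])  # (C, 1, 3, 3) slice
    return pw, dw


y = torch.randn(2, C, 16, 16)  # stands in for the layer-normalized input Y, layout (B, C, H, W)
q_f, k_f, v_f = fused_dw(fused_pw(y)).chunk(3, dim=1)

separate = []
for i in range(3):
    pw, dw = split_pair(i)
    separate.append(dw(pw(y)))  # W_d W_p Y for one projection
q_s, k_s, v_s = separate
```

Fusing is an efficiency choice: two convolution calls instead of six, with identical results up to floating-point rounding.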

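This code does not implement the Norm module shown in the figure (for its implementation, refer to Layer Normalization), nor the residual connection (+X) in the formula above. Let's look at how the Transformer Block wraps it: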

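```python
# LayerNorm (Restormer's own class, note the LayerNorm_type argument; not torch.nn.LayerNorm)
# and FeedForward are defined elsewhere in the Restormer source and are not shown here.
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # MDTA: Norm -> attention -> residual
        x = x + self.ffn(self.norm2(x))   # feed-forward: Norm -> FFN -> residual

        return x
```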

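As you can see, the implementation applies Norm first, then Attention, and finally the residual connection; it is this whole sequence that the figure above depicts.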
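Since the Attention class leaves out the Norm step, here is a minimal sketch of the Norm → sublayer → residual order just described. `ChannelLayerNorm` and `PreNormResidual` are hypothetical names, and the 1×1 convolution is only a stand-in for MDTA. The sketch normalizes across channels at each pixel, which is our reading of what Restormer's own `LayerNorm` does (it reshapes to (B, HW, C) first); treat that as an assumption rather than a copy of the original class.

```python
import torch
import torch.nn as nn


class ChannelLayerNorm(nn.Module):
    """Hypothetical Norm step: LayerNorm over C at every pixel of a (B, C, H, W) map."""

    def __init__(self, dim):
        super().__init__()
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        # (B, C, H, W) -> (B, H, W, C) so nn.LayerNorm normalizes the channel axis, then back
        return self.ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class PreNormResidual(nn.Module):
    """Hypothetical wrapper computing x + fn(norm(x)), the order used by TransformerBlock."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = ChannelLayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))


torch.manual_seed(0)
x = torch.randn(2, 8, 5, 5) * 3 + 1
normed = ChannelLayerNorm(8)(x)

# Stand-in sublayer: a 1x1 conv in place of MDTA
block = PreNormResidual(8, nn.Conv2d(8, 8, kernel_size=1, bias=False))
y = block(x)

# With a zeroed sublayer, the residual path returns the input unchanged
zero_block = PreNormResidual(8, nn.Conv2d(8, 8, kernel_size=1, bias=False))
nn.init.zeros_(zero_block.fn.weight)
y_zero = zero_block(x)
```

Putting the normalization inside the residual branch (pre-norm) leaves the identity path untouched, so the block's input always reaches its output directly.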
