Transformer Block (Restormer)

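A Transformer Block consists of one MDTA (Multi-Dconv Head Transposed Attention) module followed by one GDFN (Gated-Dconv Feed-Forward Network) module.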
Let's look at the code implementation:

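import torch.nn as nn

# LayerNorm, Attention (MDTA), and FeedForward (GDFN) are defined elsewhere in the Restormer code.
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)  # MDTA
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # GDFN

    def forward(self, x):
        # Residual connection around normalization + MDTA
        x = x + self.attn(self.norm1(x))
        # Residual connection around normalization + GDFN
        x = x + self.ffn(self.norm2(x))

        return x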

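Note that both MDTA and GDFN in the Transformer Block are wrapped in residual connections: the output of each sub-layer is added back to its own input.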
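The effect of the two residual connections can be checked with a small, self-contained sketch. It is not the Restormer implementation: `PreNormResidualBlock` and `zero_conv` are hypothetical names introduced here, the sub-layers are plain 1x1 convolutions standing in for MDTA and GDFN, and `nn.GroupNorm(1, dim)` stands in for Restormer's own `LayerNorm`, which behaves differently. Only the wiring of `forward` matches the block above.

```python
import torch
import torch.nn as nn


class PreNormResidualBlock(nn.Module):
    """Same forward wiring as TransformerBlock, with injectable sub-layers."""

    def __init__(self, dim, attn, ffn):
        super().__init__()
        # Assumption: GroupNorm with a single group is only a stand-in for
        # Restormer's LayerNorm; it is used here to keep the sketch runnable.
        self.norm1 = nn.GroupNorm(1, dim)
        self.attn = attn  # stand-in for MDTA
        self.norm2 = nn.GroupNorm(1, dim)
        self.ffn = ffn    # stand-in for GDFN

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


def zero_conv(dim):
    """Hypothetical stand-in branch: a 1x1 convolution that always outputs zeros."""
    conv = nn.Conv2d(dim, dim, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv


dim = 8
x = torch.randn(2, dim, 16, 16)  # (batch, channels, height, width)

# When both branches output zero, the residual paths make the block an identity map.
identity_block = PreNormResidualBlock(dim, zero_conv(dim), zero_conv(dim))
y = identity_block(x)
print(y.shape == x.shape, torch.allclose(y, x))

# With randomly initialized branches, the output keeps the input shape.
random_block = PreNormResidualBlock(dim, nn.Conv2d(dim, dim, 1), nn.Conv2d(dim, dim, 1))
z = random_block(x)
print(z.shape == x.shape)
```

Because each sub-layer only adds a correction to its input, the block keeps the tensor shape unchanged, which is why such blocks can be stacked one after another.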
