Transformer用于图像分类

对应论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

直接看代码

首先看Transformer 类

class Transformer(nn.Module):
    """Stack of `depth` encoder layers, each an attention block then an MLP block.

    Both blocks of a layer are wrapped as Residual(PreNorm(dim, block)).

    Args:
        dim: Size of the last dimension of the token embeddings.
        depth: Number of (attention, feed-forward) layer pairs.
        heads: Number of attention heads, forwarded to Attention.
        dim_head: Per-head size, forwarded to Attention.
        mlp_dim: Hidden size, forwarded to FeedForward.
        dropout: Dropout probability forwarded to both blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()

        def build_layer():
            # Attention is constructed before the feed-forward block so that
            # parameters are created in the same order for every layer.
            attention = Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)))
            feed_forward = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))
            return nn.ModuleList([attention, feed_forward])

        # nn.ModuleList holds sub-modules and registers their parameters with
        # the network. Unlike nn.Sequential it does not define an execution
        # order on its own; forward() below decides how the modules are called.
        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, x, mask=None):
        """Run x through every layer; `mask` is only passed to the attention blocks."""
        for attention, feed_forward in self.layers:
            x = feed_forward(attention(x, mask=mask))
        return x

self.layers中有多个类定义的对象,按照执行顺序,逐一解释。

Attention类

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: Size of the last dimension of the input and output tensors.
        heads: Number of attention heads.
        dim_head: Size of each head's query/key/value vectors.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # Scaled dot-product attention divides the logits by sqrt(d_k), where
        # d_k is the size of one head's key vectors (dim_head), not the model
        # width `dim`.
        # NOTE(review): weights trained with the earlier `dim ** -0.5` scale
        # will see differently scaled logits.
        self.scale = dim_head ** -0.5

        # A single bias-free projection produces q, k and v in one matmul.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        """Apply self-attention to `x`.

        Args:
            x: Tensor of shape (batch, tokens, dim).
            mask: Optional boolean tensor; after flattening to (batch, m) it is
                left-padded with one True entry (so the first token is always
                kept) and must then have `tokens` entries per sample. False
                marks tokens that are excluded from attention.

        Returns:
            Tensor of shape (batch, tokens, dim).
        """
        b, n, _ = x.shape
        h = self.heads

        # Project, then split the last axis into three equal chunks (q, k, v).
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h * d) -> (b, h, n, d); same layout as the einops pattern
        # 'b n (h d) -> b h n d'.
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in qkv)

        # Repeated index d is summed: dots[b, h, i, j] = <q_i, k_j> * scale.
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Pairwise validity with a singleton head axis, (b, 1, n, n), so
            # it broadcasts over heads. Without that axis the batch dimension
            # of the mask would be aligned with the head dimension of `dots`.
            pair_mask = mask[:, None, :, None] & mask[:, None, None, :]
            # Most negative finite value of the dtype: softmax gives these
            # positions zero weight without producing NaNs.
            mask_value = -torch.finfo(dots.dtype).max
            dots = dots.masked_fill(~pair_mask, mask_value)

        attn = dots.softmax(dim=-1)

        # Weighted sum of the values: out[b, h, i, :] = sum_j attn[b, h, i, j] * v[b, h, j, :].
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # (b, h, n, d) -> (b, n, h * d); same as 'b h n d -> b n (h d)'.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

第二个是FeedForward类

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Size of the last dimension of the input and output.
        hidden_dim: Size of the intermediate representation.
        dropout: Dropout probability used after both linear layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        # GELU: Gaussian Error Linear Unit activation.
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_dropout = nn.Dropout(dropout)
        self.net = nn.Sequential(
            expand,
            activation,
            hidden_dropout,
            project,
            output_dropout,
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`."""
        transformed = self.net(x)
        return transformed

简单看下自定义的两个类:残差连接类 Residual 与预归一化类 PreNorm

class Residual(nn.Module):
    """Wrap a callable `fn` with a skip connection: forward returns fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Call the wrapped `fn` (forwarding any keyword arguments) and add the input back."""
        transformed = self.fn(x, **kwargs)
        return transformed + x

class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension (`dim`) before calling `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize `x`, then pass it (with any keyword arguments) to the wrapped `fn`."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

下面详细描述主类中的使用方法

class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into square patches, each patch is linearly embedded, a
    learnable class token is prepended, learnable position embeddings are
    added, the sequence is encoded by `Transformer`, and a LayerNorm + Linear
    head maps the pooled token to class logits.

    Args (all keyword-only):
        image_size: Side length of the square input image; must be divisible
            by `patch_size`.
        patch_size: Side length of each square patch.
        num_classes: Size of the output logits.
        dim: Embedding size of each token.
        depth, heads, mlp_dim, dim_head, dropout: Forwarded to `Transformer`.
        pool: 'cls' to classify from the class token, 'mean' to average all
            tokens.
        channels: Number of image channels.
        emb_dropout: Dropout probability applied to the embedded sequence.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        # Each flattened patch holds patch_size * patch_size pixels per channel.
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        # Learnable (not fixed) position embeddings, drawn from a standard
        # normal; one per patch plus one for the class token,
        # e.g. (1, 64 + 1, 1024).
        position_shape = (1, num_patches + 1, dim)
        self.pos_embedding = nn.Parameter(torch.randn(*position_shape))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        # Learnable class token: gives the classifier a fixed sequence position
        # to read its output from.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        # nn.Identity returns its input unchanged; it is a no-op placeholder.
        self.to_latent = nn.Identity()

        head_norm = nn.LayerNorm(dim)
        head_linear = nn.Linear(dim, num_classes)
        self.mlp_head = nn.Sequential(head_norm, head_linear)

    def forward(self, img, mask=None):
        """Classify a batch of images.

        Args:
            img: Tensor of shape (batch, channels, height, width).
            mask: Optional mask passed through to the transformer's attention.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        patch = self.patch_size

        # einops rearrange: split height and width into a grid of patches and
        # flatten each patch. E.g. with patch 32, (b, 3, 256, 256) is read as
        # (b, 3, 8*32, 8*32) and becomes (b, 8*8, 32*32*3) = (b, 64, 3072):
        # a sequence of 64 patches, each a 3072-long vector.
        tokens = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch, p2=patch)
        # Linear projection to the model width, e.g. (b, 64, 3072) -> (b, 64, 1024).
        tokens = self.patch_to_embedding(tokens)
        batch, seq_len, _ = tokens.shape

        # Prepend one copy of the class token per sample, e.g. -> (b, 65, 1024).
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        # Add position embeddings; the +1 accounts for the class token.
        tokens += self.pos_embedding[:, :(seq_len + 1)]
        tokens = self.dropout(tokens)

        # The encoder keeps the shape, e.g. (b, 65, 1024) -> (b, 65, 1024).
        encoded = self.transformer(tokens, mask)

        if self.pool == 'mean':
            pooled = encoded.mean(dim=1)
        else:
            # Only token 0 (the class token) is used for classification.
            pooled = encoded[:, 0]

        # Classification head applied to the pooled representation.
        latent = self.to_latent(pooled)
        return self.mlp_head(latent)

整篇文章使用transformer的思路非常简洁,有助于了解怎样将自然语言处理的transformer方法用于图像处理。

你可能感兴趣的:(Transformer用于图像分类)