Vision Transformer Code Walkthrough and Reproduction

ViT (Vision Transformer) Network Architecture

[Figure 1: the overall ViT network architecture]

1. The Embedding Layer in Detail

In the code, the patch embedding is implemented directly with a single convolutional layer whose kernel size and stride both equal the patch size. For example, with image_size=224 and patch_size=16, the convolution produces a 14x14 grid of embed_dim-dimensional vectors, i.e. 196 patch tokens.

Note that before the sequence enters the Transformer Encoder, a learnable class token is prepended and position embeddings are added, giving a sequence length of 196 + 1 = 197.
[Figure 2: prepending the class token and adding the position embedding]

import paddle
import paddle.nn as nn
class PatchEmbedding(nn.Layer):
    def __init__(self,image_size,patch_size,in_channels,embed_dim,dropout=0.):
        super().__init__()

        self.embed_dim = embed_dim
        n_patches = (image_size//patch_size)*(image_size//patch_size)
        self.patch_embedding = nn.Conv2D(in_channels=in_channels,
                                    out_channels=embed_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size,
                                    # constant weight init for demonstration; a random
                                    # init (e.g. truncated normal) is typical for real training
                                    weight_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(1.0)),
                                    bias_attr=False)
        self.dropout = nn.Dropout(dropout)
        # add class token
        self.class_token = paddle.create_parameter(shape=[1,1,embed_dim],
                                                    dtype='float32',
                                                    default_initializer=nn.initializer.Constant(0.))
        # add position embedding
        self.position_embedding = paddle.create_parameter(shape=[1,n_patches+1,embed_dim],dtype='float32',
                                    default_initializer=nn.initializer.TruncatedNormal(std=0.02))
    def forward(self,x):
        # x: [batch, channel, h, w], e.g. [1, 3, 224, 224]
        cls_token = self.class_token.expand([x.shape[0],-1,-1])
        x = self.patch_embedding(x)
        # [batch, embed_dim, h//patch_size, w//patch_size], e.g. [1, 768, 14, 14]
        x = x.flatten(2)
        # [batch, embed_dim, n_patches], e.g. [1, 768, 196]
        x = x.transpose([0,2,1])
        # [batch, n_patches, embed_dim], e.g. [1, 196, 768]
        x = paddle.concat([cls_token,x],axis=1)
        # prepend the class token: [batch, n_patches+1, embed_dim], e.g. [1, 197, 768]
        x = x + self.position_embedding
        x = self.dropout(x)
        return x
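
A quick shape check (a hypothetical toy run, not from the original post): 224/16 = 14, so the convolution yields 14*14 = 196 patch tokens, and the class token brings the sequence length to 197.

# hypothetical shape check for the PatchEmbedding above
x = paddle.randn([1, 3, 224, 224])
patch_embed = PatchEmbedding(image_size=224, patch_size=16, in_channels=3, embed_dim=768)
print(patch_embed(x).shape)  # [1, 197, 768]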

2. The Transformer Encoder in Detail

  • Layer Norm: this normalization method was originally proposed for NLP; the figure below illustrates the difference between BatchNorm and LayerNorm quite intuitively.
    [Figure 3: comparison of BatchNorm and LayerNorm]
  • Multi-Head Attention: this structure was covered in detail in the earlier article on the Transformer.
  • Dropout/DropPath: the original paper's code uses a plain Dropout layer, whereas the PyTorch implementation uses DropPath (stochastic depth); see the sketch after this list.
  • MLP Block: simply a fully connected layer + GELU activation + Dropout.
    [Figure 4: structure of the Transformer Encoder and MLP block]
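
Since the code in this post follows the paper and uses Dropout, the following is only a minimal DropPath sketch for reference, assuming the common stochastic-depth formulation: each sample's residual branch is zeroed with probability drop_prob and the survivors are rescaled by 1/keep_prob.

# minimal DropPath (stochastic depth) sketch -- illustrative only, not used by
# the code below, which sticks with plain Dropout as in the original paper
class DropPath(nn.Layer):
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over the remaining dims
        shape = [x.shape[0]] + [1] * (x.ndim - 1)
        mask = paddle.floor(keep_prob + paddle.rand(shape, dtype=x.dtype))
        return x / keep_prob * mask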

Multi-Head Attention
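
Each head computes scaled dot-product attention, softmax(Q·K^T / sqrt(d_k))·V, and the per-head outputs are concatenated and projected back to embed_dim. In the code below, scale plays the role of 1/sqrt(d_k), i.e. head_dim ** -0.5.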

class Attention(nn.Layer):
    def __init__(self,embed_dim,num_heads,qkv_bias=True,qk_scale=None,dropout=0.,attention_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim//num_heads
        self.all_head_dim = self.head_dim*num_heads
        # qkv_bias (not qk_scale) controls whether the qkv projection has a bias;
        # it defaults to True, matching the network summary printed below
        self.qkv = nn.Linear(embed_dim,
                            self.all_head_dim*3,
                            bias_attr=None if qkv_bias else False)
        # default scale is 1/sqrt(head_dim)
        self.scale = self.head_dim**-0.5 if qk_scale is None else qk_scale
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(self.all_head_dim,embed_dim)
        self.attention_dropout = nn.Dropout(p=attention_dropout)
        self.dropout = nn.Dropout(p=dropout)
    def transpose_multi_head(self,x):
        # x: [B, N, all_head_dim], e.g. [8, 16, 96]
        new_shape = x.shape[:-1]+[self.num_heads,self.head_dim]
        x = x.reshape(new_shape)
        # [B, N, num_heads, head_dim], e.g. [8, 16, 4, 24]
        x = x.transpose([0,2,1,3])
        # [B, num_heads, N, head_dim], e.g. [8, 4, 16, 24]
        return x
    def forward(self,x):
        # x: [B, N, embed_dim], e.g. [8, 16, 96]
        B,N,_ = x.shape
        # project to q, k, v and split: three tensors of [B, N, all_head_dim]
        qkv = self.qkv(x).chunk(3,-1)
        # q, k, v: [B, num_heads, num_patches, head_dim], e.g. [8, 4, 16, 24]
        q,k,v = map(self.transpose_multi_head,qkv)
        attn = paddle.matmul(q,k,transpose_y=True)
        attn = self.softmax(attn*self.scale)
        # dropout on the attention weights
        # [B, num_heads, num_patches, num_patches]
        attn = self.attention_dropout(attn)
        # [B, num_heads, num_patches, head_dim]
        out = paddle.matmul(attn,v)
        out = out.transpose([0,2,1,3])
        # [B, num_patches, num_heads, head_dim], e.g. [8, 16, 4, 24]
        out = out.reshape([B,N,-1])
        out = self.proj(out)
        out = self.dropout(out)
        return out
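
A toy shape check matching the comments above (hypothetical, not from the original post):

# hypothetical toy run: batch 8, 16 tokens, embed_dim 96, 4 heads (head_dim 24)
attn_layer = Attention(embed_dim=96, num_heads=4)
t = paddle.randn([8, 16, 96])
print(attn_layer(t).shape)  # [8, 16, 96]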


MLP

class Mlp(nn.Layer):
    def __init__(self,embed_dim,mlp_ratio=4.0,dropout=0.):
        super().__init__()
        w_att_1,b_att_1 = self.init_weight()
        w_att_2,b_att_2 = self.init_weight()
        self.fc1 = nn.Linear(embed_dim,int(embed_dim*mlp_ratio),weight_attr=w_att_1,bias_attr=b_att_1)
        self.fc2 = nn.Linear(int(embed_dim*mlp_ratio),embed_dim,weight_attr=w_att_2,bias_attr=b_att_2)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()
    def init_weight(self):
        # truncated normal with std 0.02, consistent with the position-embedding init
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=0.02))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.))
        return weight_attr,bias_attr
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
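
A quick hypothetical check: with the default mlp_ratio=4.0, fc1 expands 768 to 3072 and fc2 projects back to 768.

# hypothetical shape check for the Mlp block
mlp = Mlp(embed_dim=768)
print(mlp(paddle.randn([4, 197, 768])).shape)  # [4, 197, 768]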

Encoder Block

class EncoderLayer(nn.Layer):
    def __init__(self,embed_dim,num_head,mlp_ratio=4.0):
        super().__init__()
        w_att_1,b_att_1 = self.init_weight()
        w_att_2,b_att_2 = self.init_weight()
        self.attn = Attention(embed_dim,num_head)
        self.attn_norm = nn.LayerNorm(embed_dim,weight_attr=w_att_1,bias_attr=b_att_1)
        self.mlp = Mlp(embed_dim,mlp_ratio)
        self.mlp_norm = nn.LayerNorm(embed_dim,weight_attr=w_att_2,bias_attr=b_att_2)
    
    def init_weight(self):
        weight = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias = paddle.ParamAttr(initializer=nn.initializer.Constant(.0))
        return weight,bias
        
    def forward(self,x):
        # pre-norm residual: x = x + Attention(LayerNorm(x))
        h = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = h + x
        # pre-norm residual: x = x + MLP(LayerNorm(x))
        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
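# Design note: implementations that use DropPath apply it to each residual
# branch in EncoderLayer, e.g. x = h + drop_path(attn(attn_norm(x))); this
# reproduction keeps the paper's simpler Dropout-only behavior.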
class Encoder(nn.Layer):
    def __init__(self,embed_dim,num_head,depth):
        super().__init__()
        w_att_1,b_att_1 = self._init_weights()
        layer_list = [EncoderLayer(embed_dim,num_head) for i in range(depth)]
        self.encoder = nn.LayerList(layer_list)
        self.encoder_norm = nn.LayerNorm(embed_dim,
                                         weight_attr=w_att_1,
                                         bias_attr=b_att_1,
                                         epsilon=1e-6)

    def _init_weights(self):
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr
    def forward(self,x):
        for encoder in self.encoder:
            x = encoder(x)
        x = self.encoder_norm(x)
        return x 
class VisualTransformer(nn.Layer):
    def __init__(self,
                image_size = 224,
                patch_size = 16,
                in_channels = 3,
                num_classes = 1000,
                embed_dim = 768,
                depth = 3,
                num_heads = 8,
                mlp_ratio = 4,
                qkv_bias = True,
                dropout = 0.,
                attention_dropout = 0.,
                droppath = 0.):
        super().__init__()
        # note: qkv_bias, the dropout rates and droppath are accepted for
        # completeness but are not threaded through this simplified Encoder
        self.patch_embedding = PatchEmbedding(image_size,patch_size,in_channels,embed_dim,dropout)
        self.encoder = Encoder(embed_dim,num_heads,depth)
        self.classifier = nn.Linear(embed_dim,num_classes)
    def forward(self,x):
        # [batch, 3, 224, 224] -> [batch, n_patches+1, embed_dim]
        x = self.patch_embedding(x)
        x = self.encoder(x)
        # classify from the class-token output only, as in the original ViT
        x = self.classifier(x[:,0])
        return x

3. Finally, let's print the network structure

Note that this reproduction uses depth=3 and num_heads=8, so the parameter count below is much smaller than the standard ViT-Base configuration (12 layers, 12 heads, roughly 86M parameters).

    vit = VisualTransformer()
    print(vit)
    paddle.summary(vit,(4,3,224,224))
VisualTransformer(
  (patch_embedding): PatchEmbedding(
    (patch_embedding): Conv2D(3, 768, kernel_size=[16, 16], stride=[16, 16], data_format=NCHW)
    (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
  )
  (encoder): Encoder(
    (encoder): LayerList(
      (0): EncoderLayer(
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, dtype=float32)
          (softmax): Softmax(axis=-1)
          (proj): Linear(in_features=768, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
        )
        (attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, dtype=float32)
          (fc2): Linear(in_features=3072, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
          (act): GELU(approximate=False)
        )
        (mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
      )
      (1): EncoderLayer(
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, dtype=float32)
          (softmax): Softmax(axis=-1)
          (proj): Linear(in_features=768, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
        )
        (attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, dtype=float32)
          (fc2): Linear(in_features=3072, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
          (act): GELU(approximate=False)
        )
        (mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
      )
      (2): EncoderLayer(
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, dtype=float32)
          (softmax): Softmax(axis=-1)
          (proj): Linear(in_features=768, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
        )
        (attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, dtype=float32)
          (fc2): Linear(in_features=3072, out_features=768, dtype=float32)
          (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
          (act): GELU(approximate=False)
        )
        (mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
      )
    )
    (encoder_norm): LayerNorm(normalized_shape=[768], epsilon=1e-06)
  )
  (classifier): Linear(in_features=768, out_features=1000, dtype=float32)
)
----------------------------------------------------------------------------
  Layer (type)       Input Shape          Output Shape         Param #    
============================================================================
    Conv2D-1      [[4, 3, 224, 224]]    [4, 768, 14, 14]       589,824    
PatchEmbedding-1  [[4, 3, 224, 224]]     [4, 197, 768]         152,064    
  LayerNorm-1      [[4, 197, 768]]       [4, 197, 768]          1,536     
    Linear-1       [[4, 197, 768]]       [4, 197, 2304]       1,771,776   
   Softmax-1      [[4, 8, 197, 197]]    [4, 8, 197, 197]          0       
   Dropout-2       [[4, 197, 768]]       [4, 197, 768]            0       
    Linear-2       [[4, 197, 768]]       [4, 197, 768]         590,592    
  Attention-1      [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-2      [[4, 197, 768]]       [4, 197, 768]          1,536     
    Linear-3       [[4, 197, 768]]       [4, 197, 3072]       2,362,368   
     GELU-1        [[4, 197, 3072]]      [4, 197, 3072]           0       
   Dropout-3       [[4, 197, 768]]       [4, 197, 768]            0       
    Linear-4       [[4, 197, 3072]]      [4, 197, 768]        2,360,064   
     Mlp-1         [[4, 197, 768]]       [4, 197, 768]            0       
 EncoderLayer-1    [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-3      [[4, 197, 768]]       [4, 197, 768]          1,536     
    Linear-5       [[4, 197, 768]]       [4, 197, 2304]       1,771,776   
   Softmax-2      [[4, 8, 197, 197]]    [4, 8, 197, 197]          0       
   Dropout-4       [[4, 197, 768]]       [4, 197, 768]            0       
    Linear-6       [[4, 197, 768]]       [4, 197, 768]         590,592    
  Attention-2      [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-4      [[4, 197, 768]]       [4, 197, 768]          1,536     
    Linear-7       [[4, 197, 768]]       [4, 197, 3072]       2,362,368   
     GELU-2        [[4, 197, 3072]]      [4, 197, 3072]           0       
   Dropout-5       [[4, 197, 768]]       [4, 197, 768]            0       
    Linear-8       [[4, 197, 3072]]      [4, 197, 768]        2,360,064   
     Mlp-2         [[4, 197, 768]]       [4, 197, 768]            0       
 EncoderLayer-2    [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-5      [[4, 197, 768]]       [4, 197, 768]          1,536     
    Linear-9       [[4, 197, 768]]       [4, 197, 2304]       1,771,776   
   Softmax-3      [[4, 8, 197, 197]]    [4, 8, 197, 197]          0       
   Dropout-6       [[4, 197, 768]]       [4, 197, 768]            0       
   Linear-10       [[4, 197, 768]]       [4, 197, 768]         590,592    
  Attention-3      [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-6      [[4, 197, 768]]       [4, 197, 768]          1,536     
   Linear-11       [[4, 197, 768]]       [4, 197, 3072]       2,362,368   
     GELU-3        [[4, 197, 3072]]      [4, 197, 3072]           0       
   Dropout-7       [[4, 197, 768]]       [4, 197, 768]            0       
   Linear-12       [[4, 197, 3072]]      [4, 197, 768]        2,360,064   
     Mlp-3         [[4, 197, 768]]       [4, 197, 768]            0       
 EncoderLayer-3    [[4, 197, 768]]       [4, 197, 768]            0       
  LayerNorm-7      [[4, 197, 768]]       [4, 197, 768]          1,536     
   Encoder-1       [[4, 197, 768]]       [4, 197, 768]            0       
   Linear-13          [[4, 768]]           [4, 1000]           769,000    
============================================================================
Total params: 22,776,040
Trainable params: 22,776,040
Non-trainable params: 0
----------------------------------------------------------------------------
Input size (MB): 2.30
Forward/backward pass size (MB): 323.93
Params size (MB): 86.88
Estimated Total Size (MB): 413.11
----------------------------------------------------------------------------

{'total_params': 22776040, 'trainable_params': 22776040}

4. References

Code implementation

Vision Transformer详解 (Vision Transformer explained)
