28 - Vision Transformer (ViT): Principles, Key Difficulties, and a Line-by-Line Implementation

Table of Contents

  • 1. Principles
    • 1.1 The General Idea of ViT
    • 1.2 ViT Architecture Diagram
  • 2. Code Implementation
  • 3. Summary

1. Principles

1.1 The General Idea of ViT

  • Paper link:
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  • Video lecture links:
    PyTorch28——Vision Transformer(ViT)的原理、难点及其逐行实现
    ViT论文逐段精读【论文精读】

The rough idea behind the Vision Transformer is to apply the Transformer to images. The obstacle is sequence length: if we model an image at the pixel level, a single image contains far too many pixels and the resulting sequence is far too long, while a typical Transformer already treats a sequence length of 512 as long. Since the Transformer had proven so effective and so popular in NLP, the authors asked: can we cut an image into patches, embed and position-encode each patch, add a few extra tokens, and feed the result into a Transformer? That is exactly the Vision Transformer. The idea itself is plain and simple, and other researchers had tried similar things before, but they lacked the money to train such a model on truly large datasets. In the end Google, with its deep pockets, claimed ViT and delivered spectacular results. The lesson: research takes money; without it things are very hard, and passion alone does not pay for the GPUs.
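To make the sequence-length argument concrete, here is a quick back-of-the-envelope check (assuming the common 224x224 input resolution and the 16x16 patch size from the paper title; these numbers are illustrative and do not appear in the code below):

# treating every pixel of a 224x224 image as a token
num_pixel_tokens = 224 * 224                    # 50176 tokens, far beyond a typical length of 512
# treating every 16x16 patch as a token instead
num_patch_tokens = (224 // 16) * (224 // 16)    # 14 * 14 = 196 tokens
patch_dim = 16 * 16 * 3                         # each RGB patch flattens to a 768-dim vector
print(num_pixel_tokens, num_patch_tokens, patch_dim)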

  • DNN Perspective
    • Image2patch
    • Patch2embedding
  • CNN Perspective
    • 2D convolution over the image
    • flatten the output feature map
  • class token embedding
  • position embedding
  • Transformer Encoder
  • Classification head
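Section 2 below implements these steps one by one in a procedural style. For orientation, here is a minimal sketch of the same pipeline wrapped as a single nn.Module (the class name SimpleViT and the default hyperparameters are my own illustration, not part of the original code):

import torch
from torch import nn

class SimpleViT(nn.Module):
    """Minimal ViT sketch: conv patch embedding + CLS token + learned positions + encoder + linear head."""
    def __init__(self, image_size=8, patch_size=4, in_channels=3,
                 model_dim=8, num_layers=6, nhead=8, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # CNN perspective: a strided conv is equivalent to patch-splitting + linear projection
        self.patch_embed = nn.Conv2d(in_channels, model_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, x):                        # x: (bs, in_channels, H, W)
        x = self.patch_embed(x)                  # (bs, model_dim, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)         # (bs, num_patches, model_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)           # prepend the CLS token
        x = x + self.pos_embed                   # add learned position embeddings
        x = self.encoder(x)                      # (bs, num_patches + 1, model_dim)
        return self.head(x[:, 0])                # classify from the CLS token

# usage: logits = SimpleViT()(torch.randn(1, 3, 8, 8))   # -> shape (1, 10)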

1.2 ViT Architecture Diagram

[Figure 1: ViT architecture diagram]

2. Code Implementation

import torch 
from torch import nn
from torch.nn import functional as F

def image2embed_naive(image, patch_size, weight):
    """
    Patch-splitting + matrix multiplication approach.
    step1: split the image into non-overlapping, fixed-size patches
    step2: multiply the flattened patches by the projection weight
    """
    # image: (bs, ic, image_h, image_w)
    # unfold extracts each patch as a column of length ic*patch_size*patch_size
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    # (bs, num_patches, patch_depth) @ (patch_depth, model_dim) -> (bs, num_patches, model_dim)
    patch_embedding = patch @ weight
    return patch_embedding

def image2embed_conv(image, kernel, patch_size):
    """
    Convolution approach: a strided conv2d is equivalent to patch-splitting + linear projection.
    step1: run conv2d over the image with stride = patch_size
    step2: flatten the output feature map into a sequence of patch embeddings
    """
    output_conv = F.conv2d(image, kernel, stride=patch_size)  # (bs, model_dim, oh, ow)
    bs, oc, oh, ow = output_conv.shape
    patch_embedding_conv = output_conv.reshape((bs, oc, oh * ow)).transpose(-1, -2)  # (bs, oh*ow, model_dim)
    return patch_embedding_conv

bs, ic, image_h, image_w = 1, 3, 8, 8
image = torch.randn(bs, ic, image_h, image_w)
patch_size = 4
patch_depth = patch_size * patch_size * ic
model_dim = 8
max_num_token = 16
num_classes = 10
label = torch.randint(num_classes, (bs,))
# model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels
weight = torch.randn(patch_depth, model_dim)

# embedding via the naive patch-splitting method
output_naive = image2embed_naive(image, patch_size, weight)
# reshape the shared weight into a conv kernel of shape (model_dim, ic, patch_size, patch_size)
kernel = weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size)

# embedding via the 2D convolution method
output_conv = image2embed_conv(image, kernel, patch_size)
print(f"output_naive={output_naive}")
print(f"output_conv={output_conv}")
print(torch.isclose(output_naive,output_conv))
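# both methods should produce the same patch embedding of shape (bs, num_patches, model_dim) = (1, 4, 8),
# so isclose should print an all-True tensor (up to floating-point tolerance)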

# step2: prepend the CLS token embedding
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, output_conv], dim=1)

# step3: add position embedding
# one learnable vector per position; slice to the current sequence length and tile across the batch
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding

# step4: pass the embedding through a transformer encoder
# batch_first=True because token_embedding is laid out as (bs, seq_len, model_dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# step5: do classification
cls_token_output = encoder_output[:, 0, :]  # the CLS token (position 0) is used for classification
linear_layer = nn.Linear(model_dim,num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits,label)
print(f"loss={loss}")

3. Summary

  • Splitting an image into fixed-size patches is what makes it possible to apply the Transformer to vision
  • The ViT model uses only the Encoder part of the Transformer
  • Deep learning is very expensive; think twice before diving in.
