The rough idea behind Vision Transformer (ViT) is to apply the Transformer to images. The obstacle is sequence length: if we model an image at the pixel level, a single image contains far too many pixels, and the resulting sequence becomes enormous, while a typical Transformer already finds a length of 512 quite long. Since the Transformer was proving so effective and popular in NLP, the authors asked: can we cut an image into patches, embed each patch, add position encodings plus a few extra tokens (such as a CLS token), and feed the resulting sequence into a Transformer? That is exactly what ViT does. The idea itself is plain, and other researchers had tried it before, but they lacked the resources to train such a model on very large datasets. Google, with its deep pockets, scaled ViT up and obtained spectacular results, a reminder that research takes money; without it things are hard, and research cannot run on enthusiasm alone. The code below walks through the pipeline step by step.
import torch
from torch import nn
from torch.nn import functional as F
def image2embed_naive(image, patch_size, weight):
    """
    Patch embedding via unfold + matrix multiplication.
    step1: split the image into non-overlapping, fixed-size patches
    step2: project each flattened patch with a matrix multiplication against weight
    """
    # image: (bs, ic, image_h, image_w)
    # patch: (bs, num_patches, ic * patch_size * patch_size)
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    # patch_embedding: (bs, num_patches, model_dim)
    patch_embedding = patch @ weight
    return patch_embedding
def image2embed_conv(image, kernel, patch_size):
    """
    Patch embedding via a single conv2d.
    step1: run conv2d with kernel size and stride both equal to patch_size
    step2: flatten the spatial output and move the channel dimension last
    """
    # output_conv: (bs, oc, oh, ow), where oc == model_dim
    output_conv = F.conv2d(image, kernel, stride=patch_size)
    bs, oc, oh, ow = output_conv.shape
    # patch_embedding_conv: (bs, oh * ow, model_dim)
    patch_embedding_conv = output_conv.reshape((bs, oc, oh * ow)).transpose(-1, -2)
    return patch_embedding_conv
bs, ic, image_h, image_w = 1, 3, 8, 8
image = torch.randn(bs, ic, image_h, image_w)
patch_size = 4
patch_depth = patch_size * patch_size * ic
model_dim = 8
max_num_token = 16
num_classes = 10
label = torch.randint(num_classes, (bs,))
# model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels
weight = torch.randn(patch_depth, model_dim)
# step1: get the patch embedding, first with the patch-splitting method
output_naive = image2embed_naive(image, patch_size, weight)
# reshape the projection matrix into a conv kernel of shape (model_dim, ic, patch_size, patch_size)
kernel = weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size)
# and then with the 2D convolution method
output_conv = image2embed_conv(image, kernel, patch_size)
print(f"output_naive={output_naive}")
print(f"output_conv={output_conv}")
print(torch.allclose(output_naive, output_conv))
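# The two results match because a conv2d whose kernel size equals its stride visits
# exactly the same non-overlapping patches as F.unfold, and each output channel
# computes the same dot product as one column of `weight`.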
# step2: prepend a learnable CLS token embedding
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, output_conv], dim=1)
# step3: add position embeddings
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding = token_embedding + position_embedding
# step4: pass the token sequence through a Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)
# step5: classify from the CLS token output
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(f"loss={loss}")
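As a follow-up, here is a minimal sketch of how the same five steps could be packaged into a single nn.Module, so that the patch projection, CLS token, position table, encoder, and classification head are all registered as parameters and trained together. The class name ViTSketch and its default hyperparameters are illustrative choices of this sketch, not part of the original snippet.
class ViTSketch(nn.Module):
    """Minimal ViT-style classifier: conv patch embedding + Transformer encoder + CLS head."""
    def __init__(self, image_size=8, patch_size=4, in_channels=3,
                 model_dim=8, num_layers=6, num_heads=8, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # conv2d with kernel_size == stride == patch_size is the patch projection (step1)
        self.patch_embed = nn.Conv2d(in_channels, model_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))                # step2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, model_dim))  # step3
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,
                                                   batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)  # step4
        self.head = nn.Linear(model_dim, num_classes)                               # step5

    def forward(self, x):
        bs = x.shape[0]
        # (bs, model_dim, oh, ow) -> (bs, oh*ow, model_dim)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = torch.cat([self.cls_token.expand(bs, -1, -1), patches], dim=1)
        tokens = tokens + self.pos_embed[:, :tokens.shape[1]]
        encoded = self.encoder(tokens)
        return self.head(encoded[:, 0])

vit = ViTSketch()
print(loss_fn(vit(image), label))  # reuses image, label, and loss_fn from above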