In code, the patch embedding is implemented directly with a single convolution layer whose kernel size and stride both equal the patch size. For a 224×224 input split into 16×16 patches this yields (224/16)² = 196 patch tokens. Before the sequence is fed into the Transformer encoder, note that the class token must be prepended and the position embedding added, giving a sequence length of 197.
import paddle
import paddle.nn as nn
from PIL import Image
import numpy as np


class PatchEmbedding(nn.Layer):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        n_patches = (image_size // patch_size) * (image_size // patch_size)
        # a convolution with kernel_size = stride = patch_size splits the image into
        # non-overlapping patches and projects each patch to embed_dim dimensions
        self.patch_embedding = nn.Conv2D(in_channels=in_channels,
                                         out_channels=embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size,
                                         weight_attr=paddle.ParamAttr(
                                             initializer=nn.initializer.Constant(1.0)),
                                         bias_attr=False)
        self.dropout = nn.Dropout(dropout)
        # learnable class token, prepended to the patch sequence
        self.class_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.))
        # learnable position embedding for the n_patches + 1 tokens
        self.position_embedding = paddle.create_parameter(
            shape=[1, n_patches + 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.TruncatedNormal(std=0.02))

    def forward(self, x):
        # x: [batch, channel, h, w], e.g. [1, 3, 224, 224]
        cls_token = self.class_token.expand([x.shape[0], -1, -1])
        x = self.patch_embedding(x)                # [batch, embed_dim, h', w'], e.g. [1, 768, 14, 14]
        x = x.flatten(2)                           # [batch, embed_dim, h'*w'],  e.g. [1, 768, 196]
        x = x.transpose([0, 2, 1])                 # [batch, h'*w', embed_dim]
        x = paddle.concat([cls_token, x], axis=1)  # prepend class token: [batch, n_patches+1, embed_dim]
        x = x + self.position_embedding
        x = self.dropout(x)                        # apply dropout to the token embeddings
        return x
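
To make the token bookkeeping concrete, here is a minimal sketch (not part of the original code) that pushes a dummy 224×224 image through the layer; with 16×16 patches there are 14×14 = 196 patch tokens plus the class token, for a sequence length of 197:

# hypothetical shape check: 224/16 = 14, so 14*14 = 196 patches + 1 class token = 197 tokens
patch_embed = PatchEmbedding(image_size=224, patch_size=16, in_channels=3, embed_dim=768)
dummy_img = paddle.randn([1, 3, 224, 224])
print(patch_embed(dummy_img).shape)   # expected: [1, 197, 768]
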
class Attention(nn.Layer):
    def __init__(self, embed_dim, num_heads, qkv_bias=True, qk_scale=None,
                 dropout=0., attention_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.all_head_dim = self.head_dim * num_heads
        # a single linear layer produces q, k and v in one pass
        self.qkv = nn.Linear(embed_dim,
                             self.all_head_dim * 3,
                             bias_attr=None if qkv_bias else False)
        self.scale = self.head_dim ** -0.5 if qk_scale is None else qk_scale
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(self.all_head_dim, embed_dim)
        self.dropout = nn.Dropout(p=dropout)

    def transpose_multi_head(self, x):
        # x: [B, N, all_head_dim], e.g. [8, 16, 96]
        new_shape = x.shape[:-1] + [self.num_heads, self.head_dim]
        x = x.reshape(new_shape)       # [B, N, num_heads, head_dim], e.g. [8, 16, 4, 24]
        x = x.transpose([0, 2, 1, 3])  # [B, num_heads, N, head_dim], e.g. [8, 4, 16, 24]
        return x

    def forward(self, x):
        # x: [B, N, embed_dim], e.g. [8, 16, 96]
        B, N, _ = x.shape
        # qkv(x): [B, N, all_head_dim*3], chunked into three [B, N, all_head_dim] tensors
        qkv = self.qkv(x).chunk(3, -1)
        # q, k, v: [B, num_heads, num_patches, head_dim]
        q, k, v = map(self.transpose_multi_head, qkv)
        attn = paddle.matmul(q, k, transpose_y=True)  # [B, num_heads, num_patches, num_patches]
        attn = self.softmax(attn * self.scale)
        attn = self.dropout(attn)
        out = paddle.matmul(attn, v)       # [B, num_heads, num_patches, head_dim]
        out = out.transpose([0, 2, 1, 3])  # [B, num_patches, num_heads, head_dim]
        out = out.reshape([B, N, -1])      # merge heads: [B, num_patches, all_head_dim]
        out = self.proj(out)
        out = self.dropout(out)
        return out
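
As a quick sanity check (a minimal sketch, not part of the original code), multi-head attention maps a [B, N, embed_dim] tensor back to the same shape; the example shapes in the comments above correspond to embed_dim=96 with num_heads=4:

# hypothetical shape check for the example shapes used above
attn_layer = Attention(embed_dim=96, num_heads=4)
tokens = paddle.randn([8, 16, 96])   # [B, N, embed_dim]
print(attn_layer(tokens).shape)      # expected: [8, 16, 96]
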
class Mlp(nn.Layer):
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        w_att_1, b_att_1 = self.init_weight()
        w_att_2, b_att_2 = self.init_weight()
        self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio),
                             weight_attr=w_att_1, bias_attr=b_att_1)
        self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim,
                             weight_attr=w_att_2, bias_attr=b_att_2)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def init_weight(self):
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=0.2))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(.0))
        return weight_attr, bias_attr

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class EncoderLayer(nn.Layer):
    def __init__(self, embed_dim, num_head, mlp_ratio=4.0):
        super().__init__()
        w_att_1, b_att_1 = self.init_weight()
        w_att_2, b_att_2 = self.init_weight()
        self.attn = Attention(embed_dim, num_head)
        self.attn_norm = nn.LayerNorm(embed_dim, weight_attr=w_att_1, bias_attr=b_att_1)
        self.mlp = Mlp(embed_dim, mlp_ratio)
        self.mlp_norm = nn.LayerNorm(embed_dim, weight_attr=w_att_2, bias_attr=b_att_2)

    def init_weight(self):
        weight = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias = paddle.ParamAttr(initializer=nn.initializer.Constant(.0))
        return weight, bias

    def forward(self, x):
        # pre-norm residual block: x + Attention(LN(x)), followed by x + MLP(LN(x))
        h = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = h + x
        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
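
The block uses the pre-norm ordering (LayerNorm before attention and MLP, residual connection added afterwards), as in the ViT paper. A minimal, hypothetical check with the same example shapes as above confirms the block preserves the token shape:

# hypothetical shape check for one pre-norm encoder block
block = EncoderLayer(embed_dim=96, num_head=4)
tokens = paddle.randn([8, 16, 96])
print(block(tokens).shape)   # expected: [8, 16, 96]
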
class Encoder(nn.Layer):
    def __init__(self, embed_dim, num_head, depth):
        super().__init__()
        w_att_1, b_att_1 = self._init_weights()
        # stack `depth` identical encoder layers, followed by a final LayerNorm
        layer_list = [EncoderLayer(embed_dim, num_head) for i in range(depth)]
        self.encoder = nn.LayerList(layer_list)
        self.encoder_norm = nn.LayerNorm(embed_dim,
                                         weight_attr=w_att_1,
                                         bias_attr=b_att_1,
                                         epsilon=1e-6)

    def _init_weights(self):
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def forward(self, x):
        for encoder in self.encoder:
            x = encoder(x)
        x = self.encoder_norm(x)
        return x
class VisualTransformer(nn.Layer):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=3,
                 num_heads=8,
                 mlp_ratio=4,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.,
                 droppatch=0.):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim, dropout)
        self.encoder = Encoder(embed_dim, num_heads, depth)
        # classification head applied to the class token
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)   # [B, n_patches+1, embed_dim]
        x = self.encoder(x)           # [B, n_patches+1, embed_dim]
        x = self.classifier(x[:, 0])  # classify from the class token: [B, num_classes]
        return x
# instantiate the model and inspect its structure and parameter counts
vit = VisualTransformer()
print(vit)
paddle.summary(vit, (4, 3, 224, 224))
VisualTransformer(
(patch_embedding): PatchEmbedding(
(patch_embedding): Conv2D(3, 768, kernel_size=[16, 16], stride=[16, 16], data_format=NCHW)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
)
(encoder): Encoder(
(encoder): LayerList(
(0): EncoderLayer(
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, dtype=float32)
(softmax): Softmax(axis=-1)
(proj): Linear(in_features=768, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
)
(attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, dtype=float32)
(fc2): Linear(in_features=3072, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
(act): GELU(approximate=False)
)
(mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
)
(1): EncoderLayer(
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, dtype=float32)
(softmax): Softmax(axis=-1)
(proj): Linear(in_features=768, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
)
(attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, dtype=float32)
(fc2): Linear(in_features=3072, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
(act): GELU(approximate=False)
)
(mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
)
(2): EncoderLayer(
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, dtype=float32)
(softmax): Softmax(axis=-1)
(proj): Linear(in_features=768, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
)
(attn_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, dtype=float32)
(fc2): Linear(in_features=3072, out_features=768, dtype=float32)
(dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)
(act): GELU(approximate=False)
)
(mlp_norm): LayerNorm(normalized_shape=[768], epsilon=1e-05)
)
)
(encoder_norm): LayerNorm(normalized_shape=[768], epsilon=1e-06)
)
(classifier): Linear(in_features=768, out_features=1000, dtype=float32)
)
----------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
============================================================================
Conv2D-1 [[4, 3, 224, 224]] [4, 768, 14, 14] 589,824
PatchEmbedding-1 [[4, 3, 224, 224]] [4, 197, 768] 152,064
LayerNorm-1 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-1 [[4, 197, 768]] [4, 197, 2304] 1,771,776
Softmax-1 [[4, 8, 197, 197]] [4, 8, 197, 197] 0
Dropout-2 [[4, 197, 768]] [4, 197, 768] 0
Linear-2 [[4, 197, 768]] [4, 197, 768] 590,592
Attention-1 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-2 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-3 [[4, 197, 768]] [4, 197, 3072] 2,362,368
GELU-1 [[4, 197, 3072]] [4, 197, 3072] 0
Dropout-3 [[4, 197, 768]] [4, 197, 768] 0
Linear-4 [[4, 197, 3072]] [4, 197, 768] 2,360,064
Mlp-1 [[4, 197, 768]] [4, 197, 768] 0
EncoderLayer-1 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-3 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-5 [[4, 197, 768]] [4, 197, 2304] 1,771,776
Softmax-2 [[4, 8, 197, 197]] [4, 8, 197, 197] 0
Dropout-4 [[4, 197, 768]] [4, 197, 768] 0
Linear-6 [[4, 197, 768]] [4, 197, 768] 590,592
Attention-2 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-4 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-7 [[4, 197, 768]] [4, 197, 3072] 2,362,368
GELU-2 [[4, 197, 3072]] [4, 197, 3072] 0
Dropout-5 [[4, 197, 768]] [4, 197, 768] 0
Linear-8 [[4, 197, 3072]] [4, 197, 768] 2,360,064
Mlp-2 [[4, 197, 768]] [4, 197, 768] 0
EncoderLayer-2 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-5 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-9 [[4, 197, 768]] [4, 197, 2304] 1,771,776
Softmax-3 [[4, 8, 197, 197]] [4, 8, 197, 197] 0
Dropout-6 [[4, 197, 768]] [4, 197, 768] 0
Linear-10 [[4, 197, 768]] [4, 197, 768] 590,592
Attention-3 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-6 [[4, 197, 768]] [4, 197, 768] 1,536
Linear-11 [[4, 197, 768]] [4, 197, 3072] 2,362,368
GELU-3 [[4, 197, 3072]] [4, 197, 3072] 0
Dropout-7 [[4, 197, 768]] [4, 197, 768] 0
Linear-12 [[4, 197, 3072]] [4, 197, 768] 2,360,064
Mlp-3 [[4, 197, 768]] [4, 197, 768] 0
EncoderLayer-3 [[4, 197, 768]] [4, 197, 768] 0
LayerNorm-7 [[4, 197, 768]] [4, 197, 768] 1,536
Encoder-1 [[4, 197, 768]] [4, 197, 768] 0
Linear-13 [[4, 768]] [4, 1000] 769,000
============================================================================
Total params: 22,776,040
Trainable params: 22,776,040
Non-trainable params: 0
----------------------------------------------------------------------------
Input size (MB): 2.30
Forward/backward pass size (MB): 323.93
Params size (MB): 86.88
Estimated Total Size (MB): 413.11
----------------------------------------------------------------------------
{'total_params': 22776040, 'trainable_params': 22776040}
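
Finally, a minimal usage sketch (not part of the original post): a forward pass on a dummy batch returns one logit vector per image.

# hypothetical forward pass on random data
images = paddle.randn([4, 3, 224, 224])
logits = vit(images)
print(logits.shape)   # expected: [4, 1000]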