import tensorflow as tf
from tensorflow.keras.layers import (Dense, Conv2D, LayerNormalization,
                                     Layer, Dropout, Input, GlobalAveragePooling1D,
                                     Add, Softmax)
from tensorflow.keras import Sequential, Model
class Identity(Layer):
    # Multi-head self-attention (the class keeps its original name, Identity).
    # Usage:
    #   first instantiate:      attn = Identity(embed_dim, num_heads)
    #   then pass in a tensor:  out = attn(a_tensor)
    def __init__(self, embed_dim, num_heads, qkv_scale=None, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.all_head_dim = self.head_dim * num_heads
        # a single projection produces q, k and v concatenated along the last axis
        self.qkv = Dense(self.all_head_dim * 3)
        # default scale is 1/sqrt(head_dim)
        self.scale = self.head_dim ** -0.5 if qkv_scale is None else qkv_scale
        self.softmax = Softmax()
        self.dropout = Dropout(dropout)
    def transpose_multi_head(self, x):
        # [B, N, all_head_dim] -> [B, N, num_heads, head_dim]
        new_shape = x.shape[:-1] + [self.num_heads, self.head_dim]
        x = tf.reshape(x, new_shape)
        # [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return x
    def call(self, inputs):
        # inputs: [B, N, embed_dim]
        B, N, _ = inputs.shape
        # project to q, k, v: [B, N, all_head_dim * 3] -> 3 x [B, N, all_head_dim]
        q, k, v = tf.split(self.qkv(inputs), 3, axis=-1)
        # each becomes [B, num_heads, N, head_dim]
        q = self.transpose_multi_head(q)
        k = self.transpose_multi_head(k)
        v = self.transpose_multi_head(v)
        # attention scores over tokens: [B, num_heads, N, N]
        atten = tf.matmul(q, k, transpose_b=True)
        atten = self.scale * atten
        atten = self.softmax(atten)
        atten = self.dropout(atten)
        # weighted sum of values: [B, num_heads, N, head_dim]
        out = tf.matmul(atten, v)
        out = self.dropout(out)
        # [B, num_heads, N, head_dim] -> [B, N, num_heads, head_dim] -> [B, N, all_head_dim]
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        out = tf.reshape(out, shape=[-1, N, self.all_head_dim])
        return out
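# Quick shape check for the attention layer above. The values here are only an
# illustration (embed_dim=16, num_heads=4 are assumptions, not the settings of
# the model built in __main__); attention should return the input shape unchanged:
# attn = Identity(embed_dim=16, num_heads=4)
# out = attn(tf.random.normal([2, 1024, 16]))
# print(out.shape)  # (2, 1024, 16)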
class PatchEmbedding(Layer):
    # Splits the image into non-overlapping patches and linearly embeds each one,
    # e.g. image_size=[224, 224], in_channels=3, patch_size=7, embed_dim=16.
    def __init__(self, patch_size, embed_dim, dropout=0.):
        super().__init__()
        # a conv with kernel_size == strides == patch_size extracts and embeds patches
        self.patch_embed = Conv2D(filters=embed_dim, kernel_size=patch_size, strides=patch_size)
        self.dropout = Dropout(dropout)
def call(self,inputs):
# [batch,224,224,3] -> [batch,32,32,16]
x = self.patch_embed(inputs)
# [batch,32,32,16] -> [batch,32*32,16]
        x = tf.reshape(x, shape=[-1, x.shape[1] * x.shape[2], x.shape[3]])
x = self.dropout(x)
return x
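# Example of the patch embedding on its own (illustrative only, assuming a
# 224x224 RGB input and the patch_size=16 / embed_dim=768 used in __main__):
# patcher = PatchEmbedding(patch_size=16, embed_dim=768)
# patches = patcher(tf.random.normal([4, 224, 224, 3]))
# print(patches.shape)  # (4, 196, 768): 14*14 patches, each embedded to 768 dims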
class MLP(Layer):
def __init__(self,embed_dim, mlp_ratio=4.0,dropout=0.0):
super().__init__()
self.fc1 = Dense(int(embed_dim*mlp_ratio))
self.fc2 = Dense(embed_dim)
self.dropout = Dropout(rate=dropout)
    def call(self, inputs):
        # [batch, n, embed_dims] -> [batch, n, embed_dims * mlp_ratio]
        x = self.fc1(inputs)
        x = tf.nn.gelu(x)  # GELU activation
        x = self.dropout(x)
        # [batch, n, embed_dims * mlp_ratio] -> [batch, n, embed_dims]
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class Encoder(Layer):
def __init__(self,embed_dims,num_heads):
super().__init__()
        self.atten = Identity(embed_dim=embed_dims, num_heads=num_heads)  # multi-head self-attention (the Identity layer above)
self.atten_norm = LayerNormalization()
self.mlp = MLP(embed_dims)
self.mlp_norm = LayerNormalization()
self.add = Add()
    def call(self, inputs):
        # [batch, h'*w', embed_dims] -> [batch, h'*w', embed_dims]
        h = inputs
        x = self.atten_norm(inputs)  # pre-norm: LayerNorm before attention
        x = self.atten(x)
        x = self.add([x, h])  # residual connection around attention
        h = x
        x = self.mlp_norm(x)  # pre-norm: LayerNorm before the MLP
        x = self.mlp(x)
        x = self.add([x, h])  # residual connection around the MLP
        return x
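# One encoder block (pre-norm attention + MLP, each with a residual add) keeps
# the token shape unchanged. Illustrative values only, matching the __main__ config:
# block = Encoder(embed_dims=768, num_heads=12)
# y = block(tf.random.normal([4, 196, 768]))
# print(y.shape)  # (4, 196, 768)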
class ViT(Layer):
def __init__(self,patch_size,embed_dims,num_heads=3,encoder_length=5,num_classes=2):
super().__init__()
self.patch_embed = PatchEmbedding(patch_size=patch_size,embed_dim=embed_dims)
        # stack of encoder_length identical encoder blocks
        layer_list = [Encoder(embed_dims=embed_dims, num_heads=num_heads) for _ in range(encoder_length)]
        self.encoders = Sequential(layer_list)
self.head = Dense(num_classes)
self.avgpool = GlobalAveragePooling1D()
self.layernorm = LayerNormalization()
def call(self,inputs):
        # [batch, h, w, channels] -> [batch, h'*w', embed_dims]
        x = self.patch_embed(inputs)
        # pass through encoder_length encoder blocks
        x = self.encoders(x)
        # LayerNorm over the embed_dims dimension
        x = self.layernorm(x)
# [batch, h'*w', embed_dims] -> [batch,embed_dims]
x = self.avgpool(x)
# [batch, embed_dims] -> [batch, num_classes]
x = self.head(x)
return x
if __name__ == '__main__':
inputs = Input(shape=(224,224,3),batch_size=4)
vision_transformer = ViT(patch_size=16,embed_dims=768,
num_heads=12,encoder_length=12,num_classes=2)
out = vision_transformer(inputs)
model = Model(inputs=inputs,outputs=out,name='vit-tf2')
model.summary()
# attention = Identity(embed_dim=16,num_heads=4)
But why doesn't model.summary() print the information for the custom layers? If anyone knows, please let me know.