TensorFlow 2 实现 Vision Transformer(完整版)

import tensorflow as tf
from tensorflow.keras.layers import (Dense,Conv2D,LayerNormalization,
                                Layer,Dropout,Input,GlobalAveragePooling1D,
                                Add,Softmax,)
from tensorflow.keras import Sequential,Model


class Identity(Layer):
    """Multi-head self-attention over a sequence of token embeddings.

    Despite its name this layer is not an identity mapping: it projects the
    input to queries/keys/values, computes scaled dot-product attention
    between tokens independently for each head, and projects the merged heads
    back to ``embed_dim``.

    Usage:
        attn = Identity(embed_dim=16, num_heads=4)   # instantiate first
        out = attn(a_tensor)                         # then call on a tensor

    Args:
        embed_dim: Size of the last axis of the input and of the output.
        num_heads: Number of attention heads. Each head gets
            ``embed_dim // num_heads`` channels.
        qkv_scale: Factor applied to the attention logits. Defaults to
            ``head_dim ** -0.5`` when ``None``.
        dropout: Dropout rate applied to the attention weights and to the
            layer output.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_heads`` is not positive or is larger than
            ``embed_dim`` (which would give zero-width heads).
    """

    def __init__(self, embed_dim, num_heads, qkv_scale=None,
                 dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        if num_heads <= 0 or embed_dim < num_heads:
            raise ValueError(
                f'num_heads must be in [1, embed_dim]; got '
                f'embed_dim={embed_dim}, num_heads={num_heads}')
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.all_head_dim = self.head_dim * num_heads
        self.qkv = Dense(self.all_head_dim * 3)
        # Maps the merged heads back to embed_dim so that the output can be
        # added to the residual even when embed_dim % num_heads != 0.
        self.proj = Dense(embed_dim)
        self.scale = self.head_dim ** (-0.5) if qkv_scale is None else qkv_scale
        self.softmax = Softmax()
        self.dropout = Dropout(dropout)

    @staticmethod
    def _token_count(x):
        """Return the size of axis 1 as an int when static, else as a tensor."""
        static_count = x.shape[1]
        return static_count if static_count is not None else tf.shape(x)[1]

    def transpose_multi_head(self, x):
        """Split the channel axis into heads and move the head axis forward.

        [batch, tokens, all_head_dim] -> [batch, num_heads, tokens, head_dim]
        """
        # -1 lets TensorFlow infer the batch size, so a batch dimension of
        # None (unknown at graph-construction time) is supported.
        x = tf.reshape(
            x, [-1, self._token_count(x), self.num_heads, self.head_dim])
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=None):
        """Apply self-attention.

        Args:
            inputs: Tensor of shape [batch, tokens, embed_dim].
            training: Optional flag forwarded to the dropout layers.

        Returns:
            Tensor of shape [batch, tokens, embed_dim].
        """
        # [batch, tokens, embed_dim] -> 3 x [batch, tokens, all_head_dim]
        q, k, v = tf.split(self.qkv(inputs), 3, axis=-1)
        q = self.transpose_multi_head(q)
        k = self.transpose_multi_head(k)
        v = self.transpose_multi_head(v)

        # Token-to-token similarity per head:
        # [batch, heads, tokens, head_dim] x [batch, heads, head_dim, tokens]
        # -> [batch, heads, tokens, tokens]
        atten = tf.matmul(q, k, transpose_b=True)
        atten = self.scale * atten
        atten = self.softmax(atten)  # normalises over the key axis (last)
        atten = self.dropout(atten, training=training)

        # -> [batch, heads, tokens, head_dim]
        out = tf.matmul(atten, v)

        # Merge heads: -> [batch, tokens, heads, head_dim]
        #              -> [batch, tokens, all_head_dim]
        out = tf.transpose(out, perm=[0, 2, 1, 3])
        out = tf.reshape(
            out, [-1, self._token_count(inputs), self.all_head_dim])

        # -> [batch, tokens, embed_dim]
        out = self.proj(out)
        out = self.dropout(out, training=training)

        return out


class PatchEmbedding(Layer):
    """Turn an image into a sequence of patch embeddings.

    A strided convolution (kernel size == stride == ``patch_size``) embeds
    each non-overlapping patch, and the resulting spatial grid is flattened
    into a token axis. No class token or position embedding is added here.

    Example: inputs [batch, 224, 224, 3] with patch_size=7, embed_dim=16
    give [batch, 32*32, 16].

    Args:
        patch_size: Side length of a patch; used as both kernel size and
            stride of the convolution.
        embed_dim: Number of output channels per patch.
        dropout: Dropout rate applied to the flattened embeddings.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, patch_size, embed_dim, dropout=0., **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_embed = Conv2D(embed_dim, patch_size, patch_size)
        self.dropout = Dropout(dropout)

    def call(self, inputs, training=None):
        """Embed and flatten patches.

        Args:
            inputs: Image tensor; assumed channels-last,
                [batch, height, width, channels].
            training: Optional flag forwarded to the dropout layer.

        Returns:
            Tensor of shape [batch, num_patches, embed_dim].
        """
        # [batch, H, W, C] -> [batch, H/patch, W/patch, embed_dim]
        x = self.patch_embed(inputs)

        # Flatten the patch grid into one token axis. The batch size is never
        # read from the static shape, because it is None whenever the model
        # is built without a fixed batch size.
        grid_h, grid_w = x.shape[1], x.shape[2]
        if grid_h is not None and grid_w is not None:
            target_shape = [-1, grid_h * grid_w, self.embed_dim]
        else:
            # Spatial size unknown until run time: pin the batch dynamically
            # and let TensorFlow infer the number of patches instead.
            target_shape = [tf.shape(x)[0], -1, self.embed_dim]
        x = tf.reshape(x, shape=target_shape)

        x = self.dropout(x, training=training)

        return x

class MLP(Layer):
    """Position-wise feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        embed_dim: Width of the block's output (second Dense layer).
        mlp_ratio: The hidden width is ``int(embed_dim * mlp_ratio)``.
        dropout: Rate of the single Dropout layer, which is applied after
            the activation and again after the second Dense layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_units = int(embed_dim * mlp_ratio)
        self.fc1 = Dense(hidden_units)
        self.fc2 = Dense(embed_dim)
        self.dropout = Dropout(rate=dropout)

    def call(self, inputs):
        """Expand the last axis, apply GELU, then project it to ``embed_dim``."""
        # [..., in_dim] -> [..., int(embed_dim * mlp_ratio)], GELU activation
        expanded = self.dropout(tf.nn.gelu(self.fc1(inputs)))

        # [..., int(embed_dim * mlp_ratio)] -> [..., embed_dim]
        return self.dropout(self.fc2(expanded))

class Encoder(Layer):
    """One pre-norm encoder block: two residual sub-blocks in sequence.

    1. LayerNormalization -> attention layer (``Identity``) -> add the input
    2. LayerNormalization -> ``MLP``                        -> add step 1's result

    Args:
        embed_dims: Passed to the attention layer as ``embed_dim`` and to the
            MLP as its ``embed_dim``.
        num_heads: Number of heads passed to the attention layer.
    """

    def __init__(self, embed_dims, num_heads):
        super().__init__()
        self.atten = Identity(embed_dim=embed_dims, num_heads=num_heads)  # TODO
        self.atten_norm = LayerNormalization()
        self.mlp = MLP(embed_dims)
        self.mlp_norm = LayerNormalization()
        self.add = Add()

    def call(self, inputs):
        """Run both residual sub-blocks on ``inputs`` and return the result."""
        # Normalise first, then attend, then add the untouched input back.
        attended = self.atten(self.atten_norm(inputs))
        after_attention = self.add([attended, inputs])

        # Same pattern for the feed-forward sub-block.
        fed_forward = self.mlp(self.mlp_norm(after_attention))
        return self.add([fed_forward, after_attention])

class ViT(Layer):
    """Image classifier: patch embedding -> encoder stack -> norm -> pool -> head.

    Args:
        patch_size: Patch side length handed to ``PatchEmbedding``.
        embed_dims: Embedding width used by the patch embedding and by every
            encoder.
        num_heads: Number of attention heads in each encoder.
        encoder_length: Number of ``Encoder`` layers stacked in sequence.
        num_classes: Number of units in the final Dense layer.
    """

    def __init__(self, patch_size, embed_dims, num_heads=3, encoder_length=5, num_classes=2):
        super().__init__()
        self.patch_embed = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dims)
        # Stack of identical encoder blocks, applied one after another.
        self.encoders = Sequential([
            Encoder(embed_dims=embed_dims, num_heads=num_heads)
            for _ in range(encoder_length)
        ])
        self.head = Dense(num_classes)
        self.avgpool = GlobalAveragePooling1D()
        self.layernorm = LayerNormalization()

    def call(self, inputs):
        """Map a batch of images to ``num_classes`` outputs per image."""
        # image -> [batch, h'*w', embed_dims], then through every encoder
        tokens = self.encoders(self.patch_embed(inputs))

        # Layer-normalise, then average over the token axis:
        # [batch, h'*w', embed_dims] -> [batch, embed_dims]
        pooled = self.avgpool(self.layernorm(tokens))

        # [batch, embed_dims] -> [batch, num_classes]
        return self.head(pooled)



if __name__ == '__main__':
    # Smoke test: wrap a ViT in a functional Model with a fixed batch of
    # four 224x224 RGB images and print the model summary.
    vit_config = dict(
        patch_size=16,
        embed_dims=768,
        num_heads=12,
        encoder_length=12,
        num_classes=2,
    )
    inputs = Input(shape=(224, 224, 3), batch_size=4)
    vision_transformer = ViT(**vit_config)
    model = Model(
        inputs=inputs,
        outputs=vision_transformer(inputs),
        name='vit-tf2',
    )
    model.summary()

    # Stand-alone attention layer example (kept disabled):
    # attention = Identity(embed_dim=16,num_heads=4)


tensorflow2实现vision transformer完整版_第1张图片

但是,为什么 `model.summary()` 没有把自定义层(layer)内部各子层的信息打印出来呢?希望有知道的大佬告知一下。

你可能感兴趣的:(tensorflow2,transformer,transformer,深度学习,keras)