Building a Transformer for Classification with Keras (Vision Transformer / ViT)

Table of Contents

  • 1. Vision Transformer (ViT) Details
  • 2. Vision Transformer Architecture
  • 3. Keras Implementation
    • 3.1 Imports
    • 3.2 Loading the data
    • 3.3 Hyperparameters
    • 3.4 Data augmentation
    • 3.5 Computing the mean and variance of the training data for normalization
    • 3.6 Defining the multilayer perceptron (MLP)
    • 3.7 Defining the patch-extraction layer
    • 3.8 Visualizing the patches
    • 3.9 Implementing the patch encoding layer
    • 3.10 Building the ViT model
    • 3.11 Training and evaluation (AdamW can be replaced with Adam, possibly with better results)
  • 4. Complete Code
  • 5. Visualization
    • 5.1 Original image and the extracted patches
    • 5.2 Model architecture and parameters

1. Vision Transformer (ViT) Details

  • Paper (download): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  • Example dataset: CIFAR-100

For more network models and their sources, see the linked post: A Collection of Neural Network Reference Papers

2. Vision Transformer Architecture

Transformers have largely replaced RNNs (including the LSTM and GRU variants) as the dominant models in natural language processing (NLP). Dosovitskiy et al. transferred the model to computer vision while changing the Transformer as little as possible: the input image is split into fixed-size patches, each patch is linearly projected and combined with a position embedding, the resulting sequence is passed through a stack of Transformer encoder blocks, and, since the task is classification, the output head is a fully connected (MLP) layer. The result is the Vision Transformer (ViT), a Transformer adapted for classification tasks.
The ViT architecture is shown in the figure below.
[Figure 1: ViT architecture diagram]
The diagram in the original paper may not be very detailed. For easier understanding, see the architecture provided in the following reference:
Vision Transformers for Remote Sensing Image Classification
[Figure 2: ViT architecture as illustrated in the remote sensing paper]

3. Keras Implementation

Note: the multi-head self-attention layer requires a TensorFlow version that ships layers.MultiHeadAttention, such as 2.4 or 2.5 (it is not available in 2.1–2.3); see the linked environment-setup post for configuration details.
Run the code blocks below in order.
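
As a quick sanity check (a minimal sketch, not part of the original post), you can confirm that the installed TensorFlow build actually provides the layer before running the rest:

# Sanity check (sketch): verify that this TensorFlow build ships MultiHeadAttention.
import tensorflow as tf
from tensorflow.keras import layers

print("TensorFlow version:", tf.__version__)  # should be 2.4 or newer
assert hasattr(layers, "MultiHeadAttention"), (
    "layers.MultiHeadAttention not found; upgrade to TensorFlow >= 2.4"
)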

3.1 Imports

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# The package below is optional and can be commented out; later you can switch to the Adam optimizer instead, which in tests worked even better than the original code
import tensorflow_addons as tfa 

3.2 Loading the data

# Number of classes
num_classes = 100
# Input shape
input_shape = (32, 32, 3)
# Load the CIFAR-100 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

3.3 Hyperparameters

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # Images are resized to this size
patch_size = 6  # Size of the patches extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Sizes of the feed-forward layers in each Transformer block
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Sizes of the dense layers in the final MLP classification head

3.4 Data augmentation

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.Normalization(),
        layers.experimental.preprocessing.Resizing(image_size, image_size),
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(factor=0.02),
        layers.experimental.preprocessing.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)

3.5 Computing the mean and variance of the training data for normalization

data_augmentation.layers[0].adapt(x_train)
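
As an optional check (a small sketch, not in the original post), you can pass a slice of the training data through the adapted Normalization layer and confirm that it is standardized to roughly zero mean and unit variance:

# Optional sanity check (sketch): the adapted Normalization layer should
# standardize the training data to roughly zero mean and unit variance.
normalized = data_augmentation.layers[0](x_train[:1000].astype("float32"))
print("mean ~", float(tf.reduce_mean(normalized)))
print("std  ~", float(tf.math.reduce_std(normalized)))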

3.6 Defining the multilayer perceptron (MLP)

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

3.7 Defining the patch-extraction layer

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

3.8 Visualizing the patches

import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"图片大小: {image_size} X {image_size}")
print(f"切块大小e: {patch_size} X {patch_size}")
print(f"每个图对应的切块大小: {patches.shape[1]}")
print(f"每个块对应的元素: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")

3.9 Implementing the patch encoding layer

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

This step may raise an error: the custom layer needs an additional method (get_config) to be declared; see the linked post for details.
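
A minimal sketch of such a method, assuming the fix is the standard Keras get_config pattern for custom layers (an assumption, not the exact code from the linked post):

# Sketch (assumption): expose the constructor arguments via get_config so that
# Keras can serialize and re-create the custom PatchEncoder layer.
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        # Return the constructor arguments so the layer can be rebuilt from its config.
        config = super(PatchEncoder, self).get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config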

3.10 Building the ViT model

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Data augmentation
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple Transformer encoder blocks
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention layer (tested to work with TensorFlow 2.5)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection.
        x2 = layers.Add()([attention_output, encoded_patches])
     
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add the MLP head.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classification output (logits).
    logits = layers.Dense(num_classes)(features)
    # Build the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    model.summary()
    return model

3.11 Training and evaluation (AdamW can be replaced with Adam, possibly with better results)

# The tfa.optimizers.AdamW optimizer below can be replaced with Adam
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
    # The line below can simply be replaced with  optimizer='adam',
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
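
If tensorflow_addons is not installed, a minimal alternative (a sketch, assuming the explicit weight decay provided by AdamW is simply dropped) is to swap the optimizer lines inside run_experiment:

# Sketch: replacement for the tfa.optimizers.AdamW lines inside run_experiment.
# Assumption: weight decay is dropped; only the learning rate defined above is kept.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)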

4. Complete Code

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# Optional package: only needed for AdamW; in tests, Keras's built-in Adam worked even better
import tensorflow_addons as tfa
# Number of classes
num_classes = 100
# Input shape
input_shape = (32, 32, 3)
# Load the CIFAR-100 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train 大小: {x_train.shape} - y_train 大小: {y_train.shape}")
print(f"x_test 大小: {x_test.shape} - y_test 大小: {y_test.shape}")

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # Images are resized to this size
patch_size = 6  # Size of the patches extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Sizes of the feed-forward layers in each Transformer block
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Sizes of the dense layers in the final MLP classification head

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.Normalization(),
        layers.experimental.preprocessing.Resizing(image_size, image_size),
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(factor=0.02),
        layers.experimental.preprocessing.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
data_augmentation.layers[0].adapt(x_train)

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches



plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"图片大小: {image_size} X {image_size}")
print(f"切块大小e: {patch_size} X {patch_size}")
print(f"每个图对应的切块大小: {patches.shape[1]}")
print(f"每个块对应的元素: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
    # A get_config method should also be defined after call (name it yourself); see the note in section 3.9
    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Data augmentation
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple Transformer encoder blocks
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention layer (tested to work with TensorFlow 2.5)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection.
        x2 = layers.Add()([attention_output, encoded_patches])
     
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add the MLP head.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classification output (logits).
    logits = layers.Dense(num_classes)(features)
    # Build the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    model.summary()
    return model

# The tfa.optimizers.AdamW optimizer below can be replaced with Adam
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
    # The line below can simply be replaced with  optimizer='adam',
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)

5. Visualization

5.1 Original image and the extracted patches

Original image
[Figure 3: original CIFAR-100 sample image]

Image after patch extraction (12 × 12 = 144 patches; each patch contains 6 × 6 × 3 = 108 values)
[Figure 4: the image split into 12 × 12 patches]

5.2 Model architecture and parameters

[Figure 5: model summary (layers and parameter counts)]
In terms of parameter count, this Transformer's memory footprint is still quite large compared with typical convolutional models.
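
To inspect the total parameter count directly (a small sketch, not part of the original post):

# Sketch: print the total number of parameters of the model built above.
print(f"Total parameters: {vit_classifier.count_params():,}")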
