I've been trying out all kinds of networks for medical image segmentation lately, and this is my first attempt at a Transformer-CNN segmentation method. I wanted to start with UNETR, which uses a complete Vision Transformer (ViT) as its encoder. Unfortunately, this time I couldn't find any public TensorFlow implementation online at all, so I had no choice but to stitch one together myself. As usual, consider this a rough attempt thrown out in the hope of attracting something better.
This post implements the 2D segmentation version. Converting it to 3D is straightforward, though: ViT itself doesn't much care about the input dimensionality, and for the CNN part you just swap the 2D convolutions for their 3D counterparts.
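For instance, the swap mostly amounts to replacing each Keras layer with its 3D twin (a minimal sketch of the idea, separate from the model code below):

import tensorflow as tf
from keras.layers import Conv2D, Conv3D

# 2D version, input [B, H, W, C]:
x2d = tf.zeros((1, 256, 256, 3))
print(Conv2D(64, 3, padding='same')(x2d).shape)  # (1, 256, 256, 64)

# 3D counterpart for volumetric input [B, H, W, D, C]; the only change is the
# layer class (and likewise Conv2DTranspose -> Conv3DTranspose elsewhere):
x3d = tf.zeros((1, 96, 96, 96, 3))
print(Conv3D(64, 3, padding='same')(x3d).shape)  # (1, 96, 96, 96, 64)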
Paper: UNETR: Transformers for 3D Medical Image Segmentation
Reference code: 1. The official Keras example: Image classification with Vision Transformer
2. The unofficial PyTorch implementation by GitHub user tamasino52
The complete UNETR architecture is shown above. Compared with U-Net, the most significant change is that the encoder is replaced with a structure similar to ViT-Base with patch size 16, and essentially all the other changes follow from that. For example, because the ViT output has a fixed patch-grid size ([H/p, W/p, D/p], p = patch_size), the "skip connection" paths in UNETR must use successive transposed convolutions to restore the feature-map resolution (the blue blocks in the figure) before passing it to the conventional U-Net decoder layers (the yellow blocks in the figure).
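To make the resolution bookkeeping concrete (a quick sanity check using the 2D defaults from build_model below, not a statement about the paper's 3D configuration): with a 256x256 input and patch_size = 16, every extracted ViT stage reshapes to a 16x16 grid, so the skip paths need different numbers of 2x transposed convolutions to line up with the decoder:

H = 256
patch_size = 16
grid = H // patch_size            # 16: every ViT stage reshapes to a 16x16 feature map
for n_deconv in (1, 2, 3):        # blue blocks on the z3 / z2 / z1 skip paths
    print(grid * 2 ** n_deconv)   # 32, 64, 128; each decoder stage then doubles once more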
The complete ViT plus the many extra convolution blocks make UNETR's parameter count balloon to 92M (for comparison, the UNETR paper lists nnU-Net at 19M), but the results really are good, and it has since become a standard comparison method for many 3D medical image segmentation tasks.
The code is as follows:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import keras
import keras.backend as K
from keras.layers import (Layer, BatchNormalization, LayerNormalization, Conv2D, Conv2DTranspose, Embedding,
Activation, Dense, Dropout, MultiHeadAttention, add, Input, concatenate, GlobalAveragePooling1D)
from keras.models import Model
The mlp, Patches, and PatchEncoder code below comes from the Keras code example:
def mlp(x, hidden_units, dropout_rate):
if not isinstance(hidden_units, list): hidden_units = [hidden_units]
for units in hidden_units:
x = Dense(units, activation=tf.nn.gelu)(x)
x = Dropout(dropout_rate)(x)
return x
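A quick shape check for mlp (a sketch using the functional API; the [2048, 768] widths match the default configuration further below):

t = Input((256, 768))
print(mlp(t, hidden_units=[2048, 768], dropout_rate=0.1).shape)  # (None, 256, 768)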
class Patches(Layer):
    '''
    Extract image patches and flatten them into a sequence:
    [B, H, W, C]
    -> [B, H/patch_size, W/patch_size, C*(patch_size^2)]
    -> [B, H*W/(patch_size^2), C*(patch_size^2)]
    '''
def __init__(self, patch_size):
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images = images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding='VALID',
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
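As a quick sanity check for Patches (a standalone sketch, assuming a batch of 256x256x3 images):

imgs = tf.zeros((2, 256, 256, 3))
seq = Patches(16)(imgs)
print(seq.shape)  # (2, 256, 768): 16x16 = 256 patches, each flattened to 16*16*3 = 768 values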
class PatchEncoder(Layer):
    '''
    Linearly project each patch to projection_dim
    and add a learnable position embedding per patch.
    '''
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
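And the matching check for PatchEncoder (continuing the sketch above; 256 patches projected to 768 dimensions):

encoded = PatchEncoder(num_patches=256, projection_dim=768)(seq)
print(encoded.shape)  # (2, 256, 768): same sequence length, now with position information added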
def normalization(input_tensor, norm_type, name=None):
    if norm_type == 'batch':
        return BatchNormalization(name=None if name is None else name + '_batchnorm')(input_tensor)
    elif norm_type == 'layer':
        return LayerNormalization(epsilon=1e-6, name=None if name is None else name + '_layernorm')(input_tensor)
    elif norm_type == 'group':
        return tfa.layers.GroupNormalization(groups=8, name=None if name is None else name + '_groupnorm')(input_tensor)
    elif norm_type is None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization: ' + str(norm_type))
def conv_norm_act(input_tensor, filters, kernel_size , norm_type='batch', act_type='relu', dilation=1):
'''
Conv2d + Normalization(norm_type:str) + Activation(act_type:str)
'''
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, norm_type)
if act_type is not None: output_tensor = Activation(act_type)(output_tensor)
return output_tensor
def conv2d_block(input_tensor, filters, kernel_size,
norm_type, use_residual, act_type='relu',
double_features = False, dilation=[1, 1], name=None):
x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0], use_bias=False, kernel_initializer='he_normal', name=None if name is None else name + '_conv2d_0')(input_tensor)
x = normalization(x, norm_type, name=None if name is None else name + '_0')
    x = Activation(act_type, name=None if name is None else name + '_' + act_type + '_0')(x)
if double_features:
filters *= 2
x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1], use_bias=False, kernel_initializer='he_normal', name=None if name is None else name + '_conv2d_1')(x)
x = normalization(x, norm_type, name=None if name is None else name + '_1')
if use_residual:
if K.int_shape(input_tensor)[-1] != K.int_shape(x)[-1]:
shortcut = Conv2D(filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal', name=None if name is None else name + '_shortcut_conv2d')(input_tensor)
shortcut = normalization(shortcut, norm_type, name=None if name is None else name + '_shortcut')
x = add([x, shortcut])
else:
x = add([x, input_tensor])
    x = Activation(act_type, name=None if name is None else name + '_' + act_type + '_1')(x)
return x
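A small usage sketch for conv2d_block (illustrative values; note how the 1x1 shortcut convolution handles a channel mismatch when use_residual is on):

feat = tf.zeros((1, 64, 64, 32))
out = conv2d_block(feat, filters=64, kernel_size=3,
                   norm_type='batch', use_residual=True, act_type='relu')
print(out.shape)  # (1, 64, 64, 64); the shortcut conv projects 32 -> 64 channels before the add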
def deconv_conv_block(x,
filters_list: list,
kernel_size,
norm_type,
act_type,
):
    '''
    Corresponds to the blue blocks in the UNETR architecture diagram.
    '''
for filts in filters_list:
x = Conv2DTranspose(filts, 2, (2, 2), kernel_initializer='he_normal')(x)
x = conv_norm_act(x, filts, kernel_size, norm_type, act_type)
return x
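The length of filters_list determines how many 2x upsampling steps a skip path receives (each loop iteration is one blue block). For example:

z = tf.zeros((1, 16, 16, 768))  # a reshaped ViT output at 1/16 resolution
z = deconv_conv_block(z, [512, 256, 128], 3, 'batch', 'relu')
print(z.shape)  # (1, 128, 128, 128): three 2x steps take 16 -> 128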
def conv_deconv_block(x,
filters,
kernel_size,
norm_type,
use_residual,
act_type,
):
    '''
    Corresponds to the yellow + green blocks in the UNETR architecture diagram.
    '''
x = conv2d_block(x, filters, kernel_size, norm_type, use_residual, act_type)
x = Conv2DTranspose(filters // 2, 2, (2, 2), kernel_initializer='he_normal')(x)
return x
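conv_deconv_block doubles the resolution while halving the channel count, which is exactly what each decoder stage needs. For example:

d = tf.zeros((1, 32, 32, 1024))  # e.g. z3 after concatenation with z4
d = conv_deconv_block(d, 512, 3, 'batch', False, 'relu')
print(d.shape)  # (1, 64, 64, 256): conv to 512 channels, then deconv to 512 // 2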
This part also comes from the Keras code example. The main changes are removing the classification head and adding the extraction of "skip connection" outputs at the specified layers; I've kept essentially all of the comments from the source.
def create_vit(x,
patch_size,
num_patches,
projection_dim,
num_heads,
transformer_units,
transformer_layers,
dropout_rate,
extract_layers,
):
skip_connections = []
# Create patches.
patches = Patches(patch_size)(x)
# Encode patches.
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
# Create multiple layers of the Transformer block.
for layer in range(transformer_layers):
# Layer normalization 1.
x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
# Create a multi-head attention layer.
attention_output = MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim//num_heads, dropout=dropout_rate
)(x1, x1)
# Skip connection 1.
x2 = add([attention_output, encoded_patches])
# Layer normalization 2.
x3 = LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=dropout_rate)
# Skip connection 2.
encoded_patches = add([x3, x2])
if layer + 1 in extract_layers:
skip_connections.append(encoded_patches)
return skip_connections
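A toy-scale smoke test for create_vit (hypothetical small settings, purely to inspect shapes; the real defaults follow below):

toy = Input((64, 64, 3))
skips = create_vit(toy, patch_size=16, num_patches=16, projection_dim=32,
                   num_heads=4, transformer_units=[64, 32], transformer_layers=4,
                   dropout_rate=0.1, extract_layers=[1, 2, 3, 4])
print([s.shape for s in skips])  # four tensors of shape (None, 16, 32)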
def build_model(# ↓ Base arguments
input_shape = (256, 256, 3),
class_nums = 5,
# ↓ ViT arguments
patch_size = 16,
projection_dim = 768,
num_heads = 12,
transformer_units = [2048, 768],
transformer_layers = 12,
extract_layers = [3, 6, 9, 12],
dropout_rate = 0.1,
# ↓ Conv arguments
kernel_size = 3,
conv_norm = 'batch',
conv_act = 'relu',
use_residual = False,
# ↓ Other arguments
show_summary = True,
output_act = 'auto',
):
    '''
    input_shape: tuple, (height, width, channels); note this is 2D segmentation
    class_nums: int, number of output channels
    patch_size: int, size of the image patches
    projection_dim: int, projection dimension inside the ViT
    num_heads: int, number of heads in the multi-head attention
    transformer_units: list, hidden layer widths of the MLP blocks in the ViT; note this must be a list
    transformer_layers: int, number of stacked Transformer layers
    extract_layers: list, which ViT layers feed the "skip connections"; default [3, 6, 9, 12]
    dropout_rate: float, dropout rate of the ViT part
    kernel_size: int, convolution kernel size
    conv_norm: str, normalization used in the conv layers: 'batch', 'layer', or 'group'
    conv_act: str, activation used in the conv layers
    use_residual: bool, whether to use residual connections
    show_summary: bool, whether to print the model summary
    output_act: str, output activation; 'auto' picks one based on class_nums, or specify 'softmax'/'sigmoid' yourself
    '''
z4_de_filts = 512
z3_de_filts_list = [512]
z2_de_filts_list = [512, 256]
z1_de_filts_list = [512, 256, 128]
z34_conv_filts = 512
z23_conv_filts = 256
z12_conv_filts = 128
z01_conv_filts = 64
if output_act == 'auto': output_act = 'sigmoid' if class_nums == 1 else 'softmax'
    assert input_shape[0] == input_shape[1] and input_shape[0] % patch_size == 0
num_patches = (input_shape[0] * input_shape[1]) // (patch_size ** 2)
inputs = Input(input_shape)
z0 = inputs
z1, z2, z3, z4 = create_vit(z0,
patch_size,
num_patches,
projection_dim,
num_heads,
transformer_units,
transformer_layers,
dropout_rate,
extract_layers)
    z1 = tf.reshape(z1, (-1, input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim)) # [B, H/p, W/p, projection_dim], p = patch_size
z2 = tf.reshape(z2, (-1, input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim))
z3 = tf.reshape(z3, (-1, input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim))
z4 = tf.reshape(z4, (-1, input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim))
z4 = Conv2DTranspose(z4_de_filts, 2, (2, 2), kernel_initializer='he_normal')(z4)
z3 = deconv_conv_block(z3, z3_de_filts_list, kernel_size, conv_norm, conv_act)
z3 = concatenate([z3, z4])
z3 = conv_deconv_block(z3, z34_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
z2 = deconv_conv_block(z2, z2_de_filts_list, kernel_size, conv_norm, conv_act)
z2 = concatenate([z2, z3])
z2 = conv_deconv_block(z2, z23_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
z1 = deconv_conv_block(z1, z1_de_filts_list, kernel_size, conv_norm, conv_act)
z1 = concatenate([z1, z2])
z1 = conv_deconv_block(z1, z12_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
z0 = conv2d_block(z0, z01_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
z0 = concatenate([z0, z1])
z0 = conv2d_block(z0, z01_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
outputs = Conv2D(class_nums, 1, activation=output_act)(z0)
model = Model(inputs=inputs, outputs=outputs)
if show_summary: model.summary()
return model
If all the code above lives in a single Python script, you can append the following and run the script to try building the network:
if __name__ == '__main__':
x = np.random.uniform(size=(1, 256, 256, 3))
model = build_model(# ↓ Base arguments
input_shape = (256, 256, 3),
class_nums = 5,
# ↓ ViT arguments
patch_size = 16,
projection_dim = 768,
num_heads = 12,
transformer_units = [2048, 768],
transformer_layers = 12,
extract_layers = [3, 6, 9, 12],
dropout_rate = 0.1,
# ↓ Conv arguments
kernel_size = 3,
conv_norm = 'batch',
conv_act = 'relu',
use_residual = False,
# ↓ Other arguments
show_summary = True,
output_act = 'auto',)
y = model(x)
print(x.shape, y.shape)
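For completeness, a minimal sketch of how one might compile and train the model (the optimizer, loss, and batch size here are illustrative placeholders, not recommendations from the paper; train_images / train_masks are hypothetical arrays with one-hot masks of shape (N, 256, 256, 5)):

model.compile(optimizer=keras.optimizers.Adam(1e-4),
              loss='categorical_crossentropy', metrics=['accuracy'])
# model.fit(train_images, train_masks, batch_size=4, epochs=100,
#           validation_data=(val_images, val_masks))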
Sigh.