近年来,计算机视觉领域中 Transformer 模型的崛起为图像处理带来了新的活力。特别是在 ViT(Vision Transformer)模型提出之后,Transformer 在图像分类、目标检测等任务上展示了超越 CNN 的潜力。然而,标准的 ViT 模型参数量大,计算复杂度高,难以在移动设备等资源受限的环境中部署。
最近,《MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer》 这篇论文提出了一种轻量化、通用且适合移动端的视觉变换器模型。该模型通过结合局部和全局特征的创新设计,在保持良好性能的同时,大大降低了计算资源的需求,为移动应用提供了新的解决方案。
本文将从零开始解读并实现 MobileViT 的核心注意力机制模块,帮助开发者理解这一轻量级视觉变换器的工作原理,从而在实际项目中灵活运用。
标准的 ViT 模型将整个图像划分为不重叠的 patches(块),并将其转换为序列输入到基于Transformer 的编码器中。这种方法虽然在性能上表现出色,但也带来了以下问题:
MobileViT 提出了一种折中的解决方案——结合 局部表示(Local Representation) 和 全局表示(Global Representation),以降低计算复杂度同时保持性能。其核心思想是:
MobileViT 的核心模块是 MobileViTAttention
。我们需要逐步解读其实现细节,并通过代码示例帮助读者理解其工作原理。
[batch_size, in_channel, height, width]
模块主要包含以下几个部分:
以下是完整的 MobileViTAttention
类的实现代码:
import torch
from torch import nn
class MobileViT_Attention(nn.Module):
def __init__(self, in_channels=3, kernel_size=3, patch_size=2, embed_dim=144):
super().__init__()
# 设置 patch 的大小(默认为7x7)
self.ph, self.pw = patch_size, patch_size
# 局部特征提取:通过卷积操作捕获局部上下文信息
self.local_conv = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
padding=kernel_size//2, stride=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True)
)
# 全局特征提取:将张量重排为 [batch_size, patch_height*patch_width, N_h*N_w, embed_dim]
# Transformer 模块用于捕获全局上下文信息
self.global_trans = Transformer(embed_dim=embed_dim,
num_heads=16,
num_transformer_layers=4)
# 特征融合:将局部特征和全局特征拼接,并通过卷积操作生成最终输出
self.fusion_conv = nn.Sequential(
nn.Conv2d(in_channels*2, in_channels, kernel_size=kernel_size,
padding=kernel_size//2, stride=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
# 提取局部特征
local_feats = self.local_conv(x) # 局部特征
if len(local_feats.shape) == 4:
B, C, H, W = local_feats.shape
else:
raise ValueError("Input tensor should have rank 4.")
# 分割图像为 patch,并进行重排:从 [B, C, H, W] 到 [B, (H*W), C]
# 每个 patch 的大小为 (patch_size, patch_size)
patches = []
for i in range(0, H, self.ph):
for j in range(0, W, self.pw):
patch = local_feats[:, :, i:i+self.ph, j:j+self.pw]
patch = torch.flatten(patch, start_dim=2) # 打平patch
patches.append(patch)
# 拼接所有的 patch,形成张量 [B, num_patches, C]
x_patched = torch.stack(patches, dim=1)
# 传递到 Transformer 中提取全局特征
global_feats = self.global_trans(x_patched) # 全局上下文特征
# 特征融合:将原始输入的局部特征与 Transformer 输出的全局特征拼接
x_fused = torch.cat([local_feats, global_feats.unsqueeze(2).unsqueeze(3)], dim=1)
return self.fusion_conv(x_fused) # 最终的特征输出
class Transformer(nn.Module):
def __init__(self, embed_dim=768, num_heads=12,
num_transformer_layers=4):
super().__init__()
self.embedding = nn.Linear(embed_dim, embed_dim)
self.layers =(nn.ModuleList([
TransformerBlock(d_model=embed_dim, nhead=num_heads)
for _ in range(num_transformer_layers)
]))
def forward(self, x):
x = self.embedding(x)
for layer in self.layers:
x = layer(x)
return x
nn.Conv2d
在局部区域内捕获上下文信息。self.local_conv = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, padding=1),
nn.BatchNorm2d(3),
nn.ReLU(inplace=True)
)
patch_size x patch_size
的小块,每个块展开成一维向量。class TransformerBlock(nn.Module):
def __init__(self, d_model=768, nhead=12):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
self.dropout = nn.Dropout(0.1)
def forward(self, x):
out = self.self_attn(x, x, x)[0]
return F.dropout(out, p=0.1, training=self.training)
self.fusion_conv = nn.Sequential(
nn.Conv2d(3*2, 3, kernel_size=3, padding=1),
nn.BatchNorm2d(3),
nn.ReLU(inplace=True)
)
[batch_size, in_channels, height, width]
[ batch_size: 4, in_channels: 3 (RGB), height: 224, width: 224 ]
[batch_size, in_channels, height, width]
通过结合局部和全局特征提取,MobileViT 成功地在轻量级计算资源的基础上实现了高效的视觉信息处理。这一模块尤其适合应用于移动设备和嵌入式系统中,同时也可以作为其他视觉任务(如目标检测、图像分割)的高效特征提取模块。
未来的工作可以尝试以下方向:
希望通过对这一模块的解读和实现,能够帮助读者更好地理解和应用 MobileViT 模型,在实际项目中发挥其优势。