__init__
def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
    """Build the convolutional and Transformer sub-layers of the block.

    Args:
        channel: Number of channels of the input feature map.
        dim: Feature dimension used by the Transformer part.
        depth: Number of Transformer layers.
        kernel_size: Kernel size of the n x n convolution layers.
        patch_size: ``(ph, pw)`` size of the patches the feature map is
            split into.
        mlp_dim: Width of the feed-forward network inside the Transformer.
        dropout: Dropout rate, used for regularisation.
    """
    super().__init__()
    # forward() reads the patch size when rearranging the feature map.
    self.ph, self.pw = patch_size

    self.mv01 = IRBlock(channel, channel)
    self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
    self.conv2 = conv_1x1_bn(channel, dim)
    self.conv3 = conv_1x1_bn(dim, channel)
    # NOTE(review): 4 and 8 are passed positionally; presumably the number
    # of heads and the per-head width - confirm against UserDefined.
    self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
    # forward() concatenates the conv3 output (`channel` channels) with the
    # conv2 output (`dim` channels), so the fusion conv must accept
    # `dim + channel` inputs. `2 * channel` only worked when dim == channel.
    self.conv4 = conv_nxn_bn(dim + channel, channel, kernel_size)
IRBlock、conv_nxn_bn 和 conv_1x1_bn 用于特征提取和维度变换。UserDefined
# Continuation of the preceding note: UserDefined is the Transformer-based
# structure described earlier; it is used to process sequence data.
def conv_1x1_bn(inp, oup):
    """Return a 1 x 1 convolution followed by batch norm and SiLU.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        An ``nn.Sequential`` of ``Conv2d`` (kernel 1, stride 1, no padding,
        no bias), ``BatchNorm2d`` and ``SiLU``.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    return nn.Sequential(
        pointwise,
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    )
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """Return an n x n convolution followed by batch norm and SiLU.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        kernal_size: Kernel size, an int or a tuple of ints. The misspelt
            name is kept so existing keyword callers keep working.
        stride: Convolution stride.

    Returns:
        An ``nn.Sequential`` of ``Conv2d`` (no bias), ``BatchNorm2d`` and
        ``SiLU``.
    """
    # Padding used to be hard-coded to 1, which only matches a 3 x 3 kernel.
    # Half the kernel size keeps the spatial size unchanged for odd kernels
    # at stride 1, and is still 1 for the default kernel of 3.
    if isinstance(kernal_size, int):
        padding = kernal_size // 2
    else:
        padding = tuple(k // 2 for k in kernal_size)
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    )
forward
def forward(self, x):
    """Run the block on a 4-D feature map and return the fused result.

    The input passes through ``conv1`` and ``conv2``, is rearranged into
    patches for the Transformer, is rearranged back and projected by
    ``conv3``, concatenated with the ``conv2`` output, fused by ``conv4``,
    added to the original input and finally passed through ``mv01``.

    Args:
        x: Tensor unpacked as ``(b, c, h, w)``. The rearrange pattern
            splits ``h`` and ``w`` by ``self.ph`` and ``self.pw``, so both
            are expected to be divisible by the patch size.

    Returns:
        The output of ``self.mv01``.
    """
    # conv1 starts with nn.Conv2d, which does not modify its input, so the
    # residual can reference the input directly instead of cloning it.
    residual = x

    x = self.conv1(x)
    x = self.conv2(x)
    # Copy kept because UserDefined is opaque here and might work in place.
    local_features = x.clone()

    _, _, h, w = x.shape
    x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
    x = self.transformer(x)
    x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

    x = self.conv3(x)
    # Fuse the Transformer branch with the pre-Transformer features.
    x = torch.cat((x, local_features), 1)
    x = self.conv4(x)
    x = x + residual
    return self.mv01(x)
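forward 方法定义了数据通过网络的流程:输入 x 首先经过几个卷积层(conv1、conv2)进行特征提取和维度变换;随后通过 rearrange 将特征图按 patch 重排为序列(形状为 b (ph pw) (h w) d),准备送入Transformer结构;Transformer 的输出再经 rearrange 还原为特征图,经 conv3 变换通道数后与 conv2 的输出拼接,由 conv4 融合并与输入相加,最后通过 IRBlock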
class MobileViTBv3(nn.Module):
    """Block that mixes convolutional and Transformer features.

    The input goes through an n x n and a 1 x 1 convolution, is split into
    patches and passed through the Transformer, is projected back by a
    1 x 1 convolution, concatenated with the pre-Transformer features,
    fused, added to the input and finally passed through an ``IRBlock``.
    """

    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        """Build the convolutional and Transformer sub-layers.

        Args:
            channel: Number of channels of the input (and output) feature map.
            dim: Channel width produced by ``conv2`` and fed to the
                Transformer.
            depth: Forwarded to ``UserDefined`` as its second argument.
            kernel_size: Kernel size of the n x n convolutions.
            patch_size: ``(ph, pw)`` patch height and width used when
                rearranging the feature map.
            mlp_dim: Forwarded to ``UserDefined``.
            dropout: Forwarded to ``UserDefined``.
        """
        super().__init__()
        self.ph, self.pw = patch_size

        self.mv01 = IRBlock(channel, channel)
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.conv3 = conv_1x1_bn(dim, channel)
        # NOTE(review): 4 and 8 are passed positionally; presumably the
        # number of heads and the per-head width - confirm against
        # UserDefined.
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        # forward() concatenates the conv3 output (`channel` channels) with
        # the conv2 output (`dim` channels), so the fusion conv must accept
        # `dim + channel` inputs. `2 * channel` only worked when
        # dim == channel.
        self.conv4 = conv_nxn_bn(dim + channel, channel, kernel_size)

    def forward(self, x):
        """Run the block on a 4-D feature map and return the fused result.

        Args:
            x: Tensor unpacked as ``(b, c, h, w)``. The rearrange pattern
                splits ``h`` and ``w`` by the patch size, so both are
                expected to be divisible by it.

        Returns:
            The output of ``self.mv01``.
        """
        # conv1 starts with nn.Conv2d, which does not modify its input, so
        # the residual can reference the input directly instead of cloning.
        residual = x

        x = self.conv1(x)
        x = self.conv2(x)
        # Copy kept because UserDefined is opaque here and might work in
        # place.
        local_features = x.clone()

        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        x = self.conv3(x)
        # Fuse the Transformer branch with the pre-Transformer features.
        x = torch.cat((x, local_features), 1)
        x = self.conv4(x)
        x = x + residual
        return self.mv01(x)