import torch.nn as nn
def conv3x3(in_ch: int, out_ch: int, group: int = 1, stride: int = 1) -> nn.Conv2d:
    """Return a 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; ``stride=2`` halves it
    (rounding up).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        group: value passed as ``groups``; both ``in_ch`` and ``out_ch`` must be
            divisible by it (``group == in_ch == out_ch`` gives a depthwise conv).
        stride: convolution stride.
    """
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
    )
def conv1x1(in_ch: int, out_ch: int, group: int = 1) -> nn.Conv2d:
    """Return a 1x1 (pointwise) ``nn.Conv2d`` with stride 1 and no padding.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        group: value passed as ``groups``; both ``in_ch`` and ``out_ch`` must be
            divisible by it.
    """
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=1,
        groups=group,
    )
class MobileNetV2_Block(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> project.

    Layout:
      * 1x1 dense conv expanding ``in_ch`` to ``in_ch * expansion`` channels,
      * 3x3 depthwise conv (``groups`` = hidden channels) carrying the stride,
      * linear 1x1 projection down to ``out_ch``.

    Every conv is followed by BatchNorm; ReLU6 follows only the first two, so
    the projection stays linear. An identity shortcut is added when
    ``stride == 1`` and ``in_ch == out_ch``.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        stride: stride of the depthwise conv (2 halves H and W).
        expansion: hidden-width multiplier (default 6).
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1, expansion: int = 6):
        super(MobileNetV2_Block, self).__init__()
        hidden = in_ch * expansion
        # Convs feed straight into BatchNorm, so their bias would be redundant;
        # that is why nn.Conv2d is used directly instead of conv1x1/conv3x3.
        # The expansion must be dense (groups=1): `groups` has to divide in_ch.
        self.conv1 = nn.Conv2d(in_ch, hidden, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        # Depthwise 3x3: one filter per channel.
        self.conv2 = nn.Conv2d(
            hidden,
            hidden,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=hidden,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(hidden)
        # Pointwise projection back down to out_ch (no activation afterwards).
        self.conv3 = nn.Conv2d(hidden, out_ch, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)
        # ReLU6 must be instantiated once and *called*; nn.ReLU6(x) would only
        # build a new module with inplace=x.
        self.act = nn.ReLU6(inplace=True)
        # Shortcut is only shape-compatible when neither H/W nor C change.
        self.use_residual = stride == 1 and in_ch == out_ch

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.act(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.use_residual:
            out = out + x
        return out
class MobileVit_Block(nn.Module):
    """Simplified MobileViT block: local convs -> global self-attention -> convs.

    Every spatial position of the ``d_model``-channel feature map becomes one
    token of a Transformer encoder. The input's spatial size is preserved.
    NOTE(review): the paper unfolds the map into patches and fuses the result
    with the block input (concat + 3x3 conv); neither is implemented here.

    Args:
        in_ch: input channels; must be divisible by ``out_ch`` because the
            first 3x3 conv uses ``groups=out_ch``.
        out_ch: output channels.
        d_model: token width fed to the transformer.
        nhead: attention heads; must divide ``d_model``.
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: accepted for backward compatibility and ignored.
            The block only needs self-attention over its own tokens, so no
            decoder is built.
        dim_feedforward: hidden width of each encoder layer's MLP.
    """

    def __init__(self, in_ch: int, out_ch: int, d_model: int, nhead: int = 2,
                 num_encoder_layers: int = 6, num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048):
        super(MobileVit_Block, self).__init__()
        self.d_model = d_model
        # Local representation: 3x3 grouped conv, then 1x1 projection to d_model.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, groups=out_ch)
        self.conv2 = nn.Conv2d(out_ch, d_model, kernel_size=1)
        # Global representation: encoder-only transformer. batch_first=True so
        # the (B, N, d) token layout is interpreted with B as the batch dim.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        # Final LayerNorm mirrors the encoder stack inside nn.Transformer.
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
            norm=nn.LayerNorm(d_model),
        )
        # Project back to out_ch, then expand and fuse with a 3x3 conv.
        self.conv3 = nn.Conv2d(d_model, out_ch, kernel_size=1)
        self.conv4 = nn.Conv2d(out_ch, out_ch * 3, kernel_size=1)
        self.conv5 = nn.Conv2d(out_ch * 3, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        # `shape` is an attribute (torch.Size), not a method.
        h, w = x.shape[-2:]
        x = self.conv2(self.conv1(x))
        batch = x.shape[0]
        # (B, d, H, W) -> (B, H*W, d): one token per pixel, row-major order.
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.transformer(tokens)
        # (B, H*W, d) -> (B, d, H, W); reshape copes with the non-contiguous
        # transpose.
        x = tokens.transpose(1, 2).reshape(batch, self.d_model, h, w)
        x = self.conv3(x)
        x = self.conv4(x)
        return self.conv5(x)
class Mobile_Vit(nn.Module):
    """MobileViT-style backbone assembled from pluggable CNN / transformer blocks.

    Every stage keeps ``in_ch`` channels. The stem conv and the mv2, mv4, mv5
    and mv6 stages each use stride 2, so H and W shrink by 32x overall. The
    head is a 1x1 conv to ``out_ch`` followed by global average pooling and
    flattening, so the output has shape ``(batch, out_ch)``.

    Args:
        cnn_block: factory called as ``cnn_block(in_ch, out_ch, stride)``.
        trans_block: factory called as ``trans_block(in_ch, out_ch, d_model)``.
        in_ch: input channels, also used as the width of every stage.
        out_ch: channels of the final 1x1 conv (= length of the output vector).
    """

    def __init__(self, cnn_block, trans_block, in_ch, out_ch):
        super(Mobile_Vit, self).__init__()
        self.conv1 = conv3x3(in_ch, in_ch, stride=2)
        self.mv1 = self._make_cnn_layer(cnn_block, in_ch, in_ch)
        self.mv2 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mv3 = self._make_cnn_layer(cnn_block, in_ch, in_ch, blocks=2)
        self.mv4 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt1 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 2)
        self.mv5 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt2 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 4)
        self.mv6 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt3 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 3)
        self.conv2 = conv1x1(in_ch, out_ch)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flat = nn.Flatten()

    @staticmethod
    def _make_cnn_layer(block, in_ch, out_ch, blocks=1, stride=1):
        """Stack ``blocks`` CNN blocks into an ``nn.Sequential``.

        Only the first block changes the channel count and applies ``stride``;
        the rest map ``out_ch -> out_ch`` at stride 1, so the stage downsamples
        once. ``blocks <= 0`` yields an empty (identity) ``nn.Sequential``.
        """
        layers = []
        for i in range(blocks):
            first = i == 0
            layers.append(block(in_ch if first else out_ch, out_ch, stride if first else 1))
        return nn.Sequential(*layers)

    @staticmethod
    def _make_trans_layer(block, in_ch, out_ch, d_model, blocks):
        """Stack ``blocks`` transformer blocks into an ``nn.Sequential``.

        Only the first block maps ``in_ch -> out_ch``; later ones are
        ``out_ch -> out_ch``.
        NOTE(review): the paper's L=2/4/3 counts transformer layers inside a
        single MobileViT block, whereas here ``blocks`` stacks whole blocks.
        Confirm the intent.
        """
        layers = [
            block(in_ch if i == 0 else out_ch, out_ch, d_model)
            for i in range(blocks)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv1(x)
        x = self.mv2(x)
        x = self.mv3(x)
        x = self.mv4(x)
        x = self.mt1(x)
        x = self.mv5(x)
        x = self.mt2(x)
        x = self.mv6(x)
        x = self.mt3(x)
        x = self.conv2(x)
        # (B, out_ch, H, W) -> (B, out_ch, 1, 1) -> (B, out_ch)
        x = self.flat(self.gap(x))
        return x
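# Feedback and corrections are welcome.
# Paper (MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer):
# https://arxiv.org/pdf/2110.02178.pdf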