The MobileViT model proposed by Apple; I reimplemented the model myself based on the paper.

import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_ch: int, out_ch: int, group: int = 1, stride: int = 1):
    # 3x3 convolution with padding; group=channels gives a depthwise conv
    return nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                     stride=stride, padding=1, groups=group)


def conv1x1(in_ch: int, out_ch: int, group: int = 1):
    # 1x1 (pointwise) convolution
    return nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, groups=group)


class MobileNetV2_Block(nn.Module):
    # Inverted residual block from MobileNetV2: 1x1 expansion, 3x3 depthwise conv, 1x1 projection.
    # The groups go on the depthwise 3x3 step; grouping the 1x1 convs with group=in_ch * 6
    # is invalid because in_ch is not divisible by in_ch * 6.
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super(MobileNetV2_Block, self).__init__()
        self.conv1 = conv1x1(in_ch, in_ch * 6)                                       # expand
        self.conv2 = conv3x3(in_ch * 6, in_ch * 6, group=in_ch * 6, stride=stride)   # depthwise
        self.conv3 = conv1x1(in_ch * 6, out_ch)                                      # project

    def forward(self, x):
        x = F.relu6(self.conv1(x))
        x = F.relu6(self.conv2(x))
        x = self.conv3(x)
        return x
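
As a quick sanity check on the block above, the following snippet (channel counts and input size are placeholders of my own choosing, not taken from the paper) should halve the spatial resolution when stride=2:

import torch

mv2_block = MobileNetV2_Block(in_ch=8, out_ch=16, stride=2)
x = torch.randn(1, 8, 32, 32)          # (batch, channels, height, width)
with torch.no_grad():
    print(mv2_block(x).shape)          # expected: torch.Size([1, 16, 16, 16])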

class MobileVit_Block(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, d_model: int, nhead: int = 2,
                 num_encoder_layers: int = 6, num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048):
        super(MobileVit_Block, self).__init__()
        self.d_model = d_model
        # local representation: grouped 3x3 conv, then 1x1 conv to project to d_model
        self.conv1 = conv3x3(in_ch, out_ch, group=out_ch)
        self.conv2 = conv1x1(out_ch, d_model)
        # batch_first=True so the transformer accepts (B, N, d_model) directly
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          batch_first=True)
        # fusion: project back to out_ch, then mix with 1x1 and 3x3 convs
        self.conv3 = conv1x1(d_model, out_ch)
        self.conv4 = conv1x1(out_ch, out_ch * 3)
        self.conv5 = conv3x3(out_ch * 3, out_ch)

    def forward(self, x):
        h, w = x.shape[2:]                       # .shape is a property, not a method
        x = self.conv1(x)
        x = self.conv2(x)
        # unfold: (B, d_model, H, W) -> (B, H*W, d_model)
        x = x.permute(0, 2, 3, 1).reshape(-1, h * w, self.d_model)
        # nn.Transformer needs both src and tgt; feed the same sequence to both
        x = self.transformer(x, x)
        # fold back: (B, H*W, d_model) -> (B, d_model, H, W)
        x = x.reshape(-1, h, w, self.d_model).permute(0, 3, 1, 2)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x
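
The interesting part of this block is the unfold / fold around the transformer, so here is a minimal shape check. The sizes (4 channels, an 8x8 feature map, d_model=64) are placeholders I picked to keep the transformer small; in_ch equals out_ch so the grouped 3x3 conv is valid:

import torch

vit_block = MobileVit_Block(in_ch=4, out_ch=4, d_model=64)
x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    y = vit_block(x)                   # (B, C, H, W) in, (B, C, H, W) out
print(y.shape)                         # expected: torch.Size([2, 4, 8, 8])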

class Mobile_Vit(nn.Module):
    def __init__(self, cnn_block, trans_block, in_ch, out_ch):
        super(Mobile_Vit, self).__init__()

        self.conv1 = conv3x3(in_ch, in_ch, stride=2)
        self.mv1 = self._make_cnn_layer(cnn_block, in_ch, in_ch)
        self.mv2 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mv3 = self._make_cnn_layer(cnn_block, in_ch, in_ch, blocks=2)
        self.mv4 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt1 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 2)
        self.mv5 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt2 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 4)
        self.mv6 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt3 = self._make_trans_layer(trans_block, in_ch, in_ch, 512, 3)
        self.conv2 = conv1x1(in_ch, out_ch)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flat = nn.Flatten()

    def _make_cnn_layer(self, block, in_ch, out_ch, blocks=1, stride=1):
        # stack `blocks` MobileNetV2-style blocks
        layers = []
        for _ in range(blocks):
            layers.append(block(in_ch, out_ch, stride))
        return nn.Sequential(*layers)

    def _make_trans_layer(self, block, in_ch, out_ch, d_model, blocks):
        # stack `blocks` MobileViT-style blocks
        layers = []
        for _ in range(blocks):
            layers.append(block(in_ch, out_ch, d_model))
        return nn.Sequential(*layers)

    def forward(self,x):
        x = self.conv1(x)
        x = self.mv1(x)
        x = self.mv2(x)
        x = self.mv3(x)
        x = self.mv4(x)
        x = self.mt1(x)
        x = self.mv5(x)
        x = self.mt2(x)
        x = self.mv6(x)
        x = self.mt3(x)
        x = self.conv2(x)
        x = self.flat(self.gap(x))
        return x
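
Finally, a minimal end-to-end smoke test of the whole model, assuming a 3-channel 64x64 input and an arbitrary class count of 10 (both values are placeholders of mine, not from the paper). Note that with 6 encoder and 6 decoder layers in every transformer block this configuration is far heavier than the original MobileViT, so the forward pass is slow; the defaults are kept only to match the code above.

import torch

model = Mobile_Vit(MobileNetV2_Block, MobileVit_Block, in_ch=3, out_ch=10)
dummy = torch.randn(1, 3, 64, 64)      # (batch, channels, height, width)
with torch.no_grad():
    out = model(dummy)
print(out.shape)                       # expected: torch.Size([1, 10])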

Feedback and corrections are welcome. Paper link: https://arxiv.org/pdf/2110.02178.pdf
