Swin Transformer's Patch Merging block and YOLO's Focus module

        I have been reading Swin Transformer these past few days, so here are some notes. Swin Transformer contains a Patch Merging block that will probably look familiar to some readers at first glance, because it looks a lot like the Focus module in YOLOv5. The two serve much the same purpose: both rearrange pixels so that no information is lost during downsampling. Their code and diagrams still differ in some details, though, and they also use different normalization layers (LayerNorm versus BatchNorm), as you will notice when reading the code below.
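To make the "no information lost" claim concrete, here is a minimal sketch of my own (not taken from either codebase) showing that the four stride-2 slices used by both modules exactly partition the spatial grid: every pixel lands in exactly one slice, and concatenating them just trades spatial resolution for channels.

import torch

# A toy 1x1x4x4 feature map whose values encode their own positions.
x = torch.arange(16.0).reshape(1, 1, 4, 4)

# The same four stride-2 slices that Patch Merging and Focus take.
slices = [x[..., ::2, ::2], x[..., ::2, 1::2],
          x[..., 1::2, ::2], x[..., 1::2, 1::2]]

# Concatenating along the channel axis gives [1, 4, 2, 2]: resolution halved,
# channels quadrupled, and all 16 original values kept exactly once.
y = torch.cat(slices, dim=1)
assert sorted(y.flatten().tolist()) == sorted(x.flatten().tolist())
print(y.shape)  # torch.Size([1, 4, 2, 2])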

        If you would like to learn more, have a look at this post: Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客_swin transformer结构

                                                      Patch Merging (figure from 太阳花的小绿豆)

The Focus module in YOLOv5 (figure)

import torch.nn as nn
import torch
import torch.nn.functional as F
class Patch(nn.Module):
    def __init__(self, dim):
        super(Patch, self).__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.linear = nn.Linear(4 * dim, 2 * dim)

    def forward(self, x, w, h):
        # Swin Transformer passes features around as token sequences, which is
        # why the input here is [B, L, C] and (w, h) must be given separately.
        B, L, C = x.shape  # with a [B, W, H, C] input this check would be unnecessary
        assert L == w * h, "input length L must equal w * h"
        x = x.view(B, w, h, C)
        # Pad by one row/column when w or h is odd, so stride-2 sampling covers the whole map.
        if (h % 2 == 1) or (w % 2 == 1):
            # F.pad pads from the last dimension backwards:
            # (C_front, C_back, h_front, h_back, w_front, w_back)
            x = F.pad(x, (0, 0, 0, h % 2, 0, w % 2))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 0::2, 1::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]
        # [B, w/2, h/2, 4*C]; dim=-1 concatenates along the last dimension (dim=3 is equivalent here)
        x_all = torch.cat([x0, x1, x2, x3], dim=-1)
        x_all = x_all.view(B, -1, 4 * C)
        x_all = self.norm(x_all)
        x_all = self.linear(x_all)
        return x_all

 Patch Merging
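A quick shape check, as a sketch of my own (the sizes below are made up): an odd 7×7 input should be padded to 8×8, after which Patch halves each spatial dimension and doubles the channel count.

# Hypothetical sizes, purely to verify the shapes.
B, w, h, C = 2, 7, 7, 96
patch = Patch(dim=C)
tokens = torch.randn(B, w * h, C)  # [2, 49, 96]
out = patch(tokens, w, h)
# 7 is odd, so both spatial dims are padded to 8 before merging:
print(out.shape)  # torch.Size([2, 16, 192]) = [B, (8/2)*(8/2), 2*C]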

from functools import partial
class Focus(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=1, stride=1):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride)

    def forward(self, x):
        patch_top_left  = x[...,  ::2,  ::2]
        patch_bot_left  = x[..., 1::2,  ::2]
        patch_top_right = x[...,  ::2, 1::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
        #--------------------------------#
        # after the concat, x is [B, 4*C, H/2, W/2]:
        # resolution halved, channels quadrupled
        #--------------------------------#
        return self.conv(x)

class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False):
        super().__init__()
        pad         = (ksize - 1) // 2
        self.conv   = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
        self.bn     = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
        # Some codebases write BN = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) and then
        # self.bn = BN(out_channels); partial only pre-fills eps and momentum, so the module is
        # still constructed with out_channels. Don't be thrown if you run into that style.
        self.act    = nn.ReLU(True)

    def forward(self, x):
        # Conv -> BN -> activation
        return self.act(self.bn(self.conv(x)))

The Focus module in YOLOv5
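Another shape check of my own (the sizes are arbitrary): Focus halves the spatial resolution and maps the stacked 4 * in_channels onto out_channels in a single convolution.

# Hypothetical sizes, purely to verify the shapes.
focus = Focus(in_channels=3, out_channels=32, ksize=3, stride=1)
img = torch.randn(2, 3, 640, 640)  # [B, C, H, W]
out = focus(img)
print(out.shape)  # torch.Size([2, 32, 320, 320])

Compared with Patch Merging, the real differences are the tensor layout ([B, C, H, W] here versus [B, L, C] there) and the choice of Conv + BatchNorm instead of LayerNorm + Linear for the channel reduction.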
