I've been looking at the Swin Transformer these past couple of days, so here are some notes. Swin Transformer contains a Patch Merging block that will look immediately familiar to some readers, because it closely resembles the Focus module in YOLOv5. The two serve essentially the same purpose: to avoid throwing away information while downsampling. Still, their code (and diagrams) differ in the details, and they use different normalization layers (LayerNorm vs. BatchNorm), as you will see when reading the code below.
If you want to learn more, have a look at this post: Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客_swin transformer结构
Patch Merging (figure; image credit: 太阳花的小绿豆)
The Focus module in YOLOv5 (figure)
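Before the two implementations, here is a tiny sanity check of the idea they share (my own minimal sketch, not from either repository): the 2×2 interleaved slicing is a pure rearrangement of pixels into channels, so unlike stride-2 pooling, nothing is discarded.

import torch

x = torch.arange(16.).reshape(1, 1, 4, 4)   # [B, C, H, W] = [1, 1, 4, 4]
parts = [x[..., i::2, j::2] for i in (0, 1) for j in (0, 1)]
y = torch.cat(parts, dim=1)                 # [1, 4, 2, 2]: H and W halved, C x4
assert y.numel() == x.numel()               # same number of elements
assert sorted(y.flatten().tolist()) == sorted(x.flatten().tolist())  # same pixels, just rearranged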
import torch
import torch.nn as nn
import torch.nn.functional as F

class Patch(nn.Module):
    def __init__(self, dim):
        super(Patch, self).__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.linear = nn.Linear(4 * dim, 2 * dim)

    def forward(self, x, h, w):
        # Swin Transformer passes features around as [B, L, C], which is why
        # the spatial size has to be supplied separately.
        B, L, C = x.shape  # if the input were already [B, H, W, C], the check and reshape below would be unnecessary
        assert L == h * w, "input size mismatch!"
        x = x.view(B, h, w, C)

        # pad on the right/bottom if H or W is odd, so the stride-2 slicing
        # below covers the whole feature map
        if (h % 2 == 1) or (w % 2 == 1):
            # pad order: (C_front, C_back, W_left, W_right, H_top, H_bottom)
            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))

        x0 = x[:, 0::2, 0::2, :]  # one of the four pixels in every 2x2 block
        x1 = x[:, 0::2, 1::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]
        # [B, H/2, W/2, 4C]; dim=-1 concatenates along the last dimension (dim=3 would do the same here)
        x_all = torch.cat([x0, x1, x2, x3], dim=-1)
        x_all = x_all.view(B, -1, 4 * C)
        x_all = self.norm(x_all)
        x_all = self.linear(x_all)  # 4C -> 2C, so each merge doubles the channel count
        return x_all
Patch Merging
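A quick shape check for Patch Merging (my own usage example; dim=96 and the 8×8 map are arbitrary choices):

merge = Patch(dim=96)
x = torch.randn(2, 8 * 8, 96)   # [B, L, C] with L = H * W
y = merge(x, 8, 8)
print(y.shape)                  # torch.Size([2, 16, 192]): a quarter of the tokens, twice the channels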
from functools import partial

class Focus(nn.Module):
    def __init__(self, in_channels, out_channels, ksize=1, stride=1):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride)

    def forward(self, x):
        # take every second pixel, offset by (0,0), (1,0), (0,1) and (1,1)
        patch_top_left = x[..., ::2, ::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right), dim=1)
        # --------------------------------------------------------- #
        #   [B, C, H, W] -> [B, 4C, H/2, W/2], e.g. a [1, 3, 4, 4]
        #   input becomes [1, 12, 2, 2] at this point
        # --------------------------------------------------------- #
        return self.conv(x)

class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False):
        super().__init__()
        pad = (ksize - 1) // 2  # "same" padding for odd kernel sizes
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
        # Some code bases instead write bn_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
        # and then self.bn = bn_layer(out_channels). Don't be confused if you see that:
        # partial just pre-fills the eps and momentum arguments for convenience.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))  # conv -> BN -> activation
The Focus module in YOLOv5
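And a quick shape check for Focus (again my own usage example; the 640×640 input and 32 output channels are arbitrary):

focus = Focus(in_channels=3, out_channels=32, ksize=3)
img = torch.randn(1, 3, 640, 640)
out = focus(img)
print(out.shape)                # torch.Size([1, 32, 320, 320])

Note the contrast visible in the two snippets: Focus mixes the stacked channels with a convolution followed by BatchNorm, while Patch Merging uses LayerNorm followed by a linear projection, which is exactly the difference in normalization mentioned at the start.

The partial trick mentioned in the BaseConv comment, spelled out:

bn_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
bn = bn_layer(32)               # same as nn.BatchNorm2d(32, eps=0.001, momentum=0.03)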