The downsampling path of the original U-Net is very simple, so let's take a look at some of its variants.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample_block(nn.Module):
    # in_channels and out_channels are both channel counts
    def __init__(self, in_channels, out_channels):
        super(Downsample_block, self).__init__()
        # conv, batch norm, activation, conv, batch norm, activation, max pooling
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(y, 2, stride=2)
        return x, y
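As a quick sanity check (my own example, the channel and image sizes are arbitrary assumptions), the block returns both the pooled map and the pre-pooling feature map, which the decoder later consumes as a skip connection:

block = Downsample_block(3, 64)          # hypothetical channel sizes
down, skip = block(torch.randn(1, 3, 224, 224))
print(down.shape, skip.shape)            # (1, 64, 112, 112) and (1, 64, 224, 224)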
This variant uses what is called a pre-activation module: batch normalization and ReLU are applied first, and only then the convolution.
import torch
import torch.nn as nn

class PreActivateDoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PreActivateDoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.double_conv(x)
During downsampling, a residual block is also added. The resulting block looks like this:
class PreActivateResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PreActivateResBlock, self).__init__()
        self.ch_avg = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        self.double_conv = PreActivateDoubleConv(in_channels, out_channels)
        self.down_sample = nn.MaxPool2d(2)

    def forward(self, x):
        identity = self.ch_avg(x)
        out = self.double_conv(x)
        out = out + identity
        return self.down_sample(out), out
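A minimal usage sketch for the pre-activation residual block (the sizes below are assumptions for illustration); the ch_avg shortcut uses a 1x1 convolution so the identity can be added despite the channel change:

block = PreActivateResBlock(3, 64)
down, skip = block(torch.randn(1, 3, 224, 224))
# down: (1, 64, 112, 112), skip: (1, 64, 224, 224)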
DenseUNet is built on dense blocks, i.e. blocks in which every layer is connected to all the preceding layers by skip connections.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Single_level_densenet(nn.Module):
    def __init__(self, filters, num_conv=4):
        super(Single_level_densenet, self).__init__()
        # num_conv is the number of convolutions; filters stays constant throughout,
        # so this block does not change the channel count.
        # The classic DenseNet block uses four convolutions.
        self.num_conv = num_conv
        self.con_list = nn.ModuleList()
        self.bn_list = nn.ModuleList()
        for i in range(self.num_conv):
            self.con_list.append(nn.Conv2d(filters, filters, 3, padding=1))
            self.bn_list.append(nn.BatchNorm2d(filters))

    def forward(self, x):
        outs = []
        outs.append(x)
        for i in range(self.num_conv):
            # temp_out temporarily holds the output of the current convolution
            temp_out = self.con_list[i](outs[i])
            if i > 0:
                # dense connectivity: add the outputs of all earlier layers
                for j in range(i):
                    temp_out += outs[j]
            outs.append(F.relu(self.bn_list[i](temp_out)))
        out_final = outs[-1]
        del outs
        return out_final
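A small check of the dense block, assuming a 64-channel feature map; since filters never changes, the output keeps the same shape as the input:

dense = Single_level_densenet(filters=64, num_conv=4)
y = dense(torch.randn(1, 64, 56, 56))
print(y.shape)  # torch.Size([1, 64, 56, 56])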
For the Down_sample block, a single max-pooling operation is all that is needed.
class Down_sample(nn.Module):
    def __init__(self, kernel_size=2, stride=2):
        super(Down_sample, self).__init__()
        self.down_sample_layer = nn.MaxPool2d(kernel_size, stride)

    def forward(self, x):
        y = self.down_sample_layer(x)
        return y, x
U-Net can also be improved with VGG blocks and dense connections, which is a relatively involved modification.
First the VGG block, which is simply two convolutional layers.
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch

class VGGBlock(nn.Module):
    def __init__(self, inchannels, middle_channels, out_channels, act_function=nn.ReLU(inplace=True)):
        super(VGGBlock, self).__init__()
        self.act_func = act_function
        self.conv1 = nn.Conv2d(inchannels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.act_func(self.bn1(self.conv1(x)))
        out = self.act_func(self.bn2(self.conv2(out)))
        return out
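A quick example with assumed channel sizes; middle_channels lets the two convolutions use different widths:

block = VGGBlock(3, 32, 64)
y = block(torch.randn(1, 3, 224, 224))
print(y.shape)  # torch.Size([1, 64, 224, 224])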
Many of the later blocks reuse this one, so let's look at some other blocks first.
Next is the residual block; we start by defining a double-convolution block.
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        # 1x1 convolution on the shortcut branch to match the channel count
        self.thinway = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.down_sample = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = self.thinway(x)
        out = self.double_conv(x)
        out = self.relu(out + identity)
        return self.down_sample(out), out
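Unlike PreActivateResBlock, the residual sum here goes through a ReLU before pooling; usage is otherwise identical (sizes assumed for illustration):

block = ResBlock(3, 64)
down, skip = block(torch.randn(1, 3, 224, 224))
# down: (1, 64, 112, 112), skip: (1, 64, 224, 224)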
Next we study some attention modules.
The most basic ones are probably the SE layer and the SK layer.
Following the usual squeeze-and-excitation diagram, the SE layer can be written as:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # squeeze: global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # excitation: reduce the channel count to 1/r of the original,
        # apply ReLU, restore the original channel count, then apply sigmoid
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        batchsize, channels, _, _ = x.size()
        # after global pooling each feature map is 1x1
        y = self.avg_pool(x).view(batchsize, channels)
        y = self.fc(y).view(batchsize, channels, 1, 1)
        # multiply the input by the learned per-channel weights
        return x * y.expand_as(x)
In effect, this assigns a learned weight to each channel.
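To see the SE layer acting as a per-channel reweighting, a sketch with an assumed 64-channel input (the output shape is unchanged, only the channel-wise scaling differs):

se = SELayer(channel=64, reduction=16)
y = se(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])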
CBAM contains two attention modules: channel attention and spatial attention.
How to combine them: the authors found that applying them sequentially, with channel attention first, works better; that is, CAM followed by SAM gives the best results.
The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
# a plain single convolution layer with optional batch norm and ReLU
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
# channel attention needs to squeeze each feature map down to a single number
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

# as in the SE layer, the MLP first reduces the number of channels
class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = self.avgpool(x)
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = self.maxpool(x)
                channel_att_raw = self.mlp(max_pool)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = self.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = self.sigmoid(x_out)
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
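CBAM is used as a drop-in module on a feature map; a small sketch with assumed sizes, where the channel gate runs first and the spatial gate second, as described above:

cbam = CBAM(gate_channels=64, reduction_ratio=16)
y = cbam(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])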
Let me show each part.
These blocks all use a double-convolution block, so let's write that first.
class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            # the output here has a single channel, so this acts like spatial attention
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        # reweight the skip-connection features with the attention map
        return x * psi
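In an attention-gated U-Net, g is the gating signal coming from the decoder and x is the encoder skip connection; both are assumed here to have been brought to the same spatial size and channel count (the numbers are illustrative):

att = Attention_block(F_g=256, F_l=256, F_int=128)
g = torch.randn(1, 256, 28, 28)   # decoder gating signal
x = torch.randn(1, 256, 28, 28)   # encoder skip connection
out = att(g, x)
print(out.shape)  # torch.Size([1, 256, 28, 28])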
The SA layer first splits the channels into groups; within each group, half of the channels go through channel attention and the other half through spatial attention. The remaining settings are the same as in the SE layer.
class SALayer(nn.Module):
    """Constructs a channel-spatial group (shuffle attention) module.

    Args:
        groups : number of groups
        avg_pool : global average pooling
        cweight : channel weight, initialised to zeros
        cbias : channel bias, initialised to ones
        sweight : spatial weight
        sbias : spatial bias
        gn : group normalisation
        sigmoid : sigmoid function
    """
    def __init__(self, channel, groups=2):
        super(SALayer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))

    # merge all the slices back together
    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        # move the group dimension next to the channels
        x = x.permute(0, 2, 1, 3, 4)
        # flatten: fold the groups back into the channel dimension
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.shape
        # split the channels into groups by folding the groups into the batch dimension
        x = x.reshape(b * self.groups, -1, h, w)
        # each group is split in two: one half for channel attention, one half for spatial attention
        x_0, x_1 = x.chunk(2, dim=1)
        # channel attention
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)
        # spatial attention
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)
        # concatenate along the channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)
        out = self.channel_shuffle(out, 2)
        return out
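A usage sketch with assumed sizes; channel must be divisible by 2*groups so that each group can be split into a channel-attention half and a spatial-attention half:

sa = SALayer(channel=64, groups=2)
y = sa(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])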
GC-Net fuses ideas from SENet and the non-local block.
The code:
class GC_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GC_block, self).__init__()
        # context modelling (non-local style): a 1x1 conv gives one attention weight per position
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)
        # SE-style transform of the pooled context vector;
        # out_channels acts as the bottleneck width so the result can be added back to x
        self.trans = nn.Sequential(
            nn.Linear(in_channels, out_channels, bias=False),
            nn.LayerNorm(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, in_channels, bias=False)
        )

    def forward(self, x):
        b, c, h, w = x.size()
        # softmax-normalised weights over all spatial positions
        attn = self.softmax(self.conv(x).view(b, 1, h * w))
        # weighted sum of the features -> one context value per channel
        context = torch.bmm(x.view(b, c, h * w), attn.transpose(1, 2)).view(b, c)
        # transform the context and fuse it back into every position of the input
        y = self.trans(context).view(b, c, 1, 1)
        return x + y
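A shape check for the block as written above (sizes are assumptions); out_channels only sets the width of the bottleneck in the transform, so the output shape matches the input:

gc = GC_block(in_channels=64, out_channels=16)
y = gc(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])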
class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # apply the same convolution recurrently, feeding the sum of the input
        # and the previous output back in t times
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1

class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1
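A usage sketch with assumed sizes: the 1x1 convolution adjusts the channel count, the two recurrent blocks refine the features, and the result is added back to the 1x1 output as a residual:

rrcnn = RRCNN_block(ch_in=3, ch_out=64, t=2)
y = rrcnn(torch.randn(1, 3, 224, 224))
print(y.shape)  # torch.Size([1, 64, 224, 224])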
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch
import torch.nn.functional as F
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
    def __init__(self, in_channels, gate_channels, reduction_ratio=8,
                 pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()
        # this part is different: if the input channel count does not match the
        # gate channel count, first project it to the gate channel count
        if in_channels != gate_channels:
            self.att_fc = nn.Sequential(
                nn.Conv2d(in_channels, gate_channels, kernel_size=1),
                nn.BatchNorm2d(gate_channels),
                nn.ReLU(inplace=True)
            )
        # self.alpha fuses the pooled features of the current block with the
        # attention map passed in from the previous block
        self.alpha = nn.Sequential(
            nn.Conv2d(2, 1, bias=False, kernel_size=1),
            nn.LayerNorm(gate_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, inputs):
        x = inputs[0]
        b, c, _, _ = x.size()
        pre_att = inputs[1]
        channel_att_sum = None
        if pre_att is not None:
            if hasattr(self, 'att_fc'):
                pre_att = self.att_fc(pre_att)
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = self.avg_pool(x)
                if pre_att is not None:
                    avg_pool = torch.cat((avg_pool.view(b, 1, 1, c),
                                          self.avg_pool(pre_att).view(b, 1, 1, c)), dim=1)
                    # self.alpha is used here to merge the two pooled vectors
                    avg_pool = self.alpha(avg_pool).view(b, c)
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = self.max_pool(x)
                if pre_att is not None:
                    max_pool = torch.cat((max_pool.view(b, 1, 1, c),
                                          self.max_pool(pre_att).view(b, 1, 1, c)), dim=1)
                    max_pool = self.alpha(max_pool).view(b, c)
                channel_att_raw = self.mlp(max_pool)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = self.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        out = x * scale
        return {0: out, 1: out}
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self, inchannel, gate_channels):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()
        # learnable weights that balance the current map against the previous spatial attention
        self.p1 = Parameter(torch.ones(1))
        self.p2 = Parameter(torch.zeros(1))
        self.bnrelu = nn.Sequential(
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        if x[1] is None:
            x_compress = self.compress(x[0])
        else:
            if x[1].size()[2] != x[0].size()[2]:
                # the previous spatial attention has a larger resolution, so pool it down first
                extent = (x[1].size()[2]) // (x[0].size()[2])
                pre_spatial_att = F.avg_pool2d(x[1], kernel_size=extent, stride=extent)
            else:
                pre_spatial_att = x[1]
            x_compress = self.bnrelu(self.p1 * self.compress(x[0]) + self.p2 * self.compress(pre_spatial_att))
        x_out = self.spatial(x_compress)
        scale = self.sigmoid(x_out)
        return {0: x[0] * scale, 1: x[0] * scale}
class CBAM(nn.Module):
    def __init__(self, in_channel, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(in_channel, gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate(in_channel, gate_channels)

    def forward(self, x):
        # x[0] is the feature map, x[1] the previous channel attention, x[2] the previous spatial attention
        x_out = self.ChannelGate({0: x[0], 1: x[1]})
        channel_att = x_out[1]
        if not self.no_spatial:
            x_out = self.SpatialGate({0: x_out[0], 1: x[2]})
        return {0: x_out[0], 1: channel_att, 2: x_out[1]}
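This variant passes attention between blocks through a dictionary: index 0 is the feature map, 1 the previous channel attention and 2 the previous spatial attention. For the first block both attentions can be None; the sizes below are assumptions for illustration:

cbam = CBAM(in_channel=64, gate_channels=64, reduction_ratio=8)
feats = torch.randn(2, 64, 32, 32)
out = cbam({0: feats, 1: None, 2: None})
# out[0]: attended features; out[1] and out[2] can be fed to the next CBAM block
print(out[0].shape)  # torch.Size([2, 64, 32, 32])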
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale

class TripletAttention(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # branch 1: swap C and H, apply the gate, swap back
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        # branch 2: swap C and W, apply the gate, swap back
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            # branch 3: ordinary spatial attention, then average the three branches
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out
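A usage sketch with assumed sizes; each of the three branches permutes the tensor so the 7x7 spatial gate attends over a different pair of dimensions, and the results are averaged:

ta = TripletAttention(gate_channels=64)
y = ta(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])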