UNet系列网络合集

1.U-net

最初的U-Net的下采样过程非常简单:两次"卷积 + 批归一化 + ReLU",再接一个2×2最大池化。下面先给出原始版本,然后看看它的一些变体。

import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.nn.functional as F

class Downsample_block(nn.Module):
    """Plain U-Net encoder stage: (conv3x3 -> BN -> ReLU) x 2, then 2x2 max-pool.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.

    ``forward`` returns ``(pooled, skip)``: the max-pooled map that feeds the
    next encoder stage and the full-resolution map kept for the skip
    connection.
    """

    def __init__(self, in_channels, out_channels):
        super(Downsample_block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # The second conv consumes the first conv's output, hence out -> out.
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(y, 2, stride=2)
        return x, y

2. DeepResUNet

它使用了一种叫做预激活(pre-activation)的模块,其实就是先进行批归一化(BatchNorm)和ReLU,之后再进行卷积。

import torch
import torch.nn as nn

class PreActivateDoubleConv(nn.Module):
    """Pre-activation double convolution: (BN -> ReLU -> conv3x3) x 2.

    "Pre-activation" means normalization and activation are applied before
    each convolution instead of after it. Spatial size is preserved
    (3x3 kernels with padding=1).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(PreActivateDoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # The second pre-activation unit also normalizes before its ReLU.
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.double_conv(x)

在下采样过程中,还需要加入一个残差(res)块,代码如下:

class PreActivateResBlock(nn.Module):
    """Residual encoder stage built on PreActivateDoubleConv.

    A 1x1 conv + BN projection maps the input to ``out_channels`` so it can
    be added to the double-conv branch. The sum is returned twice: max-pooled
    (for the next stage) and at full resolution (for the skip connection).
    """

    def __init__(self, in_channels, out_channels):
        super(PreActivateResBlock, self).__init__()
        projection = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.ch_avg = nn.Sequential(*projection)
        self.double_conv = PreActivateDoubleConv(in_channels, out_channels)
        self.down_sample = nn.MaxPool2d(2)

    def forward(self, x):
        shortcut = self.ch_avg(x)
        merged = self.double_conv(x) + shortcut
        return self.down_sample(merged), merged

3.denseUnet

DenseUNet利用了dense块:块内的每一层都与前面各层之间有跳层连接(下面的实现用逐元素相加,而不是通道拼接)。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Single_level_densenet(nn.Module):
    """Dense block used by DenseUNet, with additive (not concatenative) links.

    Stage ``i`` applies a 3x3 conv to the previous stage's output, then adds
    every feature map older than that input (the block input and the outputs
    of earlier stages), then applies BN and ReLU. Every conv maps
    ``filters`` -> ``filters``, so the channel count never changes.

    Args:
        filters: channels of the input and of every intermediate map.
        num_conv: number of conv/BN stages (4 by default).
    """

    def __init__(self, filters, num_conv=4):
        super(Single_level_densenet, self).__init__()
        self.num_conv = num_conv
        # Attribute names are kept so existing state_dict keys still match.
        self.con_list = nn.ModuleList(
            [nn.Conv2d(filters, filters, 3, padding=1) for _ in range(num_conv)]
        )
        self.bn_list = nn.ModuleList(
            [nn.BatchNorm2d(filters) for _ in range(num_conv)]
        )

    def forward(self, x):
        history = [x]
        for idx in range(self.num_conv):
            z = self.con_list[idx](history[idx])
            # Dense shortcut: add every map produced before this stage's input.
            for earlier in history[:idx]:
                z = z + earlier
            history.append(F.relu(self.bn_list[idx](z)))
        return history[-1]

对于Down_sample块,只需要进行一个最大池化操作就可以了。

class Down_sample(nn.Module):
    """DenseUNet down-sampling step: a single max-pool.

    ``forward`` returns ``(pooled, x)``: the pooled map for the next level
    and the untouched input for the skip connection.
    """

    def __init__(self, kernel_size=2, stride=2):
        super(Down_sample, self).__init__()
        self.down_sample_layer = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        return self.down_sample_layer(x), x

4.VGGBlock & Unet++

UNet++使用VGG块以及密集的跳跃连接对U-Net进行改进,算是一种比较复杂的改进。
首先是VGG模块,其实就是两层"卷积 + 批归一化 + 激活"。

import numpy as np
from torch import nn
from torch.nn import functional as F
import torch


class VGGBlock(nn.Module):
    """Two (conv3x3 -> BN -> activation) stages, the building block of UNet++.

    Args:
        inchannels: channels of the input tensor.
        middle_channels: channels after the first conv.
        out_channels: channels after the second conv.
        act_function: activation module applied after each BN. Defaults to a
            new ``nn.ReLU(inplace=True)`` created for each block.
    """

    def __init__(self, inchannels, middle_channels, out_channels, act_function=None):
        super(VGGBlock, self).__init__()
        # A module instance as a default argument would be one object shared by
        # every VGGBlock ever built; create a fresh one per block instead.
        self.act_func = nn.ReLU(inplace=True) if act_function is None else act_function
        self.conv1 = nn.Conv2d(inchannels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.act_func(self.bn1(self.conv1(x)))
        # The second conv consumes the first stage's output, not the raw input.
        out = self.act_func(self.bn2(self.conv2(out)))
        return out

其实很多块都用到这个了。接下来先看下别的块吧

5.ResBlock & Unet++Res version

我们可以了解一下res块,首先定义一下双卷积块。

class DoubleConv(nn.Module):
    """(conv3x3 -> BN -> ReLU) x 2 with padding=1, so spatial size is kept.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
class ResBlock(nn.Module):
    """Residual encoder stage: DoubleConv branch + 1x1 projection shortcut.

    The shortcut (1x1 conv + BN) maps the input to ``out_channels``. It is
    added to the DoubleConv output and passed through ReLU. ``forward``
    returns ``(pooled, out)``: the 2x2 max-pooled map and the full-resolution
    map for the skip connection.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of both branches and of the outputs.
    """

    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.thinway = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.down_sample = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = self.thinway(x)
        out = self.double_conv(x)
        out = self.relu(out + identity)
        return self.down_sample(out), out
	

6.attention Blocks

6.1 SElayer

接下来学习一些注意力机制的模块:
最基础的可能就是SELayer或者SKLayer(这里只实现SELayer)。
UNet系列网络合集_第1张图片
那么我们可以看着这张图写出来selayer的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Squeeze: global average pool each channel to a single value.
    Excite: FC (C -> C/r) -> ReLU -> FC (C/r -> C) -> sigmoid gives one
    weight in (0, 1) per channel, which rescales the input.

    Args:
        channel: number of input channels C.
        reduction: bottleneck ratio r of the excitation MLP. The hidden width
            is clamped to at least 1 so that ``channel < reduction`` still
            builds a usable layer.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = max(1, channel // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batchsize, channels, _, _ = x.size()
        # After global pooling each channel is a single 1x1 value.
        y = self.avg_pool(x).view(batchsize, channels)
        y = self.fc(y).view(batchsize, channels, 1, 1)
        # Reweight every channel of the input by its learned importance.
        return x * y.expand_as(x)

相当于给每个通道增加了一个权重。

6.2 CBAM Module

UNet系列网络合集_第2张图片
CBAM包含两个注意力模块:通道注意力模块(CAM)和空间注意力模块(SAM)。
组合方式:作者发现两个模块串行(顺序)组合,并且先做通道注意力(CAM)、再做空间注意力(SAM)时效果最好。
代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """A single Conv2d, optionally followed by BatchNorm and ReLU.

    Arguments mirror ``nn.Conv2d``. ``group`` is forwarded to the conv's
    ``groups``. ``bn`` and ``relu`` toggle the normalization and activation
    layers (each is ``None`` when disabled).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, group=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=group, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into one.

    Used by channel attention to turn each pooled (C, 1, 1) map into a
    length-C vector.
    """

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class ChannelGate(nn.Module):
    """CBAM channel attention (CAM).

    Each pooling in ``pool_types`` squeezes the input to one value per
    channel. The pooled vectors share a bottleneck MLP (C -> C/r -> C) that
    keeps the parameter count low. Their outputs are summed, and a sigmoid
    turns the sum into per-channel weights that rescale the input.

    Args:
        gate_channels: number of input channels C.
        reduction_ratio: bottleneck ratio r of the shared MLP.
        pool_types: pooling kinds to use; each must be ``'avg'`` or ``'max'``.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown kind.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        pool_types = tuple(pool_types)
        if not pool_types or any(p not in ('avg', 'max') for p in pool_types):
            raise ValueError(
                "pool_types must be a non-empty sequence of 'avg'/'max', got {!r}".format(pool_types))
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pools = {'avg': self.avgpool, 'max': self.maxpool}
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(pools[pool_type](x))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # (B, C) weights -> (B, C, 1, 1) -> broadcast over every position.
        scale = self.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
class ChannelPool(nn.Module):
    """Stack the channel-wise max and mean of ``x`` along dim 1.

    For an input of shape (B, C, H, W) the result has shape (B, 2, H, W):
    channel 0 is the max over C and channel 1 is the mean over C.
    """

    def forward(self, x):
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial attention (SAM).

    Compresses the channels to [max, mean], runs a 7x7 BasicConv without
    ReLU to get a single-channel map, and rescales the input by its sigmoid.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        same_padding = (kernel_size - 1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=same_padding, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention = self.sigmoid(self.spatial(self.compress(x)))
        return x * attention
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then spatial gate.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: bottleneck ratio of the channel-gate MLP.
        pool_types: pooling kinds used by the channel gate (an immutable
            tuple default, so instances never share a mutable list).
        no_spatial: if True, skip the spatial gate and apply channel
            attention only.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # Channel attention first, then spatial attention (CAM -> SAM).
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

给大家展示一下每个部分吧:

CAM:通道注意力机制
UNet系列网络合集_第3张图片
SAM:空间注意力机制
UNet系列网络合集_第4张图片

7.attention_unet

都用到了双卷积块,所以双卷积块可以写一下

class conv_block(nn.Module):
    """Double convolution used by Attention U-Net: (conv3x3 -> BN -> ReLU) x 2.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Attention_block(nn.Module):
    """Additive attention gate from Attention U-Net.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels by a 1x1 conv + BN. The projections are summed and
    passed through ReLU, then reduced to a single-channel map whose sigmoid
    weights every spatial position of ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.

    ``forward(g, x)`` returns ``x`` scaled by the attention map. ``g`` and
    ``x`` must have broadcast-compatible (normally identical) spatial sizes.
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            # A single output channel: one attention weight per spatial position.
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

8. SALayer

先把通道分成若干组,再把每组的通道一分为二:一半用通道注意力,一半用空间注意力;最后拼接并做通道混洗(channel shuffle),让两部分的信息重新融合。通道注意力部分的思路与SELayer类似。
UNet系列网络合集_第5张图片

class SALayer(nn.Module):
    """Shuffle-attention layer: channel and spatial attention per channel group.

    The channels are split into ``groups`` groups. Inside each group, half of
    the channels get channel attention (global average pool -> per-channel
    affine -> sigmoid). The other half get spatial attention (GroupNorm ->
    per-channel affine -> sigmoid). The halves are concatenated, and a channel
    shuffle interleaves them so information mixes across the two branches.

    Args:
        channel: number of input channels; must be divisible by ``2 * groups``.
        groups: number of channel groups.

    Attributes:
        cweight, cbias: per-channel scale/shift of the channel branch
            (initialized to 0 and 1).
        sweight, sbias: per-channel scale/shift of the spatial branch
            (initialized to 0 and 1).
        gn: GroupNorm with one group per channel of a half-group.

    Raises:
        ValueError: if ``channel`` is not divisible by ``2 * groups``.
    """

    def __init__(self, channel, groups=2):
        super(SALayer, self).__init__()
        if channel % (2 * groups) != 0:
            raise ValueError(
                "channel ({}) must be divisible by 2 * groups ({})".format(channel, 2 * groups))
        self.groups = groups
        half = channel // (2 * groups)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, half, 1, 1))
        self.cbias = Parameter(torch.ones(1, half, 1, 1))
        self.sweight = Parameter(torch.zeros(1, half, 1, 1))
        self.sbias = Parameter(torch.ones(1, half, 1, 1))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(half, half)

    @staticmethod
    def channel_shuffle(x, groups):
        """Interleave ``groups`` contiguous channel blocks of ``x``.

        (b, groups*k, h, w) is viewed as (b, groups, k, h, w), the two
        channel axes are swapped, and the result is flattened back, so the
        channel order [a0..a(k-1), b0..b(k-1)] becomes [a0, b0, a1, b1, ...].
        """
        b, _, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        return x.reshape(b, -1, h, w)

    def forward(self, x):
        b, c, h, w = x.shape
        # Fold the groups into the batch axis: (b * groups, c / groups, h, w).
        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)

        # Channel attention on the first half.
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        # Spatial attention on the second half.
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        # Concatenate along the channel axis and unfold the groups again.
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)

        return self.channel_shuffle(out, 2)

9.GC unet

GC-Net(Global Context Network)融合了SENet和Non-local block的思想:先用类似Non-local的方式建模全局上下文,再用类似SE的瓶颈结构做变换,最后把结果加回到每个位置上。
UNet系列网络合集_第6张图片
代码部分

class GC_block(nn.Module):
    """Global Context (GC) block: non-local style context + SE style transform.

    1. Context modeling: a 1x1 conv gives one attention logit per pixel. A
       softmax over all H*W positions turns the logits into weights, and the
       weighted sum of the input features is one context vector per sample.
    2. Transform: a bottleneck MLP (Linear -> LayerNorm -> ReLU -> Linear)
       maps the context vector back to ``in_channels``.
    3. Fusion: the transformed context is broadcast-added to every position.

    Args:
        in_channels: channels of the input (and of the output).
        out_channels: hidden width of the transform bottleneck. With
            ``out_channels == in_channels`` the transform layers have the
            same shapes as the original sketch of this block.
    """

    def __init__(self, in_channels, out_channels):
        super(GC_block, self).__init__()
        # One logit per spatial position (attention mask over H*W).
        self.conv = nn.Conv2d(in_channels, 1, 1)
        self.trans = nn.Sequential(
            nn.Linear(in_channels, out_channels, bias=False),
            nn.LayerNorm(out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels, in_channels, bias=False),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        # (b, 1, h*w) weights that sum to 1 over all positions.
        weights = torch.softmax(self.conv(x).reshape(b, 1, h * w), dim=-1)
        # (b, c, h*w) @ (b, h*w, 1) -> (b, c, 1): attention-pooled context.
        context = torch.bmm(x.reshape(b, c, h * w), weights.transpose(1, 2))
        y = self.trans(context.reshape(b, c)).reshape(b, c, 1, 1)
        return x + y

10. R2U-Net:循环残差卷积神经网络(同一个卷积被反复循环使用)

UNet系列网络合集_第7张图片

class Recurrent_block(nn.Module):
    """Recurrent convolution unit from R2U-Net.

    One shared conv3x3 -> BN -> ReLU stack is applied to the raw input once,
    then ``t`` more times, each time to the sum of the block input and the
    previous output.
    """

    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        layers = [
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        for step in range(self.t):
            state = self.conv(x) if step == 0 else state
            state = self.conv(x + state)
        return state

class RRCNN_block(nn.Module):
    """Recurrent residual conv block (R2U-Net).

    A 1x1 conv maps ``ch_in`` -> ``ch_out``. Two stacked Recurrent_blocks
    refine that projection, and the projection is added back as a residual.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(*[Recurrent_block(ch_out, t=t) for _ in range(2)])
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)

11.CBAM_DCA 模块

DCA模块我还没仔细了解
UNet系列网络合集_第8张图片

import torch.nn as nn
from torch.nn.parameter import Parameter
import torch
import torch.nn.functional as F


class BasicConv(nn.Module):
    """A single Conv2d, optionally followed by BatchNorm and ReLU.

    Arguments mirror ``nn.Conv2d``. ``bn`` and ``relu`` toggle the
    normalization and activation layers (each is ``None`` when disabled).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    """Reshape ``x`` to (batch, -1), keeping only the leading batch axis."""

    def forward(self, x):
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat

class ChannelGate(nn.Module):
    """Channel attention that also looks at the previous attention output (DCA).

    ``forward`` takes ``inputs`` indexable by 0 and 1:
    - ``inputs[0]`` is the feature map ``x`` (``gate_channels`` channels).
    - ``inputs[1]`` is the previous block's attention output, or ``None``.

    When a previous output is given, it is first projected to
    ``gate_channels`` (only if ``in_channels`` differs). It is then pooled
    the same way as ``x``, stacked with the current pooled descriptor, and
    fused by ``self.alpha`` (1x1 conv over the pair + LayerNorm + ReLU)
    before the shared MLP.

    Args:
        in_channels: channels of the previous attention output.
        gate_channels: channels of ``x``.
        reduction_ratio: bottleneck ratio of the shared MLP.
        pool_types: pooling kinds to use; each must be ``'avg'`` or ``'max'``.

    Returns:
        ``{0: out, 1: out}``, where ``out`` is ``x`` rescaled per channel.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown kind.
    """

    def __init__(self, in_channels, gate_channels, reduction_ratio=8,
                 pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        pool_types = tuple(pool_types)
        if not pool_types or any(p not in ('avg', 'max') for p in pool_types):
            raise ValueError(
                "pool_types must be a non-empty sequence of 'avg'/'max', got {!r}".format(pool_types))
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()
        # If the previous attention output has a different channel count,
        # project it to gate_channels so it can be pooled alongside x.
        if in_channels != gate_channels:
            self.att_fc = nn.Sequential(
                nn.Conv2d(in_channels, gate_channels, kernel_size=1),
                nn.BatchNorm2d(gate_channels),
                nn.ReLU(inplace=True),
            )
        # Fuses the stacked [current, previous] pooled descriptors into one.
        self.alpha = nn.Sequential(
            nn.Conv2d(2, 1, bias=False, kernel_size=1),
            nn.LayerNorm(gate_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        x = inputs[0]
        b, c, _, _ = x.size()
        pre_att = inputs[1]
        if pre_att is not None and hasattr(self, 'att_fc'):
            pre_att = self.att_fc(pre_att)
        pools = {'avg': self.avg_pool, 'max': self.max_pool}
        channel_att_sum = None
        for pool_type in self.pool_types:
            pool = pools[pool_type]
            pooled = pool(x)
            if pre_att is not None:
                # Stack current and previous descriptors as a (b, 2, 1, c)
                # map; alpha's 1x1 conv fuses them back to (b, c).
                stacked = torch.cat(
                    (pooled.view(b, 1, 1, c), pool(pre_att).view(b, 1, 1, c)), dim=1)
                pooled = self.alpha(stacked).view(b, c)
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # Scale once, after every pool type has contributed.
        scale = self.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        out = x * scale
        return {0: out, 1: out}
class ChannelPool(nn.Module):
    """Return [channel-max, channel-mean] of ``x`` stacked along dim 1.

    For an input of shape (B, C, H, W) the output is (B, 2, H, W).
    """

    def forward(self, x):
        peak = x.max(dim=1, keepdim=True).values
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)
class SpatialGate(nn.Module):
    """Spatial attention that also looks at the previous attention output (DCA).

    ``forward`` takes ``x`` indexable by 0 and 1:
    - ``x[0]`` is the feature map.
    - ``x[1]`` is the previous block's spatial-attention output, or ``None``.

    Both are compressed to [channel-max, channel-mean]. If a previous output
    exists, it is average-pooled down to the current height when the heights
    differ. The two compressed maps are then blended with the learnable
    scalars ``p1``/``p2``, followed by BN + ReLU. A 7x7 BasicConv without
    ReLU turns the result into a single-channel map, and its sigmoid
    rescales ``x[0]``.

    Args:
        inchannel, gate_channels: accepted for interface compatibility with
            CBAM; not used.

    Returns:
        ``{0: out, 1: out}``, where ``out`` is ``x[0]`` rescaled per position.
    """

    def __init__(self, inchannel, gate_channels):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()
        self.p1 = Parameter(torch.ones(1))
        self.p2 = Parameter(torch.zeros(1))
        self.bnrelu = nn.Sequential(
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        feat, prev = x[0], x[1]
        if prev is None:
            x_compress = self.compress(feat)
        else:
            if prev.size(2) != feat.size(2):
                # Assumes the previous map's height is an integer multiple of
                # the current height (e.g. one down-sampling step earlier).
                extent = prev.size(2) // feat.size(2)
                pre_spatial_att = F.avg_pool2d(prev, kernel_size=extent, stride=extent)
            else:
                pre_spatial_att = prev
            x_compress = self.bnrelu(
                self.p1 * self.compress(feat) + self.p2 * self.compress(pre_spatial_att))
        # Runs for both branches above (previously only when prev was given).
        x_out = self.spatial(x_compress)
        scale = self.sigmoid(x_out)
        out = feat * scale
        return {0: out, 1: out}

class CBAM(nn.Module):
    """CBAM variant whose gates also receive the previous block's attention (DCA).

    Args:
        in_channel: channels of the previous block's attention outputs.
        gate_channels: channels of the current feature map.
        reduction_ratio: bottleneck ratio of the channel-gate MLP.
        pool_types: pooling kinds used by the channel gate (an immutable
            tuple default, so instances never share a mutable list).
        no_spatial: if True, apply only the channel gate.
    """

    def __init__(self, in_channel, gate_channels, reduction_ratio=16,
                 pool_types=('avg', 'max'), no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(in_channel, gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate(in_channel, gate_channels)

    def forward(self, x):
        """Apply channel attention, then (optionally) spatial attention.

        ``x`` must be indexable by 0, 1 and 2: the feature map, the previous
        channel-attention output and the previous spatial-attention output
        (either previous output may be ``None``). Returns
        ``{0: output, 1: channel-attention output, 2: spatial-attention
        output}``. With ``no_spatial``, key 2 holds the channel-attention
        output and ``x[2]`` is ignored.
        """
        x_out = self.ChannelGate({0: x[0], 1: x[1]})
        channel_att = x_out[1]
        if not self.no_spatial:
            x_out = self.SpatialGate({0: x_out[0], 1: x[2]})
        return {0: x_out[0], 1: channel_att, 2: x_out[1]}

12. TA-UNet(Triplet Attention,三重注意力)

UNet系列网络合集_第9张图片
直接来TA_unet

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    """A single Conv2d, optionally followed by BatchNorm and ReLU.

    Arguments mirror ``nn.Conv2d``. ``bn`` and ``relu`` toggle the
    normalization and activation layers (each is ``None`` when disabled).
    ``dilation`` is keyword-appended at the end so existing positional calls
    keep their meaning.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 groups=1, relu=True, bn=True, bias=False, dilation=1):
        super(BasicConv, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class ChannelPool(nn.Module):
    """Channel-wise max and mean, stacked as two channels.

    For an input of shape (B, C, H, W) the output is (B, 2, H, W): index 0
    holds the max over C and index 1 the mean over C.
    """

    def forward(self, x):
        return torch.stack([x.max(1)[0], x.mean(1)], dim=1)
class SpatialGate(nn.Module):
    """Spatial attention branch used by TripletAttention.

    Compresses the channel dimension to [max, mean], applies a 7x7
    BasicConv without ReLU, and rescales the input by the in-place sigmoid
    of the resulting single-channel map.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        k = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, k, stride=1, padding=(k - 1) // 2, relu=False)

    def forward(self, x):
        logits = self.spatial(self.compress(x))
        return x * torch.sigmoid_(logits)
class TripletAttention(nn.Module):
    """Triplet attention: SpatialGate branches over rotated views of ``x``.

    For an input laid out as (B, C, H, W):
    - Branch H permutes it to (B, H, C, W), so the gate pools over H and
      attends over the (C, W) plane.
    - Branch W permutes it to (B, W, H, C) and attends over the (H, C) plane.
    - The optional third branch is a plain SpatialGate over (H, W).

    Each branch output is permuted back, and the branches are averaged.

    Args:
        gate_channels, reduction_ratio, pool_types: accepted for interface
            compatibility with CBAM; not used.
        no_spatial: if True, drop the plain (H, W) branch and average the
            other two.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # Swap dims 1 and 2, attend, then swap back.
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out11 = self.ChannelGateH(x_perm1).permute(0, 2, 1, 3).contiguous()
        # Swap dims 1 and 3, attend, then swap back.
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out21 = self.ChannelGateW(x_perm2).permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            return (1 / 3) * (x_out + x_out11 + x_out21)
        return (1 / 2) * (x_out11 + x_out21)
		
	

你可能感兴趣的:(医学图像处理,pytorch)