The core of CCNet (Criss-Cross Network) is the recurrent criss-cross attention module. By applying CC Attention twice, the module captures the relationship between a target pixel and every other position in the feature map, and uses these relationships to weight the target pixel's features, yielding a more informative representation.
The idea is still ***self attention***: the Q and K vectors determine the weights, which are then used to take a weighted sum over V.
H: (batch, c1, h, w)  # input feature map
Q: (batch, c2, h, w)  # query vectors (c2 = c1 // 8 in the code below)
K: (batch, c2, h, w)  # key vectors
V: (batch, c1, h, w)  # value vectors
Ignoring the batch dimension for now: each pixel's query is compared against the keys of all positions in its own row and its own column, i.e. h + w - 1 positions (the pixel itself is counted only once).
This yields an attention map Atten of shape (batch, h, w, h + w - 1). The weights are then applied to V, again ignoring batch for now.
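To make the (h, w, h + w - 1) shape concrete, here is a deliberately naive, unvectorized sketch (my own illustration, batch omitted as above; the vectorized module follows later). For each pixel it gathers the keys of its own column plus the rest of its own row, so there are exactly h + w - 1 of them, and softmax-normalizes their dot products with the query:

import torch

c2, h, w = 4, 5, 6
Q = torch.randn(c2, h, w)   # query features
K = torch.randn(c2, h, w)   # key features

attn = torch.empty(h, w, h + w - 1)
for i in range(h):
    for j in range(w):
        col_keys = [K[:, k, j] for k in range(h)]             # the h positions in column j
        row_keys = [K[:, i, k] for k in range(w) if k != j]   # the w - 1 other positions in row i
        keys = torch.stack(col_keys + row_keys)               # (h + w - 1, c2)
        attn[i, j] = torch.softmax(keys @ Q[:, i, j], dim=0)  # affinity of (i, j) with its cross

print(attn.shape)        # torch.Size([5, 6, 10]) == (h, w, h + w - 1)
print(attn[0, 0].sum())  # each pixel's weights sum to 1 after the softmax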
The blog post ICCV2019语义分割文章CCNet详解 (listed in the references at the end) has very good animated figures that help with understanding.
A single CC Attention can only gather information from the same row and column as the current position; stacking two CC Attentions gives access to global information.
For example, suppose the target pixel is (Ux, Uy) and we want to relate it to (θx, θy).
Loop 1: the pixels (Ux, θy) and (θx, Uy) each establish a relationship with (θx, θy) through the first CC Attention;
Loop 2: the pixel (Ux, Uy) relates to (Ux, θy) and (θx, Uy) through the second CC Attention, and thereby indirectly picks up the relationship with (θx, θy).
So two passes of CC Attention are enough to fuse information between the current pixel and any other pixel in the feature map.
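This indirect-connection argument can also be checked mechanically. The following throwaway snippet (my own illustration, not part of CCNet) marks, for each pixel on a small grid, which positions lie on its criss-cross path, and verifies that composing two such steps reaches every position:

import torch

h, w = 5, 7
# cross[i, j, k, l] == True if (k, l) lies on the criss-cross (same row or column) of (i, j)
cross = torch.zeros(h, w, h, w, dtype=torch.bool)
for i in range(h):
    for j in range(w):
        cross[i, j, i, :] = True   # same row
        cross[i, j, :, j] = True   # same column

# two steps: (i, j) can reach (k, l) if some intermediate (m, n) lies on both crosses
reach2 = torch.einsum('ijmn,mnkl->ijkl', cross.float(), cross.float()) > 0
print(reach2.all())  # tensor(True): every pixel reaches every other pixel after two passes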
The above explains the logic of CC Attention, but implementing it literally with per-pixel loops would waste compute, run slowly, and make back-propagation awkward. The implementations I found online (including the official source code) appear to rely on a custom CUDA operator, which makes them less convenient to extend or port.
https://github.com/speedinghzl/CCNet
Reference: https://github.com/yearing1017/CCNet_PyTorch . I have added detailed comments below to make the code easier to follow.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def INF(B, H, W):
    '''
    Build a (B*W, H, H) tensor whose diagonal is -inf, used to mask out the
    duplicated self-position in the H-direction energy.

    Parameters
    ----------
    B: batch
    H: height
    W: width
    '''
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B*W, 1, 1)
class CC_module(nn.Module):
    def __init__(self, in_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channels of the input feature map
        device : torch.device
        '''
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        # Register gamma as a plain Parameter; calling .to(device) here would return a
        # non-leaf tensor and break parameter registration. Move the whole module with .to(device) instead.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.device = device
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)  # size = (b, c2, h, w), with c1 = in_dim, c2 = c1 // 8
        # size = (b*w, h, c2)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height).permute(0, 2, 1)
        # size = (b*h, w, c2)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)  # size = (b, c2, h, w)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)  # size = (b*w, c2, h)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)  # size = (b*h, c2, w)
        proj_value = self.value_conv(x)  # size = (b, c1, h, w)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize*width, -1, height)  # size = (b*w, c1, h)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize*height, -1, width)  # size = (b*h, c1, w)
        # size = (b*w, h, h); [:, i, j] is the dot product over channels between the query at row i
        # and the key at row j of the same column
        energy_H = torch.bmm(proj_query_H, proj_key_H)
        # size = (b, h, w, h). The -inf diagonal from INF masks each pixel's own position in the
        # H-direction energy: that position also appears in energy_W, so without the mask it would
        # be counted twice in the joint softmax over the h + w entries.
        energy_H = (energy_H + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        # size = (b*h, w, w); [:, i, j] is the dot product over channels between the query at column i
        # and the key at column j of the same row
        energy_W = torch.bmm(proj_query_W, proj_key_W)
        energy_W = energy_W.view(m_batchsize, height, width, width)  # size = (b, h, w, w)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))  # size = (b, h, w, h+w), softmax-normalized
        # optional thresholding variant, left commented out in the reference code:
        # concate = concate * (concate > torch.mean(concate, dim=3, keepdim=True)).float()
        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize*width, height, height)  # size = (b*w, h, h)
        att_W = concate[:, :, :, height:height+width].contiguous().view(m_batchsize*height, width, width)  # size = (b*h, w, w)
        # size = (b*w, c1, h); [:, i, j] is the weighted sum of channel i of the values along the column,
        # weighted by the attention that position j pays to each position in that column
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1))
        out_H = out_H.view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)  # size = (b, c1, h, w)
        # size = (b*h, c1, w); [:, i, j] is the weighted sum of channel i of the values along the row,
        # weighted by the attention that position j pays to each position in that row
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1))
        out_W = out_W.view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)  # size = (b, c1, h, w)
        return self.gamma * (out_H + out_W) + x
if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = CC_module(8, device).to(device)
    x = torch.randn(4, 8, 20, 20).to(device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 8, 20, 20])
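For completeness: the CCNet paper applies this module recurrently (R = 2 passes with shared weights), which is exactly the two-pass argument above. Below is a minimal sketch of that wrapper on top of the CC_module defined here; the class name RCCAModule and the bare loop are my own simplification, and the official RCCA head additionally uses convolutions and feature concatenation.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class RCCAModule(nn.Module):
    def __init__(self, in_dim, device, recurrence=2):
        super(RCCAModule, self).__init__()
        self.recurrence = recurrence
        self.cca = CC_module(in_dim, device)       # one set of weights, reused on every pass
    def forward(self, x):
        out = x
        for _ in range(self.recurrence):           # two passes give every pixel a global receptive field
            out = self.cca(out)
        return out

rcca = RCCAModule(8, device).to(device)
print(rcca(torch.randn(4, 8, 20, 20).to(device)).shape)   # torch.Size([4, 8, 20, 20])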
Note that the self-attention here is purely spatial attention; it does not involve any channel attention.
Axial Attention. I first noticed Axial Attention because Google's weather-forecasting model MetNet uses it, which prompted me to read up on it.
Leaving aside how the follow-up work builds a ***Transformer*** on top of it, Axial Attention on its own is quite similar to CC-Attention: the underlying idea is still self-attention, and the realization differs only in minor details, so I will not repeat the theory here.
Below, based on my own understanding, are implementations of the Row-Attention and Col-Attention used in Axial Attention.
# Row attention in axial attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class RowAttention(nn.Module):
    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channels of the input feature map
        q_k_dim : int
            channels of the Q and K vectors
        device : torch.device
        '''
        super(RowAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        # As above, keep gamma a plain Parameter; move the module to the device with .to(device).
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dims, height, width) -- (b, c1, h, w)
        '''
        # c1 = in_dim; c2 = q_k_dim
        b, _, h, w = x.size()
        Q = self.query_conv(x)  # size = (b, c2, h, w)
        K = self.key_conv(x)    # size = (b, c2, h, w)
        V = self.value_conv(x)  # size = (b, c1, h, w)
        Q = Q.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w).permute(0, 2, 1)  # size = (b*h, w, c2)
        K = K.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)  # size = (b*h, c2, w)
        V = V.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)  # size = (b*h, c1, w)
        # size = (b*h, w, w); [:, i, j] is the dot product over channels between the query at column i
        # and the key at column j of the same row, i.e. (1, c2) @ (c2, 1) = (1, 1)
        row_attn = torch.bmm(Q, K)
        # row_attn[:, i, 0:w] holds the affinities of position i with every position 0:w in its row.
        # Softmax over dim=2 normalizes each of these rows of affinities: row_attn[k, i, 0:w] sums to 1.
        row_attn = self.softmax(row_attn)
        # size = (b*h, c1, w). row_attn is transposed first, so each output position j becomes a
        # weighted sum of the values at all positions in its row, with weights row_attn[:, j, :].
        out = torch.bmm(V, row_attn.permute(0, 2, 1))
        # size = (b, c1, h, w)
        out = out.view(b, h, -1, w).permute(0, 2, 1, 3)
        out = self.gamma * out + x
        return out
# Test the row attention on a (4, 8, 16, 20) input
x = torch.randn(4, 8, 16, 20).to(device)
row_attn = RowAttention(in_dim=8, q_k_dim=4, device=device).to(device)
print(row_attn(x).size())  # torch.Size([4, 8, 16, 20])
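A quick sanity check of the claim that a row pass only mixes information within rows (my own addition, not from the reference repo; gamma is bumped to a non-zero value first, since it is initialized to 0 and would otherwise hide the attention branch):

torch.manual_seed(0)
ra = RowAttention(in_dim=8, q_k_dim=4, device=device).to(device)
with torch.no_grad():
    ra.gamma.fill_(1.0)                            # expose the attention branch
    x1 = torch.randn(1, 8, 16, 20, device=device)
    x2 = x1.clone()
    x2[:, :, 5, 3] += 10.0                         # perturb a single pixel in row 5
    diff = (ra(x1) - ra(x2)).abs().sum(dim=1)[0]   # per-position change, shape (16, 20)
    print(diff[5].sum() > 0)                       # row 5 changes ...
    print(diff[torch.arange(16) != 5].sum() == 0)  # ... while all other rows are untouched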
# Column attention in axial attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class ColAttention(nn.Module):
    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channels of the input feature map
        q_k_dim : int
            channels of the Q and K vectors
        device : torch.device
        '''
        super(ColAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        # Same note as RowAttention: keep gamma a plain Parameter.
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dims, height, width) -- (b, c1, h, w)
        '''
        # c1 = in_dim; c2 = q_k_dim
        b, _, h, w = x.size()
        Q = self.query_conv(x)  # size = (b, c2, h, w)
        K = self.key_conv(x)    # size = (b, c2, h, w)
        V = self.value_conv(x)  # size = (b, c1, h, w)
        Q = Q.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h).permute(0, 2, 1)  # size = (b*w, h, c2)
        K = K.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)  # size = (b*w, c2, h)
        V = V.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)  # size = (b*w, c1, h)
        # size = (b*w, h, h); [:, i, j] is the dot product over channels between the query at row i
        # and the key at row j of the same column, i.e. (1, c2) @ (c2, 1) = (1, 1)
        col_attn = torch.bmm(Q, K)
        # col_attn[:, i, 0:h] holds the affinities of position i with every position 0:h in its column.
        # Softmax over dim=2 normalizes them: col_attn[k, i, 0:h] sums to 1.
        col_attn = self.softmax(col_attn)
        # size = (b*w, c1, h). col_attn is transposed first, so each output position j becomes a
        # weighted sum of the values at all positions in its column, with weights col_attn[:, j, :].
        out = torch.bmm(V, col_attn.permute(0, 2, 1))
        # size = (b, c1, h, w)
        out = out.view(b, w, -1, h).permute(0, 2, 3, 1)
        out = self.gamma * out + x
        return out
# Test the column attention on a (4, 8, 16, 20) input
x = torch.randn(4, 8, 16, 20).to(device)
col_attn = ColAttention(in_dim=8, q_k_dim=4, device=device).to(device)
print(col_attn(x).size())  # torch.Size([4, 8, 16, 20])
Row Attention (or Col Attention) on its own, even when stacked several times, cannot fuse global information, since it only ever mixes positions within the same row (or column). In general, Row Attention and Col Attention need to be combined to fuse global information properly.
Suggested usage: combine a row pass with a column pass; one such combined out is comparable to a single CCAtten(x) operation.
So, in general, at least two iterations of any of the above schemes are needed to fuse global information; a sketch follows below.
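As one hedged illustration of this (my own arrangement, not taken from any reference implementation): chain the RowAttention and ColAttention modules defined above and repeat the pair, following the at-least-two-passes suggestion. The helper name axial_pass and the parameter values are purely illustrative.

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
row1 = RowAttention(in_dim=8, q_k_dim=4, device=device).to(device)
col1 = ColAttention(in_dim=8, q_k_dim=4, device=device).to(device)

def axial_pass(t):
    # one row pass followed by one column pass; the text above treats one such
    # combined pass as roughly comparable to a single CCAtten(x)
    return col1(row1(t))

x = torch.randn(4, 8, 16, 20).to(device)
out = axial_pass(axial_pass(x))   # iterate (at least) twice, as suggested above
print(out.shape)                  # torch.Size([4, 8, 16, 20])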
A flexible open-source implementation of Axial Attention is available on GitHub: https://github.com/lucidrains/axial-attention . It can be installed and used directly.
pip install axial_attention  # install
#Image
import torch
from axial_attention import AxialAttention
img = torch.randn(1, 3, 256, 256)
attn = AxialAttention(
dim = 3, # embedding dimension
dim_index = 1, # where is the embedding dimension
dim_heads = 32, # dimension of each head. defaults to dim // heads if not supplied
heads = 1, # number of heads for multi-head attention
num_dimensions = 2, # number of axial dimensions (images is 2, video is 3, or more)
)
img_attn = attn(img)
print(img_attn.size())
#%%
#Channel-last image latents
import torch
from axial_attention import AxialAttention
img = torch.randn(1, 20, 20, 512)
attn = AxialAttention(
dim = 512, # embedding dimension
dim_index = -1, # where is the embedding dimension
heads = 8, # number of heads for multi-head attention
num_dimensions = 2, # number of axial dimensions (images is 2, video is 3, or more)
)
img_attn = attn(img)
print(img_attn.size())
#%%
#Video
import torch
from axial_attention import AxialAttention
video = torch.randn(1, 5, 10, 20, 20)
attn = AxialAttention(
dim = 10, # embedding dimension
dim_index = 2, # where is the embedding dimension
heads = 5, # number of heads for multi-head attention
num_dimensions = 3, # number of axial dimensions (images is 2, video is 3, or more)
)
video_atten = attn(video)
print(video_atten.size())
#%%
# Image Transformer, with reversible network
import torch
from torch import nn
from axial_attention import AxialImageTransformer
conv1x1 = nn.Conv2d(3, 128, 1)
transformer = AxialImageTransformer(
dim = 128,
depth = 12,
reversible = True
)
img = torch.randn(1, 3, 20, 20)
img1 = transformer(conv1x1(img))
print(img1.size())
I have not fully worked through this library's source code yet; anyone who understands it well is welcome to discuss.
References:
Blog: ICCV2019语义分割文章CCNet详解
Blog: CCNet: Criss-Cross Attention for Semantic Segmentation论文解读
Code: https://github.com/yearing1017/CCNet_PyTorch/tree/master/CCNet
Code: https://github.com/speedinghzl/CCNet
Paper: CCNet: Criss-Cross Attention for Semantic Segmentation
Code: https://github.com/lucidrains/axial-attention
Paper: Axial Attention in Multidimensional Transformers
Paper: MetNet: A Neural Weather Model for Precipitation Forecasting