Swin Transformer is a new vision Transformer (Swin = shift + window, i.e. shifted windows) that can serve as a general-purpose backbone for computer vision. It has three key properties: first, a hierarchical structure; second, self-attention restricted to fixed-size windows, which makes the attention complexity linear in the image size; third, shifted windows that let neighboring windows exchange information.
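From the Swin Transformer paper, for an h×w grid of patches with embedding dimension C and window size M, global self-attention costs about 4hwC^2 + 2(hw)^2·C, while window self-attention costs about 4hwC^2 + 2M^2·hwC; with M fixed, the cost grows linearly rather than quadratically in the number of patches. The code below builds the model up piece by piece: patch embedding, window-based multi-head self-attention, shifted-window attention with masking, patch merging, and finally the full SwinTransformerModel.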
import torch
from torch import nn
from torch.nn import functional as F
import math
def image2emb_naive(image, patch_size, weight):
    """Naive patch embedding: unfold the image into patches, then project with a weight matrix."""
# image.shape=[bs,channel,h,w]
patch = F.unfold(image,kernel_size=(patch_size,patch_size),
stride=(patch_size,patch_size)).transpose(-1,-2)
patch_embedding = patch @ weight
return patch_embedding
def image2emb_conv(image, kernel, stride):
# bs*oc*oh*ow
conv_output = F.conv2d(image,kernel,stride=stride)
bs,oc,oh,ow = conv_output.shape
patch_embedding = conv_output.reshape((bs,oc,oh*ow)).transpose(-1,-2)
return patch_embedding
import torch
from torch.nn import functional as F
# method_1 : using unfold to achieve the patch_embedding
# step_1: unfold the image
# step_2: unfold_output@weight
def image2embed_naive(image, patch_size, weight):
"""
:param image: [bs,in_channel,height,width]
:param patch_size:
:param weight : weight.shape=[patch_depth=in_channel*patch_size*patch_size,model_dim_C]
:return: patch_embedding,it shape is [batch_size,num_patches,model_dim_C]
"""
# patch_depth = in_channel*patch_size*patch_size
# image_output.shape = [batch_size,num_patch,patch_depth=in_channel*patch_size*patch_size]
image_output = F.unfold(image, kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size)).transpose(-1, -2)
# change the final_channel dimension from patch_depth to model_dim_C
patch_embedding = image_output @ weight
return patch_embedding
# using F.conv2d to achieve the patch_embedding
def image2embed_conv(image, weight, patch_size):
# image =[batch_size,in_channel,height,width]
# weight = [out_channels,in_channels,kernel_h,kernel_w]
conv_output = F.conv2d(image, weight=weight, stride=patch_size)
bs, oc, oh, ow = conv_output.shape
patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1,-2)
return patch_embedding
batch_size = 1
in_channel = 2
out_channel = 5
height = 3
width = 4
input = torch.randn(batch_size, in_channel, height, width)
patch_size = 2
weight1_depth = in_channel * patch_size * patch_size
weight1_model_c = out_channel
weight1 = torch.randn(weight1_depth,weight1_model_c)
weight2_out_channel = weight1_model_c
weight2 = weight1.transpose(0,1).reshape(weight1_model_c,in_channel,patch_size,patch_size)
output1 = image2embed_naive(input, patch_size, weight1)
output2 = image2embed_conv(input, weight2, patch_size)
# flag checks whether output1 equals output2 element-wise
# if every entry of flag is True, the two methods agree
flag = torch.isclose(output1,output2)
print(f"flag={flag}")
print(f"output1={output1}")
print(f"output2={output2}")
print(f"output1.shape={output1.shape}")
print(f"output2.shape={output2.shape}")
# Output:
# flag=tensor([[[True, True, True, True, True],
# [True, True, True, True, True]]])
# output1=tensor([[[ -0.3182, -2.0556, -0.4092, 0.8453, 3.8825],
# [ 4.1530, -2.4645, -0.8912, 3.9692, -11.5213]]])
#output2=tensor([[[ -0.3182, -2.0556, -0.4092, 0.8453, 3.8825],
# [ 4.1530, -2.4645, -0.8912, 3.9692, -11.5213]]])
#output1.shape=torch.Size([1, 2, 5])
#output2.shape=torch.Size([1, 2, 5])
from torch import nn
class MultiHeadSelfAttention(nn.Module):
def __init__(self,model_dim,num_head):
super(MultiHeadSelfAttention, self).__init__()
self.num_head=num_head
self.proj_linear_layer = nn.Linear(model_dim,3*model_dim)
self.final_linear_layer = nn.Linear(model_dim,model_dim)
def forward(self,input,additive_mask=None):
bs,seqlen,model_dim = input.shape
num_head = self.num_head
head_dim = model_dim//num_head
# proj_output=[bs,seqlen,3*model_dim]
proj_output = self.proj_linear_layer(input)
# 3*[bs,seqlen,model_dim]
q,k,v = proj_output.chunk(3,dim=-1)
q = q.reshape(bs,seqlen,num_head,head_dim).transpose(1,2)
# q=[bs,num_head,seqlen,head_dim]
q = q.reshape(bs*num_head,seqlen,head_dim)
# k=[bs,num_head,seqlen,head_dim]
k = k.reshape(bs,seqlen,num_head,head_dim).transpose(1,2)
k = k.reshape(bs*num_head,seqlen,head_dim)
v = v.reshape(bs,seqlen,num_head,head_dim).transpose(1,2)
v = v.reshape(bs*num_head,seqlen,head_dim)
if additive_mask is None:
attn_prob = F.softmax(torch.bmm(q,k.transpose(-1,-2))/math.sqrt(head_dim),dim=-1)
else:
            # repeat the mask per head so its layout matches q/k/v, i.e. [bs*num_head, seqlen, seqlen]
            additive_mask = additive_mask.repeat_interleave(num_head, dim=0)
            attn_prob = F.softmax(torch.bmm(q, k.transpose(-2, -1))/math.sqrt(head_dim)+additive_mask, dim=-1)
output = torch.bmm(attn_prob,v)
output = output.reshape(bs,num_head,seqlen,head_dim).transpose(1,2)
output = output.reshape(bs,seqlen,model_dim)
return attn_prob,output
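# A minimal sanity check for the module above; the sizes (bs=3, seqlen=16, model_dim=8, num_head=2)
# are illustrative assumptions, not values fixed by the model.
mhsa_demo = MultiHeadSelfAttention(model_dim=8, num_head=2)
x_demo = torch.randn(3, 16, 8)                 # [bs, seqlen, model_dim]
attn_prob_demo, out_demo = mhsa_demo(x_demo)
print(attn_prob_demo.shape)                    # torch.Size([6, 16, 16]) -> [bs*num_head, seqlen, seqlen]
print(out_demo.shape)                          # torch.Size([3, 16, 8])  -> [bs, seqlen, model_dim]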
# Window-based multi-head self-attention (W-MSA)
def window_multi_head_self_attention(patch_embedding,mhsa,window_size=4,num_head=2):
    # number of patches inside one window
num_patch_in_window=window_size*window_size
    # unpack the shape parameters
bs,num_patch,patch_depth = patch_embedding.shape
    # expand 3D to 4D: the patch-grid height/width can be recovered from num_patch (assumed to be a perfect square)
image_height =image_width = int(math.sqrt(num_patch))
#[bs,num_patch,patch_depth] -> [bs,patch_depth,num_patch]
patch_embedding = patch_embedding.transpose(-1,-2)
# [bs,patch_depth,num_patch] -> [bs,patch_depth,image_height,image_width]
patch = patch_embedding.reshape(bs,patch_depth,image_height,image_width)
    # use unfold (as in convolution) to extract windows, then swap the last two dimensions
# window.shape=[bs,windows_depth,num_window] ->[bs,num_window,windows_depth]
window = F.unfold(patch,kernel_size=(window_size,window_size),
stride=(window_size,window_size)).transpose(-1,-2)
bs,num_window,patch_depth_times_num_patch_in_window = window.shape
window = window.reshape(bs*num_window,patch_depth,num_patch_in_window).transpose(-1,-2)
attn_prob,output = mhsa(window)
output = output.reshape(bs,num_window,num_patch_in_window,patch_depth)
return output
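# Illustrative shape walk-through for W-MSA, assuming a 16x16 patch grid (num_patch=256) and patch_depth=8;
# these sizes are example choices, not required by the function.
w_mhsa_demo = MultiHeadSelfAttention(model_dim=8, num_head=2)
demo_embedding = torch.randn(1, 256, 8)        # [bs, num_patch, patch_depth]
demo_w_msa = window_multi_head_self_attention(demo_embedding, w_mhsa_demo, window_size=4, num_head=2)
print(demo_w_msa.shape)                        # torch.Size([1, 16, 16, 8]) -> [bs, num_window, num_patch_in_window, patch_depth]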
def window2image(msa_output):
bs,num_window,num_patch_in_window,patch_depth=msa_output.shape
window_size=int(math.sqrt(num_patch_in_window))
image_height = int(math.sqrt(num_window))*window_size
image_width = image_height
msa_output = msa_output.reshape(bs,int(math.sqrt(num_window)),
int(math.sqrt(num_window)),
window_size,
window_size,
patch_depth)
msa_output = msa_output.transpose(2,3)
image = msa_output.reshape(bs,image_height*image_width,patch_depth)
image = image.transpose(-1,-2).reshape(bs,patch_depth,image_height,image_width)
return image
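# window2image folds the per-window layout back into a spatial feature map; continuing the example above:
demo_image = window2image(demo_w_msa)
print(demo_image.shape)                        # torch.Size([1, 8, 16, 16]) -> [bs, patch_depth, height, width]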
def shift_window(w_msa_output,window_size,shift_size,generate_mask=False):
bs,num_window,num_patch_in_window,patch_depth=w_msa_output.shape
w_msa_output = window2image(w_msa_output)
bs,patch_depth,image_height,image_width = w_msa_output.shape
rolled_w_msa_output = torch.roll(w_msa_output,shifts=(shift_size,shift_size),dims=(2,3))
shifted_w_msa_input = rolled_w_msa_output.reshape(bs,patch_depth,
int(math.sqrt(num_window)),
window_size,
int(math.sqrt(num_window)),
window_size)
shifted_w_msa_input = shifted_w_msa_input.transpose(3,4)
shifted_w_msa_input = shifted_w_msa_input.reshape(bs,patch_depth,num_window*num_patch_in_window)
shifted_w_msa_input = shifted_w_msa_input.transpose(-1,-2)
shifted_window = shifted_w_msa_input.reshape(bs,num_window,num_patch_in_window,patch_depth)
if generate_mask:
additive_mask = build_mask_for_shifted_wmsa(bs,image_height,image_width,window_size)
else:
additive_mask = None
return shifted_window,additive_mask
def build_mask_for_shifted_wmsa(batch_size,image_height,image_width,window_size):
index_matrix = torch.zeros(image_height,image_width)
for i in range(image_height):
for j in range(image_width):
row_times = (i+window_size//2)//window_size
col_times = (j+window_size//2)//window_size
index_matrix[i,j] = row_times*(image_height//window_size)+col_times+1
    rolled_index_matrix = torch.roll(index_matrix, shifts=(-window_size//2, -window_size//2), dims=(0, 1))
rolled_index_matrix = rolled_index_matrix.unsqueeze(0).unsqueeze(0)
c = F.unfold(rolled_index_matrix,kernel_size=(window_size,window_size),
stride=(window_size,window_size)).transpose(-1,-2)
c = c.tile(batch_size,1,1)
bs,num_window,num_patch_in_window = c.shape
c1 = c.unsqueeze(-1)
c2=(c1-c1.transpose(-1,-2)) == 0
valid_matrix = c2.to(torch.float32)
    additive_mask = (1 - valid_matrix) * (-1e9)   # large negative values mask out cross-window attention
additive_mask = additive_mask.reshape(bs*num_window,num_patch_in_window,num_patch_in_window)
return additive_mask
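# A small check of the shifted-window path plus its additive mask, reusing demo_w_msa from above;
# shift_size=-2 corresponds to -window_size//2 for window_size=4 (all sizes are illustrative).
demo_shifted, demo_mask = shift_window(demo_w_msa, window_size=4, shift_size=-2, generate_mask=True)
print(demo_shifted.shape)                      # torch.Size([1, 16, 16, 8])
print(demo_mask.shape)                         # torch.Size([16, 16, 16]) -> [bs*num_window, num_patch_in_window, num_patch_in_window]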
def shift_window_multi_head_self_attention(w_msa_output,mhsa,window_size=4,num_head=2):
bs,num_window,num_patch_in_window,patch_depth = w_msa_output.shape
shifted_w_msa_input,additive_mask = shift_window(w_msa_output,window_size,
shift_size=-window_size//2,
generate_mask=True)
shifted_w_msa_input = shifted_w_msa_input.reshape(bs*num_window,num_patch_in_window,patch_depth)
attn_prob,output = mhsa(shifted_w_msa_input,additive_mask=additive_mask)
output = output.reshape(bs,num_window,num_patch_in_window,patch_depth)
output,_ = shift_window(output,window_size,shift_size=window_size//2,generate_mask=False)
return output
class PatchMerging(nn.Module):
def __init__(self,model_dim,merge_size,output_depth_scale=0.5):
super(PatchMerging,self).__init__()
self.merge_size = merge_size
self.proj_layer= nn.Linear(
model_dim*merge_size*merge_size,
int(model_dim*merge_size*merge_size*output_depth_scale))
def forward(self,input):
bs,num_window,num_patch_in_window,patch_depth = input.shape
window_size = int(math.sqrt(num_patch_in_window))
input = window2image(input)
merged_window = F.unfold(input,kernel_size=(self.merge_size,self.merge_size),
stride=(self.merge_size,self.merge_size)).transpose(-1,-2)
        merged_window = self.proj_layer(merged_window)   # project to the reduced output depth
        return merged_window
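# PatchMerging halves the spatial resolution and, with merge_size=2 and output_depth_scale=0.5,
# doubles the channel dimension; an illustrative check reusing demo_w_msa from above.
patch_merging_demo = PatchMerging(model_dim=8, merge_size=2)
demo_merged = patch_merging_demo(demo_w_msa)
print(demo_merged.shape)                       # torch.Size([1, 64, 16]) -> [bs, num_patch/4, 2*model_dim]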
class SwinTransformerBlock(nn.Module):
def __init__(self,model_dim,window_size,num_head):
super(SwinTransformerBlock, self).__init__()
        self.window_size = window_size
        self.num_head = num_head
        self.layer_norm1 = nn.LayerNorm(model_dim)
self.layer_norm2 = nn.LayerNorm(model_dim)
self.layer_norm3 = nn.LayerNorm(model_dim)
self.layer_norm4 = nn.LayerNorm(model_dim)
self.wsma_mlp1 = nn.Linear(model_dim,4*model_dim)
self.wsma_mlp2 = nn.Linear(4*model_dim,model_dim)
self.swsma_mlp1 = nn.Linear(model_dim,4*model_dim)
self.swsma_mlp2 = nn.Linear(4*model_dim,model_dim)
self.mhsa1 = MultiHeadSelfAttention(model_dim,num_head)
self.mhsa2 = MultiHeadSelfAttention(model_dim,num_head)
def forward(self,input):
bs,num_patch,patch_depth = input.shape
input1 = self.layer_norm1(input)
        w_msa_output = window_multi_head_self_attention(input1, self.mhsa1,
                                                        window_size=self.window_size,
                                                        num_head=self.num_head)
bs,num_window,num_patch_in_window,patch_depth = w_msa_output.shape
w_msa_output=input+w_msa_output.reshape(bs,num_patch,patch_depth)
        output1 = self.wsma_mlp2(self.wsma_mlp1(self.layer_norm2(w_msa_output)))
        output1 = output1 + w_msa_output   # residual connection around the W-MSA MLP
input2 = self.layer_norm3(output1)
input2 = input2.reshape(bs,num_window,num_patch_in_window,patch_depth)
        sw_msa_output = shift_window_multi_head_self_attention(input2, self.mhsa2,
                                                               window_size=self.window_size,
                                                               num_head=self.num_head)
sw_msa_output=output1+sw_msa_output.reshape(bs,num_patch,patch_depth)
output2 = self.swsma_mlp2(self.swsma_mlp1(self.layer_norm4(sw_msa_output)))
output2 +=sw_msa_output
output2 = output2.reshape(bs,num_window,num_patch_in_window,patch_depth)
return output2
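# One full Swin block (W-MSA + SW-MSA, each with a pre-norm MLP and residual connections);
# the sizes below are illustrative and match the 16x16 patch grid used in the earlier examples.
block_demo = SwinTransformerBlock(model_dim=8, window_size=4, num_head=2)
demo_block_out = block_demo(torch.randn(1, 256, 8))
print(demo_block_out.shape)                    # torch.Size([1, 16, 16, 8]) -> [bs, num_window, num_patch_in_window, model_dim]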
class SwinTransformerModel(nn.Module):
def __init__(self,input_image_channel=1,patch_size=4,model_dim_C=8,num_classes=10,
window_size=4,num_head=2,merge_size=2):
super(SwinTransformerModel, self).__init__()
patch_depth = patch_size*patch_size*input_image_channel
self.patch_size = patch_size
self.model_dim_C=model_dim_C
self.num_classes = num_classes
self.patch_embedding_weight = nn.Parameter(torch.randn(patch_depth,model_dim_C))
self.block1 = SwinTransformerBlock(model_dim_C,window_size,num_head)
self.block2 = SwinTransformerBlock(model_dim_C*2,window_size,num_head)
self.block3 = SwinTransformerBlock(model_dim_C*4,window_size,num_head)
self.block4 = SwinTransformerBlock(model_dim_C*8,window_size,num_head)
self.patch_merging1 = PatchMerging(model_dim_C,merge_size)
self.patch_merging2 = PatchMerging(model_dim_C*2,merge_size)
self.patch_merging3 = PatchMerging(model_dim_C*4,merge_size)
self.final_layer = nn.Linear(model_dim_C*8,num_classes)
def forward(self,image):
patch_embedding_naive = image2embed_naive(image,self.patch_size,self.patch_embedding_weight)
# block1
patch_embedding = patch_embedding_naive
print(patch_embedding.shape)
sw_msa_output = self.block1(patch_embedding)
print("block1_output",sw_msa_output.shape)
merged_patch1=self.patch_merging1(sw_msa_output)
sw_msa_output1 = self.block2(merged_patch1)
print("block2_output",sw_msa_output1.shape)
merged_patch2 = self.patch_merging2(sw_msa_output1)
sw_msa_output2 = self.block3(merged_patch2)
print("block3_output",sw_msa_output2.shape)
merged_patch3 = self.patch_merging3(sw_msa_output2)
sw_msa_output3=self.block4(merged_patch3)
print("block4_output",sw_msa_output3.shape)
bs,num_window,num_patch_in_window,patch_depth = sw_msa_output3.shape
sw_msa_output3=sw_msa_output3.reshape(bs,-1,patch_depth)
pool_output = torch.mean(sw_msa_output3,dim=1)
logits = self.final_layer(pool_output)
print("logits",logits.shape)
return logits
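# End-to-end smoke test with the default configuration; the single-channel 256x256 input is an
# illustrative assumption (any size that keeps the patch grid divisible through all stages works).
if __name__ == "__main__":
    model = SwinTransformerModel()
    dummy_image = torch.randn(2, 1, 256, 256)
    logits = model(dummy_image)                # expected: torch.Size([2, 10])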