import torch
from torch import nn
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def squash(x,dim = -1):
    """Squash non-linearity: shrinks each capsule vector so its norm lies in [0,1) while keeping its direction."""
    squared_norm = (x ** 2).sum(dim = dim,keepdim = True)
    scale = squared_norm / (1 + squared_norm)
    # 1e-8 avoids division by zero for all-zero capsules
    return scale * x / (squared_norm.sqrt() + 1e-8)
# x = torch.ones([5,2])
# squash(x)
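As a quick sanity check (illustrative only, not part of the original code), we can verify that squashed vectors keep their direction while their norms are mapped into [0, 1):

x = torch.randn(5, 8) * 10                     # 5 capsules of dimension 8 with large norms
v = squash(x)
print(v.norm(dim=-1))                          # every norm is < 1
print(torch.cosine_similarity(x, v, dim=-1))   # every value is ~1, i.e. the direction is preserved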
Input: [batch_size,1,28,28]
Process: [batch_size,1,28,28]->(256 9x9 convolution kernels & ReLU)->[batch_size,256,20,20]
Output: [batch_size,256,20,20]
This part of the code is simple, so it is embedded directly into the overall CapsNet class; a standalone shape check is sketched below.
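A minimal shape check for this first convolution (a sketch only; the actual layer is defined inside CapsNet further down, here we assume an MNIST-sized 28x28 input):

conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9)
relu = nn.ReLU(inplace=True)
dummy = torch.zeros(4, 1, 28, 28)        # [batch_size,1,28,28]
print(relu(conv1(dummy)).shape)          # torch.Size([4, 256, 20, 20]), 20 = 28 - 9 + 1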
Parameter settings: num_conv_units = 32,in_channels = 256,out_channels = 8,kernel_size = 9,stride = 2
Input: [batch_size,256,20,20]
Process: [batch_size,256,20,20]->(8x32 9x9 conv kernels,stride = 2)->[batch_size,8x32,6,6]->(reshape & squash)->[batch_size,6x6x32,8]
Meaning: produces 6x6x32 capsules of dimension 8, which serve as the input to the next layer.
Output: [batch_size,6x6x32,8]
class PrimaryCaps(nn.Module):
    def __init__(self,num_conv_units,in_channels,out_channels,kernel_size,stride):
        super(PrimaryCaps,self).__init__()
        self.conv = nn.Conv2d(in_channels = in_channels,
                              out_channels = out_channels * num_conv_units,
                              kernel_size = kernel_size,
                              stride = stride)
        self.out_channels = out_channels

    def forward(self,x):
        out = self.conv(x) # out:[batch_size,out_channels * num_conv_units,6,6]
        batch_size = out.shape[0]
        # output:[batch_size,num_conv_units * height * width,out_channels]
        return squash(out.contiguous().view(batch_size,-1,self.out_channels),dim = -1)
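A quick usage sketch for PrimaryCaps with the parameter settings listed above (illustrative only):

primary = PrimaryCaps(num_conv_units=32, in_channels=256, out_channels=8,
                      kernel_size=9, stride=2)
features = torch.zeros(4, 256, 20, 20)   # output of the first conv layer
capsules = primary(features)
print(capsules.shape)                     # torch.Size([4, 1152, 8]), 1152 = 6 * 6 * 32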
Dynamic routing process:
Parameter settings: in_dim = 8,in_caps = 32x6x6,out_caps = 10,out_dim = 16,num_routing = 3,W.shape = [1,10(out_caps),6x6x32(in_caps),16(out_dim),8(in_dim)]
Input: [batch_size,6x6x32(in_caps),8(in_dim)]
Process: [batch_size,6x6x32,8]->(unsqueeze)->[batch_size,1,6x6x32,8,1]->(linear transformation by W)->[batch_size,10(out_caps),6x6x32(in_caps),16(out_dim),1]->(squeeze & detach)->[batch_size,10,6x6x32,16]->(dynamic routing with num_routing = 3 & squash)->[batch_size,10,16]
class DigitCaps(nn.Module):
    def __init__(self,in_dim,in_caps,out_caps,out_dim,num_routing):
        # in_dim: dimensionality of each input capsule vector
        # in_caps: number of input capsules to the digit capsule layer
        # out_caps: number of output capsules (one per digit class)
        # out_dim: dimensionality of each output capsule vector
        # num_routing: number of iterations of the routing algorithm
        super(DigitCaps,self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        self.device = device
        # W:[1,out_caps,in_caps,out_dim,in_dim]
        self.W = nn.Parameter(0.01 * torch.randn(1,out_caps,in_caps,out_dim,in_dim),requires_grad = True)

    def forward(self,x):
        batch_size = x.size(0)
        x = x.unsqueeze(1).unsqueeze(4) # [batch_size,1,in_caps,in_dim,1]
        u_hat = torch.matmul(self.W,x) # [batch_size,out_caps,in_caps,out_dim,1]
        u_hat = u_hat.squeeze(-1) # [batch_size,out_caps,in_caps,out_dim]
        # Detached copy: the intermediate routing iterations should not receive gradients
        temp_u_hat = u_hat.detach() # [batch_size,out_caps,in_caps,out_dim]
        b = torch.zeros(batch_size,self.out_caps,self.in_caps,1).to(device) # [batch_size,out_caps,in_caps,1]
        for route_iter in range(self.num_routing - 1):
            c = b.softmax(dim = 1) # [batch_size,out_caps,in_caps,1]
            # [batch_size,out_caps,in_caps,1] .* [batch_size,out_caps,in_caps,out_dim]
            # ->[batch_size,out_caps,in_caps,out_dim] ->(sum over dim 2)->[batch_size,out_caps,out_dim]
            s = (c * temp_u_hat).sum(dim = 2) # [batch_size,out_caps,out_dim]
            v = squash(s) # [batch_size,out_caps,out_dim]
            # [batch_size,out_caps,in_caps,out_dim] x [batch_size,out_caps,out_dim,1]
            # -> [batch_size,out_caps,in_caps,1]  meaning: agreement between v and each input capsule
            uv = torch.matmul(temp_u_hat,v.unsqueeze(-1))
            b += uv # [batch_size,out_caps,in_caps,1]
        # Final iteration uses the non-detached u_hat so gradients flow back through W
        c = b.softmax(dim = 1)
        # [batch_size,out_caps,in_caps,1] .* [batch_size,out_caps,in_caps,out_dim]
        # ->(sum over dim 2)->[batch_size,out_caps,out_dim]
        s = (c * u_hat).sum(dim = 2)
        v = squash(s) # [batch_size,out_caps,out_dim]
        return v
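A shape check for DigitCaps matching the routing description above (illustrative only):

digit = DigitCaps(in_dim=8, in_caps=32 * 6 * 6, out_caps=10, out_dim=16, num_routing=3).to(device)
primary_out = torch.zeros(4, 32 * 6 * 6, 8).to(device)   # output of PrimaryCaps
digit_out = digit(primary_out)
print(digit_out.shape)                                    # torch.Size([4, 10, 16])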
Input: [batch_size,10(out_caps),16(out_dim)]
Output: [batch_size,784]
Meaning: reconstructs the image from the digit capsule feature vectors.
This part of the code is simple, so it is embedded directly into the overall CapsNet class (the decoder below).
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet,self).__init__()
        # Conv2d layer
        self.conv = nn.Conv2d(1,256,9)
        self.relu = nn.ReLU(inplace = True)
        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units = 32,in_channels = 256,out_channels = 8,kernel_size = 9,stride = 2)
        # Digit capsule
        self.digit_caps = DigitCaps(in_dim = 8,in_caps = 32 * 6 * 6,out_caps = 10,out_dim = 16,num_routing = 3)
        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10,512),
            nn.ReLU(inplace = True),
            nn.Linear(512,1024),
            nn.ReLU(inplace = True),
            nn.Linear(1024,784),
            nn.Sigmoid()
        )

    def forward(self,x):
        out = self.relu(self.conv(x)) # [batch_size,256,20,20], 20 = 28 - 9 + 1
        out = self.primary_caps(out) # [batch_size,num_conv_units*height*width,out_channels] = [batch_size,1152,8]
        out = self.digit_caps(out) # [batch_size,out_caps,out_dim]
        logits = torch.norm(out,dim = -1) # [batch_size,out_caps], capsule lengths used as class scores
        # One-hot mask of the predicted class: [batch_size,out_caps]
        pred = torch.eye(10).to(device).index_select(dim = 0,index = torch.argmax(logits,dim = 1))
        # Reconstruction
        batch_size = out.shape[0]
        # (out * pred.unsqueeze(2)):[batch_size,out_caps,out_dim]->view->[batch_size,out_caps*out_dim]
        # reconstruction:[batch_size,784]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size,-1))
        return logits,reconstruction
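Putting the pieces together, a forward pass over a dummy batch (a minimal sketch):

model = CapsNet().to(device)
images = torch.zeros(4, 1, 28, 28).to(device)
logits, reconstruction = model(images)
print(logits.shape)          # torch.Size([4, 10])  capsule lengths per class
print(reconstruction.shape)  # torch.Size([4, 784]) flattened reconstructed images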
class CapsuleLoss(nn.Module):
    def __init__(self,upper_bound = 0.9,lower_bound = 0.1,lmda = 0.5):
        super(CapsuleLoss,self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-4
        self.mse = nn.MSELoss(reduction = 'sum')

    def forward(self,images,labels,logits,reconstructions):
        # Margin loss: the true class capsule should be longer than upper,
        # all other capsules should be shorter than lower
        left = (self.upper - logits).relu() ** 2 # penalty when the true class capsule is too short
        right = (logits - self.lower).relu() ** 2 # penalty when an absent class capsule is too long
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)
        # Reconstruction loss, scaled down so it does not dominate the margin loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape),images)
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss
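A single training step wiring CapsNet and CapsuleLoss together (a minimal sketch; the Adam optimizer, learning rate, and one-hot label encoding are assumptions, not part of the original code):

# Illustrative training step (assumes images in [0,1] and integer class labels).
model = CapsNet().to(device)
criterion = CapsuleLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer/learning rate

images = torch.rand(4, 1, 28, 28).to(device)   # dummy batch
targets = torch.randint(0, 10, (4,))
labels = torch.eye(10)[targets].to(device)     # one-hot labels, [batch_size, 10]

logits, reconstructions = model(images)
loss = criterion(images, labels, logits, reconstructions)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())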
Paper (Dynamic Routing Between Capsules): https://arxiv.org/abs/1710.09829
Code: https://github.com/Riroaki/CapsNet