Xiaohei's CapsNet Notes: Mapping Formulas to Code

1. Overall Architecture

[Figure 1: overall CapsNet architecture]

2. Code Details

import torch
from torch import nn

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(1) The squash operation

The squash nonlinearity shrinks a capsule vector's norm into [0, 1) while preserving its direction:

v_j = (‖s_j‖² / (1 + ‖s_j‖²)) · (s_j / ‖s_j‖)

# Squash nonlinearity: shrink the vector norm into [0, 1) while preserving direction
def squash(x,dim = -1):
    squared_norm = (x ** 2).sum(dim = dim,keepdim = True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / (squared_norm.sqrt() + 1e-8)
# x = torch.ones([5,2])
# squash(x)
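
A quick sanity check of the squashing behavior (a dummy input is assumed):

# Each row of x has norm sqrt(2); after squashing every norm becomes 2/(1+2) ≈ 0.667
x = torch.ones(5,2)
print(squash(x).norm(dim = -1))    # tensor([0.6667, 0.6667, 0.6667, 0.6667, 0.6667])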

(2) ReLU Conv1 layer

[Figure: ReLU Conv1 layer]

Input: [batch_size, 1, 28, 28]
Process: [batch_size, 1, 28, 28] -> (256 convolution kernels of size 9×9 & ReLU) -> [batch_size, 256, 20, 20]
Output: [batch_size, 256, 20, 20]

This part of the code is simple and is embedded directly in the overall CapsNet class; a minimal standalone sketch follows below.
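
A minimal standalone sketch of this step (a dummy MNIST-sized batch is assumed; the real layer is defined inside the CapsNet class below):

# 256 feature maps from 9x9 kernels with stride 1, followed by ReLU
conv1 = nn.Sequential(nn.Conv2d(1,256,kernel_size = 9),nn.ReLU(inplace = True))
x = torch.randn(4,1,28,28)    # dummy batch of 4 MNIST images
print(conv1(x).shape)    # torch.Size([4, 256, 20, 20]), since 20 = 28 - 9 + 1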

(3) PrimaryCaps layer (32 capsules)

[Figures: PrimaryCaps layer]
Parameter settings: num_conv_units = 32, in_channels = 256, out_channels = 8, kernel_size = 9, stride = 2
Input: [batch_size, 256, 20, 20]
Process: [batch_size, 256, 20, 20] -> (8×32 convolution kernels of size 9×9, stride 2) -> [batch_size, 8×32, 6, 6] -> (reshape & squash) -> [batch_size, 6×6×32, 8]
Meaning: produces 6×6×32 capsules of dimension 8, which feed the next layer.
Output: [batch_size, 6×6×32, 8]

class PrimaryCaps(nn.Module):
    def __init__(self,num_conv_units,in_channels,out_channels,kernel_size,stride):
        super(PrimaryCaps,self).__init__()
        
        self.conv = nn.Conv2d(in_channels = in_channels,
                              out_channels = out_channels * num_conv_units,
                              kernel_size = kernel_size,
                              stride = stride)
        self.out_channels = out_channels
    def forward(self,x):
        out = self.conv(x)    # out:[batch_size,out_channels * num_conv_units,6,6]
        batch_size = out.shape[0]
        # output:[batch_size,out_capsules * height * width,out_channels]
        return squash(out.contiguous().view(batch_size,-1,self.out_channels),dim = -1)
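
A quick shape check (a dummy feature map is assumed):

# 8 * 32 = 256 conv channels reshaped into 6 * 6 * 32 = 1152 capsules of dimension 8
primary_caps = PrimaryCaps(num_conv_units = 32,in_channels = 256,out_channels = 8,kernel_size = 9,stride = 2)
feat = torch.randn(4,256,20,20)    # dummy Conv1 output
print(primary_caps(feat).shape)    # torch.Size([4, 1152, 8])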

(4) DigitCaps layer (10 capsules)

[Figures: DigitCaps layer; prediction vectors û_{j|i} = W_ij · u_i; s_j = Σ_i c_ij · û_{j|i} and v_j = squash(s_j)]

The dynamic routing procedure:

[Figure: the routing-by-agreement algorithm from the paper]
Parameter settings: in_dim = 8, in_caps = 32×6×6 = 1152, out_caps = 10, out_dim = 16, num_routing = 3, W.shape = [1, 10(out_caps), 1152(in_caps), 16(out_dim), 8(in_dim)]
Input: [batch_size, 6×6×32(in_caps), 8(in_dim)]
Process: [batch_size, 6×6×32, 8] -> (unsqueeze) -> [batch_size, 1, 6×6×32, 8, 1] -> (linear transform by W) -> [batch_size, 10(out_caps), 6×6×32(in_caps), 16(out_dim), 1] -> (squeeze & detach) -> [batch_size, 10, 6×6×32, 16] -> (dynamic routing with num_routing = 3 & squash) -> [batch_size, 10, 16]

class DigitCaps(nn.Module):
    def __init__(self,in_dim,in_caps,out_caps,out_dim,num_routing):
        # in_dim: dimensionality of each input capsule vector
        # in_caps: number of input capsules fed into this layer
        # out_caps: number of output (digit) capsules
        # out_dim: dimensionality of each output capsule vector
        # num_routing: number of iterations of the routing algorithm
        super(DigitCaps,self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        self.device = device
        # W:[1,out_caps,in_caps,out_dim,in_dim]
        self.W = nn.Parameter(0.01 * torch.randn(1,out_caps,in_caps,out_dim,in_dim),requires_grad = True)
    def forward(self,x):
        batch_size = x.size(0)
        x = x.unsqueeze(1).unsqueeze(4)    # [batch_size,1,in_caps,in_dim,1]
        u_hat = torch.matmul(self.W,x)    # [batch_size,out_caps,in_caps,out_dims,1]
        u_hat = u_hat.squeeze(-1)    # [batch_size,out_caps,in_caps,out_dims]
        temp_u_hat = u_hat.detach()    # [batch_size,out_caps,in_caps,out_dims] 
        b = torch.zeros(batch_size,self.out_caps,self.in_caps,1).to(device)    # [batch_size,out_caps,in_caps,1]
        for route_iter in range(self.num_routing - 1):
            c = b.softmax(dim = 1)    # [batch_size,out_caps,in_caps,1]
            # [batch_size,out_caps,in_caps,1] .* [batch_size,out_caps,in_caps,out_dims]
            # ->[batch_size,out_caps,in_caps,out_dims] ->(sum.dim=2)->[batch_size,out_caps,out_dims]
            s = (c * temp_u_hat).sum(dim = 2)    # [batch_size,out_caps,out_dims]
            v = squash(s)    # [batch_size,out_caps,out_dims]
            # [batch_size,out_caps,in_caps,out_dims] * [batch_size,out_caps,out_dims,1]
            # -> [batch_size,out_caps,in_caps,1]   meaning: agreement between v and each input capsule's prediction
            uv = torch.matmul(temp_u_hat,v.unsqueeze(-1))
            b += uv    # [batch_size,out_caps,in_caps,1]
        c = b.softmax(dim = 1)
        # [batch_size,out_caps,in_caps,1] .* [batch_size,out_caps,in_caps,out_dims]
        # ->sum.dim = 2->[batch_size,out_caps,out_dims]
        s = (c * u_hat).sum(dim = 2)
        v = squash(s)    # [batch_size,out_caps,out_dims]
        return v
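
A quick shape check (dummy capsule input assumed):

# 1152 8-dim input capsules are routed into 10 16-dim digit capsules
digit_caps = DigitCaps(in_dim = 8,in_caps = 32 * 6 * 6,out_caps = 10,out_dim = 16,num_routing = 3).to(device)
caps_in = torch.randn(4,32 * 6 * 6,8).to(device)
print(digit_caps(caps_in).shape)    # torch.Size([4, 10, 16])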

(5) Reconstruction, overall architecture & CapsuleLoss

(i) The reconstruction operation

Input: [batch_size, 10(out_caps), 16(out_dim)]
Output: [batch_size, 784]
Meaning: decode the capsule feature vectors back into a 28×28 image.

This part of the code is simple and is embedded directly in the overall CapsNet class; a minimal sketch of the masking and decoding step follows below.
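
A minimal sketch of the masking and decoding step, with `caps`, `pred` and `decoder` as illustrative stand-ins for the corresponding members of the CapsNet class below:

# Zero out every capsule except the predicted digit's, flatten, and decode to 28x28 = 784 pixels
caps = torch.randn(4,10,16)    # DigitCaps output
pred = torch.eye(10).index_select(dim = 0,index = torch.randint(0,10,(4,)))    # dummy one-hot predictions
decoder = nn.Sequential(nn.Linear(16 * 10,512),nn.ReLU(inplace = True),nn.Linear(512,1024),nn.ReLU(inplace = True),nn.Linear(1024,784),nn.Sigmoid())
reconstruction = decoder((caps * pred.unsqueeze(2)).contiguous().view(4,-1))
print(reconstruction.shape)    # torch.Size([4, 784])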

(ii) Overall architecture

[Figure: CapsNet model architecture]

class CapsNet(nn.Module):
    
    def __init__(self):
        super(CapsNet,self).__init__()
        # Conv2d layer
        self.conv = nn.Conv2d(1,256,9)
        self.relu = nn.ReLU(inplace = True)
        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units = 32,in_channels = 256,out_channels = 8,kernel_size = 9,stride = 2)
        # Digit capsule
        self.digit_caps = DigitCaps(in_dim = 8,in_caps = 32 * 6 * 6,out_caps = 10,out_dim = 16,num_routing = 3)
        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10,512),
            nn.ReLU(inplace = True),
            nn.Linear(512,1024),
            nn.ReLU(inplace = True),
            nn.Linear(1024,784),
            nn.Sigmoid()
        )
    def forward(self,x):
        out = self.relu(self.conv(x))    # [batch_size,256,20,20], 20 = 28 - 9 + 1
        out = self.primary_caps(out)    # [batch_size,out_capsules*height*width,out_channels]
        out = self.digit_caps(out)    # [batch_size,out_caps,out_dim]
        logits = torch.norm(out,dim = -1)    # [batch_size,out_caps]
        # [batch_size,out_caps]
        pred = torch.eye(10).to(device).index_select(dim = 0,index = torch.argmax(logits,dim = 1))    
        # Reconstruction
        batch_size = out.shape[0]
        # (out * pred.unsqueeze(2)):[batch_size,out_caps,out_dim]->view->[batch_size,out_caps*out_dim]
        # reconstruction:[batch_size,784]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size,-1))
        return logits,reconstruction
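
A quick end-to-end shape check (a dummy batch is assumed):

model = CapsNet().to(device)
images = torch.randn(4,1,28,28).to(device)
logits,reconstruction = model(images)
print(logits.shape,reconstruction.shape)    # torch.Size([4, 10]) torch.Size([4, 784])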

(iii)CapsuleLoss

Margin loss (m⁺ = 0.9, m⁻ = 0.1, λ = 0.5), matching the code below:

L_k = T_k · max(0, m⁺ − ‖v_k‖)² + λ · (1 − T_k) · max(0, ‖v_k‖ − m⁻)²

The total loss adds the reconstruction MSE scaled by 5e-4.

class CapsuleLoss(nn.Module):
    def __init__(self,upper_bound = 0.9,lower_bound = 0.1,lmda = 0.5):
        super(CapsuleLoss,self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-4
        self.mse = nn.MSELoss(reduction = 'sum')
    def forward(self,images,labels,logits,reconstructions):
        left = (self.upper - logits).relu() ** 2
        right = (logits - self.lower).relu() ** 2
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)
        # Reconstruction loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape),images)
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss
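
A usage sketch, reusing `images`, `logits` and `reconstruction` from the sketch above (dummy one-hot labels are assumed):

criterion = CapsuleLoss()
labels = torch.eye(10).to(device).index_select(dim = 0,index = torch.randint(0,10,(4,)).to(device))
loss = criterion(images,labels,logits,reconstruction)
loss.backward()
print(loss.item())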

Paper link (Dynamic Routing Between Capsules): https://arxiv.org/abs/1710.09829
Code link: https://github.com/Riroaki/CapsNet
