[CapsNet] A PyTorch-based Capsule Network Toolkit

API

Paper: Dynamic Routing Between Capsules

I developed a capsule network toolkit based on PyTorch. Its API is listed below:

### Capsulation layer
class Capsulation2D(nn.Module)
    # input.shape = (batch_size, channels, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, height, width)
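A minimal sketch of how a feature map could be turned into capsules with the documented shapes, assuming a 3x3 convolution with `padding=1` (so height and width are preserved) followed by a reshape. The class name `Capsulation2DSketch` and its constructor arguments are hypothetical; the toolkit's own layer may use a different kernel or also apply `squash`.

```python
import torch
import torch.nn as nn


class Capsulation2DSketch(nn.Module):
    """Hypothetical: a conv layer whose output channels are regrouped into capsules."""

    def __init__(self, in_channels, out_channels, out_dim_capsule):
        super().__init__()
        self.out_channels = out_channels
        self.out_dim_capsule = out_dim_capsule
        # assumption: 3x3 kernel with padding=1 keeps height and width unchanged
        self.conv = nn.Conv2d(in_channels, out_channels * out_dim_capsule,
                              kernel_size=3, padding=1)

    def forward(self, x):
        b, _, h, w = x.shape
        # (B, C_out * D_out, H, W) -> (B, C_out, D_out, H, W)
        return self.conv(x).view(b, self.out_channels, self.out_dim_capsule, h, w)


x = torch.randn(2, 16, 8, 8)
print(Capsulation2DSketch(16, 4, 8)(x).shape)  # torch.Size([2, 4, 8, 8, 8])
```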
    
### De-capsulation layer
class DeCapsulation2D(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, height, width)
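One plausible way to go back from capsules to an ordinary feature map is to fold the capsule dimension into the channels and mix them with a 1x1 convolution. This is a hedged sketch: `DeCapsulation2DSketch`, its arguments and the 1x1 conv are assumptions, not the toolkit's confirmed implementation.

```python
import torch
import torch.nn as nn


class DeCapsulation2DSketch(nn.Module):
    """Hypothetical: merge capsule channels and dims, then mix them with a 1x1 conv."""

    def __init__(self, channels, dim_capsule, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(channels * dim_capsule, out_channels, kernel_size=1)

    def forward(self, x):
        b, c, d, h, w = x.shape
        # (B, C, D, H, W) -> (B, C*D, H, W): each capsule component becomes a feature map
        return self.conv(x.reshape(b, c * d, h, w))


x = torch.randn(2, 4, 8, 5, 5)
print(DeCapsulation2DSketch(4, 8, 6)(x).shape)  # torch.Size([2, 6, 5, 5])
```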
    
### Flatten layer
class CapFlatten(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, channels * height * width, dim_capsule), i.e. (batch_size, num_capsules, dim_capsule)
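The flatten step can be expressed as a permute followed by a reshape, so that every (channel, row, column) position becomes one capsule. A minimal sketch; the class name and the capsule ordering (channel-major, then height, then width) are assumptions.

```python
import torch
import torch.nn as nn


class CapFlattenSketch(nn.Module):
    """Hypothetical: (B, C, D, H, W) -> (B, C*H*W, D)."""

    def forward(self, x):
        b, c, d, h, w = x.shape
        # move the capsule dim to the end, then merge (C, H, W) into num_capsules
        return x.permute(0, 1, 3, 4, 2).reshape(b, c * h * w, d)


x = torch.randn(2, 3, 4, 5, 6)
print(CapFlattenSketch()(x).shape)  # torch.Size([2, 90, 4])
```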


### De-flatten layer
class DeCapFlatten(nn.Module)
    # input.shape = (batch_size, channels * height * width, dim_capsule),
    #     i.e. (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, channels, dim_capsule, height, width)
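Since the flattened tensor no longer carries `channels`, `height` and `width`, a de-flatten layer has to be told them. A hedged sketch that assumes they are passed to the constructor and that capsules were flattened in channel-major, then height, then width order; `DeCapFlattenSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class DeCapFlattenSketch(nn.Module):
    """Hypothetical inverse of a flatten: (B, C*H*W, D) -> (B, C, D, H, W)."""

    def __init__(self, channels, height, width):
        super().__init__()
        self.channels, self.height, self.width = channels, height, width

    def forward(self, x):
        b, _, d = x.shape
        # split num_capsules back into (C, H, W), then move the capsule dim to position 2
        x = x.reshape(b, self.channels, self.height, self.width, d)
        return x.permute(0, 1, 4, 2, 3)


flat = torch.randn(2, 90, 4)
print(DeCapFlattenSketch(3, 5, 6)(flat).shape)  # torch.Size([2, 3, 4, 5, 6])
```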


### Scalarization layer
class CapScalarization(nn.Module)
    # input.shape = (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, num_capsules)
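In the paper the length of a capsule's output vector is read as the probability that the entity it represents is present, so a natural scalarization is the Euclidean norm over the capsule dimension. A minimal sketch under that assumption; `CapScalarizationSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class CapScalarizationSketch(nn.Module):
    """Hypothetical: replace each capsule vector by its Euclidean length."""

    def forward(self, x):
        # (B, N, D) -> (B, N); after squash every length lies in [0, 1)
        return x.norm(dim=-1)


x = torch.tensor([[[3.0, 4.0], [0.0, 0.0]]])
print(CapScalarizationSketch()(x))  # tensor([[5., 0.]])
```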


### Capsule 2D convolution layer (V1)
class CapConv2dV1(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, out_height, out_width)
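A simple capsule convolution can be built by folding the capsule dimension into the channels, applying a regular `nn.Conv2d`, and unfolding the result into `out_channels` capsules of size `out_dim_capsule`. This is only a sketch of one possible design; `CapConv2dSketch`, its arguments and the fold/unfold approach are assumptions and may not match what V1 in the toolkit does.

```python
import torch
import torch.nn as nn


class CapConv2dSketch(nn.Module):
    """Hypothetical capsule conv: fold capsule dims into channels, convolve, unfold."""

    def __init__(self, channels, dim_capsule, out_channels, out_dim_capsule,
                 kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.out_channels = out_channels
        self.out_dim_capsule = out_dim_capsule
        self.conv = nn.Conv2d(channels * dim_capsule, out_channels * out_dim_capsule,
                              kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        b, c, d, h, w = x.shape
        y = self.conv(x.reshape(b, c * d, h, w))
        # out_height / out_width follow the usual Conv2d output-size formula
        _, _, oh, ow = y.shape
        return y.view(b, self.out_channels, self.out_dim_capsule, oh, ow)


x = torch.randn(2, 4, 8, 10, 10)
print(CapConv2dSketch(4, 8, 6, 16, stride=2)(x).shape)  # torch.Size([2, 6, 16, 4, 4])
```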


### Digit capsules (routing output layer)
class CapsuleLayer(nn.Module)
    # Dynamic-routing version
    # input.shape = (batch, input_num_capsule, input_dim_capsule)
    # output.shape = (batch, num_capsule, 1, dim_capsule)
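The routing-by-agreement procedure from the paper computes prediction vectors `u_hat[j|i] = W_ij u_i`, then repeatedly sets the coupling coefficients `c_ij = softmax_j(b_ij)`, forms `s_j = sum_i c_ij u_hat[j|i]`, squashes it to `v_j`, and increases `b_ij` by the agreement `u_hat[j|i] · v_j`. Below is a minimal re-implementation sketch of that algorithm with the shapes documented above; `CapsuleLayerSketch`, the initialization scale and `routings=3` default are assumptions, not the toolkit's code.

```python
import torch
import torch.nn as nn


def squash(s, dim=-1, epsilon=1e-8):
    # v = |s|^2 / (1 + |s|^2) * s / |s|
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return sq_norm / (1 + sq_norm) * s / torch.sqrt(sq_norm + epsilon)


class CapsuleLayerSketch(nn.Module):
    """Hypothetical routing-by-agreement layer (Sabour, Frosst & Hinton, 2017)."""

    def __init__(self, input_num_capsule, input_dim_capsule, num_capsule, dim_capsule, routings=3):
        super().__init__()
        self.routings = routings  # assumed >= 1
        # one transformation matrix W_ij per (input capsule i, output capsule j)
        self.W = nn.Parameter(0.01 * torch.randn(1, input_num_capsule, num_capsule,
                                                 dim_capsule, input_dim_capsule))

    def forward(self, u):
        # u: (B, N_in, D_in) -> prediction vectors u_hat: (B, N_in, N_out, D_out)
        u_hat = (self.W @ u[:, :, None, :, None]).squeeze(-1)
        # routing logits b_ij start at zero
        b = torch.zeros(u_hat.shape[:3], device=u.device, dtype=u.dtype)
        for r in range(self.routings):
            c = torch.softmax(b, dim=2)                # coupling coefficients over output capsules
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)   # (B, N_out, D_out)
            v = squash(s, dim=-1)
            if r < self.routings - 1:
                # agreement update: b_ij += u_hat[j|i] . v_j
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        # documented output shape: (batch, num_capsule, 1, dim_capsule)
        return v.unsqueeze(2)


u = torch.randn(4, 32, 8)
print(CapsuleLayerSketch(32, 8, 10, 16)(u).shape)  # torch.Size([4, 10, 1, 16])
```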


### Mask layer
class CapReconMask(nn.Module)
    # input.shape = (batch, num_classes, dim_capsules) | (batch, num_capsules, dim_capsules)
    # masked.shape = (batch, dim_capsules)
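In the paper's reconstruction branch, all capsules except one are masked out: the true-label capsule during training, and the longest capsule at test time. A minimal sketch that returns the selected capsule directly with the documented `(batch, dim_capsules)` shape; `CapReconMaskSketch` and the optional `target` argument are hypothetical.

```python
import torch
import torch.nn as nn


class CapReconMaskSketch(nn.Module):
    """Hypothetical: keep one capsule per sample for the reconstruction decoder."""

    def forward(self, x, target=None):
        # x: (B, num_capsules, dim_capsules)
        if target is None:
            # no labels given (inference): pick the capsule with the largest length
            target = x.norm(dim=-1).argmax(dim=1)
        # select that capsule for each sample -> (B, dim_capsules)
        return x[torch.arange(x.shape[0], device=x.device), target]


x = torch.zeros(2, 3, 4)
x[0, 1] = 1.0
x[1, 2] = 2.0
print(CapReconMaskSketch()(x).shape)  # torch.Size([2, 4])
```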
    

## Utilities
class CapTool():
    def one_hot(self, y, num_dim=10):
        """
        One Hot Encoding, similar to `torch.eye(num_dim).index_select(dim=0, index=y)`
        :param y: N-dim tensor
        :param num_dim: do one-hot labeling from `0` to `num_dim-1`
        :return: shape = (batch_size, num_dim)
        """

    def margin_loss(self, input, target, num_classes=10, m_plus=None, m_minus=None, m_lambda=0.5):
        """
        Margin loss from *Dynamic Routing Between Capsules*.
        It pushes the length of the true-class capsule above `m_plus` and the lengths of the other capsules below `m_minus`.

        input.shape = (batch_size, num_classes)
        target.shape = (batch_size, ), type of `LongTensor`, true labels of the classifications

        :param input: predicted probability (capsule length) of each class
        :param target: true labels of the classifications
        :param num_classes: number of classes, default 10
        :param m_plus: upper margin, default 0.9
        :param m_minus: lower margin, default 0.1
        :param m_lambda: down-weighting factor for absent classes, default 0.5
        :return: shape = (1, )
        """
    
    def squash(self, s, dim=-1, constant=1, epsilon=1e-8):
        """
        It drives the length of a large vector to near 1 and of a small vector to near 0
        :param s: N-dim tensor
        :param dim: the dimension to squash
        :param constant: (0, 1]
        :param epsilon: small value added for numerical stability
        :return: same shape as `s`
        """
    
    def acc_eval(self, model, test_loader, loss_fn, y_pred_dim=0): ...

    def model_summary(self, model, show_layer_detail=True): ...
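The docstrings of `one_hot`, `squash` and `margin_loss` above describe behavior without bodies. A minimal sketch as free functions, assuming the standard formulas from the paper: `v = |s|^2 / (1 + |s|^2) * s / |s|` for squash, and `L_k = T_k * max(0, m_plus - |v_k|)^2 + m_lambda * (1 - T_k) * max(0, |v_k| - m_minus)^2` for the margin loss. Assumptions: the undocumented `constant` argument of `squash` is left out, and `margin_loss` sums over classes and averages over the batch, returning a 0-dim scalar rather than shape `(1, )`.

```python
import torch


def one_hot(y, num_dim=10):
    # same idea as the docstring: pick rows of an identity matrix
    return torch.eye(num_dim, device=y.device).index_select(dim=0, index=y)


def squash(s, dim=-1, epsilon=1e-8):
    # v = |s|^2 / (1 + |s|^2) * s / |s|; epsilon avoids division by zero
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return sq_norm / (1 + sq_norm) * s / torch.sqrt(sq_norm + epsilon)


def margin_loss(input, target, num_classes=10, m_plus=0.9, m_minus=0.1, m_lambda=0.5):
    # input: (B, num_classes) capsule lengths; target: (B,) LongTensor of true labels
    t = one_hot(target, num_classes)
    present = t * torch.clamp(m_plus - input, min=0) ** 2
    absent = m_lambda * (1 - t) * torch.clamp(input - m_minus, min=0) ** 2
    # sum over classes, then average over the batch (assumed reduction)
    return (present + absent).sum(dim=1).mean()


print(one_hot(torch.tensor([2, 0]), 3))
print(margin_loss(torch.tensor([[0.5, 0.5]]), torch.tensor([0]), num_classes=2))
```

For the last call, the true class contributes `(0.9 - 0.5)^2 = 0.16` and the other class `0.5 * (0.5 - 0.1)^2 = 0.08`, so the loss is about 0.24.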

Toolkit download

The toolkit is available as a Jupyter notebook (ipynb). Export or copy it to a .py file and it can be imported directly: Capsule Network Toolkit/capsnet_tool.ipynb@gist

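The end of the notebook also contains a test that uses this toolkit to implement the dynamic-routing version of CapsNet proposed by Hinton et al. Remember to delete that part when exporting the notebook to a .py file.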
