图像分类 Convnext 结合capsnet胶囊网络

首先介绍Convnext:

ConvNeXt
论文名称:A ConvNet for the 2020s
论文下载链接:https://arxiv.org/abs/2201.03545
论文对应源码链接:https://github.com/facebookresearch/ConvNeXt
视频讲解:https://www.bilibili.com/video/BV1SS4y157fu

相关介绍可以看这位大佬太阳花的小绿豆的博文
原文链接:https://blog.csdn.net/qq_37541097/article/details/122556545

关于capsnet网络:

图像分类 Convnext 结合capsnet胶囊网络_第1张图片

什么是胶囊网络

将神经元替换为胶囊就是胶囊网络。

高层胶囊和低层胶囊之间权重通过dynamic routing获得。

胶囊是一组神经元,它会学习检测给定区域图像的特定目标,它输出一个向量,向量的长度代表目标存在的概率估计,用向量的方向表示实体的属性。如果对象有轻微的变化(例如移位、旋转、改变大小等),那么胶囊将输出相同长度但方向略有不同的向量,因此,胶囊是等变化的。

更多讲解参考这篇文章:
版权声明:本文为CSDN博主「GodWriter」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/GodWriter/article/details/79216404

经典的胶囊网络作为一种图像分类模型被提出,其创新性地提出vector in- vector out,并采用动态路由算法替代了传统的卷积层,但是由于动态路由算法庞大的计算量,导致其难以应用到更普遍的场景.

《CAPSULES WITH INVERTED DOT-PRODUCT  ATTENTION ROUTING》引入了一种新的胶囊网络路由算法,该算法仅基于父节点的状态和子节点的投票之间的协议将子节点路由到父节点。创新点有:

1.通过反向点积注意设计路由;

2采用层归一化作为归一化;

3 用并发迭代路由代替顺序迭代路由。与之前提出的路由算法相比,我们的方法在基准数据集如CIFAR-10和CIFAR-100上提高了性能,并与参数少4倍的强大CNN (ResNet-18)媲美。在从重叠数字图像中识别数字的不同任务中,所提出的胶囊模型在给定相同的层数和每层神经元的情况下具有较好的性能。

图像分类 Convnext 结合capsnet胶囊网络_第2张图片

图像分类 Convnext 结合capsnet胶囊网络_第3张图片

 本文将convnext 作为backbone,由CAPSULES WITH INVERTED DOT-PRODUCT  ATTENTION ROUTING作为最后的分类网络,改进了原先的网络,实现了更好的性能。

预测效果,代码链接:https://download.csdn.net/download/qq_54557352/85233010

图像分类 Convnext 结合capsnet胶囊网络_第4张图片打开文件,终端执行 python .\main_capsule.py --bs 4 --backbone convnext图像分类 Convnext 结合capsnet胶囊网络_第5张图片

from datetime import datetime

# +
parser = argparse.ArgumentParser(description='Training Capsules using Inverted Dot-Product Attention Routing')

parser.add_argument('--resume_dir', '-r', default='', type=str, help='dir where we resume from checkpoint')
parser.add_argument('--num_routing', default=1, type=int, help='number of routing. Recommended: 0,1,2,3.')
parser.add_argument('--data_path', default="flower_photos/flower_photos")
parser.add_argument('--backbone', default='simple', type=str, help='type of backbone. simple or convnext or resnet')
parser.add_argument('--num_workers', default=0, type=int, help='number of workers. 0 or 2')
parser.add_argument('--config_path', default='./configs/resnet_backbone_CIFAR10.json', type=str,
                    help='path of the config')
parser.add_argument('--debug', action='store_true',
                    help='use debug mode (without saving to a directory)')
parser.add_argument('--sequential_routing', action='store_true', help='not using concurrent_routing')

parser.add_argument('--lr', default=0.1, type=float, help='learning rate. 0.1 for SGD')
parser.add_argument('--dp', default=0.0, type=float, help='dropout rate')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--epoches', default=100, type=int, help='')
parser.add_argument('--bs', default=4, type=int, help='batch-size')

你可能感兴趣的:(机器视觉,python,算法)