mmclassification使用步骤与心得/ACCV实验记录

比赛链接:

  • accv官网:https://sites.google.com/view/webfg2020
  • 比赛网站:https://www.cvmart.net/race/9917/base
数据预处理参考博客: https://blog.csdn.net/u013347145/article/details/109250455

主要包括:
1.数据清洗

由于图片直接由网上爬取得到,强制把后缀名改为.jpg,其原先可能包含多种格式,如.gif、.png、.tiff等,还有prefix存在问题的图片,以及四通道图片。
筛选出读取不会报错的图片,其余图片直接过滤掉(占少数,不用担心会影响数据集的总体分布)

然后进行数据标签分布的统计,大致呈现【长尾分布的规律】(横坐标一团黑是因为5000个种类太密集):

mmclassification使用步骤与心得/ACCV实验记录_第1张图片

PS:一开始抓住长尾分布这一特性,考虑了一种重采样进行数据平衡的方式,参考论文:BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition,作者是XiuShen Wei,也是本次大赛主委会成员之一。
但是后来经过实验发现BBN的方法提分有瓶颈,并且两个ResNet连起来使得网络结构较为复杂,训练速度较慢。
网络结构如下:

mmclassification使用步骤与心得/ACCV实验记录_第2张图片

========2020.11.20=========

更新重采样方式,借鉴BBN的思想,给数量多的类别分配较小的采样权重,最终达到各个类别数量较为平衡,公式和代码如下:
mmclassification使用步骤与心得/ACCV实验记录_第3张图片

# 统计各个类别的分布
x_label = []
y_num_cls = []
for d in sorted(os.listdir(root)):
    fd = os.path.join(root, d)
    label = int(d)
    num_cls = len(os.listdir(fd))
    # print('label:',label,'num:',num_cls)
    x_label.append(label)
    y_num_cls.append(num_cls)

def get_weight(num_list):
    max_num = max(num_list)
    class_weight = [max_num / i for i in num_list]
    sum_weight = sum(class_weight)
    p_class_weight = [i/sum_weight for i in class_weight]
    return class_weight, sum_weight,p_class_weight

class_weight, sum_weight,p_class_weight = get_weight(num_list=y_num_cls)

# 遍历每个文件夹进行抽样
total_sum = 0
for d in sorted(os.listdir(root)):
    fd = os.path.join(root, d)
    num_files = len(os.listdir(fd))
    total_sum += num_files

lines = []
for d in sorted(os.listdir(root)):
    fd = os.path.join(root, d)
    label = int(d)
    files = os.listdir(fd)
    file_num = len(files)
    print('label {} '.format(d),int(total_sum*p_class_weight[label]))
    print(len(files))
    total_num = int(total_sum*p_class_weight[label])
    sample_num = total_num
    if total_num > file_num:
        sample_num = file_num
    sample_files = random.sample(files,sample_num)

    for file in sample_files:
        line = os.path.join(prefix, d, file)
        line += ','
        line += str(d)
        line += '\n'
        print(line)
        lines.append(line)

 

2.预处理
这里用的mmclassification中自带的data augmentation的方式,模块化设计,只需要指定字段名即可:
其中ColorJitter是继承了torchvision.transform.ColorJitter的父类方法,AddGaussianNoise是在网上找的加入高斯噪声的方法。
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='RandomFlip', flip_prob=0.3, direction='vertical'),
    dict(type='ColorJitter', brightness=0.126, saturation=0.5,hue=0.4),
    dict(type='AddGaussianNoise', mean=0.0, variance=1.0, amplitude=15.0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
下面进行mmclasification使用的步骤

这是港中文和商汤联合实验室开源的分类框架,同系列的还有mmdetection、mmsegmentation等,高度集成化和模块化的设计,使用者只需要修改对应配置文件及其接口即可。

Github: https://github.com/open-mmlab/mmclassification

1.主要文件是tools/train.py和tools/test.py。
train.py里面我们需要修改的地方是:
parser.add_argument('config', type=str,default='../*/resnet50_b32x8.py',help='train config file path')

其中resnet50_b32x8.py文件内容如下所示:

_base_ = [
    '../_base_/models/resnet50.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]


指定配置文件的路径,其余参数均不需要修改,会自动读取这个路径下所有配置文件的路径并送入网络。

2.数据的读取

在mmcls/datasets文件路径下新建数据类,它继承了BaseDataset这个基类,并复现其load_annotations函数方法,在这里实现训练和验证集的读取。
我分别尝试了遍历文件夹和将文件路径保存在.txt文件然后读取这两种方式,注意image和label对应即可。
注意比赛官网写到了有【标签噪声】,所以用label smoothing进行标签平滑。

然后就可以运行train.py啦~

3.结果保存

按照代码里的设置,运行日志和模型会保存在../work_dirs目录下,.log和.json文件保存模型设置的参数以及loss和acc等信息,.pth是每个epoch保存下来的模型。

你可能感兴趣的:(计算机视觉,深度学习,神经网络,pytorch)