ResNet-50 classification with the Open-mmLab library


For environment setup, see the earlier post on PyTorch environment configuration and the problems encountered.

The source code comes from the official repository:

Link: https://github.com/open-mmlab/mmclassification

Official documentation: https://mmclassification.readthedocs.io/en/latest/index.html

Example application and how to adapt it:

1. Prepare the dataset

First, make sure the mmclassification folder contains the following folders:

mmclassification
├── mmcls
├── tools
├── configs
├── docs
├── data

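Then add your own dataset under data/, using the layout and names below. The folder name must match the data_prefix and ann_file paths set in the dataset config in step 3.

Training set: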

netdata
├── ...
├── train
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   ├── ...
│   ├── ...
│   ├── n15075141
│   │   ├── n15075141_999.JPEG
│   │   ├── n15075141_9993.JPEG
│   │   ├── ...

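Test set: the images go directly into val/, and their labels are listed in an annotation file, one "filename label" pair per line:

├── val
│   ├── ILSVRC2012_val_00000001.JPEG
│   ├── ILSVRC2012_val_00000002.JPEG
│   ├── ...

ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516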

How to prepare the TXT files: https://blog.csdn.net/alexzhang9208/article/details/115488509
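If you use annotation files for the training set too, they can be generated from the class-per-folder layout with a short script. The sketch below is a hypothetical helper (write_ann_file is not part of mmcls): it scans one level of class folders, assigns labels by sorted folder name (the same rule as find_folders in step 3), and writes paths relative to the folder you later set as data_prefix. File names must not contain spaces, because the loader splits each line on a space.

```python
import os

# Same extensions as MyNetData.IMG_EXTENSIONS (compared in lower case)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')


def write_ann_file(root, out_path, extensions=IMG_EXTENSIONS):
    """Write one "<class>/<file> <label>" line per image in root/<class>/."""
    # Sorted folder names map to consecutive labels, as in find_folders()
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    with open(out_path, 'w') as f:
        for name in classes:
            for fn in sorted(os.listdir(os.path.join(root, name))):
                if fn.lower().endswith(extensions):
                    # Path relative to root, i.e. to data_prefix in the config
                    f.write(f'{name}/{fn} {class_to_idx[name]}\n')
    return class_to_idx
```

For example, write_ann_file('data/mynet/train', 'data/mynet/meta/train.txt') after creating the meta/ folder.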

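2. Parameter settings

All training configs live under **/configs. That directory contains _base_, which holds the base configs for datasets, models, training schedules and so on, plus one folder per model. Taking ResNet-50 as an example, create resnet50_b32x8_mynet.py under **/configs/resnet.

Its _base_ list is: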
_base_ = [
    '../_base_/models/resnet50.py',                  # model
    '../_base_/datasets/mynet_bs32.py',              # dataset
    '../_base_/schedules/imagenet_bs256.py',         # training schedule
    '../_base_/default_runtime.py'                   # logging and checkpoints
]

By customizing these four parts, you can train a model on your own task.
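The four _base_ files are combined into one config, and any key you also set in resnet50_b32x8_mynet.py overrides the inherited value. As far as I know, mmcv's Config merges nested dicts key by key instead of replacing them wholesale. The function below is a simplified, hypothetical sketch of that behaviour written from scratch (it ignores mmcv extras such as _delete_); it is not mmcv's code.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge override into a copy of base (simplified sketch)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged key by key, not replaced
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Scalars, tuples and lists simply replace the inherited value
            merged[key] = copy.deepcopy(value)
    return merged


base_model = dict(
    type='ImageClassifier',
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=2048,
              topk=(1, 5)))
child = dict(head=dict(num_classes=10))  # hypothetical override
model = merge_cfg(base_model, child)
```

Only num_classes changes; in_channels and the other head fields are inherited.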

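3. Loading a custom dataset

Relevant directories:

../_base_/datasets/			# data preprocessing and path settings
../data/Mydataset/			# the data goes here
../mmcls/datasets/			# data loading code

Code under …/_base_/datasets/ (the mynet_bs32.py file listed in _base_):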

# dataset settings
dataset_type = 'MyNetData'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(128, 128), backend='pillow'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1), backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='data/mynet/train',
        ann_file='data/mynet/meta/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/mynet/val',
        ann_file='data/mynet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/mynet/val',
        ann_file='data/mynet/meta/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
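The Normalize step configured by img_norm_cfg is a per-channel standardisation. mmcv loads images in BGR order, so to_rgb=True first swaps the channels to RGB and then computes (x - mean) / std. A NumPy sketch of that arithmetic (my own illustration, not mmcv's function):

```python
import numpy as np

# Values copied from img_norm_cfg above (ImageNet RGB statistics)
MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)


def normalize(img_bgr, mean=MEAN, std=STD, to_rgb=True):
    """Standardise an HxWx3 image the way Normalize is configured here."""
    img = img_bgr.astype(np.float32)
    if to_rgb:
        img = img[..., ::-1]  # reverse the channel axis: BGR -> RGB
    return (img - mean) / std
```

An all-black image maps to -mean/std in every pixel, and an image whose RGB values equal the mean maps to zero.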

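In …/mmcls/datasets/, first define your own data-loading class. The code below goes in a new file, mynet.py (the module imported in __init__.py in the second step):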

import os
import mmcv

import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple): allowed extensions, in lower case

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_folders(root):
    """Find classes by folders under a root.

    Args:
        root (string): root directory of folders

    Returns:
        folder_to_idx (dict): the map from folder name to class idx
    """
    folders = [
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    ]
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folder_to_idx


def get_samples(root, folder_to_idx, extensions):
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        extensions (tuple): allowed extensions

    Returns:
        samples (list): a list of tuple where each element is (image, label)
    """
    samples = []
    root = os.path.expanduser(root)
    for folder_name in sorted(os.listdir(root)):
        _dir = os.path.join(root, folder_name)
        if not os.path.isdir(_dir):
            continue

        for _, _, fns in sorted(os.walk(_dir)):
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    path = os.path.join(folder_name, fn)
                    item = (path, folder_to_idx[folder_name])
                    samples.append(item)
    return samples


@DATASETS.register_module()
class MyNetData(BaseDataset):

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def load_annotations(self):
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (RuntimeError('Found 0 files in subfolders of: '
                                    f'{self.data_prefix}. '
                                    'Supported extensions are: '
                                    f'{",".join(self.IMG_EXTENSIONS)}'))

            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                # one "filename label" pair per line; skip blank lines
                samples = [x.strip().rsplit(' ', 1) for x in f if x.strip()]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
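To see exactly what load_annotations hands to the pipeline, its txt branch can be reproduced without mmcls. The stand-alone sketch below (parse_ann_lines is a hypothetical helper, slightly more forgiving about blank lines) builds the same dictionaries; LoadImageFromFile later joins img_prefix and img_info['filename'] into the full image path.

```python
import numpy as np


def parse_ann_lines(lines, data_prefix):
    """Turn "filename label" lines into MyNetData-style data_infos."""
    data_infos = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        # Split on the last space so the label is always the final field
        filename, gt_label = line.rsplit(' ', 1)
        data_infos.append({
            'img_prefix': data_prefix,
            'img_info': {'filename': filename},
            'gt_label': np.array(int(gt_label), dtype=np.int64),
        })
    return data_infos
```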

Second, register the custom dataset in __init__.py:

from .base_dataset import BaseDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST
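from .multi_label import MultiLabelDataset
from .pipelines import Compose  # 'Compose' is exported in __all__ below
from .samplers import DistributedSampler
from .voc import VOC
from .mynet import MyNetData  # importing runs @DATASETS.register_module()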

__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset', 'Compose',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'MyNetData'
]
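The import line matters because @DATASETS.register_module() only runs when mynet.py is imported; the string type='MyNetData' in the config is then looked up by name. The sketch below is a deliberately simplified registry written from scratch to show the idea (the Registry and MyNetData classes here are toy stand-ins, not mmcv's Registry, which has more features).

```python
class Registry:
    """Minimal name -> class registry (simplified sketch)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._modules:
                raise KeyError(f'{cls.__name__} already in {self.name}')
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's config is not mutated
        obj_cls = self._modules[cfg.pop('type')]  # look up by the type string
        return obj_cls(**cfg)  # remaining keys become constructor arguments


DATASETS = Registry('dataset')


@DATASETS.register_module()
class MyNetData:  # toy stand-in for the real dataset class
    def __init__(self, data_prefix, ann_file=None):
        self.data_prefix = data_prefix
        self.ann_file = ann_file
```

Without the import, the class is never registered and the lookup by name fails.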

4. Model configuration

../_base_/models/resnet50.py:
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

5. Training schedule

../_base_/schedules/imagenet_bs256.py:
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
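With policy='step', the learning rate is multiplied by a decay factor each time training passes a milestone in step. Assuming mmcv's default factor gamma=0.1 (it is not set explicitly above), the rate is 0.1 for epochs 0 to 29, 0.01 for 30 to 59, 0.001 for 60 to 89 and 0.0001 for the last ten of the 100 epochs. A small sketch of that rule (step_lr is a hypothetical name):

```python
def step_lr(epoch, base_lr=0.1, steps=(30, 60, 90), gamma=0.1):
    """Learning rate for a 0-indexed epoch under a step schedule."""
    # Each milestone already reached applies one more decay by gamma
    num_decays = sum(1 for s in steps if epoch >= s)
    return base_lr * gamma ** num_decays
```

The file name suggests lr=0.1 was tuned for a total batch of 256 (8 GPUs x 32). When training on a single GPU with 32 images per batch, a common heuristic (the linear scaling rule) is to scale it to 0.1 x 32 / 256 = 0.0125; treat that as a starting point, not a guarantee.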

6. Train the model

The single-GPU training command is:

python tools/train.py ${CONFIG_FILE} [optional arguments]

The optional arguments include the working directory, where logs and checkpoints are written:

--work-dir ${YOUR_WORK_DIR}

For example:

python tools/train.py D:\DeepLearning\mmclassfication_res\configs\resnet\resnet50_b32x8_mynet.py  --work-dir D:\DeepLearning\mmclassfication_res\myres50
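If --work-dir is omitted, tools/train.py in the mmcls versions I have used falls back to a work_dir set in the config, and otherwise to ./work_dirs/ followed by the config file name without .py. The function below mirrors that priority order as a sketch; resolve_work_dir is a hypothetical name, not an mmcls API.

```python
import os.path as osp


def resolve_work_dir(config_path, cli_work_dir=None, cfg_work_dir=None):
    """Priority: --work-dir flag > work_dir in the config > config name."""
    if cli_work_dir is not None:
        return cli_work_dir
    if cfg_work_dir is not None:
        return cfg_work_dir
    # e.g. configs/resnet/resnet50_b32x8_mynet.py -> resnet50_b32x8_mynet
    config_name = osp.splitext(osp.basename(config_path))[0]
    return osp.join('./work_dirs', config_name)
```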

7. Notes

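(1) When training on external data, either store the images in folders named after their classes, or provide a txt file that maps each image to its class. One of the two is enough.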

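(2) mmdetection trains on jpg data by default. To train on another format, modify LoadImageFromFile in pipelines/loading.py so that .jpg file names are mapped to your format: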

import os.path as osp

import mmcv
import numpy as np

from ..builder import PIPELINES


@PIPELINES.register_module()
class LoadImageFromFile(object):

    def __init__(self, to_float32=False):
        self.to_float32 = to_float32

    def __call__(self, results):
        if results['img_prefix'] is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
            # Swap only the extension, not every 'jpg' substring in the path
            root, ext = osp.splitext(filename)
            if ext.lower() == '.jpg':
                filename = root + '.png'  # replace with your own format
        else:
            filename = results['img_info']['filename']
        img = mmcv.imread(filename)
        if self.to_float32:
            img = img.astype(np.float32)
        results['filename'] = filename
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results

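When training on TIFF data: although the TIFF format supports 32-bit samples, OpenCV's libtiff does not handle 32-bit TIFFs. The accepted TIFF data types are double, single, uint8, uint16 or logical, so convert the images to one of these types; otherwise the error is: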

"Sorry, can not handle images with %d-bit samples"

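One way to convert is to rescale each image to uint8 offline and save it in a format OpenCV reads. The sketch assumes you can already load the 32-bit data into a NumPy array with another reader (for example the third-party tifffile package, an assumption not covered here). Per-image min-max scaling is only one choice: it discards the absolute intensity scale, so a fixed global range may suit your data better.

```python
import numpy as np


def to_uint8(arr):
    """Min-max rescale a float (e.g. 32-bit) image array to uint8."""
    arr = arr.astype(np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        # Constant image: avoid division by zero
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - lo) / (hi - lo) * 255.0
    return np.round(scaled).astype(np.uint8)
```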
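(3) The number of classes differs from task to task, so set topk in the corresponding model config under \configs\_base_\models\ to suit your class count: every k in topk must be no larger than num_classes. Otherwise you get: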

RuntimeError:invalid argument 5:k not in range for dimension at /pytorch/ate ... 
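The error comes from PyTorch's topk being asked for more entries than there are classes: accuracy with topk=(1, 5) needs at least five classes. The NumPy sketch below computes top-k accuracy in the same conceptual way and checks the bound explicitly; topk_accuracy is an illustrative name, not mmcls's accuracy implementation.

```python
import numpy as np


def topk_accuracy(scores, labels, topk=(1, )):
    """Fraction of samples whose label is among the k highest scores."""
    num_classes = scores.shape[1]
    results = []
    for k in topk:
        if k > num_classes:
            # The torch-based metric fails here; fail early with a clear message
            raise ValueError(f'topk={k} exceeds num_classes={num_classes}')
        # Indices of the k largest scores per row (order within top-k irrelevant)
        top = np.argsort(-scores, axis=1)[:, :k]
        hits = (top == labels[:, None]).any(axis=1)
        results.append(float(hits.mean()))
    return results
```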
