mmclassification注意事项-修改、增加模块,测试时参数说明

目录

1.测试image_demo.py效果时,在windows下传入参数为绝对路径,不然可能找不着。

 2.测试时参数说明test.py

3.修改、增加模块

3.1配置文件D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py中

(1)neck

(2)loss--示例增加L1Loss,并进行debug

(3)数据增强--颜色抖动、mixup

(4)修改配置文件其他参数--加预训练模型、修改学习率

1.测试image_demo.py效果时,在windows下传入参数为绝对路径,不然可能找不着。

D:\Code\mmclassification\demo\image_demo.py

参数格式示例如下:

D:\Code\mmclassification\demo\00001.png D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth

mmclassification注意事项-修改、增加模块,测试时参数说明_第1张图片

#D:\Code\mmclassification\demo\image_demo.py

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

import mmcv

from mmcls.apis import inference_model, init_model, show_result_pyplot

##00001.png ../configs/resnet/GWF_resnet18_8xb32_in1k.py ../tools/work_dirs/resnet18_8xb32_in1k/epoch_6.pth #不是绝对路径,找不到
###D:\Code\mmclassification\demo\00001.png D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth

def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether to show the predict results by matplotlib.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    # test a single image
    result = inference_model(model, args.img)
    # show the results
    print(mmcv.dump(result, file_format='json', indent=4))
    if args.show:
        show_result_pyplot(model, args.img, result)


if __name__ == '__main__':
    main()

运行结果:

mmclassification注意事项-修改、增加模块,测试时参数说明_第2张图片

 2.测试时参数说明test.py

D:\Code\mmclassification\tools\test.py

#D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth
# --show ##一个个结果展示
#--show-dir D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\val_result  ##把每张图片的测试结果打印
#--metrics accuracy recall #支持多种测量结果输出

找不到测试文件夹时,在配置文件中D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py将数据修改为绝对路径。

3.修改、增加模块

3.1配置文件D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py中

(1)neck

neck=dict(type='GlobalAveragePooling'),

在D:\Code\mmclassification\mmcls\models\necks中找目前有的结构,如gem.py中有

class GeneralizedMeanPooling(nn.Module):#平均池化和最大池化的综合版本

则配置文件中可以改为neck=dict(type='GeneralizedMeanPooling'),

(2)loss--示例增加L1Loss,并进行debug

增加D:\Code\mmclassification\mmcls\models\losses\l1_loss.py只是作为例子

#D:\Code\mmclassification\mmcls\models\losses\l1_loss.py

import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def l1_loss(pred, target):
    target = nn.functional.one_hot(target,num_classes=102)  #修改标注的格式
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class L1Loss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

并在D:\Code\mmclassification\mmcls\models\losses\__init__.py中增加

from .l1_loss import L1Loss, l1_loss

__all__ = [
    'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
    'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
    'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
    'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss', 'L1Loss', 'l1_loss'
]

则配置文件中可以改为

loss=dict(type='L1Loss', loss_weight=1.0),

注意:

当D:\Code\mmclassification\mmcls\models\losses\l1_loss.py中

@weighted_loss
def l1_loss(pred, target):
    #target = nn.functional.one_hot(target, num_classes=102) #数据格式必须维度统一
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

 target的数据格式和输入维度不统一时,即缺少

    #target = nn.functional.one_hot(target, num_classes=102) #数据格式必须维度统一

会出现如下错误:

File "D:\Code\mmclassification\mmcls\models\losses\l1_loss.py", line 9, in l1_loss
    assert pred.size() == target.size() and target.numel() > 0
AssertionError

则在D:\Code\mmclassification\mmcls\models\losses\l1_loss.py中打断点,debug调试

debug调试

 会得到:
mmclassification注意事项-修改、增加模块,测试时参数说明_第3张图片

 则表明是pred和target的维度不一致,在 \l1_loss.py中增加以下行即可:

    target = nn.functional.one_hot(target, num_classes=10) #使得目标标签格式由[3,0,...],修改为[[0,0,0,1,0,0,0,0,0,0],[1,0,0,0,0,0,0,0,0,0],...].

    assert pred.size() == target.size() and target.numel() > 0

(3)数据增强--颜色抖动、mixup

颜色抖动:在D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py中增加,如下例子:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='ColorJitter', brightness=0.5, contrast=0.5, saturation=0.5), #gwf add

增加mixup

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GeneralizedMeanPooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='FocalLoss', loss_weight=1.0),#L1Loss CrossEntropyLoss
        topk=(1, 5)),
    train_cfg=dict(
            augments=dict(type='BatchMixup', alpha=0.2, num_classes=10,   #gwf add
 
                          prob=1.))
)

mmclassification注意事项-修改、增加模块,测试时参数说明_第4张图片

mixup表示泊松融合,按通透明度两张图各占多少,cutmix则是按区域两张图各占多少。

 

(4)修改配置文件其他参数--加预训练模型、修改学习率

预训练模型:

load_from = r'D:\Code\mmclassification\mmcls\data\resnet18_8xb32_in1k_20210831-fbbb1da6.pth'

若出现loss nan,减小学习率、或者换loss函数。

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)#减小学习率
        loss=dict(type='FocalLoss', loss_weight=1.0),

你可能感兴趣的:(分类,python,深度学习,pytorch)