Contents
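1. When testing image_demo.py on Windows, pass the arguments as absolute paths; otherwise the files may not be found.
2. Parameters for testing with test.py
3. Modifying and adding modules
3.1 In the config file D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py
(1) neck
(2) loss: adding L1Loss as an example, and debugging it
(3) Data augmentation: color jitter, mixup
(4) Other config parameters: adding a pretrained model, changing the learning rate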
D:\Code\mmclassification\demo\image_demo.py
Example arguments:
D:\Code\mmclassification\demo\00001.png D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth
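Relative paths are resolved against the current working directory, which depends on how the script is launched (an IDE run configuration on Windows may use a different one than you expect). A minimal sketch of a hypothetical helper, `resolve_path` (not part of mmcls), that turns an argument into an absolute path relative to a chosen base directory:

```python
import os


def resolve_path(path, base_dir=None):
    """Return an absolute, normalized path (hypothetical helper).

    Absolute paths are only normalized; relative paths are joined to
    base_dir, which defaults to the current working directory.
    """
    if os.path.isabs(path):
        return os.path.normpath(path)
    base = base_dir if base_dir is not None else os.getcwd()
    return os.path.normpath(os.path.join(base, path))
```

In image_demo.py one could, for instance, pass `os.path.dirname(os.path.abspath(__file__))` as `base_dir` so that relative arguments resolve from the demo folder; passing absolute paths, as done here, avoids the question entirely.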
#D:\Code\mmclassification\demo\image_demo.py
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import mmcv
from mmcls.apis import inference_model, init_model, show_result_pyplot
##00001.png ../configs/resnet/GWF_resnet18_8xb32_in1k.py ../tools/work_dirs/resnet18_8xb32_in1k/epoch_6.pth  # not absolute paths: the files are not found
###D:\Code\mmclassification\demo\00001.png D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth
def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether to show the predict results by matplotlib.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    # test a single image
    result = inference_model(model, args.img)
    # show the results
    print(mmcv.dump(result, file_format='json', indent=4))
    if args.show:
        show_result_pyplot(model, args.img, result)


if __name__ == '__main__':
    main()
Output:
D:\Code\mmclassification\tools\test.py
#D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\epoch_6.pth
# --show  ## show the results one at a time
# --show-dir D:\Code\mmclassification\tools\work_dirs\resnet18_8xb32_in1k\val_result  ## save each image with its test result drawn on it
# --metrics accuracy recall  ## several metrics can be reported at once
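The options above can also be assembled programmatically. A sketch with a hypothetical helper, `build_test_cmd`; the option names are the ones listed above, while the default script path `tools/test.py` is an assumption (on Windows an absolute path is safer):

```python
import sys


def build_test_cmd(config, checkpoint, metrics=('accuracy',), show=False,
                   show_dir=None, script='tools/test.py'):
    """Build the argument list for mmcls tools/test.py (hypothetical helper)."""
    cmd = [sys.executable, script, config, checkpoint]
    if metrics:
        cmd += ['--metrics', *metrics]   # several metrics in one run
    if show:
        cmd.append('--show')             # display results one by one
    if show_dir:
        cmd += ['--show-dir', show_dir]  # save each image with its result drawn on it
    return cmd
```

The returned list can be passed to `subprocess.run(cmd, check=True)`.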
If the test folder cannot be found, change the data paths in the config file D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py to absolute paths.
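Instead of editing every path by hand, the relative entries can be made absolute in one place. A sketch assuming the `data` dict of an mmcls 0.x config, where each split (`train`, `val`, `test`) has a `data_prefix` and optionally an `ann_file`; `absolutize_data_paths` is a hypothetical helper, and writing absolute raw-string paths directly into the config works just as well:

```python
import os


def absolutize_data_paths(data_cfg, root):
    """Return a copy of data_cfg with relative data paths joined to root.

    Hypothetical helper: only the 'data_prefix' and 'ann_file' keys of
    dict-valued entries (the splits) are touched; the input is not modified.
    """
    out = {}
    for split, cfg in data_cfg.items():
        if not isinstance(cfg, dict):
            out[split] = cfg  # e.g. samples_per_gpu, workers_per_gpu
            continue
        cfg = dict(cfg)
        for key in ('data_prefix', 'ann_file'):
            value = cfg.get(key)
            if value and not os.path.isabs(value):
                cfg[key] = os.path.normpath(os.path.join(root, value))
        out[split] = cfg
    return out
```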
The neck in the config file is currently:
neck=dict(type='GlobalAveragePooling'),
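Look in D:\Code\mmclassification\mmcls\models\necks for the neck structures that already exist; for example, gem.py contains
class GeneralizedMeanPooling(nn.Module):  # a generalization of average pooling and max pooling
so the config file can use neck=dict(type='GeneralizedMeanPooling'), instead.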
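GeM raises each activation to a power p, averages over the spatial positions, and takes the p-th root: p = 1 gives average pooling and a large p approaches max pooling. A scalar pure-Python sketch of that formula for one channel (not the mmcls module, which works on tensors; the defaults p=3 and eps=1e-6 here are assumptions):

```python
def gem_pool(values, p=3.0, eps=1e-6):
    """Generalized mean of one channel's activations (sketch of the formula).

    (mean(max(x, eps) ** p)) ** (1 / p); values are clamped to eps so the
    power and root stay well defined for non-positive inputs.
    """
    clamped = [max(float(v), eps) for v in values]
    return (sum(v ** p for v in clamped) / len(clamped)) ** (1.0 / p)
```

`gem_pool([1, 2, 3], p=1)` returns 2.0 (the mean), the default p=3 gives about 2.29, and p=50 gives about 2.93, already close to the maximum of 3.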
Add D:\Code\mmclassification\mmcls\models\losses\l1_loss.py (just as an example):
#D:\Code\mmclassification\mmcls\models\losses\l1_loss.py
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def l1_loss(pred, target):
    # Convert integer labels [N] to one-hot [N, C] so they match pred.
    # C is taken from pred (= the head's num_classes, 10 in this config)
    # instead of being hard-coded; soft labels that are already [N, C]
    # (e.g. after BatchMixup) are left as they are.
    if target.dim() == 1:
        target = nn.functional.one_hot(target, num_classes=pred.size(-1))
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class L1Loss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
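The `@weighted_loss` decorator lets `l1_loss` return an element-wise loss and takes care of `weight`, `reduction` and `avg_factor` itself. A pure-Python sketch of what that reduction step does, on flat lists; `reduce_l1` is a hypothetical name, and it mirrors the idea of mmcls's `weight_reduce_loss`, not its exact code:

```python
def reduce_l1(pred, target, weight=None, reduction='mean', avg_factor=None):
    """Element-wise L1 loss followed by a weighted reduction (sketch).

    pred and target are non-empty flat lists of equal length.
    """
    if len(pred) != len(target) or not pred:
        raise ValueError('pred and target must be non-empty and the same size')
    loss = [abs(p - t) for p, t in zip(pred, target)]
    if weight is not None:
        loss = [x * w for x, w in zip(loss, weight)]  # per-element weights
    if reduction == 'none':
        return loss
    if avg_factor is not None:
        # avg_factor replaces the element count, and only makes sense for 'mean'
        if reduction != 'mean':
            raise ValueError('avg_factor can only be used with reduction="mean"')
        return sum(loss) / avg_factor
    if reduction == 'sum':
        return sum(loss)
    if reduction == 'mean':
        return sum(loss) / len(loss)
    raise ValueError(f'unknown reduction: {reduction}')
```

With pred = [0.7, 0.2, 0.1] and the one-hot target [1, 0, 0], the element losses are about [0.3, 0.2, 0.1]; 'mean' gives about 0.2, 'sum' about 0.6, and avg_factor=2 about 0.3.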
Then add the import, and the new names in __all__, to D:\Code\mmclassification\mmcls\models\losses\__init__.py:
from .l1_loss import L1Loss, l1_loss
__all__ = [
'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss', 'L1Loss', 'l1_loss'
]
The loss in the config file can then be changed to
loss=dict(type='L1Loss', loss_weight=1.0),
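The `type` string in `loss=dict(type='L1Loss', ...)` is looked up in the `LOSSES` registry, and a class only lands in the registry when its `@LOSSES.register_module()` decorator runs, i.e. when l1_loss.py is imported; that is why the import in `__init__.py` is needed. A tiny stand-in showing the pattern (this `Registry` is a simplified illustration, not mmcv's class, which has more features):

```python
class Registry:
    """Minimal registry: maps class names to classes, builds from dicts."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            if cls.__name__ in self._modules:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)              # do not mutate the caller's config
        obj_type = cfg.pop('type')   # remaining keys become constructor kwargs
        if obj_type not in self._modules:
            raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return self._modules[obj_type](**cfg)


LOSSES = Registry('loss')


@LOSSES.register_module()   # runs at import time, which registers the class
class L1Loss:
    def __init__(self, reduction='mean', loss_weight=1.0):
        self.reduction = reduction
        self.loss_weight = loss_weight


loss = LOSSES.build(dict(type='L1Loss', loss_weight=1.0))
```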
Note:
If l1_loss in D:\Code\mmclassification\mmcls\models\losses\l1_loss.py is written without the one-hot conversion:
@weighted_loss
def l1_loss(pred, target):
    #target = nn.functional.one_hot(target, num_classes=10)  # the label format must match pred's dimensions
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss
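then target (shape [N]) does not have the same dimensions as pred (shape [N, C]), because this line is missing:
#target = nn.functional.one_hot(target, num_classes=10)  # the label format must match pred's dimensions
and the following error occurs: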
File "D:\Code\mmclassification\mmcls\models\losses\l1_loss.py", line 9, in l1_loss
assert pred.size() == target.size() and target.numel() > 0
AssertionError
To find the cause, set a breakpoint in D:\Code\mmclassification\mmcls\models\losses\l1_loss.py and step through it in the debugger.
[Figure: debugging session]
The debugger shows that pred and target have different dimensions. Adding the following line to l1_loss.py fixes it:
    target = nn.functional.one_hot(target, num_classes=10)  # turns labels like [3, 0, ...] into [[0,0,0,1,0,0,0,0,0,0], [1,0,0,0,0,0,0,0,0,0], ...]
    assert pred.size() == target.size() and target.numel() > 0
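Why the shapes differ: `pred` holds one score per class, shape [N, C], while the labels arrive as class indices, shape [N]. A pure-Python sketch of the conversion (`one_hot` and `shape` are illustrative helpers; `torch.nn.functional.one_hot` does the same on tensors). `num_classes` must equal the head's `num_classes` (10 in this config), otherwise the shapes still differ and the assertion still fails:

```python
def one_hot(labels, num_classes):
    """Convert integer class labels to one-hot rows."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f'label {label} outside [0, {num_classes})')
        row = [0] * num_classes
        row[label] = 1
        rows.append(row)
    return rows


def shape(nested):
    """Shape of a rectangular nested list, e.g. [[0, 1], [1, 0]] -> (2, 2)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)
```

`shape([3, 0])` is `(2,)`, while `shape(one_hot([3, 0], 10))` is `(2, 10)`, matching a `pred` for 2 images and 10 classes.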
Color jitter: add it to the training pipeline in D:\Code\mmclassification\configs\resnet\GWF_resnet18_8xb32_in1k.py, for example:
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='ColorJitter', brightness=0.5, contrast=0.5, saturation=0.5),  # gwf add
    # ... the remaining steps stay as in the original pipeline
]
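`brightness`, `contrast` and `saturation` set how far each property may be randomly changed. A pure-Python sketch of the brightness and contrast parts on a flat list of 8-bit grayscale pixels, assuming the torchvision-style convention where each factor is drawn uniformly from [max(0, 1 - s), 1 + s]; mmcls may sample its factors differently, and `jitter_brightness_contrast` is a hypothetical name:

```python
import random


def _clip(values):
    """Clamp to the 8-bit range [0, 255]."""
    return [min(255.0, max(0.0, v)) for v in values]


def jitter_brightness_contrast(pixels, brightness=0.5, contrast=0.5, rng=None):
    """Randomly change brightness and contrast of a non-empty pixel list."""
    rng = rng if rng is not None else random.Random()
    b = rng.uniform(max(0.0, 1 - brightness), 1 + brightness)
    c = rng.uniform(max(0.0, 1 - contrast), 1 + contrast)
    out = _clip([p * b for p in pixels])               # brightness: scale every pixel
    mean = sum(out) / len(out)
    out = _clip([mean + (p - mean) * c for p in out])  # contrast: stretch around the mean
    return [round(p) for p in out]
```

With `brightness=0` and `contrast=0` both factors are exactly 1 and the image comes back unchanged.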
Adding mixup (BatchMixup under train_cfg in the model config):
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GeneralizedMeanPooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='FocalLoss', loss_weight=1.0),  # alternatives: L1Loss, CrossEntropyLoss
        topk=(1, 5)),
    train_cfg=dict(
        augments=dict(type='BatchMixup', alpha=0.2, num_classes=10,  # gwf add
                      prob=1.))
)
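mixup blends two whole images linearly, like overlaying them with a given transparency, and mixes their labels with the same ratio; cutmix instead pastes a rectangular region of one image into the other and mixes the labels in proportion to the area each image contributes.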
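Both methods sample a mixing ratio lam from Beta(alpha, alpha) (the config above uses alpha=0.2 for BatchMixup). A pure-Python sketch on flat pixel lists with one-hot labels; the function names are illustrative, the cutmix default alpha=1.0 is an assumption, and `cutmix_1d` replaces CutMix's 2-D rectangle with a 1-D segment to stay short:

```python
import random


def mixup(img_a, img_b, label_a, label_b, alpha=0.2, rng=None):
    """Blend images and labels as lam * a + (1 - lam) * b."""
    rng = rng if rng is not None else random.Random()
    lam = rng.betavariate(alpha, alpha)
    img = [lam * a + (1 - lam) * b for a, b in zip(img_a, img_b)]
    label = [lam * a + (1 - lam) * b for a, b in zip(label_a, label_b)]
    return img, label, lam


def cutmix_1d(img_a, img_b, label_a, label_b, alpha=1.0, rng=None):
    """Paste a contiguous segment of img_b into img_a; labels follow the area."""
    rng = rng if rng is not None else random.Random()
    lam = rng.betavariate(alpha, alpha)
    n = len(img_a)
    cut = round(n * (1 - lam))           # length of the pasted segment
    start = rng.randint(0, n - cut)      # segment fits inside the image
    img = img_a[:start] + img_b[start:start + cut] + img_a[start + cut:]
    kept = 1 - cut / n                   # fraction of pixels still from img_a
    label = [kept * a + (1 - kept) * b for a, b in zip(label_a, label_b)]
    return img, label, kept
```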
Pretrained model (set load_from in the config file):
load_from = r'D:\Code\mmclassification\mmcls\data\resnet18_8xb32_in1k_20210831-fbbb1da6.pth'
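The checkpoint name says `in1k`, i.e. it was trained on ImageNet-1k with a 1000-class head, while this config uses `num_classes=10`. Backbone weights keep their shapes and can be reused; the classifier (`head.fc`) cannot. A sketch of that shape check with a hypothetical helper, `filter_compatible`, working on name-to-shape dicts (for real checkpoints these could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`); the key names in the example are illustrative:

```python
def filter_compatible(pretrained_shapes, model_shapes):
    """Split pretrained parameter names into loadable and skipped.

    A tensor is loadable only if the model has the same key with the
    same shape; everything else has to be trained from scratch.
    """
    loadable, skipped = [], []
    for key, shape in pretrained_shapes.items():
        if model_shapes.get(key) == shape:
            loadable.append(key)
        else:
            skipped.append(key)
    return loadable, skipped


pretrained = {'backbone.conv1.weight': (64, 3, 7, 7),
              'head.fc.weight': (1000, 512), 'head.fc.bias': (1000,)}
model = {'backbone.conv1.weight': (64, 3, 7, 7),
         'head.fc.weight': (10, 512), 'head.fc.bias': (10,)}
loadable, skipped = filter_compatible(pretrained, model)
```

Here only the backbone weight is loadable; both head tensors are skipped.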
If the loss becomes NaN, lower the learning rate or switch to a different loss function:
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)  # reduced learning rate
loss=dict(type='FocalLoss', loss_weight=1.0),  # or switch the loss, e.g. to FocalLoss
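`8xb32` in the config name means 8 GPUs x 32 images per GPU, i.e. 256 images per iteration. A common heuristic, the linear scaling rule, keeps lr proportional to the total batch size. A sketch assuming the inherited schedule uses lr=0.1 for a batch of 256 (check the `_base_` schedule your config actually inherits); `scale_lr` is a hypothetical helper:

```python
def scale_lr(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: keep lr / total_batch_size constant."""
    return base_lr * batch_size / base_batch_size


# 8xb32 = 8 GPUs x 32 images per GPU
base_batch = 8 * 32
# Assumed base lr of 0.1 for batch 256; training on 1 GPU with 32 images:
lr_one_gpu = scale_lr(0.1, base_batch, 1 * 32)
```

This gives 0.0125, close to the 0.01 used above; it is only a starting point, and if the loss still turns NaN, lowering lr further is the usual next step.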