基于pytorch的图像分类框架-更新日志

基于pytorch的图像分类框架-更新日志

源码地址 github

使用示例 使用pytorch实现花朵分类


pytorch-classifier v1.1 更新日志

  • 2022.11.8

    1. 修改processing.py的分配数据集逻辑,之前是先分出test_size的数据作为测试集,然后再从剩下的数据里面分val_size的数据作为验证集,这种分数据的方式,当我们的val_size=0.2和test_size=0.2,最后出来的数据集比例不是严格等于6:2:2,现在修改为等比例的划分,也就是现在的逻辑分割数据集后严格等于6:2:2.
    2. 参考yolov5,训练中的模型保存改为FP16保存.(在精度基本保持不变的情况下,模型相比FP32小一半)
    3. metrice.py和predict.py新增支持FP16推理.(在精度基本保持不变的情况下,速度更加快)
  • 2022.11.9

    1. 支持albumentations库的数据增强.
    2. 训练过程新增R-Drop,具体在main.py中添加–rdrop参数即可.
  • 2022.11.10

    1. 利用Pycm库进行修改metrice.py中的可视化内容.增加指标种类.
  • 2022.11.11

    1. 支持EMA(Exponential Moving Average),具体在main.py中添加–ema参数即可.
    2. 修改早停法中的–patience机制,当–patience参数为0时,停止使用早停法.
    3. 知识蒸馏中增加了一些实验数据.
    4. 修复一些bug.

FP16推理实验:

实验环境:

System CPU GPU RAM
Ubuntu i9-12900KF RTX-3090 32G

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练RepVGG-A0:

    python main.py --model_name RepVGG-A0 --config config/config.py --save_path runs/RepVGG-A0 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练densenet121:

    python main.py --model_name densenet121 --config config/config.py --save_path runs/densenet121 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/RepVGG-A0
    python metrice.py --task val --save_path runs/densenet121

    python metrice.py --task val --save_path runs/mobilenetv2 --half
    python metrice.py --task val --save_path runs/resnext50 --half
    python metrice.py --task val --save_path runs/RepVGG-A0 --half
    python metrice.py --task val --save_path runs/densenet121 --half

计算各个模型的fps:

    python metrice.py --task fps --save_path runs/mobilenetv2
    python metrice.py --task fps --save_path runs/resnext50
    python metrice.py --task fps --save_path runs/RepVGG-A0
    python metrice.py --task fps --save_path runs/densenet121

    python metrice.py --task fps --save_path runs/mobilenetv2 --half
    python metrice.py --task fps --save_path runs/resnext50 --half
    python metrice.py --task fps --save_path runs/RepVGG-A0 --half
    python metrice.py --task fps --save_path runs/densenet121 --half
model val accuracy(train stage) val accuracy(test stage) val accuracy half(test stage) FP32 FPS(batch_size=64) FP16 FPS(batch_size=64)
mobilenetv2 0.74284 0.74340 0.74396 52.43 92.80
resnext50 0.80966 0.80966 0.80966 19.48 30.28
RepVGG-A0 0.73666 0.73666 0.73666 54.74 98.87
densenet121 0.77035 0.77148 0.77035 18.87 32.75

R-Drop实验:

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --rdrop

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练ghostnet:

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练efficientnet_v2_s:

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_rdrop
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_rdrop
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_rdrop
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_rdrop

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_rdrop
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_rdrop
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_rdrop
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_rdrop
model val accuracy val accuracy(r-drop) test accuracy test accuracy(r-drop)
mobilenetv2 0.74340 0.75126 0.73784 0.73741
resnext50 0.80966 0.81134 0.82437 0.82092
ghostnet 0.77597 0.76698 0.76625 0.77012
efficientnet_v2_s 0.84166 0.85289 0.84460 0.85837

EMA实验:

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --ema

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练ghostnet:

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练efficientnet_v2_s:

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_ema
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_ema
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_ema
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_ema

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_ema
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_ema
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_ema
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_ema
model val accuracy val accuracy(ema) test accuracy test accuracy(ema)
mobilenetv2 0.74340 0.74958 0.73784 0.73870
resnext50 0.80966 0.81246 0.82437 0.82307
ghostnet 0.77597 0.77765 0.76625 0.77142
efficientnet_v2_s 0.84166 0.83998 0.84460 0.83986

pytorch-classifier v1.2 更新日志

  1. 新增export.py,支持导出(onnx, torchscript, tensorrt)模型.

  2. metrice.py支持onnx,torchscript,tensorrt的推理.

     此处在predict.py中暂不支持onnx,torchscript,tensorrt的推理的推理,原因是因为predict.py中的热力图可视化没办法在onnx、torchscript、tensorrt中实现,后续单独推理部分会额外写一部分代码.
     在metrice.py中,onnx和torchscript和tensorrt的推理也不支持tsne的可视化,那么我在metrice.py中添加onnx,torchscript,tensorrt的推理的目的是为了测试fps和精度.
     所以简单来说,使用metrice.py最好还是直接用torch模型,torchscript和onnx和tensorrt的推理的推理模型后续会写一个单独的推理代码.
    
  3. main.py,metrice.py,predict.py,export.py中增加–device参数,可以指定设备.

  4. 优化程序和修复一些bug.

训练命令:

python main.py --model_name efficientnet_v2_s --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/efficientnet_v2_s --device 0 \
--pretrained --amp --warmup --ema --imagenet_meanstd

GPU 推理速度测试 sh脚本:

batch_size=1 # 1 2 4 8 16 32 64
python metrice.py --task fps --save_path runs/efficientnet_v2_s --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --model_type torchscript --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --half --batch_size $batch_size 
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --half --batch_size $batch_size

CPU 推理速度测试 sh脚本:

python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --dynamic --device cpu
batch_size=1
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=2
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=4
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=8
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=16
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size

各导出模型在cpu和gpu上的fps实验:

实验环境:

System CPU GPU RAM Model
Ubuntu20.04 i7-12700KF RTX-3090 32G DDR5 6400 efficientnet_v2_s

GPU

model Torch FP32 FPS Torch FP16 FPS TorchScript FP32 FPS TorchScript FP16 FPS ONNX FP32 FPS ONNX FP16 FPS TensorRT FP32 FPS TensorRT FP16 FPS
batch-size 1 93.77 105.65 233.21 260.07 177.41 308.52 311.60 789.19
batch-size 2 94.32 108.35 208.53 253.83 166.23 258.98 275.93 713.71
batch-size 4 95.98 108.31 171.99 255.05 130.43 190.03 212.75 573.88
batch-size 8 94.03 85.76 118.79 210.58 87.65 122.31 147.36 416.71
batch-size 16 61.93 76.25 75.45 125.05 50.33 69.01 87.25 260.94
batch-size 32 34.56 58.11 41.93 72.29 26.91 34.46 48.54 151.35
batch-size 64 18.64 31.57 23.15 38.90 12.67 15.90 26.19 85.47

CPU

model Torch FP32 FPS Torch FP16 FPS TorchScript FP32 FPS TorchScript FP16 FPS ONNX FP32 FPS ONNX FP16 FPS TensorRT FP32 FPS TensorRT FP16 FPS
batch-size 1 27.91 Not Support 46.10 Not Support 79.27 Not Support Not Support Not Support
batch-size 2 25.26 Not Support 24.98 Not Support 45.62 Not Support Not Support Not Support
batch-size 4 14.02 Not Support 13.84 Not Support 23.90 Not Support Not Support Not Support
batch-size 8 7.53 Not Support 7.35 Not Support 12.01 Not Support Not Support Not Support
batch-size 16 3.07 Not Support 3.64 Not Support 5.72 Not Support Not Support Not Support

pytorch-classifier v1.3 更新日志

  1. 增加repghost模型.
  2. 推理阶段把模型中的conv和bn进行fuse.
  3. 发现mnasnet0_5有点问题,暂停使用.
  4. torch.no_grad()更换成torch.inference_mode().

pytorch-classifier v1.4 更新日志

  1. predict.py支持检测灰度图,其读取后会检测是否为RGB通道,不是的话会进行转换.
  2. 更新readme.md.
  3. 修复一些bug.

Knowledge Distillation Experiment

为了测试知识蒸馏的可用性,基于CUB-200-2011百度网盘链接数据集进行实验.

stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD1 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw 

计算通过resnet50蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74116 0.74200 0.73483 0.73452 0.77012 0.76979
resnet50 0.78720 0.78744 0.77744 0.77670 0.81231 0.81162
teacher->resnet50
student->mobilenetv2
SoftTarget
0.77092 0.77179 0.75248 0.75191 0.77787 0.77752
teacher->resnet50
student->mobilenetv2
MGD
0.78888 0.78994 0.78390 0.78296 0.79940 0.79890
teacher->resnet50
student->mobilenetv2
AT
0.74789 0.74878 0.73870 0.73795 0.76324 0.76244

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.2 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74116 0.74200 0.73483 0.73452 0.77012 0.76979
ghostnet 0.77709 0.77756 0.76367 0.76277 0.78046 0.77958
teacher->ghostnet
student->mobilenetv2
SoftTarget
0.77878 0.77968 0.76108 0.76022 0.77916 0.77807
teacher->ghostnet
student->mobilenetv2
MGD
0.75632 0.75723 0.74688 0.74638 0.77357 0.77302
teacher->ghostnet
student->mobilenetv2
AT
0.74846 0.74945 0.73827 0.73782 0.76625 0.76534

由于SP蒸馏开启AMP时,kd_loss大概率会出现nan,所在SP蒸馏实验中,我们把所有模型都不开启AMP.

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP --test_tta
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74509 0.74568 0.73827 0.73761 0.76969 0.76903
ghostnet 0.77821 0.77881 0.75807 0.75708 0.77873 0.77805
teacher->ghostnet
student->mobilenetv2
SP
0.74733 0.74836 0.73267 0.73198 0.75893 0.75850

stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/resnet50_admaw
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74509 0.74568 0.73827 0.73761 0.76969 0.76903
resnet50 0.78720 0.78707 0.77400 0.77321 0.81231 0.81138
teacher->resnet50
student->mobilenetv2
SP
0.74116 0.74200 0.74042 0.73969 0.76840 0.76753

以下实验是通过训练好的自身模型再作为教师模型进行训练.

知识蒸馏, resnet50作为teacher, resnet50作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/resnet50_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw 

计算通过resnet50蒸馏resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self --test_tta

知识蒸馏, mobilenetv2作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/mobilenetv2_admaw 

计算通过mobilenetv2蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self --test_tta

知识蒸馏, ghostnet作为teacher, ghostnet作为student, 使用AT进行蒸馏:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000 --teacher_path runs/ghostnet_admaw 

计算通过ghostnet蒸馏ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self --test_tta
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74116 0.74200 0.73483 0.73452 0.77012 0.76979
teacher->mobilenetv2
student->mobilenetv2
AT
0.74677 0.74758 0.74430 0.74342 0.77012 0.76926
resnet50 0.78720 0.78744 0.77744 0.77670 0.81231 0.81162
teacher->resnet50
student->resnet50
AT
0.79057 0.79091 0.79165 0.79026 0.81102 0.81030
ghostnet 0.77709 0.77756 0.76367 0.76277 0.78046 0.77958
teacher->ghostnet
student->ghostnet
AT
0.78046 0.78080 0.77142 0.77069 0.78820 0.78742

在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2 --test_tta

普通训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算efficientnet_v2_s指标:

python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

计算通过efficientnet_v2_s蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta
model val accuracy val mpa test accuracy test mpa test accuracy(TTA) test mpa(TTA)
mobilenetv2 0.74116 0.74200 0.73483 0.73452 0.77012 0.76979
efficientnet_v2_s 0.84166 0.84191 0.84460 0.84441 0.86483 0.86484
teacher->efficientnet_v2_s
student->mobilenetv2
ST
0.76137 0.76209 0.75161 0.75088 0.77830 0.77715
teacher->efficientnet_v2_s
student->mobilenetv2
MGD
0.77204 0.77288 0.77529 0.77464 0.79337 0.79261
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA)
0.77204 0.77267 0.77744 0.77671 0.80284 0.80201
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(RDrop)
0.77204 0.77288 0.77529 0.77464 0.79337 0.79261
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA,RDrop)
0.77204 0.77267 0.77744 0.77671 0.80284 0.80201

关于Knowledge Distillation的一些解释

实验解释:

  1. 对于AT和SP蒸馏方法,上述实验都是使用block3和block4的特征层进行蒸馏.
  2. MPA是平均类别精度,在类别不平衡的情况下非常有用,当类别基本平衡的情况下,跟accuracy差不多.
  3. 当蒸馏loss出现nan的时候请不要开启AMP,AMP可能会导致浮点溢出导致的nan.

目前支持的类型有:

Name Method paper
SoftTarget logits https://arxiv.org/pdf/1503.02531.pdf
MGD features https://arxiv.org/abs/2205.01529.pdf
SP features https://arxiv.org/pdf/1907.09682.pdf
AT features https://arxiv.org/pdf/1612.03928.pdf

蒸馏学习跟模型,参数,蒸馏的方法,蒸馏的层都有关系,效果不好需要自行调整,其中SP和AT都可以对模型中的四个block进行组合计算蒸馏损失具体代码在utils/utils_fit.py的fitting_distill函数中可以进行修改.

你可能感兴趣的:(pytorch,分类,深度学习)