图像分类竞赛——Test Time Augmentation(TTA)

TTA介绍

Test Time Augmentation(TTA),测试数据增强,是在测试阶段时,将输入的测试数据进行,翻转、旋转操作等数据增强,并最后对同一样本的不同数据增强的结果根据任务需求进行例如平均,求和等数据处理。

TTA实现

这里推荐github上一个库https://github.com/qubvel/ttach,可以直接调用tta,非常方便。

tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(),  merge_mode='mean')
  • 第一个参数为model,即为输入的模型

  • 第二个参数为transform类型

    • 调用作者已经设定好的tta类型tta.aliases图像分类竞赛——Test Time Augmentation(TTA)_第1张图片
    • 自定义tta
      图像分类竞赛——Test Time Augmentation(TTA)_第2张图片
  • 第三个参数为数据融合方式,一般对于图像分类来说多采用mean,和geo-mean,一般geo-mean要由于mean,至于这几种融合方法的优劣,可以参考kaggle的这篇文章MLWave/Kaggle-Ensemble-Guide
    图像分类竞赛——Test Time Augmentation(TTA)_第3张图片

注意事项

在模型调用这个库的时候,因为这个库的是并行处理数据,也就是说,原本test-batch-size为12时,如果输入TTA的数量是2×2×2 = 8时,相应的数据的输入会变为12×8 = 96,所以在使用的时候应对应使用的TTA数量而缩小test-batch-size,以显存不足。

你可能感兴趣的:(图像分类竞赛)