李宏毅2022机器学习HW3解析

准备工作:作业三是食物分类,需要助教代码+数据集放置于同一目录下,数据集需解压。关注本公众号,可获得代码和数据集(文末有方法)。

kaggle提交: https://www.kaggle.com/competitions/ml2022spring-hw3b,提交结果可能需要科学上网,想讨论的可进QQ群:156013866。

  • Simple Baseline (acc>0.50099): 直接运行代码,可能需要下载一些工具包,运行过后出现submission.csv文件,将其提交到kaggle上得到分数:0.55278。

  • Medium Baseline (acc>0.73207)Data augmentation+dropout。对train_tfm进行修改,添加了常用的augmentation 方法,包括RandomResizedCrop(随机截取并resize)、RandomHorizontalFlip(随机横向翻转)、RandomVerticalFlip(随机竖向翻转)、RandomRoation(随机旋转)、RandomAffine(随机仿射)、RandomGrayscale(随机灰度化)。另外在模型的全连接层的最前面加上dropout层,注意dropout一定放到全连接层,千万不要放到卷积层。运行代码,提交得到kaggle分数:0.77689

  • Strong Baseline (acc>0.81872)Data augmentation + 残差网络架构+ FocalLoss + Cross Validation + Ensemble

    Augmentation的方法同medium。残差神经网络架构,如下图所示,基本的block包含两层卷积,卷积层的输出F(x)与block的输入x相加,注意这两个可能不是相同维度的,如果不相同,我使用1X1的卷积对x进行变换使其与F(x)有相同的维度,在Kaiming He的论文里面,还有一种使用zero-padding的方法来解决维度不同问题。FocalLoss相对于CrossEntropy,考虑样本不均衡的问题并增加了错误分类样本loss的权重,有alpha和gamma两个参数,我统计了各个样本的数量,根据不同类别的数目设定FocalLoss的alpha值,gamma值设为固定值2。Cross Validation + Ensemble,我使用了4-fold,得到4个模型,做推理的时候,每张图片有4个输出,将4个输出求和,然后使用argmax得到分类结果。运行代码,提交后得到分数:0.85159,Ensemble真是太强了,这四个模型里面最好的一个准确率是0.79,最差的0.77,合并一块准确率居然高了这么多。

    李宏毅2022机器学习HW3解析_第1张图片

李宏毅2022机器学习HW3解析_第2张图片

  • Boss Baseline (acc>0.88446)Data augmentation + 残差网络架构+ FocalLoss + Cross Validation + Ensemble + Test Time Augmentation

    相对于strong baseline,加大了网络的深度和训练次数,并使用Test Time Augmentation(TTA)。TTA使用了原本的test_tfm构建的testloader,另外使用train_tfm构建其他5个testloader,6个testloader的权重分别为0.6和0.1,0.1,0.1,0.1,01。运行后得到分数:0.88247,接近boss baseline,还需要微调。

李宏毅2022机器学习HW3解析_第3张图片

作业三答案获得方式:

  1. 关注微信公众号 “机器学习手艺人” 

  2. 后台回复关键词:202203

你可能感兴趣的:(机器学习,人工智能,深度学习)