李宏毅2022机器学习HW2解析

李宏毅2022机器学习HW2解析_第1张图片

准备工作:去课程github下载原始代码,kaggle下载数据集。或者关注本公众号,下载代码和数据集(文末有方法)。解压数据集,出现libriphone文件夹将文件和代码放到同一目录下。

kaggle提交: https://www.kaggle.com/c/ml2022spring-hw2,提交结果可能需要科学上网,想讨论的可进QQ群:156013866。

  • Simple Baseline (acc>0.45797): 直接运行代码,可能需要下载一些工具包,运行过后出现prediction.csv文件,将其提交到kaggle上得到分数:0.46083。

李宏毅2022机器学习HW2解析_第2张图片

  • Medium Baseline (acc>0.69747)concat_nframes参数设置+网络架构改变+学习率设置。对train_labels.txt文件进行统计,发现每一个音位占用的frame均值是9个,因此可以将concat_nframes参数设置为>9(必须为奇数),经尝试可以将concat_nframes设置的大些,这里我设置为17。网络架构调整的更宽和稍深。学习率也稍微调整的大些。运行代码,提交得到kaggle分数:0.70594

李宏毅2022机器学习HW2解析_第3张图片

  • Strong Baseline (acc>0.75028)concat_nframes参数设置+batch_size+网络架构改变+余弦退火学习率。concat_nframes参数设置为19。batch_size设置为2048。设置三个宽度为1024的隐藏层。利用余弦退火学习率,有的学生可能问了,为什么老是余弦退火啊,用李宏毅老师的话,这都是古圣先贤的意思,用就对了,不过我的理解是使用余弦退火的时候可以很直观的看到哪些学习率是比较合适的,这对我们选择正确的学习率参数很有帮助。运行代码,提交后得到分数:0.75321,好于strong baseline。

李宏毅2022机器学习HW2解析_第4张图片

  • Boss Baseline (acc>0.82324)concat_nframes参数设置+batch_size+BiLSTM-CRF网络架构+余弦退火学习率

    BiLSTM-CRF网络结构是序列标注中的经典模型,该结构可以综合考虑lstm的输出结果和标签顺序分布,可参考pytorch官方样例:https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html,或者使用pytorchcrf库。在使用BiLSTM-CRF架构的时候,需要修改数据的产生方式,之前每个sample的feature和label size分别是(batch_size, 39*concat_nframes)(batch_size,),现在是(batch_size,concat_nframes, 39)(batch_size,concat_nframes),最后做推理的时候也需要相应的改变。同时因为BiLSTM和CRF的收敛速度一般是不一样的,CRF的学习率要设置的大些,运行代码提交后,分数是:0.79449,还没到boss baseline,想得到更好的结果需要进行精细调参,另外可以尝试Transfromer-CRF或Bert-CRF结构。

李宏毅2022机器学习HW2解析_第5张图片

作业二答案获得方式:

  1. 关注微信公众号 “机器学习手艺人” 

  2. 后台回复关键词:202202

你可能感兴趣的:(机器学习,人工智能,深度学习)