李宏毅2022机器学习HW7解析

李宏毅2022机器学习HW7解析_第1张图片

准备工作

作业七是BERT问答,需要助教代码和数据集,运行代码过程中保持联网可以自动下载数据集,已经有数据集的情况可关闭助教代码中的下载数据部分。关注本公众号,可获得代码和数据集(文末有方法)。

提交地址

Kaggle:https://www.kaggle.com/competitions/ml2022spring-hw7,另外关于作业有想讨论的同学可进QQ群:156013866。

数据预处理流程梳理

数据解压后包含3个json文件:hw7_train.json, hw7_dev.json, hw7_test.json

第一步是通过read_data函数读取这三个文件,每个文件返回相应的question数据和paragraph数据,都是文本数据,不能作为模型的输入。

第二步是利用tokenizer将question和paragraph文本数据转换为数字数据。

第三步是使用QA_Dataset将选取paragraph中固定长度的片段(固定长度为150),片段需包含answer部分,然后使用CLS + question + SEP + paragraph + CLS + padding(不足的补0)作为训练数据。

以上三个步骤进行后,模型就具备标准输入数据了。

Simple Baseline (Acc>0.45139)

方法:直接运行助教代码。注意在本地或kaggle上运行时候,需要调整文件名称或者路径。提交kaggle的score是:0.51028 。

李宏毅2022机器学习HW7解析_第2张图片

Medium Baseline (Acc>0.65792)

方法:修改doc_stride+learning rate scheduler。将self.doc_stride设置为32,这样移动的更密集,做inference的时候准确率更高。另外使用get_linear_schedule_with_warmup训练,提升优化能力。提交kaggle的socre是:0.66196

# 一、降低doc_stride值为32self.doc_stride = 32
# 二、使用scheduler,一个epoch大概1000步,所以num_training_steps=1000learning_rate = 2e-4optimizer = AdamW(model.parameters(), lr=learning_rate)from transformers import get_linear_schedule_with_warmupscheduler = get_linear_schedule_with_warmup(optimizer,                 num_warmup_steps=100, num_training_steps=1000                ....                optimizer.step()        optimizer.zero_grad()        # 每次运算完启动scheduler        scheduler.step()        ....

李宏毅2022机器学习HW7解析_第3张图片

Strong Baseline (Acc>0.78136)

方法:修改doc_stride+learning rate scheduler+preprocessing + postprocessing + new model。与medium baseline相比,有三个地方进行了改进。第一个是做preprocessing,能提升大概9%的正确率,助教代码中训练集(QA_Dataset类生成)都是以answer为中心截取段落,这可能让模型学习到“答案在段落中央”这样的结论,为避免此问题,我将训练集变为随机抽取片段,并且片段包含答案。第二个是采用新的pretrain模型,能提升大概7%的正确率,新的模型大小是1.2G比原模型的400M大了很多,这样造成内存不够,需要降低batch size,并采用accumulate gradient的方法,注意这里要跟lr scheduler混合使用,需要改动lr部分,下面列举的是部分代码。第三个是postprocessing,能提升大概0.5%的正确率,在evaluate函数中,可能出现预测的start index比end index大的情况,要添加代码修复。提交kaggle的socre是:0.80112

李宏毅2022机器学习HW7解析_第4张图片

# 一、preprocessing#single window is obtained by slicing the portion of paragraph containing the answer#mid = (answer_start_token + answer_end_token) // 2#paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))   
start_min = max(0, answer_end_token - self.max_paragraph_len + 1)
start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)
start_max = max(start_min, start_max)
paragraph_start = random.randint(start_min, start_max + 1)
paragraph_end = paragraph_start + self.max_paragraph_len
# 二、采用新的pretrian模型# model_name = "bert-base-chinese"model_name = "luhua/chinese_pretrain_mrc_roberta_wwm_ext_large"model = BertForQuestionAnswering.from_pretrained(model_name).to(device)tokenizer = BertTokenizerFast.from_pretrained(model_name)...train_batch_size = 16...acc_steps = 4...        if step % acc_steps == 0:            optimizer.step()            optimizer.zero_grad()            scheduler.step()​​​​​​
# 三、postprocessingstart_prob, start_index = torch.max(output.start_logits[k], dim=0)end_prob, end_index = torch.max(output.end_logits[k], dim=0)if start_index > end_index:    continue

Boss Baseline (Acc>0.83091)

方法:doc_stride + max_length+ learning rate scheduler + preprocessing + postprocessing + new model + no validation。与strong baseline相比,最大的改变有两个,一是换pretrain model,在hugging face中搜索chinese + QA的模型,根据model card描述选择最好的模型,使用后大概提升2.5%的精度,二是更近一步的postprocessing,查看提交文件可看到很多answer包含CLS, SEP, UNK等字符,CLS和SEP的出现表示预测位置有误,UNK的出现说明有某些字符无法正常编码解码(例如一些生僻字),错误字符的问题均可在evaluate函数中改进,这个步骤提升了大概1%的精度。其他的修改主要是针对overfitting问题,包括减少了learning rate,提升dataset里面的paragraph max length, 将validation集合和train集合进行合并等。另外可使用的办法有ensemble,大概能提升0.5%的精度,改变random seed,也有提升的可能性。提交kaggle的socre是:0.83461

李宏毅2022机器学习HW7解析_第5张图片

作业七答案获得方式:

  1. 关注微信公众号 “机器学习手艺人” 

  2. 后台回复关键词:202207

你可能感兴趣的:(机器学习,人工智能,深度学习)