此次数据使用的是LCQMC,它是做短文本匹配的一个数据,长这样
判断两个文本是否相似,如果相似标签为1,不相似为0
对于数据的处理在这里就不做研究了,无非就是分词,构建词表之类的,这里只说Dataset和DataLoader的用法。
我们首先构建一个类,并继承Dataset类
class DatasetIterater(Dataset):
def __init__(self,texta,textb,label):
self.texta = texta
self.textb = textb
self.label = label
def __getitem__(self, item):
return self.texta[item],self.textb[item],self.label[item]
def __len__(self):
return len(self.texta)
既然继承Dataset类,就要实现Dataset类的方法。
第一个方法就不用说了,初始化方法。
第二个是迭代方法,每次得到是一个数据,不是一个batch,我最初以为是一个batch的数据。
第三个就是返回数据的个数
但是,处理NLP的数据,通常情况下是需要对数据进行补齐,也就是在不够长度的数据后补0,所以需要自己实现一个collate_fn函数来进行对文本的补齐操作。
def collate_fn(batch_data,pad=0):
texta,textb,label = list(zip(*batch_data))#batch_data的结构是[([texta_1],[textb_1],[label_1]),([texta_2],[textb_2],[label_2]),...],所以需要使用zip函数对它解压
max_len_a = max([len(seq_a) for seq_a in texta])
max_len_b = max([len(seq_b) for seq_b in textb])
max_len = max(max_len_a,max_len_b) #这里我使用的是一个batch中text_a或者是text_b的最大长度作为max_len,也可以自定义长度
texta = [seq+[pad]*(max_len-len(seq)) for seq in texta]
textb = [seq+[pad]*(max_len-len(seq)) for seq in textb]
texta = torch.LongTensor(texta)
textb = torch.LongTensor(textb)
label = torch.FloatTensor(label)
return (texta,textb,label)
接下来就可以向自己实现的 DatasetIterater 类里传值了
train_data = DatasetIterater(train_texta,train_textb,train_label)
然后使用DataLoader进行处理,DataLoader每次返回的是一个batch的数据,这里的shuffle代表每次是否要打乱数据,一般对于训练数据都是要打乱的,验证集可打乱也可不打乱,测试集是千万不能打乱的!这里的collate_fn使用的是我们自己定义的函数,
DataLoader里还有一个参数是num_workers,这个是使用几个线程来处理。
train_iter = DataLoader(dataset=train_data,batch_size=args.batch_size,shuffle=True,collate_fn=collate_fn)
最后我们就可以迭代数据向模型里传入了,
for batch_data in tqdm(train_data):
texta, textb, tag = map(lambda x: x.to(device), batch_data)
output = model(texta, textb)