Huggingface transformer的Trainer中data_collator的使用

什么时候使用

由Transformers Trainer的文档中可知,Trainer函数有一个参数data_collator,其值也为一个函数,用于从一个list of elements来构造一个batch。这个函数其实就是torch.utils.data.DataLoader中的collate_fn。如果还不知道collate_fn是做何用处,请参考这篇文档。
要用到这个函数,要符合如下两个条件:

  1. Trainer的参数train_dataseteval_dataset是torch.utils.data.Dataset或torch.utils.data.IterableDataset的实体
  2. train_dataseteval_dataset(torch.utils.data.Dataset)加载入DataLoader后,得到的batch不可用,还不能直接加入到model的forward中计算

如何用

这里假设读者已经知道torch.utils.data.DataLoader的collate_fn用法,只介绍Trainer的data_collator和torch.utils.data.DataLoader的collate_fn的差异。
差异就是,输出格式!torch.utils.data.DataLoader的collate_fn的输出可以是各种格式,但Trainer的data_collator的输出只能是一个dict,这个dict的键必须包含“input_ids”,“attention_mask”等transformers models正常运算必要的参数的名称,如果需要,也可以加入任何transformers model.forward()可接受的参数名,而这些键对应的值也必须是transformers model中该键应该对应的输入值。
如果想让模型自动训练loss,还要在这个dict中加入如下键值对:{“labels”: labels in tensor type},这样模型的输出里就有loss了。

为什么呢?

看两段源码其实就差不多明白了:
Huggingface transformer的Trainer中data_collator的使用_第1张图片
Huggingface transformer的Trainer中data_collator的使用_第2张图片
第一张图中,这个DataLoader就是一个纯粹的torch.utils.data.DataLoader,self.data_collator就是输入的data_collator函数。所以,这个data_collator就彻彻底底是一个DataLoader的collate_fn啊
第二张图中,input就是如下迭代的结果(其中的dataloader就是第一张图中的dataloader)

for step, inputs in enumerate(DataLoader)

所以,inputs的键值对必须要与model.forwards()的参数相对应也是显然的

你可能感兴趣的:(huggingface,transformer,深度学习,pytorch)