创建Dataloader基础篇【一】

概述

在transformers trainer训练、评估模型中,大致根据以下过程加载与处理训练、评估数据集:

  1. 使用dataset.Dataset加载数据
  2. 使用Dataset.map与自定义的convert_examples_to_features函数处理Dataset中的每一行数据
  3. 定义sampler,在迭代Dataloader过程中,本质是迭代sampler。默认auto-batch模式下,sampler在每次迭代过程中,会返回一个batch的索引数值(indices),然后根据indices从Dataloader.dataset中取数据(fetch)。e.g. [self.dataset[index] for index in batch_indices]
  4. 将第三步取到的数据,喂到collator_fn中,组装成tensor类型,并返回组装后的结果。

因为trainer默认是以aotu_batch方式加载与处理数据,因此本部分仅记录aotu_batch方式。另外本文仅记录trainer中创建dataloader的基础过程。对于一些个性化加载与处理、如长文档文本分类,如有必要,会另起一篇文章再进行记录。

实例

# set up
from typing import List, Dict, Union

from datasets import Dataset
from transformers import default_data_collator
from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler

from config import CKP  # huggingface 中预训练模型下载到本地的地址

# emotion classification demo
x = [{"texts": "我爱中国。", "labels": 1}, {"texts": "今天天气真糟糕!", "labels": 0}] * 3

# 可以使用datasets.load_dataset函数,将样本数据存储为json格式,每一条样本占据一行
examples: Dataset = Dataset.from_list(x)
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(CKP)

def convert_examples_to_features(exams: Dict[str, List[Union[str, int]]]):
    return tokenizer(exams["texts"], padding=True, max_length=20, truncation=True)

# map函数中的batch=True并不影响最终结果,只是影响convert_examples_to_features的签名|定义
dataset = examples.map(convert_examples_to_features, with_indices=False, with_rank=False, batched=True,
                       batch_size=1, remove_columns=["texts"])

# 验证sampler
sequence_sampler = SequentialSampler(dataset)
print(f"sequence sampler: {list(sequence_sampler)}")

random_sampler = RandomSampler(dataset)
print(f"random sampler: {list(random_sampler)}")

batch_sampler = BatchSampler(random_sampler, batch_size=2, drop_last=False)
print(f"batch sampler: {list(batch_sampler)}")

# 在convert_examples_to_features已经对input_ids进行了pad,所以使用default_data_collator
# 如果仅进行编码,即padding=False, 此处使用transformers.DataCollatorWithPadding
dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

# add breakpoint in here, you will see
# step1. get next batch indices
# step2. fetch data according batch indices
# step3. collator data by collator_fn and return batch
for batch in dataloader:
    print(batch)

参考资料

datasets.Dataset.map方法学习笔记
transformers中的data_collator
【pytorch】Dataloader学习笔记

你可能感兴趣的:(#,huggingface,transformers,pytorch)