Overview
[Zero Basics 1] Learning BERT with PaddlePaddle: https://blog.csdn.net/qq_42276781/article/details/121488335
[Zero Basics 2] Learning BERT with PaddlePaddle: https://blog.csdn.net/qq_42276781/article/details/121523268
import paddle

# Create the dataloader
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    # Optionally convert every raw sample before batching
    if trans_fn:
        dataset = dataset.map(trans_fn)
    # Shuffle only the training set
    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    # batchify_fn assembles the sampled examples into one batch
    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
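create_dataloader builds a data loader. Its inputs are the dataset (dataset), the mode (mode, default 'train'), the batch size (batch_size, default 1), batchify_fn, and trans_fn (a function that converts each sample). I have not found documentation for batchify_fn itself; for now I read the name as "batchify function", that is, the function that turns samples into a batch, which fits the fact that it is passed to paddle.io.DataLoader as collate_fn.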
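To make the role of batchify_fn more concrete, here is a minimal sketch of what such a function might look like. It uses plain Python lists instead of Paddle tensors, and pad_batchify_fn, pad_id and the sample layout (token ids, label) are all hypothetical names chosen for illustration, not taken from the original code.

```python
def pad_batchify_fn(samples, pad_id=0):
    """Hypothetical collate function: pad token-id lists to a common length."""
    # Each sample is assumed to be a (token_ids, label) pair.
    max_len = max(len(ids) for ids, _ in samples)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids, _ in samples]
    labels = [label for _, label in samples]
    return input_ids, labels


# Two samples of different lengths, as a sampler might hand them over.
batch = [([101, 7, 8, 102], 1), ([101, 9, 102], 0)]
input_ids, labels = pad_batchify_fn(batch)
print(input_ids)  # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(labels)     # [1, 0]
```

The point is only the shape of the contract: the function receives a list of individual samples and returns one batch.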
if trans_fn:
    dataset = dataset.map(trans_fn)
If trans_fn is passed in, it is applied to dataset to convert the samples. The API documentation for dataset.map is here:
dataset (PaddleNLP documentation): https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.datasets.dataset.html?highlight=dataset.map#paddlenlp.datasets.dataset.MapDataset.map
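The pattern dataset = dataset.map(trans_fn) can be pictured with a small stand-in class. This is only a sketch: TinyMapDataset is a hypothetical simplification written for this post, not the PaddleNLP MapDataset, and the trans_fn below (whitespace tokenization) is likewise made up for illustration.

```python
class TinyMapDataset:
    """Hypothetical stand-in for a map-style dataset with a .map method."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def map(self, fn):
        # Apply fn to every sample and return the dataset itself,
        # so that `dataset = dataset.map(fn)` works as in the post.
        self.data = [fn(sample) for sample in self.data]
        return self


def trans_fn(example):
    # Hypothetical conversion: split the text into tokens, keep the label.
    return {"tokens": example["text"].split(), "label": example["label"]}


dataset = TinyMapDataset([
    {"text": "good movie", "label": 1},
    {"text": "too slow", "label": 0},
])
dataset = dataset.map(trans_fn)
print(dataset[0])  # {'tokens': ['good', 'movie'], 'label': 1}
```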
shuffle = True if mode == 'train' else False
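If this is the training set, the data is shuffled; otherwise it is not. The syntax is the Python counterpart of the ternary operator in C and Java: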
shuffle = mode == 'train' ? true : false
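A quick check of the conditional expression: because mode == 'train' already evaluates to a bool, the True if ... else False wrapper gives the same result as the bare comparison. The helper names below are hypothetical and exist only for this comparison.

```python
def shuffle_flag_verbose(mode):
    # Conditional expression, as written in create_dataloader.
    return True if mode == 'train' else False


def shuffle_flag_direct(mode):
    # The comparison itself is already a bool.
    return mode == 'train'


for mode in ('train', 'dev', 'test'):
    assert shuffle_flag_verbose(mode) == shuffle_flag_direct(mode)
    print(mode, shuffle_flag_direct(mode))
```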
if mode == 'train':
    batch_sampler = paddle.io.DistributedBatchSampler(
        dataset, batch_size=batch_size, shuffle=shuffle)
else:
    batch_sampler = paddle.io.BatchSampler(
        dataset, batch_size=batch_size, shuffle=shuffle)
If the training set is passed in, paddle.io.DistributedBatchSampler is called to produce batch_sampler; for any other mode, paddle.io.BatchSampler is called to produce batch_sampler.
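As a rough picture of what a batch sampler produces, the sketch below yields lists of sample indices, one list per mini-batch. simple_batch_sampler is a hypothetical pure-Python stand-in, not Paddle's implementation, and it ignores the splitting of data across processes that a distributed sampler performs.

```python
import random


def simple_batch_sampler(dataset_len, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, one list per mini-batch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Shuffle the index order, not the data itself.
        random.Random(seed).shuffle(indices)
    for start in range(0, dataset_len, batch_size):
        yield indices[start:start + batch_size]


# Evaluation-style sampling: fixed order.
eval_batches = list(simple_batch_sampler(7, batch_size=3))
print(eval_batches)  # [[0, 1, 2], [3, 4, 5], [6]]

# Training-style sampling: same indices, shuffled order.
train_batches = list(simple_batch_sampler(7, batch_size=3, shuffle=True, seed=0))
print(train_batches)
```

Under this picture, the data loader looks up the samples for each index list and hands them to the collate function.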
Why do we need a batch_sampler here?