Pytorch的DataLoarder中的collate_fn参数

使用方法

作为dataLoader的形参,不传入的时候使用默认的,可以自己定义。

DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

自己定义:

def collate_fn(examples):
    """
    wfj:该函数表示对于batch_size中的每一个元素做以下一下的操作,通常用来进行数据的标准化工作
    """
    print("==========================")
    print(examples)
    print(len(examples))
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    # 对batch内的样本进行padding,使其具有相同长度
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab[""])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab[""])
    #输出的几个参数的解释:解释变量;每个解释变量的长度;被解释变量;是否为填充位的标记。
    return inputs, lengths, targets, inputs != vocab[""]

打印信息

在这里插入图片描述
我们的batch_size设置的是32。

解析

所以collate_fn接受的一个参数,就是Dataloader迭代取出的每个batch_size,我们可以在collate_fn中对每个batch_size的数据进行相关的操作和个性化的处理。

你可能感兴趣的:(python,工具使用,pytorch,深度学习,python)