Pytorch使用技巧之Dataloader中的collate_fn参数详析

以MNIST为例

from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])

结果

(, 5)

MINIST数据集的dataset是由一张图片和一个label组成的元组

dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
    print(each)
    break

结果

[(, 0), (, 2)]

collate_fn为lamda x:x时表示对传入进来的数据不做处理

下面自定义collate_fn看看什么效果

def collate(data):
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
    print(each)
    break

结果

([, ], [9, 3])

说明:若不设置collate_fn参数则会使用默认处理函数

但必须保证传进来的数据都是tensor格式否则会报错

附:DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

总结

到此这篇关于Pytorch使用技巧之Dataloader中的collate_fn参数的文章就介绍到这了,更多相关Dataloader中的collate_fn参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

你可能感兴趣的:(Pytorch使用技巧之Dataloader中的collate_fn参数详析)