torch.utils.data.dataloader参数collate_fn简析

torch.utils.data.DataLoader是pytorch提供的数据加载类,初始化函数如下,

torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataset,batch_size等参数重要且容易理解,而collate_fn参数就不太直白,官方解释为:

collate_fn (callableoptional) – merges a list of samples to form a mini-batch

不明不白。

其实,collate_fn可理解为函数句柄、指针...或者其他可调用类(实现__call__函数)。 函数输入为list,list中的元素为欲取出的一系列样本。具体如下

indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])

其中self.sampler_iter即采样器,返回下一个batch中样本的序号,indices。

通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

 

你可能感兴趣的:(机器学习,深度学习,pytorch)