Pytorch中的collate_fn函数用法

Pytorch中的collate_fn函数用法

官方的解释:
  Puts each data field into a tensor with outer dimension batch size
  即用于对单个样本生成batch的函数,如果没有特殊需求其实不用自己写collate_fn方法,有默认的default_collate方法。

collate_fn方法用在如下位置:

DataLoader(dataset=train_data, batch_size=4, shuffle=True, num_workers=2, collate_fn=train_data.collate_fn)

自定义collate_fn的一个demo,训练数据data和target。


def collate_fn(batch):
    data = [item[0] for item in batch]
    # 这里对我的target进行了reshape操作
    target = [torch.reshape(item[1], (-1,)) for item in batch]
    data = torch.stack(data)
    target = torch.stack(target)
    return [data, target]

你可能感兴趣的:(Pytorch)