关于如何构建自己的dataset, dataloader,请看博客:pytorch -- 构建自己的Dateset,DataLoader如何使用_无脑敲代码,bug漫天飞的博客-CSDN博客
下面内容来自博客:pytorch中collate_fn函数的使用&如何向collate_fn函数传参_XJTU-Qidong的博客-CSDN博客_collate_fn pytorch
然后明确dataloader返回数据的形式是怎样的; 是以batch的形式取数据,是取出等同于batch size的index列表,然后将列表中的index输入到dataset的getitem()中,按照index来获取数据,之后对每个index对应的数据进行堆叠, 就形成了一个batch的数据,并返回;
但是在堆叠数据的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠,因此在构建dataset数据集时这样就应该处理好.
collate-fn函数就是手动将抽取出的样本堆叠起来的函数;
( 自定义数据堆叠过程;自定义batch数据的输出形式)
collate_fn函数是实例化dataloader的时候, 以函数形式传递给loader.
collate_fn是以函数作为参数进行传递, 那么其一定有默认参数:
默认参数就是getitem函数返回的数据项的batch形成的列表.
先假设dataset类为:
class testData(Dataset):
def __init__(self):
super().__init__()
def __getitem__(self, index):
return x, y
假设的dataset返回两个数据项: x和y. 那么, 传入collate_fn的参数定义为data, 则其shape为(batch_size, 2,…).知道了输入参数的形式, 就可以
去定义collate_fn函数了:
def collate_fn(data):
for unit in data:
unit_x.append(unit[0])
unit_y.append(unit[1])
...
return {x: torch.tensor(unit_x), y: torch.tensor(unit_y)}
最后返回的是一个字典. 这也是collate_fn函数最大的一个好处: 可以自定义取出一个batch数据的格式.
在collate_fn的使用过程中, 只输入data有时候是非常不方便的, 需要额外的参数来传递其他变量.
info = args.info # info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))
这里巧用lambda函数, 相当于使用collate_fn函数再定义了一个匿名函数.
class collater():
def __init__(self, *params):
self. params = params
def __call__(self, data):
'''在这里重写collate_fn函数'''
实例化该类
collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)