pytorch中collate_fn函数的使用&如何向collate_fn函数传参

1.为什么要使用collate_fn

这里先从dataset的运行机制讲起.

在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表; 然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据; 最后, 对每个index对应的数据进行堆叠, 就形成了一个batch的数据.

看这个视频可以理解得更透彻一些

⚠️ 在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据.

在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的.

⭐️ 此外, 某些优化方法是要对一个batch的数据进行操作.

collate-fn函数就是手动将抽取出的样本堆叠起来的函数

2.collate_fn的用法

loader = Dataloader(dataset, batch_size, shuffle, collate_fn, ...)

collate_fn函数是实例化dataloader的时候, 以函数形式传递给loader.

既然是collate_fn是以函数作为参数进行传递, 那么其一定有默认参数. 这个默认参数就是getitem函数返回的数据项的batch形成的列表.

先假设, datase类是如下形式:

class testData(Dataset):
	def __init__(self):
		super().__init__()
		
	def __getitem__(self, index):
		return x, y

可以看到, 假设的dataset返回两个数据项: x和y. 那么, 传入collate_fn的参数定义为data, 则其shape为(batch_size, 2,…).

知道了输入参数的形式, 就可以去定义collate_fn函数了:

def collate_fn(data):
	for unit in data:
		unit_x.append(unit[0])
		unit_y.append(unit[1])
		...
	return {x: torch.tensor(unit_x),  y: torch.tensor(unit_y)}

可以看到我对collate_fn函数的定义,最后返回的是一个字典. 这也是collate_fn函数最大的一个好处: 可以自定义取出一个batch数据的格式. 该函数的输出就是对dataloader进行遍历, 取出一个batch的数据.

3.如何给collate_fn函数传参

在collate_fn的使用过程中, 我发现只输入data有时候是非常不方便的, 需要额外的参数来传递其他变量.

这里有两个方法可以解决以上问题:

(1)使用lambda函数

info = args.info	# info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))

这里巧用lambda函数, 相当于使用collate_fn函数再定义了一个匿名函数.

(2)创建可被调用的类

class collater():
	def __init__(self, *params):
		self. params = params
	
	def __call__(self, data):
		'''在这里重写collate_fn函数'''
collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)

参考网址

4.总结

collate_fn的用处:

  • 自定义数据堆叠过程
  • 自定义batch数据的输出形式

collate_fn的使用

  • 定义一个以data为输入的函数
  • 注意, 输入输出分别域getitem函数和loader调用时对应

你可能感兴趣的:(Pytorch,pytorch)