关于DataLoader中的collate_fn

关于如何构建自己的dataset, dataloader,请看博客:pytorch -- 构建自己的Dateset,DataLoader如何使用_无脑敲代码,bug漫天飞的博客-CSDN博客

        下面内容来自博客:pytorch中collate_fn函数的使用&如何向collate_fn函数传参_XJTU-Qidong的博客-CSDN博客_collate_fn pytorch

        然后明确dataloader返回数据的形式是怎样的; 是以batch的形式取数据,是取出等同于batch size的index列表,然后将列表中的index输入到dataset的getitem()中,按照index来获取数据,之后对每个index对应的数据进行堆叠, 就形成了一个batch的数据,并返回;

        但是在堆叠数据的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠,因此在构建dataset数据集时这样就应该处理好.

        collate-fn函数就是手动将抽取出的样本堆叠起来的函数;

        (  自定义数据堆叠过程;自定义batch数据的输出形式

        collate_fn函数是实例化dataloader的时候, 以函数形式传递给loader.

collate_fn是以函数作为参数进行传递, 那么其一定有默认参数:

        默认参数就是getitem函数返回的数据项的batch形成的列表.

        先假设dataset类为:

class testData(Dataset):
	def __init__(self):
		super().__init__()
		
	def __getitem__(self, index):
		return x, y

        假设的dataset返回两个数据项: x和y. 那么, 传入collate_fn的参数定义为data, 则其shape为(batch_size, 2,…).知道了输入参数的形式, 就可以

去定义collate_fn函数了:

def collate_fn(data):
	for unit in data:
		unit_x.append(unit[0])
		unit_y.append(unit[1])
		...
	return {x: torch.tensor(unit_x),  y: torch.tensor(unit_y)}

最后返回的是一个字典. 这也是collate_fn函数最大的一个好处: 可以自定义取出一个batch数据的格式.

如何给collate_fn函数传参

在collate_fn的使用过程中, 只输入data有时候是非常不方便的, 需要额外的参数来传递其他变量.

使用lambda函数

info = args.info	# info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))

这里巧用lambda函数, 相当于使用collate_fn函数再定义了一个匿名函数.

创建可被调用的类

class collater():
	def __init__(self, *params):
		self. params = params
	
	def __call__(self, data):
		'''在这里重写collate_fn函数'''

        实例化该类

collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)

你可能感兴趣的:(编程,pytorch,python)