Pytorch之Dataloader参数collate_fn研究

前言

    之前看了不到pytorch代码,对Dataloader的大部分参数都比较了解,今天看代码时,发现了一个参数collate_fn ,之前论文代码没怎么见过,也就自动忽略了,今天既然遇到了,就突然来了好奇心,想搞清楚用途及用法,以下为正文。

问题及实验

  1. 问题

今天看代码时出现如下问题:

Pytorch之Dataloader参数collate_fn研究_第1张图片

 对Dataloader参数中的collate_fn甚感好奇,故想一探究竟。

        2. 实验

1):myDataset()类中__getitem__方法返回的数据

代码如下:

Pytorch之Dataloader参数collate_fn研究_第2张图片

测试结果:

Pytorch之Dataloader参数collate_fn研究_第3张图片

 可见myDataset()类中__getitem__方法返回值为两个,网络的输入数据为128x40的tensor,输出是个分类标签数据。

2):Dataloader 运行过程

过程:首先Dataloader 会根据batch参数生成一个长度为batch值的列表,列表的值是myDataset()类中__getitem__()的参数,如果shuffle为True ,列表的值就是从0到len(data)中随机抽样索引。然后列表的索引值会依次送入__getitem__()方法,最终返回一个列表的数据,该列表数据会作为collate_fn 函数的参数传入,最终得到一个batch的数据。其中collate_fn 函数可以使用系统默认的也可以使用自己设计的,非常灵活。

debug验证:

程序:

Pytorch之Dataloader参数collate_fn研究_第4张图片

debug1:

Pytorch之Dataloader参数collate_fn研究_第5张图片

进入self._next_data():

Pytorch之Dataloader参数collate_fn研究_第6张图片

index 即为根据batch参数得到的列表:

Pytorch之Dataloader参数collate_fn研究_第7张图片

 进入fetch函数:

Pytorch之Dataloader参数collate_fn研究_第8张图片

 debug如下:

Pytorch之Dataloader参数collate_fn研究_第9张图片

即fetch函数通过传入index列表得到一个新的列表数据,然后该列表数据通过collate_fn()函数得到最终数据。

--------------------------------------------------------------分割线----------------------------------------------------------

以下无关Dataloarder的使用方法,仅仅是研究以下这里的自定义collate_fn函数的功能。

进入collate_fn函数:

Pytorch之Dataloader参数collate_fn研究_第10张图片

 关于zip的拆包功能研究:

Pytorch之Dataloader参数collate_fn研究_第11张图片

测试结果:

Pytorch之Dataloader参数collate_fn研究_第12张图片

通过pad_sequence得到最终结果:

Pytorch之Dataloader参数collate_fn研究_第13张图片

参考文章:

https://blog.csdn.net/dong_liuqi/article/details/114521240

 

 

 

 

 

 

 

 

 

 

 

 

        

你可能感兴趣的:(pytorch)