参考资料:
https://pytorch.org/docs/stable/data.html#dataloader-collate-fn
https://blog.csdn.net/anshiquanshu/article/details/112868740
在使用Pytorch深度学习框架的时候,一定绕不开的就是dataset和dataloader,后者依赖于前者,并给出了高效加载数据的解决方案(多线程,batch训练等)。
以RGB图片为例,dataset出来的数据形状是(3, H, W),而dataloader出来的数据形状是(batch_size, 3, H, W)。很明显,多了一维即batch维度。这显然是dataloader将数据给“叠”了起来。事实上,dataloader是有一个参数为collate_fn的,它的默认值是None,即当你在使用dataloader并不指定collate_fn的时候,实际上pytorch调用了默认的collate_fn函数,将数据“叠”起来之后再交给你。
然而,当你的数据是不定长的数据的时候,它就没有办法成功把数据叠起来,比如我就遇到了如下报错:
RuntimeError: stack expects each tensor to be equal size, but got [2, 4] at entry 0 and [5, 4] at entry 1
一个数据长度为2,一个数据长度为5,显然无法直接stack?此时在面对不定长数据的时候就需要自定义collate_fn进行填充了。譬如,pytorch文档上有这么一段话:
A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch.
那么,如何自定义一个collate_fn?这个collate_fn的输入和输出又是什么?我们来看一下这个例子:
def padding_collate_fn(data_batch):
batch_bbox_list = [item['bbox'] for item in data_batch]
batch_label_list = [item['label'] for item in data_batch]
batch_filename_list = [item['filename'] for item in data_batch]
padding_bbox = pad_sequence(batch_bbox_list, batch_first=True, padding_value=0)
padding_label = pad_sequence(batch_bbox_list, batch_first=True, padding_value=5)
result = dict()
result["bbox"] = padding_bbox
result["label"] = padding_label
result["filename"] = batch_filename_list
return result
首先我原始的dataset输出是一个字典,上述代码就是把字典中的值取出来再叠起来,最后放到大字典中返回。其中pad_sequence这个函数在torch.nn.utils.rnn这个包里,很好用。
实际上,batch就是你的dataset[index] ~ dataset[index + batch_size] 构成的列表,知道这一点后问题就迎刃而解了。