pytorch dataloader collate_fn 在复杂label情况下的使用技巧

pytorch dataset collect_fn 在复杂label情况下的使用技巧

  • 在如目标检测等任务的训练过程中,label可能出现shape不一致的情况,以目标检测为例,一个batch的不同图片,可能有不同数量的bounding box,如果将bbox以(n,5)的shape的张量形式返回,n的数量就不统一,在使用默认的collect_fn时,pytorch的dataloader就会报错,这时就需要专门写一个collect_fn,如下:
import torch.utils.data as data
from PIL import Image
import os

class my_dataloader_with_det_label(data.Dataset):
	def __init__(self, images_path):
		self.train_list = [os.path.join(images_path, i) for i in os.listdir(images_path)]
		self.size = 256
		self.data_list = self.train_list
		print("Total training examples:", len(self.train_list))

	def __getitem__(self, index):
		# get image
		data_path = self.data_list[index]
		data = Image.open(data_lowlight_path).convert('RGB')
		data = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
		data = (np.asarray(data_lowlight)/255.0) 
		data = torch.from_numpy(data_lowlight).float()
		# get label
		label_path = data_path.replace('images', 'labels')
		label_path = label_path[:-3] + 'txt'
		with open(label_path) as f:
			l = [x.split() for x in f.read().strip().splitlines() if len(x)]
		nl = len(l)
		label = np.zeros((nl, 6))
		label[:, 1:] = np.array(l, dtype=np.float32)
		return data.permute(2,0,1), torch.from_numpy(label).float()
		
	def __len__(self):
		return len(self.data_list)
		
	@staticmethod
	def collate_fn(batch):
		img, label = zip(*batch)  # transposed
		for i, l in enumerate(label):
			l[:, 0] = i  # add target image index for build_targets()
		return {'img': torch.stack(img, 0), 'label': torch.cat(label, 0)}
		
train_dataset = my_dataloader_with_det_label(images_path)	
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True, collate_fn=train_dataset.collate_fn)
  • 解释一下,collate_fn之所以这么写,是因为dataloader在返回一个batch的数据时,首先会调用 batch_size 次 dataset的__getitem__()方法,然后将 batch_size 个返回值作为一个列表送入collate_fn 函数,以打包成一个batch的数据返回出来。如果没有重写collate_fn函数,会用默认的collate_fn方法打包,这时label自动concatenate时会出现维度不匹配的错误。
  • 因此这里就用到了label 6列中的第一列。可能你会好奇为什么bbox明明只需要5个数据,label却是一个 n × 6 n\times6 n×6的矩阵。第2列到第6列装的就是label本身的5个数据,而第一列装的是一个batch中各个图片的索引。这里可以注意到img用的是stack方法,而label用的是cat方法,这样一来img就变成了(bs,c,h,w)而label却变成了(k,6),少了batch size的一维。这里k其实是一个batch的所有图片的所有bbox的总数,把不同图片的所有bbox都按顺序混在了一起,这样一来就不会出现维度不匹配的问题,但就会需要第一列来区分出哪些行是对batch中第一张图片的标注,哪些行是对第二张的标注,等等。然后在算loss的时候再根据label的第一维去对应出来就可以了。
  • 这个collate_fn其实是pytorch版的yolov3的库中借鉴来的,源码的地址在:

https://github.com/ultralytics/yolov3

你可能感兴趣的:(实用代码,pytorch,深度学习,python)