Using collate_fn in PyTorch


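In object detection we routinely subclass PyTorch's Dataset class, which is built around three methods: __init__, __len__ and __getitem__.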

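from torch.utils.data import Dataset

class DetectionDataset(Dataset):  # hypothetical class name
    def __init__(self, list_image_path, list_anno_path):
        'Initialization'
        self.images = list_image_path   # one image file path per sample
        self.labels = list_anno_path    # matching annotation file path per sample

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.labels)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample; read() is a placeholder for your own image loader
        # and should return a numpy array of shape (C, H, W)
        image = read(self.images[index])

        # Load data and get label; parse_anno() is a placeholder that should return
        # a numpy array of shape (num_boxes, k) whose column 0 is left free for
        # the batch index that collate_fn fills in
        bbox = parse_anno(self.labels[index])

        return image, bbox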

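Classification causes no trouble here, because each image has exactly one label. Detection is different. Each image contains a different number of objects, so the per-image bbox arrays have different shapes. The default collate function cannot stack them into a batch. In this situation we need a custom collate_fn.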
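The following minimal sketch makes the problem concrete. It uses synthetic data, and the image size, the box counts and the 6-column box layout are all assumptions made for illustration. A plain Python list of (image, boxes) pairs serves as the dataset. With no collate_fn, DataLoader uses its default collate function. That function tries to torch.stack box arrays of different lengths and raises a RuntimeError.

```python
import numpy as np
from torch.utils.data import DataLoader

# Hypothetical toy data: two 3x32x32 images with 2 and 3 boxes respectively.
# Each box row is assumed to be [batch_idx, class, x, y, w, h].
samples = [
    (np.zeros((3, 32, 32), dtype=np.float32), np.zeros((2, 6), dtype=np.float32)),
    (np.zeros((3, 32, 32), dtype=np.float32), np.zeros((3, 6), dtype=np.float32)),
]

# No collate_fn: DataLoader falls back to the default collate function,
# which stacks each field of the samples into a single tensor.
loader = DataLoader(samples, batch_size=2, shuffle=False)

try:
    images, boxes = next(iter(loader))
    failed = False
except RuntimeError as err:
    # The images stack fine, but (2, 6) and (3, 6) box tensors cannot be stacked
    failed = True
    print("default collate failed:", err)
```

The images would collate without problems, because they all share one shape. Only the variable-length box arrays break the default behaviour.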

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=my_collate)

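YOLOv3's collate_fn shows one way to handle this. Each bbox array reserves one extra column (column 0) that records which image in the batch the box belongs to. The boxes from all images can then be concatenated along the first dimension. The returned images tensor is 4-D, with shape (batch, channels, height, width). The bboxes tensor is 2-D, with shape (total number of boxes in the batch, box columns).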
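The layout can be illustrated with plain NumPy. This sketch uses two hand-written box arrays and assumes a 6-column [batch_idx, class, x, y, w, h] row format. Column 0 of each image's boxes is overwritten with that image's position in the batch, and all boxes are then concatenated into one 2-D array. A boolean mask on column 0 recovers the boxes of any single image.

```python
import numpy as np

# Hypothetical per-image box arrays; column 0 is a placeholder for the batch index.
boxes_img0 = np.array([[0, 1, 0.5, 0.5, 0.2, 0.2],
                       [0, 2, 0.3, 0.3, 0.1, 0.1]], dtype=np.float32)
boxes_img1 = np.array([[0, 0, 0.6, 0.4, 0.3, 0.2],
                       [0, 1, 0.2, 0.7, 0.1, 0.3],
                       [0, 3, 0.8, 0.8, 0.1, 0.1]], dtype=np.float32)

per_image = [boxes_img0, boxes_img1]
for i, box in enumerate(per_image):
    box[:, 0] = i  # tag every row with the index of its image in the batch

# One 2-D array for the whole batch: (2 + 3, 6) = (5, 6)
batch_boxes = np.concatenate(per_image, axis=0)
print(batch_boxes[:, 0])  # [0. 0. 1. 1. 1.]

# Recover the boxes of image 1 by masking on column 0
img1_boxes = batch_boxes[batch_boxes[:, 0] == 1]
print(img1_boxes.shape)   # (3, 6)
```

No per-image information is lost in the merge. A loss function can group boxes by column 0 to match them with the right image's predictions.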

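import numpy as np
import torch

def my_collate(batch):
    # batch is a list of (img, box) pairs returned by Dataset.__getitem__
    images = []
    bboxes = []
    for i, (img, box) in enumerate(batch):
        images.append(img)
        box = box.copy()   # avoid mutating the array held by the dataset
        box[:, 0] = i      # record which image in the batch each box belongs to
        bboxes.append(box)

    # All images must share one shape to be stacked: (B, C, H, W)
    images = torch.from_numpy(np.array(images)).float()
    # Boxes from every image are concatenated along dim 0: (total_boxes, box_columns)
    bboxes = torch.from_numpy(np.concatenate(bboxes, 0)).float()
    return images, bboxes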
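A quick end-to-end check of this approach can be run on synthetic data. The list-of-tuples dataset, the image size and the 6-column box format below are all assumptions. With my_collate passed as collate_fn, images come out as a (B, C, H, W) tensor and all boxes as a single (total_boxes, 6) tensor. The last, smaller batch is handled the same way.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

def my_collate(batch):
    # Same approach as my_collate above: tag boxes with their image index, then merge
    images, bboxes = [], []
    for i, (img, box) in enumerate(batch):
        images.append(img)
        box = box.copy()
        box[:, 0] = i
        bboxes.append(box)
    images = torch.from_numpy(np.stack(images)).float()
    bboxes = torch.from_numpy(np.concatenate(bboxes, 0)).float()
    return images, bboxes

# Hypothetical toy dataset: three 3x32x32 images with 2, 3 and 1 boxes.
# Box rows are assumed to be [batch_idx, class, x, y, w, h].
samples = [
    (np.random.rand(3, 32, 32).astype(np.float32), np.zeros((n, 6), dtype=np.float32))
    for n in (2, 3, 1)
]

loader = DataLoader(samples, batch_size=2, shuffle=False, collate_fn=my_collate)
batches = list(loader)

images, bboxes = batches[0]
print(images.shape)   # torch.Size([2, 3, 32, 32])
print(bboxes.shape)   # torch.Size([5, 6])
print(bboxes[:, 0])   # tensor([0., 0., 1., 1., 1.])
```

In the second batch, which holds only the third sample, the batch index restarts at 0. Column 0 is therefore always relative to the current batch, not to the whole dataset.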
