collate_fn 使用详解

collate_fn 参数

当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。

default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。所以在我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种你也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。

以下是文字识别时,文本行图像长度不一,需要自定义整理。

class AlignCollate(object):
    """将数据整理成batch"""
    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, batch):# 有可能__getitem__返回的图像是None, 所以需要过滤掉
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            resized_max_w = self.imgW
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

            resized_images = []
            for image in images:
                w, h = image.size
                ratio = w / float(h)
                # 图片的宽度大于设定的输入ingW
                if math.ceil(self.imgH * ratio) > self.imgW:
                    resized_w = self.imgW
                else:
                    resized_w = math.ceil(self.imgH * ratio)

                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized_image))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = [transform(image) for image in images]
            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

        return image_tensors, labels

再看个在做目标检测时自定义collate_fn函数,给每个图像添加索引

def collate_fn(self, batch):
        paths, imgs, targets = list(zip(*batch))
        # Remove empty placeholder targets  
        # 有可能__getitem__返回的图像是None, 所以需要过滤掉
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        # boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        targets = torch.cat(targets, 0)
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        # 每个图像大小不同呢,所以resize到统一大小
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

其实也可以自定义collate_fn同时,结合使用默认的default_collate

from torch.utils.data.dataloader import default_collate  # 导入这个函数


def collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch

    """

    # 这一部分是对 batch 进行重新 “校对、整理”的代码
    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

Reference:

一文读懂Dataset, DataLoader及collate_fn, Sampler等参数

你可能感兴趣的:(Pytorch,batch,计算机视觉,深度学习)