Dataset与Dataloader学习记录

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 19 and 1 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689

在使用DataLoader中,遇到了一个问题

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 19 and 1 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689

我的初始DataLoader代码是:

dataset=mydata(label_path)
batch_iterator=iter(DataLoader(dataset,batchsize,shuffle=True,num_workers=num_workers))

batch_size为32,数据集的总长度为16098
报错说明:batchsize在纬度1上不匹配
查看其他人的博客,说原因有两个:
你输入的图像数据的维度不完全是一样的,比如是训练的数据有100组,其中99组是256×256,但有一组是384×384,这样会导致Pytorch的检查程序报错
另外一个则是比较隐晦的batchsize的问题,Pytorch中检查你训练维度正确是按照每个batchsize的维度来检查的,比如你有1000组数据(假设每组数据为三通道256px256px的图像),batchsize为4,那么每次训练则提取(4,3,256,256)维度的张量来训练,刚好250个epoch解决(2504=1000)。但是如果你有999组数据,你继续使用batchsize为4的话,这样999和4并不能整除,你在训练前249组时的张量维度都为(4,3,256,256)但是最后一个批次的维度为(3,3,256,256),Pytorch检查到(4,3,256,256) != (3,3,256,256),维度不匹配,自然就会报错了,这可以称为一个小bug。

那么怎么解决,针对第一种,很简单,整理一下你的数据集保证每个图像的维度和通道数都一直即可。第二种来说,挑选一个可以被数据集个数整除的batchsize或者直接把batchsize设置为1即可。
具体解释链接

我使用的WIDER数据集,因此图片均为三通道,大小不一样,首先将图片resize成同样大小,依然报错
接着将batchsize改为1,确实不再报错了,不过batchsize改为1,在实际训练中不太合理,因此寻找其他解决办法。

在DataLoader中,除去上面的几个参数,还有其他的一些参数
其中的collate_fn参数,参考其他人的解释,说明如下:
一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label
collate_fn详细解释链接
因此修改代码

from data.data_iter1 import mydata,detection_collate
batch_iterator=iter(DataLoader(dataset,batchsize,shuffle=True,num_workers=num_workers,collate_fn=detection_collate))
def detection_collate(batch):
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type(np.empty(0))):
                annos = torch.from_numpy(tup).float()
                targets.append(annos)

    return (torch.stack(imgs, 0), targets)

问题解决

你可能感兴趣的:(pytorch学习记录)