PyTorch custom Dataset and DataLoader: removing missing and empty data when the batch contains None

Error message: 'NoneType' object has no attribute 'numel'
or: TypeError: batch must contain tensors, numbers, dicts or lists; found ...
The exact message differs across versions, but the cause is the same.
Environment: Python 3.6
PyTorch 1.4.0
torchvision 0.2.1

Cause: the batch contains None entries.

At test time, I had put the URL fetching inside my custom data loading code, and a failed fetch returned None as the sample.
Below are two solutions; the second one solved my problem more cleanly.
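
For context, a minimal sketch of how None ends up in a batch: a Dataset whose __getitem__ returns None when the download fails. RemoteImageDataset and fetch_image are hypothetical stand-ins, not code from the original post:

import random
import torch
from torch.utils.data import Dataset

def fetch_image(url):
    # Hypothetical downloader: fails randomly to simulate a bad URL
    if random.random() < 0.2:
        return None
    return torch.rand(3, 32, 32)

class RemoteImageDataset(Dataset):
    def __init__(self, urls, labels):
        self.urls = urls
        self.labels = labels

    def __len__(self):
        return len(self.urls)

    def __getitem__(self, idx):
        image = fetch_image(self.urls[idx])
        if image is None:
            # This None is what later breaks default_collate
            return None, self.labels[idx]
        return image, self.labels[idx]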

Solution 1: a custom collate_fn

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def my_collate_fn(batch):
    '''
    Each element of batch has the form (data, label).
    '''
    # Filter out samples whose data is None
    batch = list(filter(lambda x: x[0] is not None, batch))
    if len(batch) == 0:
        return torch.Tensor()
    return default_collate(batch)  # collate the filtered batch the default way

# With num_workers != 0 this raised an error for me (reason unknown), so it has to run single-process
test_loader = DataLoader(dataset=test_set, collate_fn=my_collate_fn, num_workers=0, batch_size=15, shuffle=False)
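
If every sample in a batch fails, my_collate_fn returns an empty torch.Tensor(), so the consuming loop has to skip those batches. A minimal sketch of that guard (the loop body is illustrative, not from the original post):

for batch in test_loader:
    # An empty tensor means the whole batch was filtered out by my_collate_fn
    if isinstance(batch, torch.Tensor) and batch.numel() == 0:
        continue
    data, labels = batch
    # ... run the model on data ...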

The following content is reposted from https://blog.csdn.net/guyuealian/article/details/91129367
[Source code on GitHub]: https://github.com/PanJinquan/pytorch-learning-tutorials/tree/master/image_classification/utils

The other, better solution is to modify the collate_fn source directly. Find the dataset_collate.py file (its path is shown in the traceback) and add the following to the collate_fn function:

    # Added: check whether image is None; if it is, drop that sample from the
    # batch so the iteration does not crash.
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)

The collate_fn function then looks like this:

import re
import torch
from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import get_worker_info

# These helpers are module-level internals of the old torch.utils.data
# collate source; a standalone copy has to define them itself.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}

def collate_fn(batch):
    '''
    collate_fn (callable, optional): merges a list of samples to form a mini-batch.
    Adapted from torch's default_collate, the default collate method of
    DataLoader. When the batch contains None entries, default_collate raises
    an error. The fix: check whether image is None and, if so, drop that
    sample from the batch, so the iteration does not crash.
    :param batch:
    :return:
    '''
    r"""Puts each data field into a tensor with outer dimension batch size"""
    # Added: drop samples whose image is None
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)

    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        # get_worker_info() replaces the _use_shared_memory internal flag
        # used in the original torch source
        if get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))

            return collate_fn([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [collate_fn(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))
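
A sketch of wiring the patched collate_fn into the DataLoader and handling the (None, None) sentinel it returns when a whole batch is filtered out (assumes DataLoader is imported from torch.utils.data; test_set and the loop body are illustrative):

test_loader = DataLoader(dataset=test_set, collate_fn=collate_fn, num_workers=0, batch_size=15, shuffle=False)

for images, image_ids in test_loader:
    # (None, None) means every sample in this batch had image is None
    if images is None:
        continue
    # ... run the model on images ...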
