PicklingError: Can‘t pickle <function <lambda>...attribute lookup <lambda> on __main__ failed

PicklingError: Can’t pickle : attribute lookup on main failed

环境:win10, python3.7, jupyter notebook
报错来源:20天吃掉那只pytorch-DAY2
报错原因:cell3 + cell4 ⬅ torch.utils.data

一、cell3 + cell4

# cell3
def target_transform(t):
    return torch.tensor([t]).float()

ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
            transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
            transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())

print(ds_train.class_to_idx)

# cell4
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)

cell3中datasets.ImageFolder使用了lambda函数,cell4中num_workers被设置为3,这两个因素共同作用导致报错。

二、torch.utils.data

torch.utils.data.DataLoader默认采用单进程(主进程)来加载数据,但可以通过num_workers设置同时使用几个子进程,num_workers=0表示只使用主进程。这里的workers由pytorch提供,其实现依赖于python的multiprocessing,其实现在windows下和unix下是不同的。

  • unix下默认采用fork(),子进程通过从父进程那里继承来的地址空间直接访问dataset和代码中其他带参数的函数
  • windows下默认采用spawn(),这时候会另起一个python解释器来执行主代码(main script),之后通过pickle序列化接收dataset, collate_fn以及其他参数来执行主代码内部需要由workers来执行的代码

采用spawn()的时候,worker_init_fn参数不能为unpicklable对象,例如lambda函数。

三、其他

  1. 由于spawn()的存在,linux下不会报错的代码在windows下可能会报错,所以在windows下使用多进程加载数据的时候要注意:
    (1)python脚本的主代码应该放在if __name__ == '__main__':内,这样它们就不会在worker子进程启动的时候再次运行,而DataLoader的构造是不需要被重复执行的,所以这部分代码也应该放在这里(比如cell4)
    (2)任何自定义的内容,包括参数collate_fn, worker_init_fn或dataset的具体代码(通常是函数的形式)则要放在__main__外面(比如cell3)。在pickle序列化的过程中对于函数传递的是引用而非二进制代码,所以要使worker子进程正常工作,这一步是必须的。
  2. torchvision.transforms.ToTensor,把读入的图片转为tensor
  3. torchvision.transforms.Compose,把对读入图片的各种处理组合起来

四、解决办法

方法一:依旧是在jupyter notebook环境中,把num_workers=3改为num_workers=0
方法二:由于cell3中lambda函数的存在,无论是notebook还是脚本中的__main__都不能实现子进程num_workers>0,所以把cell3改为:

ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
            transform = transform_train)
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
            transform = transform_train)

print(ds_train.class_to_idx)
# 也就是把target_transform = lambda ...删掉
# 可以打印一下ds_train[0]和type(ds_train[0][1])看一下

再把cell5改为:

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# 查看部分样本
from matplotlib import pyplot as plt 

plt.figure(figsize=(8,8)) 
for i in range(9):
    img,label = ds_train[i]
    img = img.permute(1,2,0)
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()
# 也就是把label.item()改为label

这里要说一下target_transform,如果说错了敬请指正。这个参数网上资料比较少,官网的说明也不清楚,作用应该是对类别,也就是ds_train[0][1]作转换,如果没有这个,可以看到ds_train[0][1]的类型是int,那么cell5中会报“int 没有 item()”的错误,加上这个以后就变成了tensor,再取item()实际上就是得到int(或者按原文是float,因为原文在tensor([t])后面加了.float()),所以感觉这一步比较多余,不知道有没有人可以在评论里告诉我为什么要这么做?

你可能感兴趣的:(Python,Deep,Learning,pytorch)