环境:win10, python3.7, jupyter notebook
报错来源:20天吃掉那只pytorch-DAY2
报错原因:cell3 + cell4 ⬅ torch.utils.data
# cell3
def target_transform(t):
return torch.tensor([t]).float()
ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
print(ds_train.class_to_idx)
# cell4
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)
cell3中datasets.ImageFolder使用了lambda函数,cell4中num_workers被设置为3,这两个因素共同作用导致报错。
torch.utils.data.DataLoader默认采用单进程(主进程)来加载数据,但可以通过num_workers设置同时使用几个子进程,num_workers=0表示只使用主进程。这里的workers由pytorch提供,其实现依赖于python的multiprocessing,其实现在windows下和unix下是不同的。
采用spawn()的时候,worker_init_fn参数不能为unpicklable对象,例如lambda函数。
if __name__ == '__main__':
内,这样它们就不会在worker子进程启动的时候再次运行,而DataLoader的构造是不需要被重复执行的,所以这部分代码也应该放在这里(比如cell4)__main__
外面(比如cell3)。在pickle序列化的过程中对于函数传递的是引用而非二进制代码,所以要使worker子进程正常工作,这一步是必须的。方法一:依旧是在jupyter notebook环境中,把num_workers=3改为num_workers=0
方法二:由于cell3中lambda函数的存在,无论是notebook还是脚本中的__main__
都不能实现子进程num_workers>0,所以把cell3改为:
ds_train = datasets.ImageFolder("./dataset/cifar2/train/",
transform = transform_train)
ds_valid = datasets.ImageFolder("./dataset/cifar2/test/",
transform = transform_train)
print(ds_train.class_to_idx)
# 也就是把target_transform = lambda ...删掉
# 可以打印一下ds_train[0]和type(ds_train[0][1])看一下
再把cell5改为:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
# 查看部分样本
from matplotlib import pyplot as plt
plt.figure(figsize=(8,8))
for i in range(9):
img,label = ds_train[i]
img = img.permute(1,2,0)
ax=plt.subplot(3,3,i+1)
ax.imshow(img.numpy())
ax.set_title("label = %d"%label)
ax.set_xticks([])
ax.set_yticks([])
plt.show()
# 也就是把label.item()改为label
这里要说一下target_transform,如果说错了敬请指正。这个参数网上资料比较少,官网的说明也不清楚,作用应该是对类别,也就是ds_train[0][1]作转换,如果没有这个,可以看到ds_train[0][1]的类型是int,那么cell5中会报“int 没有 item()”的错误,加上这个以后就变成了tensor,再取item()实际上就是得到int(或者按原文是float,因为原文在tensor([t])后面加了.float()),所以感觉这一步比较多余,不知道有没有人可以在评论里告诉我为什么要这么做?