新版TensorDataset 和 downloader中num_workers 的解决方案

Pytorch学习

小问题一:
#旧版
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
#会报错,init() got an unexpected keyword argument ‘data_tensor’
#新版
torch_dataset = Data.TensorDataset(x, y)

小问题二:
对于num_workers大于0是报错
Error1:
DataLoader worker (pid(s) 15464, 8000) exited unexpectedly
or
Error2:
BrokenPipeError: [Errno 32] Broken pipe
解决方法:
在运行的部分放进__main__()函数里

以下是报错的代码

import torch
import torch.utils.data as Data
 
# 虚构要训练的数据
x = torch.linspace(11, 20, 10)  # 在[11, 20]里取出10个间隔相等的数 (torch tensor)
y = torch.linspace(20, 11, 10)
 
 
BATCH_SIZE = 5  # 每批需要训练的数据个数
 
 
# 把tensor转换成torch能识别的数据集
#torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
#新版
torch_dataset = Data.TensorDataset(x, y)
 

```c
import torch
import torch.utils.data as Data

x = torch.linspace(11,20,10)
y = torch.linspace(20,11,10)


BATCH_SIZE = 5

torch_dataset = Data.TensorDataset(data_tensor = x, target_tensor = y)

loader = Data.DataLoader(
        
        dataset = torch_dataset,
        batch_size = BATCH_SIZE,
        shuffle = True,
        num_workers = 2
        )

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        print('Epoch:', epoch, '|Step:', step,  # Epoch表示哪一遍, Step表示哪一次
              'batch x:', batch_x.numpy(),
              'batch y:', batch_y.numpy(),
        )


正确代码如下


# -*- coding: utf-8 -*-
"""
Created on Thu Nov  7 20:23:36 2019

@author:"""

import torch
import torch.utils.data as Data
 
 
# 虚构要训练的数据
x = torch.linspace(11, 20, 10)  # 在[11, 20]里取出10个间隔相等的数 (torch tensor)
y = torch.linspace(20, 11, 10)
 
 
BATCH_SIZE = 5  # 每批需要训练的数据个数
 
 
# 把tensor转换成torch能识别的数据集
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
#新版
#torch_dataset = Data.TensorDataset(x, y)
 
 
# 把数据集放进数据装载机里
loader = Data.DataLoader(
    dataset=torch_dataset,  # 数据集
    batch_size=BATCH_SIZE,  # 每批需要训练的数据个数
    shuffle=True,  # 是否打乱取数据的顺序(打乱的训练效果更好)
    num_workers=2,  # 多线程读取数据
)
 
# 批量取出数据来训练
def main():
    for epoch in range(3):  # 把整套数据重复训练3for step, (batch_x, batch_y) in enumerate(loader):  # 每次从数据装载机里取出批量数据来训练
            # 以下为训练的地方
            # …………
            # 把每遍里每次取出的数据打印出来
            print('Epoch:', epoch, '|Step:', step,  # Epoch表示哪一遍, Step表示哪一次
                  'batch x:', batch_x.numpy(),
                  'batch y:', batch_y.numpy(),
            )
if __name__=="__main__":
    main()

你可能感兴趣的:(pytorch深度学习)