pytorch笔记-batch

在学习莫烦大神的pytorch视频的batch部分,由于pytorch版本更新,产生了一些不兼容的情况。源代码如下:

import torch
import torch.utils.data as Data
torch.manual_seed(1) # 设定随机数种子


BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)

torch_dataset=Data.TensorDataset(data_tensor=x,target_tensor=y)
loader=Data.DataLoader(#变成小批数据
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,#每一组batch里面原数据个数
    shuffle=True,  #是否将原数据打乱分组
    num_workers=2
)

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        print('Epoch:',epoch)

直接运行会报错,是由于Data.TensorDataset()函数版本更新后接受参数为*tensor,不再设默认值,故只需将对应行改为:

torch_dataset=Data.TensorDataset(x,y)

但是会继续报错:
The “freeze_support()” line can be omitted if the program
is not going to be frozen to produce an executable.
只需把训练过程放在if name == ‘main’:以下即可。更正后代码:

import torch
import torch.utils.data as Data
torch.manual_seed(1) # 设定随机数种子


BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)

torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(#变成小批数据
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,#每一组batch里面原数据个数
    shuffle=True,  #是否将原数据打乱分组
    num_workers=2
)
if __name__ == '__main__':
    for epoch in range(3):   # 训练所有!整套!数据 3 次
        for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习
            #  假设这里就是你训练的地方...
            #  打出来一些数据
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())


你可能感兴趣的:(pytorch笔记-batch)