pytorch中的minibatch

日期:2020-2-13

pytorch中的torch.utils.data这个库可以非常好的对数据实现批处理
主要用到2个函数

import torch
import torch.utils.data as Data
Data.TensorDataset()#设置数据集,数据与标签相对应
Data.DataLoader()#传入数据集,设置批处理大小,是否打乱数据,顺序,线程数

举个例子

#batch training
import torch
import torch.utils.data as Data
#设置批量大小和原始数据
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
#设置数据集
torch_dataset = Data.TensorDataset(x, y)
#传入数据集
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=BATCH_SIZE,
 shuffle=True,#shuffle的英文含义是洗牌,=true意味着每回顺序是打乱的
 num_workers=2,#2个“工人”
 )
 #开始训练
if __name__ == '__main__':
#训练3次,每次训练2批,顺序是打乱的
 for epoch in range(3):
  for step, (batch_x, batch_y) in enumerate(loader):
   #training
   print('Epoch:',epoch, '| Step:', step, '| batch x:',
        batch_x.numpy(), '| batch_y', batch_y.numpy())

输出结果:
pytorch中的minibatch_第1张图片

你可能感兴趣的:(pytorch初学笔记,深度学习,神经网络,python)