PyTorch batch data training

import torch
import torch.utils.data as Data

BATCH_SIZE = 8

x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

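# newer PyTorch versions take the tensors positionally (TensorDataset(*tensors)); the old data_tensor=/target_tensor= keywords no longer exist
torch_dataset = Data.TensorDataset(x, y)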
loader = Data.DataLoader(
    dataset = torch_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = 2,
)

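# with num_workers > 0, worker processes may re-import this module (e.g. on Windows and macOS),
# so the loop must be guarded to avoid re-running it in every worker
if __name__ == '__main__':
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # a real training step (forward pass, loss, backward, optimizer step) would go here
            print('Epoch: ', epoch,
                  '| Step: ', step,
                  '| batch x: ', batch_x.numpy(),
                  '| batch y: ', batch_y.numpy())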

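Result (shuffle=True reorders the data every epoch, so your order will differ; with 10 samples and BATCH_SIZE = 8, each epoch yields one batch of 8 followed by a final batch of the remaining 2, and each x stays paired with its y, so x + y = 11 throughout):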

Epoch:  0 | Step:  0 | batch x:  [  9.   5.   4.   8.  10.   1.   3.   6.] | batch y:  [  2.   6.   7.   3.   1.  10.   8.   5.]
Epoch:  0 | Step:  1 | batch x:  [ 2.  7.] | batch y:  [ 9.  4.]
Epoch:  1 | Step:  0 | batch x:  [  3.   7.   8.   4.  10.   2.   9.   6.] | batch y:  [ 8.  4.  3.  7.  1.  9.  2.  5.]
Epoch:  1 | Step:  1 | batch x:  [ 5.  1.] | batch y:  [  6.  10.]
Epoch:  2 | Step:  0 | batch x:  [ 1.  8.  2.  7.  3.  5.  6.  4.] | batch y:  [ 10.   3.   9.   4.   8.   6.   5.   7.]
Epoch:  2 | Step:  1 | batch x:  [  9.  10.] | batch y:  [ 2.  1.]
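The batching in the output above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not PyTorch's implementation: `PairDataset` and `iterate_batches` are hypothetical names, Python's `random.shuffle` stands in for PyTorch's sampler, and batches come back as lists rather than the stacked tensors that DataLoader's default collation produces. It shows the two ideas behind the result: the dataset pairs x[i] with y[i], so shuffling indices keeps each pair aligned, and slicing the shuffled index list gives ceil(10 / 8) = 2 batches per epoch, or 10 // 8 = 1 if the incomplete batch is dropped.

```python
import math
import random


class PairDataset:
    """Hypothetical stand-in for TensorDataset: item i is (xs[i], ys[i])."""

    def __init__(self, xs, ys):
        assert len(xs) == len(ys), "all inputs must have the same length"
        self.xs, self.ys = xs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        return self.xs[i], self.ys[i]


def iterate_batches(dataset, batch_size, shuffle=True, drop_last=False, rng=None):
    """Hypothetical helper: yield (batch_x, batch_y) lists, assuming shuffle-then-slice batching."""
    indices = list(range(len(dataset)))
    if shuffle:
        (rng or random).shuffle(indices)  # a new order on every call, i.e. every epoch
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the incomplete final batch
        pairs = [dataset[i] for i in chunk]
        # "collate": turn a list of (x, y) pairs into (list of x, list of y)
        batch_x = [p[0] for p in pairs]
        batch_y = [p[1] for p in pairs]
        yield batch_x, batch_y


xs = [float(v) for v in range(1, 11)]      # same values as torch.linspace(1, 10, 10)
ys = [float(v) for v in range(10, 0, -1)]  # same values as torch.linspace(10, 1, 10)
dataset = PairDataset(xs, ys)
BATCH_SIZE = 8
rng = random.Random(0)  # seeded only to make the sketch repeatable

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(iterate_batches(dataset, BATCH_SIZE, rng=rng)):
        print('Epoch:', epoch, '| Step:', step, '| batch x:', batch_x, '| batch y:', batch_y)

# batches per epoch: ceil(N / B) by default, floor(N / B) when dropping the last partial batch
print(math.ceil(len(dataset) / BATCH_SIZE), len(dataset) // BATCH_SIZE)  # 2 1
```

In the real script, passing `drop_last=True` to `Data.DataLoader(...)` likewise discards the 2-sample batch, which can be useful when a model expects every batch to have the same size.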
