PyTorch学习笔记(八)——将数据分批训练

代码和解释如下,最后附上了输出结果:

# coding=gbk
import torch
import torch.utils.data as Data #将数据分批次需要用到它

torch.manual_seed(1)    # 种子,可复用
BATCH_SIZE = 8 #设置批次大小

x = torch.linspace(1, 15, 15)       # 1到15共15个点
y = torch.linspace(15, 1, 15)       # 15到1共15个点

torch_dataset = Data.TensorDataset(x, y) #将x,y读取,转换成Tensor格式
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # 最新批数据
    shuffle=True,               # 是否随机打乱数据
    num_workers=2,              # 用于加载数据的子进程
)

def show_batch():
    for epoch in range(3):   # 对整个数据集进行3次培训
        for step, (batch_x, batch_y) in enumerate(loader):  # 每个训练步骤
            # 此处省略一些训练数据步骤...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()
'''
(1)每次训练5个数据,打乱数据;每进行一次完整的训练需要进行3个训练步骤:
Epoch:  0 | Step:  0 | batch x:  [10. 12.  9.  5.  1.] | batch y:  [ 6.  4.  7. 11. 15.]
Epoch:  0 | Step:  1 | batch x:  [ 7. 15.  8. 13.  3.] | batch y:  [ 9.  1.  8.  3. 13.]
Epoch:  0 | Step:  2 | batch x:  [ 2.  6. 14.  4. 11.] | batch y:  [14. 10.  2. 12.  5.]
Epoch:  1 | Step:  0 | batch x:  [ 3. 10.  8. 13.  2.] | batch y:  [13.  6.  8.  3. 14.]
Epoch:  1 | Step:  1 | batch x:  [ 5.  4. 12. 14.  1.] | batch y:  [11. 12.  4.  2. 15.]
Epoch:  1 | Step:  2 | batch x:  [15.  9. 11.  6.  7.] | batch y:  [ 1.  7.  5. 10.  9.]
Epoch:  2 | Step:  0 | batch x:  [ 8.  7.  3. 10. 12.] | batch y:  [ 8.  9. 13.  6.  4.]
Epoch:  2 | Step:  1 | batch x:  [ 6. 13.  9.  4. 15.] | batch y:  [10.  3.  7. 12.  1.]
Epoch:  2 | Step:  2 | batch x:  [14.  2.  5.  1. 11.] | batch y:  [ 2. 14. 11. 15.  5.]
(2)每次训练8个数据,打乱数据;每进行一次完整的训练需要进行2个训练步骤,一次8个数据,一次7个数据:
Epoch:  0 | Step:  0 | batch x:  [10. 12.  9.  5.  1.  7. 15.  8.] | batch y:  [ 6.  4.  7. 11. 15.  9.  1.  8.]
Epoch:  0 | Step:  1 | batch x:  [13.  3.  2.  6. 14.  4. 11.] | batch y:  [ 3. 13. 14. 10.  2. 12.  5.]
Epoch:  1 | Step:  0 | batch x:  [ 3. 10.  8. 13.  2.  5.  4. 12.] | batch y:  [13.  6.  8.  3. 14. 11. 12.  4.]
Epoch:  1 | Step:  1 | batch x:  [14.  1. 15.  9. 11.  6.  7.] | batch y:  [ 2. 15.  1.  7.  5. 10.  9.]
Epoch:  2 | Step:  0 | batch x:  [ 8.  7.  3. 10. 12.  6. 13.  9.] | batch y:  [ 8.  9. 13.  6.  4. 10.  3.  7.]
Epoch:  2 | Step:  1 | batch x:  [ 4. 15. 14.  2.  5.  1. 11.] | batch y:  [12.  1.  2. 14. 11. 15.  5.]
'''

 

你可能感兴趣的:(人工智能,Pytorch学习笔记)