PyTorch Notes 6 - Mini-Batch

This series of notes follows the Morvan (莫烦) PyTorch video tutorials. GitHub source code

Overview

Torch provides a handy utility for organizing data called DataLoader. It wraps your own dataset so the model can be trained in mini-batches, and the batches can be drawn in several different ways.

import torch
import torch.utils.data as Data

torch.manual_seed(1)     # reproducible

DataLoader

DataLoader is the tool Torch uses to wrap data: first convert your data (a numpy array, for example) into a Tensor, then put it into this wrapper. With a DataLoader you can iterate over the data efficiently. The official docs describe it as: "Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset."
A quick sketch of wrapping numpy arrays follows, and then the main demo.
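If your data starts out as numpy arrays, torch.from_numpy converts them before wrapping; a minimal sketch (np_x and np_y below are made-up example arrays, not part of the original notes):

import numpy as np

np_x = np.arange(1, 11, dtype=np.float32)   # made-up data
np_y = np_x[::-1].copy()                    # copy() because from_numpy needs a contiguous array
dataset_from_np = Data.TensorDataset(torch.from_numpy(np_x), torch.from_numpy(np_y))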

MINIBATCH_SIZE = 5    # mini batch size
x = torch.linspace(1, 10, 10)  # torch tensor
y = torch.linspace(10, 1, 10)

# first wrap the tensors into a Dataset that torch can work with
torch_dataset = Data.TensorDataset(x, y)   # TensorDataset takes the tensors positionally
# put the dataset into DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=True,
    num_workers=2           # number of subprocesses used to load the data
)

for epoch in range(3):
    # one epoch goes over the whole dataset once
    for step, (batch_x, batch_y) in enumerate(loader):
        # here to train your model
        print('\n\n epoch: ', epoch, '| step: ', step, '| batch x: ', batch_x.numpy(), '| batch_y: ', batch_y.numpy())
 epoch:  0 | step:  0 | batch x:  [  4.   2.  10.   7.   3.] | batch_y:  [ 7.  9.  1.  4.  8.]


 epoch:  0 | step:  1 | batch x:  [ 1.  8.  6.  9.  5.] | batch_y:  [ 10.   3.   5.   2.   6.]


 epoch:  1 | step:  0 | batch x:  [ 10.   2.   7.   3.   4.] | batch_y:  [ 1.  9.  4.  8.  7.]


 epoch:  1 | step:  1 | batch x:  [ 5.  6.  9.  1.  8.] | batch_y:  [  6.   5.   2.  10.   3.]


 epoch:  2 | step:  0 | batch x:  [  7.   1.   4.   3.  10.] | batch_y:  [  4.  10.   7.   8.   1.]


 epoch:  2 | step:  1 | batch x:  [ 5.  8.  2.  6.  9.] | batch_y:  [ 6.  3.  9.  5.  2.]

As the output shows, each step takes MINIBATCH_SIZE (5 here) samples to process, and the order differs from epoch to epoch, so the data has indeed been shuffled.
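For comparison, a minimal sketch of the unshuffled case (reusing the same torch_dataset as above): with shuffle=False every epoch yields the same batches in the original sample order.

loader_ordered = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=False            # keep the original sample order
)
for step, (batch_x, batch_y) in enumerate(loader_ordered):
    print('step:', step, '| batch x:', batch_x.numpy())   # [1..5] first, then [6..10]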

Training on mini-batches is faster than iterating one sample at a time, because each batch lets the CPU or GPU parallelize the computation and speed things up. It is also usually better than training on the entire dataset in one iteration: with full-batch training each iteration allows only a single forward propagation and backward propagation, which is slow, and gets worse as the dataset grows.
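To make the "here to train your model" step concrete, here is a minimal sketch that feeds the batches into a hypothetical one-layer regression net; net, optimizer and loss_func are illustrative names, not part of the original notes.

import torch.nn as nn

net = nn.Linear(1, 1)                                     # hypothetical one-layer net
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_func = nn.MSELoss()

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        prediction = net(batch_x.unsqueeze(1))            # reshape batch to (batch, 1)
        loss = loss_func(prediction, batch_y.unsqueeze(1))
        optimizer.zero_grad()    # clear gradients from the previous step
        loss.backward()          # backward propagation on this mini-batch only
        optimizer.step()         # update parameters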

What happens if the number of samples is not divisible by MINIBATCH_SIZE? The last step simply returns whatever is left over, as the demo below shows.

MINIBATCH_SIZE = 8
# put the dataset into DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=True,
    num_workers=2           # number of subprocesses used to load the data
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('\n\n epoch: ', epoch, '| step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
 epoch:  0 | step:  0 | batch x:  [ 9.  1.  3.  8.  6.  7.  4.  2.] | batch y:  [  2.  10.   8.   3.   5.   4.   7.   9.]


 epoch:  0 | step:  1 | batch x:  [ 10.   5.] | batch y:  [ 1.  6.]


 epoch:  1 | step:  0 | batch x:  [  7.   1.   8.   3.   9.   2.   5.  10.] | batch y:  [  4.  10.   3.   8.   2.   9.   6.   1.]


 epoch:  1 | step:  1 | batch x:  [ 4.  6.] | batch y:  [ 7.  5.]


 epoch:  2 | step:  0 | batch x:  [  3.   1.   9.   6.   5.   7.  10.   2.] | batch y:  [  8.  10.   2.   5.   6.   4.   1.   9.]


 epoch:  2 | step:  1 | batch x:  [ 4.  8.] | batch y:  [ 7.  3.]
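If the incomplete final batch should be discarded instead of returned, DataLoader accepts a drop_last flag; the sketch below shows the idea (my own example, not from the original notes).

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True           # silently drop the 2 leftover samples each epoch
)
for step, (batch_x, batch_y) in enumerate(loader):
    print('step:', step, '| batch size:', len(batch_x))   # only step 0, with 8 samples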
