This series of notes follows Morvan (莫烦)'s PyTorch video tutorials (GitHub source code).
Torch provides a handy utility for organizing data called DataLoader. It wraps your own data for batch training, and batch training can be set up in several ways.
import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
DataLoader is the tool Torch uses to wrap data: convert data such as numpy arrays into Tensors and then put them into this wrapper, after which the data can be iterated over efficiently.
Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset. (See the official docs.)
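For example, data held in a numpy array can first be converted to Tensors with torch.from_numpy and then wrapped. A minimal sketch (np_x and np_y are made-up arrays used only for illustration):
import numpy as np

np_x = np.arange(1, 11, dtype=np.float32)    # hypothetical numpy data: 1, 2, ..., 10
np_y = np_x[::-1].copy()                     # copy() because from_numpy needs a contiguous array
dataset = Data.TensorDataset(torch.from_numpy(np_x), torch.from_numpy(np_y))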
The following demonstrates batch training with DataLoader:
MINIBATCH_SIZE = 5    # mini-batch size

x = torch.linspace(1, 10, 10)    # torch tensor: 1, 2, ..., 10
y = torch.linspace(10, 1, 10)    # torch tensor: 10, 9, ..., 1

# first wrap the tensors in a TensorDataset so torch can work with them
torch_dataset = Data.TensorDataset(x, y)

# put the dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=True,        # shuffle the data at every epoch
    num_workers=2        # number of subprocesses used to load the data
)
for epoch in range(3):
    # one epoch goes through the whole dataset once
    for step, (batch_x, batch_y) in enumerate(loader):
        # train your model here
        print('\n\n epoch: ', epoch, '| step: ', step,
              '| batch x: ', batch_x.numpy(), '| batch_y: ', batch_y.numpy())
epoch: 0 | step: 0 | batch x: [ 4. 2. 10. 7. 3.] | batch_y: [ 7. 9. 1. 4. 8.]
epoch: 0 | step: 1 | batch x: [ 1. 8. 6. 9. 5.] | batch_y: [ 10. 3. 5. 2. 6.]
epoch: 1 | step: 0 | batch x: [ 10. 2. 7. 3. 4.] | batch_y: [ 1. 9. 4. 8. 7.]
epoch: 1 | step: 1 | batch x: [ 5. 6. 9. 1. 8.] | batch_y: [ 6. 5. 2. 10. 3.]
epoch: 2 | step: 0 | batch x: [ 7. 1. 4. 3. 10.] | batch_y: [ 4. 10. 7. 8. 1.]
epoch: 2 | step: 1 | batch x: [ 5. 8. 2. 6. 9.] | batch_y: [ 6. 3. 9. 5. 2.]
As the output shows, each step takes MINIBATCH_SIZE (here 5) samples to process, and the batches differ from epoch to epoch, so the data has indeed been shuffled.
Training on mini-batches is faster than iterating over one sample at a time, because the CPU or GPU can parallelize the computation within a batch to speed things up. It is also better than using the entire dataset in each iteration: a full-batch pass performs only one forward propagation and one backward propagation per iteration, which is slow, and gets even worse as the dataset grows.
What happens if the data size is not divisible by MINIBATCH_SIZE? The last step simply returns whatever remains, as the following demonstrates:
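By contrast, passing shuffle=False makes every epoch return the batches in the original order. A minimal sketch (reusing torch_dataset from above; the name ordered_loader is just for illustration):
ordered_loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=False,       # keep the original order
)
for step, (batch_x, batch_y) in enumerate(ordered_loader):
    print('step:', step, '| batch x:', batch_x.numpy())
# with shuffle=False the batches are always [1..5] followed by [6..10]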
MINIBATCH_SIZE = 8

# put the dataset into a DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=MINIBATCH_SIZE,
    shuffle=True,
    num_workers=2        # number of subprocesses used to load the data
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('\n\n epoch: ', epoch, '| step: ', step,
              '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
epoch: 0 | step: 0 | batch x: [ 9. 1. 3. 8. 6. 7. 4. 2.] | batch y: [ 2. 10. 8. 3. 5. 4. 7. 9.]
epoch: 0 | step: 1 | batch x: [ 10. 5.] | batch y: [ 1. 6.]
epoch: 1 | step: 0 | batch x: [ 7. 1. 8. 3. 9. 2. 5. 10.] | batch y: [ 4. 10. 3. 8. 2. 9. 6. 1.]
epoch: 1 | step: 1 | batch x: [ 4. 6.] | batch y: [ 7. 5.]
epoch: 2 | step: 0 | batch x: [ 3. 1. 9. 6. 5. 7. 10. 2.] | batch y: [ 8. 10. 2. 5. 6. 4. 1. 9.]
epoch: 2 | step: 1 | batch x: [ 4. 8.] | batch y: [ 7. 3.]
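If you would rather discard the incomplete last batch instead of receiving a smaller one, DataLoader also accepts a drop_last flag. A minimal sketch:
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True       # discard the final, incomplete batch of 2 samples
)
# each epoch now yields a single step with a full batch of 8 samples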