莫烦pytorch 批训练

批训练,即把完整数据分成数批分别进行训练。

DataLoader

import torch
import torch.utils.data as Data

torch.manual_seed(1)
BATCH_SIZE = 5

x = torch.linespace(1, 10, 10)
y = torch.linespace(10, 1, 10)

#转换成torch能识别的Dataset
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

loader = Data.DataLoader(
	dataset = torch_dataset,
	batch_size = BATCH_SIZE,
	shuffle = True,
	num_workers = 2
)

for epoch in range(3):
	for step, (batch_x, batch_y) in enumerate(loader):
		print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

结果:

"""
Epoch:  0 | Step:  0 | batch x:  [ 6.  7.  2.  3.  1.] | batch y:  [  5.   4.   9.   8.  10.]
Epoch:  0 | Step:  1 | batch x:  [  9.  10.   4.   8.   5.] | batch y:  [ 2.  1.  7.  3.  6.]
Epoch:  1 | Step:  0 | batch x:  [  3.   4.   2.   9.  10.] | batch y:  [ 8.  7.  9.  2.  1.]
Epoch:  1 | Step:  1 | batch x:  [ 1.  7.  8.  5.  6.] | batch y:  [ 10.   4.   3.   6.   5.]
Epoch:  2 | Step:  0 | batch x:  [ 3.  9.  2.  6.  7.] | batch y:  [ 8.  2.  9.  5.  4.]
Epoch:  2 | Step:  1 | batch x:  [ 10.   4.   8.   1.   5.] | batch y:  [  1.   7.   3.  10.   6.]
"""

函数剖析:

torch.manual_seed(args.seed) #为CPU设置种子用于生成随机
数,以使得结果是确定的 
if args.cuda: 
torch.cuda.manual_seed(args.seed)#为当前GPU设置随机种子
;如果使用多个GPU,应该使用torch.cuda.manual_seed_all()
为所有的GPU设置种子。
class torch.utils.data.TensorDataset(data_tensor, target_tensor)

包装数据和目标张量的数据集。
通过沿着第一个维度索引两个张量来恢复每个样本
参数:
#data_tensor(Tensor) — 包含样本数据
#target_tensor(Tensor) — 包含样本目标(标签)

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
参数:
#dataset(Dataset) — 从中加载数据的数据集。
#bactch_size(int, optional) — 批训练的数据个数(默认:1)
#shuffle(bool, optional) — 设置为True在每个epoch重新排列数据(默认值:False,一般打乱比较好)。
#sampler(Sampler, optional) — 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
#batch_sampler(sampler, 可选) — 和sampler一样,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last相互排斥。
#num_workers(int, 可选) — 用于数据加载的子进程数。0表示数据将在主进程中加载(默认值:0)
#collate_fn(callable, optional) — 合并样本列表以形成小批量。
#pin_memory(bool, optional) — 如果True,数据加载器在返回将张量复制到CUDA固定内存中。
#drop_last(bool, optional) — 如果数据集大小不能被batch_size整除,设置True可删除最后一个不完整的批处理。如果设为False并且数据集的大小不能被batch_size整除,则最后一个batch更小。(默认:false)

enumerate(sequence, [start=0])

参数:
#sequence — 一个序列、迭代器或其他支持迭代对象。
#start — 下标起始位置

实例:

>>>seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))       # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]

你可能感兴趣的:(pytorch)