欢迎来到本期神经网络编程系列。在本集中,我们将看到如何利用PyTorch DataLoader
类的多进程功能来加快神经网络训练过程。
为了加快训练过程,我们将利用DataLoader
类的num_workers
可选属性。
num_workers
属性告诉DataLoader
实例要使用多少个子进程进行数据加载。默认情况下,num_workers
值被设置为0,0值代表告诉加载器在主进程内部加载数据。
这意味着训练进程将在主进程内部依次工作。在训练过程中使用一批批处理之后,我们从磁盘上读取另一批批处理数据。
现在,如果我们有一个工作进程,我们可以利用我们的机器有多个内核这一事实。这意味着,在主进程准备好另一个批处理的时候,下一个批处理已经可以加载并准备好了。这就是速度提升的原因。批批处理使用附加的辅助进程加载,并在内存中排队。
随之而来的问题是,我们应该添加多少个工作进程?这里有很多因素可以影响最佳数量,因此最好的方法就是测试。
为了设置这个测试,我们将创建一个num_workers
值的列表来尝试。我们将尝试以下值:
对于这些值,我们将通过尝试以下值来改变批次大小。
对于学习率,我们将在所有的运行中保持在0.01的恒定值。
最后要提到的是,这里的设置是-我们只为每个运行做一个单一的epoch
。
好了,让我们看看我们得到了什么。
num_workers
值。结果:好了,我们可以看到下面的结果。我们总共完成了十八次运行。我们有三组不同的批量大小,在每个组中,我们改变了工作进程的数量。
params = OrderedDict(
lr = [.01]
,batch_size = [100,1000,10000]
,num_workers = [0,1,2,4,8,16]
#,shuffle = [True, False]
)
m = RunManager()
for run in RunBuilder.get_runs(params):
network = Network()
loader = DataLoader(train_set,batch_size = run.batch_size,num_workers = run.num_workers)
optimizer = optim.Adam(network.parameters(),lr = run.lr)
m.begin_run(run,network,loader)
for epoch in range(1):
m.begin_epoch()
for batch in loader:
images,labels = batch #get batch
preds = network(images)#pass batch
loss = F.cross_entropy(preds,labels) #calculate loss
optimizer.zero_grad() #zero gradients
loss.backward() #calculate gradients
optimizer.step() # update weights
m.track_loss(loss)
m.track_num_correct(preds,labels)
m.end_epoch()
m.end_run()
m.save('results')
从这些结果中得到的主要结论是,在所有三个批次规模中,除了主流程外,拥有一个单一的工作流程可使速度提高约百分之二十。
此外,在第一个流程之后增加额外的工作流程并没有真正显示出任何进一步的改进。
在增加一个工作流程后,我们看到的20%的加速是有意义的,因为主流程要做的工作较少。
当主进程忙于执行前向和后向传递时,工作进程正在加载下一个批次。当主进程准备好另一个批次的时候,工作进程已经在内存中排好了队。
因此,主进程不必从磁盘读取数据。相反,数据已经在内存中,这使我们的速度提高了20%。
现在,为什么我们在添加更多的工作者后没有看到额外的速度提升呢?
我们会如果一个worker足以让主进程的队列充满数据,那么向队列中添加更多的数据批次是不会有任何作用的。这就是我认为我们在这里看到的情况。
仅仅因为我们向队列中添加了更多的批次,并不意味着这些批次的处理速度更快。因此,我们受制于前向和后向传播一个给定批次所需的时间。
我们甚至可以看到,当我们到达16个工作进程时,事情开始陷入僵局。
希望这能帮助你加快进度!
欢迎来到本期神经网络编程系列。在本集中,我们将看到如何利用PyTorch DataLoader
类的多进程功能来加快神经网络训练过程。
DataLoader
支持映射样式和迭代样式的数据集,具有单进程或多进程加载、自定义加载顺序以及可选的自动批处理(排序)和内存固定。
num_workers
(int, optional) - 有多少个子进程用于数据加载。0表示将在主进程中加载数据。(默认:0) 只是节省了我们从磁盘中读取批处理数据的时间。
英文原文链接是:https://deeplizard.com/learn/video/kWVgvsejXsE