torch 单机多卡训练

最近在尝试用torch单机多卡进行训练。 网上有很多方法,有的讲的也很详细,但是torch版本更新的还是很快的。所以自己也踩了很多坑。在这里记录下来,希望对大家有帮助。
本文适用torch版本:1.10

torch单机多gpu训练有两种方式

torch.nn.DataParallel

torch.nn.DataParallel(
	module, 
	device_ids=None, 
	output_device=None, 
	dim=0)

这一种方式是比较老的,官方现在也不太推荐,相对来说,这种方式的速度并不算太快。
但是使用起来还算简单,需要修改的代码并不多。
在训练RNN的时候,坑非常多。可以看看博客:RNN与torch DataParallel的爱恨情仇
下面介绍一下相对于普通的训练,这种方式需要进行哪些修改。
使用方式:

from torch.nn import DataParallel
...
model = nn.DataParallel(model, device_ids=devices).to(device)
...

没错,只需要修改一行即可。

torch.nn.parallel.DistributedDataParallel

torch.nn.parallel.DistributedDataParallel(
	module,
 	device_ids=None, 
 	output_device=None, 
 	dim=0, 
 	broadcast_buffers=True, 
 	process_group=None, 
 	bucket_cap_mb=25, 
 	find_unused_parameters=False, 
 	check_reduction=False, 
 	gradient_as_bucket_view=False, 
 	static_graph=False)

这种办法是现在最普遍最常用的,因为其性能最好,效果也最佳。使用方法相较上一个来说比较复杂。但是网上的教程都比较过时。

import os
local_rank = int(os.environ["LOCAL_RANK"])

首先,在所有代码之前,先添加这两句。现在torch distributed的启动器已经由之前的torch.distributed.launch改为torchrun了。

然后在数据集构建的地方添加上面这一段。请使用torch.nn.utils.Datasettorch.nn.utils.Dataloader进行数据集的构建,这样会极大提高数据的加载速度,从而提高模型的训练速度。

请注意,在Dataloader指定sampler时,不能指定shuffle

train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(dataset=dataset,
                    batch_size=batch_size,
                    # shuffle=True,
                    num_workers=8,
                    pin_memory=True,
                    sampler=train_sampler
                    )

然后在模型进行训练前添加:

torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)
torch.distributed.init_process_group(backend="nccl", init_method='env://')
model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

然后,运行你的main.py:

torchrun --nproc_per_node=4 main.py train

然后直接起飞,飞一般的速度,真的比DataParallel快很多(十倍以上)。

源码

Github

References

  • https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  • https://pytorch.org/docs/stable/elastic/run.html#launcher-api
  • https://blog.csdn.net/laizi_laizi/article/details/115299263

你可能感兴趣的:(pytorch,深度学习,人工智能,分布式)