pytorch多GPU并行DistributedDataParallel应用和踩坑记录(本节-单机多卡实现)持续补充

一、 前言

说在前面:网上参考链接很多,参考之后可以实现分布式,但是对其原理还是云里雾里,有时间的建议去看一看原理。并且我实现分布式之后还是显示显存不足,不知道为什么。

参考链接:pytorch多GPU并行训练DistributedDataParallel应用和踩坑记录_train_sampler = distributedsampler(train_dataset, -CSDN博客
https://blog.csdn.net/weixin_44966641/article/details/121872773

二、.实现步骤

1 首先设置程序可见的GPU设备列表:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 

注意:一定要将其设置在所有访问cuda的语句之前,否则会失效。建议import os后就设置,因为import的文件中可能也有访问cuda的操作。或者在启动程序的命令行中设置,见下文。

2 初始化使用nccl后端:(使用 init_process_group 设置GPU 之间通信使用的后端和端口:)

torch.distributed.init_process_group(backend="nccl")

3而后为每个进程设置其device:

torch.cuda.set_device(opt.local_rank)
device = torch.device("cuda", opt.local_rank)

4 DistributedDataParallel并不会自动分配数据。如果使用Dataset类的话,构建DataLoader的时候需要使用sampler对其进行采样并分配,具体如下:

train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, opt.batch_size, shuffle=False,  sampler=train_sampler)

5 然后,加载模型,并使用 DistributedDataParallel 包装模型,它能帮助我们为不同 GPU 上求得的梯度进行 all reduce(即汇总不同 GPU 计算所得的梯度,并同步计算结果)。all reduce 后不同 GPU 中模型的梯度均为 all reduce 之前各 GPU 梯度的均值:

model = Model()
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],output_device=mypara.local_rank, find_unused_parameters=True)

6. 程序运行

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
或者用修改torch.distributed.launch用torchrun启动

问题一:显示分布式成功,但是还是报错

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 808.00 MiB (GPU 0; 23.69 GiB total capacity; 21.41 GiB already allocated; 371.69 MiB free; 22.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 20344) of binary: /HOME/scz0p8m/.conda/envs/torch/bin/python
Traceback (most recent call last):

你可能感兴趣的:(pytorch,人工智能,python)