Pytorch单机多卡并行应用经验分享

目录

  • 一、设置进程组
  • 二、封装模型
    • 【Tips】
  • 三、分割数据
    • 【Tips】
  • 四、训练模型
    • 【Tips】
  • 五、执行命令行
    • 【Tips】
  • 参考资料


记录分享一下最近使用单机多卡执行并行运算时总结的一些经验,以下内容均假设实验所使用的设备上有GPU,且已准备好了单卡运算时所需的数据及模型代码,本文仅介绍使用多卡(即多进程)并行运算时要在代码及命令行中需要额外设置的内容。

一、设置进程组

首先,我们要在指定存储设备device前设置进程组(建议直接加在代码最开始处),代码如下:

import torch
from torch.distributed import init_process_group, get_rank, get_world_size

init_process_group(backend = "nccl")
rank = get_rank()
world_size = get_world_size()
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)

init_process_group用于初始化进程组,rankworld_size分别表示当前进程的GPU编号(从0开始)和该进程组中的总GPU数(在命令行中设置,具体稍后介绍),例如若你在命令行中共设置了3张卡可见,即world_size = 3,那么在该进程组中你的每个进程编号分别为rank = 0rank = 1rank = 2

此外,若在代码中使用了argparse.ArgumentParser(),那么还需要指定默认的local_rank,即增加如下代码:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type = int, default = -1)

二、封装模型

接下来要使用DistributedDataParallel对你的模型进行封装,以保证在执行并行运算时,各进程(卡)能同步更新优化模型。代码如下:

from torch.nn.parallel import DistributedDataParallel

model = <your model>.to(device)
model = DistributedDataParallel(model, device_ids = [rank], output_device = rank)

此处model即为你定义的模型,一般为nn.Module类。

【Tips】

  1. 对模型进行封装后,当我们想要调用模型的一些自定义属性(如模型的一个自定义损失函数model.loss_function())时,我们需要首先额外调用一个module属性(如model.module.loss_function()),但对于nn.Module类的内置属性(如model.parameters())则不需要。

三、分割数据

多进程并行运算的核心(个人理解)即是将数据划分至多个进程,在每个进程中独立同步地训练优化模型,从而减轻单一进程的运算及存储负担,最终达到加速模型训练的目的,因此需要将训练(及验证)数据集随机平均划分给每一个进程,具体代码很简单,如下所示:

from torch.utils.data.distributed import DistributedSampler

dataset = <your dataset>.to(device)
sampler = DistributedSampler(dataset, shuffle = True)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, sampler = sampler)

此处dataset为你的数据集,要求为torch中的Dataset类。sampler将数据集均分给各个进程,参数shuffle指定是否随机划分,如果希望每个epoch中各进程的数据集划分保持一致则设置为False,但dataloadershuffle参数必须设为False

【Tips】

  1. 在多卡情形下强烈建议设置num_workers0,否则可能会出现一些意想不到的bug,例如本人在一开始运行代码时发现经常在某个epoch突然卡主,但并没有任何报错,代码也并未终止,只是卡主不动,后来设置num_workers = 0后就没有这个问题了(之前使用单卡时设置num_workers不为0也没有这个问题),怀疑可能是因为进程太多引发内存不足所致。

四、训练模型

训练过程与单卡运行基本一致,代码如下:

optimizer = torch.optim.Adam(model.parameters())
model.train()
for epoch in num_epochs:
        dataloader.sampler.set_epoch(epoch)
        for batch, inputs in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = model.module.loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

可以看出唯一区别在于需要告知dataloader当前执行的epoch(即代码第三行),主要作用是打乱每个epoch的数据顺序(类似于单卡情形下dataloader设置参数shuffle = True)。因此若希望每个epoch喂入模型的数据顺序保持不变,则应删除这行代码。

【Tips】

  1. 在测试阶段,由于我们通常需要得到完整测试数据集的测试结果,因此我们无需对测试集进行划分,即不需要设置sampler,仅使用单卡执行测试即可,但需要指定一个要使用的master卡。示例代码如下:
if rank == 0: # set the master node
	test_dataset = <your testing dataset>.to(device)
	test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)
	model.eval()
    with torch.no_grad():
        for batch, inputs in enumerate(test_loader):
            outputs = model(inputs)
  1. 在某些情况下,我们需要将各个进程中保存的变量合并到单一进程上,以便用于后续操作(如模型测试),此时可以使用如下代码:
from torch.distributed import all_gather_object

outputs_all = [None for _ in range(world_size)]
# gather outputs among all processes to a list outputs_all
all_gather_object(outputs_all, outputs)
  1. 在训练模型时,为了减轻过拟合同时提高算法效率,有时我们会使用early_stop,即在满足一定条件下提前终止程序。此时在多进程情形下,经常会出现不同进程无法同时终止导致代码报错的问题。对此一种解决方案是不使用early_stop,另一种方案是要求所有进程均满足early_stop条件时才终止程序。

五、执行命令行

最后我们就可以使用命令行运行代码了,由于在许多情况下我们无法同时使用设备上的所有GPU(例如使用公用服务器等),因此我们需要指定允许此并行组可见的GPU编号和总数。例如设备上共有0,1,2,3四张卡,我们希望选择其中的0,1,3为可见卡,则命令行代码如下:

CUDA_VISIBLE_DEVICES=0,1,3 python -m torch.distributed.launch --nproc_per_node=3 <your code file>.py

其中nproc_per_node指定所要使用的GPU数(即world_size)。

【Tips】

  1. 注意此处的GPU编号和前述的rank并不完全相同,例如在本例中我们指定了卡0,1,3为可见卡,那么在初始化进程组时,程序会自动对这三张卡重新从0开始连续编号,即它们在进程组中的编号分别为rank = 0rank = 1rank = 2

  2. 本例使用的是torch版本为1.13,在2.0以上的版本中,torch.distributed.launch被替换为了torchrun,但其具体用法本人暂不清楚,欢迎大家补充!

  3. 若我们的程序中有任何print操作,那么每个进程都会重复执行此命令,导致在命令行出现大量重复信息。若想避免此情况,可以指定一个master进程来执行print,例如:

if rank == 0:
	print('hello world!')

参考资料

[1] https://blog.csdn.net/qq_39448884/article/details/120971703.

[2] https://medium.com/codex/a-comprehensive-tutorial-to-pytorch-distributeddataparallel-1f4b42bb1b51.

你可能感兴趣的:(pytorch,pytorch,python)