pytorch多gpu训练

1. 安装apex

apex是nvidia开发的基于pytorch的多gpu训练加速包。下载:

git clone https://github.com/NVIDIA/apex
cd apex

#安装时碰倒问题pip._internal.operations.install.legacy.LegacyInstallFailure,因此采用另一个branch
git checkout f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0

pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

2. 初始化

初始化该流程并与其他流程合并。这过程是"阻塞",在所有进程都加入之前,没有进程将继续。

import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp

dist.init_process_group(
    	backend='nccl',
   		init_method='env://',
    	world_size=args.world_size,
    	rank=rank
    )

init_method 告诉进程组去哪儿寻找相关设置。本例中将会去环境变量中寻找MASTER_ADDRMASTER_PORT

相关的概念:

  • nodes: 节点的数量,即机器的数量
  • gpus: 每个节点上gpu的数量
  • nr: 当前节点在所有节点中的rank, 范围 [ 0 , n o d e s − 1 ] [0, nodes-1] [0,nodes1]
  • world_size: 总的gpu的数量,也就是所有进程的总数量, w o r l d _ s i z e = n o d e s × g p u s world\_size=nodes\times gpus world_size=nodes×gpus

3. 训练代码

import torch.multiprocessing as mp


def train():
	pass

def main():
	os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8888'
    mp.spawn(train, nprocs=4, args=(config,))


if __name__ == "__main__":
	main()

4. 执行训练

使用如下命令:

python -m torch.distributed.launch --nproc_per_node=4 train.py --a b --c d

其中python -m torch.distributed.lanuch表示调用torch.distributed.lanuch.py文件来进行分布式训练,-m表示将后面的torch.distributed.lanuch当做模块加载。 --nproc_per_node通常与GPU数量保持一致。train.py才是真正的训练文件,后面是相关参数。

执行过程中,其实相当于会自动在terminal中执行多个相同的命令:

python', '-u', 'train.py', '--local_rank=3' --a b --c d

只是每次执行时会赋值不同的local_rank,从0gpus-1.

5. 相关函数

  • torch.distributed.get_world_size(): 获取GPU数量。
  • torch.distributed.get_rank(): 返回当前进程组的排名。Rank是分配给分布式进程组中每个进程的唯一标识符。它们总是从0到world_size的连续整数。
  • torch.distributed.barrier(async_op=False): 同步所有的进程, 直到整组(也就是所有节点的所有GPU)到达这个函数的时候, 才会执行后面的代码。例如下面的代码:
def create_dataloader():
    if rank == 0:
        dataset = LoadImagesAndLabels()
	torch.distributed.barrier()

如果rank不为0,create_dataloader()函数的进程不是主进程,会执行相应的barrier(),此时会设置一个阻塞栅栏,让此进程处于等待状态,等待所有进程到达栅栏处(包括主进程数据处理完毕);如果执行create_dataloader()函数的进程是主进程,其会直接去读取数据并处理,然后其处理结束之后会接着遇到torch.distributed.barrier(),此时,所有进程都到达了当前的栅栏处,这样所有进程就达到了同步,并同时得到释放。

  • torch.distributed.all_gather(tensor_list, tensor, group=, async_op=False):
    从列表中收集整个组的张量。tensor_list中的每个张量应位于单独的GPU上。

如果要写一个损失函数,则可以使用:

import torch
from torch import distributed
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP


class awesome_allgather_function(autograd.Function):
    @staticmethod
    def forward(ctx, input):
        world_size = distributed.get_world_size()
        # print(f"world size: {world_size}")
        # create a destination list for the allgather.  I'm assuming you're gathering from 3 workers.
        allgather_list = [torch.empty_like(input) for _ in range(world_size)]
        #if distributed.get_rank() == 0:
        #    import IPython;IPython.embed()
        # print(f"hahaha, {input.shape}")  # torch.Size([2, 4])
        # input is actually in multiple gpus, here we collect it to allgather_list
        distributed.all_gather(allgather_list, input)
        # print(f"allgather_list: {len(allgather_list)}")
        return torch.cat(allgather_list, dim=0)

    @staticmethod
    def backward(ctx, grad_output):  # distribute gradient to each gpu
        #print_if_rank0("backward grad_output len", len(grad_output))
        #print_if_rank0("backward grad_output shape", grad_output.shape)
        grads_per_rank = grad_output.shape[0] // distributed.get_world_size()
        rank = distributed.get_rank()
        # We'll receive gradients for the entire catted forward output, so to mimic DataParallel,
        # return only the slice that corresponds to this process's input:
        sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank)
        #print("worker", rank, "backward slice", sl)
        return grad_output[sl]

nominator = awesome_allgather_function.apply(nominator)
denominator = awesome_allgather_function.apply(denominator)
nominator = nominator.sum(0)
denominator = denominator.sum(0)

l = loss_fn(nominator, denominator)

l.backward()

打印和保存的时候使用:

def print():
	if local_rank == 0:
    	print("...")

def save_checkpoint():
	if local_rank == 0:
    	save(...)

6. kill掉后台进程

对于DDP程序,仅仅使用ctrl+c不能kill后台程序,使用如下命令即可全完kill掉:

kill $(ps aux | grep train.py | grep -v grep | awk '{print $2}')

你可能感兴趣的:(pytorch)