PyTorch 分布式多卡训练：DistributedDataParallel

主要分为以下几个部分:

  1. 单机多卡：DataParallel（简单、常用，但它是单进程多线程，效率较低）
  2. 多机多卡：DistributedDataParallel（多进程，性能最好；官方也推荐用它做单机多卡训练）
  3. 训练指令
  4. 注意事项

一、单机多卡(DataParallel)

from torch.nn import DataParallel

# Single machine, multiple GPUs: wrap the model in DataParallel, then move the
# wrapper (and therefore the model's parameters) onto the default CUDA device.
model = DataParallel(MyModel()).cuda()

二、多机多卡(DistributedDataParallel)

2.1 argparse和dist初始化设置

import argparse
import os
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def _str2bool(value):
    """Convert command-line text such as ``"True"``, ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` would call ``bool("False")``. That is
    ``True``, because every non-empty string is truthy. This converter parses
    the text instead.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args_parser():
    """Build the command-line parser for the DDP training script.

    Options:
        --dist          enable DistributedDataParallel. Accepts a bare
                        ``--dist``, ``--dist True`` or ``--dist False``.
                        Default is False.
        --local_rank    GPU index on this node (default 0). Also accepted as
                        ``--local-rank``, the spelling ``torch.distributed.launch``
                        passes from PyTorch 2.0 onwards.
        --seed          random seed (default 42).
        --num_workers   DataLoader worker processes (default 4).
        --device        device used outside DDP mode (default ``'cuda'``).

    Returns:
        argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(description='Sparse-to-Dense')
    parser.add_argument('--dist', type=_str2bool, nargs='?', const=True, default=False)
    # dest is derived from the first long option, so both spellings set ``local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    return parser

def _init_dist_pytorch(opt):
    """Bind this process to its GPU and join the NCCL process group.

    The local rank comes from the ``LOCAL_RANK`` environment variable when it
    is set, because ``torchrun`` provides the rank only there and passes no
    ``--local_rank`` flag. Otherwise ``opt.local_rank`` is used. The resolved
    value is written back to ``opt.local_rank``, so later code that reads it
    targets the same GPU.

    ``init_method="env://"`` makes PyTorch read MASTER_ADDR, MASTER_PORT, RANK
    and WORLD_SIZE from the environment.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', opt.local_rank))
    opt.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

def get_dist_info():
    """Report this process's position in the distributed job.

    Returns:
        ``(rank, world_size)`` from the default process group, or ``(0, 1)``
        when torch.distributed is unavailable or has not been initialized.
    """
    if not dist.is_available() or not dist.is_initialized():
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
    
if __name__ == '__main__':
    # Parse the command line. main() reads the options through the
    # module-level names ``opt`` and ``args``, so bind both to one namespace.
    opt = get_args_parser().parse_args()
    args = opt

    ### Multi-node setup ###
    # Export the local rank for code that reads LOCAL_RANK from the environment.
    # A value already set by the launcher takes precedence.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(opt.local_rank)

    # DDP backend initialization (distributed mode only).
    if opt.dist:
        _init_dist_pytorch(opt)

    # Seed every RNG: use --seed when given (it defaults to 42), otherwise draw one.
    seed = opt.seed if opt.seed is not None else random.randint(1, 10000)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The parser may not define --device, so a missing option counts as 'cuda'.
    if getattr(opt, 'device', 'cuda') == 'cuda':
        torch.cuda.manual_seed_all(seed)
    main()

2.2 main函数改写

def _resolve_global(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise ValueError(
            f'main() needs `{name}`: pass it explicitly or define it at module level') from None


def _wrap_model(model, args):
    """Move ``model`` to its GPU(s), wrap it in DDP or DataParallel, and return ``(model, device)``."""
    if args.dist:
        ### Multi-node model parallelism ###
        device = torch.device('cuda', args.local_rank)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            broadcast_buffers=False, find_unused_parameters=False)
    else:
        device = torch.device(getattr(args, 'device', 'cuda'))
        model = torch.nn.DataParallel(model)
        model = model.cuda()
    return model, device


def _build_train_loader(train_dataset, args):
    """Return ``(train_loader, train_sampler)``; the sampler is None outside DDP mode."""
    if args.dist:
        ### Multi-node data parallelism ###
        # Leave ``rank`` unset so the sampler takes the *global* rank from the
        # process group. Passing the local rank would give every node the same
        # shards.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
            pin_memory=True, sampler=train_sampler)
    else:
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, sampler=None)
    return train_loader, train_sampler


def main(model=None, train_dataset=None, optimizer=None, loss_fn=None,
         max_epoch=None, args=None):
    """Train ``model`` with DistributedDataParallel (``args.dist``) or DataParallel.

    Every argument is optional so the zero-argument call ``main()`` keeps
    working. An argument left as None is looked up among the module-level
    globals under the same name. ``loss_fn`` defaults to
    ``torch.nn.MSELoss()``.

    ``args`` must provide ``dist``, ``local_rank``, ``batch_size``,
    ``num_workers`` and ``output_dir``. ``device`` is read in non-DDP mode and
    defaults to ``'cuda'``.

    NOTE(review): assumes ``train_dataset`` yields ``(inputs, labels)`` pairs;
    adjust the unpacking in the batch loop otherwise.

    Raises:
        ValueError: a required object was neither passed nor defined globally.
    """
    args = _resolve_global('args', args)
    model = _resolve_global('model', model)
    train_dataset = _resolve_global('train_dataset', train_dataset)
    optimizer = _resolve_global('optimizer', optimizer)
    max_epoch = _resolve_global('max_epoch', max_epoch)
    if loss_fn is None:
        loss_fn = torch.nn.MSELoss()

    model, device = _wrap_model(model, args)
    train_loader, train_sampler = _build_train_loader(train_dataset, args)

    # barrier() is only valid once a process group exists. Outside DDP this
    # process acts as rank 0.
    in_process_group = bool(args.dist) and dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if in_process_group else 0

    for epoch in range(1, max_epoch + 1):
        model.train()
        if train_sampler is not None:
            # Re-seed DistributedSampler's shuffle each epoch; without this
            # every epoch iterates the data in the same order.
            train_sampler.set_epoch(epoch)

        for inputs, labels in train_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            optimizer.step()

        if rank == 0:
            # Save a checkpoint every epoch, from a single process only.
            torch.save(model.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pth'))
        if in_process_group:
            # Keep ranks in step, so no process loads epoch N's checkpoint
            # before rank 0 has written it.
            dist.barrier()

2.3 训练入口

import os
import json
# from easydict import EasyDict
import socket
import logging
import argparse

# Module-level logger named after the module, per the logging convention.
# Its records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)

def get_work_index(max_retries=None, retry_interval=1.0):
    """Read the cluster layout from the environment.

    Environment variables:
        MASTER_ADDR  host of the master node. The value ``"localhost"`` (as seen
                     on the master itself) is replaced by ``HOSTNAME``, so the
                     resolved IP is not 127.0.0.1.
        MASTER_PORT  returned unchanged, as a string.
        WORLD_SIZE   converted to int and returned first.
        RANK         converted to int and returned second.

    Resolving MASTER_ADDR is retried every ``retry_interval`` seconds.
    Presumably this is because the master's hostname may not resolve yet when
    a worker starts. ``max_retries=None`` (the default) retries forever.

    Returns:
        (world_size, rank, master_ip, master_port)

    Raises:
        OSError: (e.g. socket.gaierror) resolution still fails after ``max_retries`` retries.
        ValueError: WORLD_SIZE or RANK is missing or not an integer.
    """
    import time

    addr = os.environ.get("MASTER_ADDR", "{}")
    if addr == "localhost":  # on the master node, resolve via hostname, otherwise the IP is 127.0.0.1
        addr = os.environ.get("HOSTNAME", "{}")
    master_port = os.environ.get("MASTER_PORT", "{}")
    world_size = os.environ.get("WORLD_SIZE", "{}")  # total number of processes in the job
    rank = os.environ.get("RANK", "{}")  # this process's index; parameters must be saved from rank 0

    attempt = 0
    while True:
        try:
            master_addr = socket.gethostbyname(addr)
            break
        except OSError as exc:  # socket.gaierror / socket.herror subclass OSError
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            print(f"resolving MASTER_ADDR {addr!r} failed ({exc}), retrying in {retry_interval}s")
            time.sleep(retry_interval)

    try:
        return int(world_size), int(rank), master_addr, master_port
    except ValueError:
        raise ValueError(
            f"WORLD_SIZE={world_size!r} and RANK={rank!r} must be set to integers") from None


def parse_args(argv=None):
    """Parse the launcher's command line.

    Args:
        argv: argument list to parse; None (the default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``expfile`` (str), ``num_gpus`` (int) and ``eval`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--expfile', type=str, default='exp.yaml',
                        help='experiment config yaml file (under experiments/)')
    parser.add_argument('--num_gpus', type=int, default=8,
                        help='number of GPUs / processes to launch per node')
    parser.add_argument('--eval', action='store_true', help='only eval')
    return parser.parse_args(argv)


if __name__ == "__main__":
    import subprocess
    import sys
    import time

    # Make logger.info() visible; the root logger defaults to WARNING.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()

    # Fixed start-up delay before reading the cluster environment.
    time.sleep(10)
    cfg = {'config_file': "experiments/" + args.expfile,
           'num_gpus': args.num_gpus,
           'cluster_name': 'tecent'}

    num_nodes, node_rank, master_addr, master_port = get_work_index()

    # NOTE(review): this compares the node rank against the per-node GPU count,
    # so nodes with rank >= num_gpus never start training; confirm this is
    # intended (a check against num_nodes may have been meant).
    if node_rank < cfg['num_gpus']:
        # Argument order matches the script: CONFIG NPROC_PER_NODE MASTER_ADDR NNODES NODE_RANK.
        cmd = ['./dist_val_mpi.sh', cfg['config_file'], str(cfg['num_gpus']),
               master_addr, str(num_nodes), str(node_rank)]

        # start train
        logger.info(' '.join(cmd))
        # An argument list (no shell) keeps spaces or shell metacharacters in
        # --expfile from being interpreted by a shell.
        result = subprocess.run(cmd)
        # Propagate the training job's exit status to the scheduler.
        sys.exit(result.returncode)
        

2.4 训练脚本

#!/bin/bash
# Start torch.distributed.launch for this node.
# Usage: dist_val_mpi.sh CONFIG NPROC_PER_NODE MASTER_ADDR NNODES NODE_RANK
# PORT (environment, default 29506) is the rendezvous port on MASTER_ADDR.
set -euo pipefail

if [ "$#" -lt 5 ]; then
    echo "usage: $0 CONFIG NPROC_PER_NODE MASTER_ADDR NNODES NODE_RANK" >&2
    exit 1
fi

CONFIG=$1  # NOTE(review): read but never passed to main.py below; confirm this is intended
nproc_per_node=$2
master_addr=$3
nnodes=$4
node_rank=$5
PORT=${PORT:-29506}

export OMP_NUM_THREADS=8
# Prepend the parent directory of this script. Add the ':' separator only when
# PYTHONPATH is already set; an empty entry would put the current directory
# on the import path.
export PYTHONPATH="$(dirname "$0")/../${PYTHONPATH:+:$PYTHONPATH}"

# scenflow
# TORCH_DISTRIBUTED_DEBUG=DETAIL turns on extra c10d debug checks/logging (slower).
TORCH_DISTRIBUTED_DEBUG=DETAIL python -m torch.distributed.launch \
    --nproc_per_node="$nproc_per_node" \
    --nnodes="$nnodes" \
    --node_rank="$node_rank" \
    --master_addr="$master_addr" \
    --master_port="$PORT" \
    main.py --dist True

三、开始训练指令

# Presumably run once per node: the launcher script from section 2.3 reads
# MASTER_ADDR / MASTER_PORT / WORLD_SIZE / RANK from the environment and
# starts --num_gpus processes on this node.
python torch_dist_job.py --num_gpus 8

四、注意事项

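PyTorch 分布式训练中，每个进程都会各自创建 DataLoader，并通过 DistributedSampler 只读取属于自己的那一部分数据；框架本身并不会让主进程预读取、缓存数据后再分发给其它进程。如果确实需要先由主进程（rank 0）完成某项一次性工作（例如下载/预处理数据集并写入缓存，或保存 checkpoint），其余进程再继续，就要用 torch.distributed.barrier() 做同步：所有进程都执行到 barrier() 之后，才会继续往下运行。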

# Synchronization point: each process blocks here until every rank in the
# default process group has also called barrier().
dist.barrier()

你可能感兴趣的:(python,人工智能,深度学习,分布式,pytorch)