PyTorch单机多卡训练(DDP-DistributedDataParallel的使用)备忘记录

不做具体的原理分析和介绍(因为我也不咋懂),针对我实际修改可用的一个用法介绍,主要是模型训练入口主函数(main_multi_gpu.py)的四处修改。
PyTorch单机多卡训练(DDP-DistributedDataParallel的使用)备忘记录_第1张图片
以上的介绍来源https://zhuanlan.zhihu.com/p/206467852

0. 概述

使用DDP进行单机多卡训练时,通过多进程在多个GPU上复制模型,每个GPU都由一个进程控制,同时需要将参数local_rank传递给进程,用于表示当前进程使用的是哪一个GPU。

要将单机单卡训练修改为基于DDP的单机多卡训练,需要进行的修改如下(总共四处需要修改):

  1. 初始化设置,需要设置local_rank参数,并需要将local_rank参数传递到进程中,如下:
# 在参数设置中添加local_rank参数,见parse_args()函数
# 运行时无需指定local_rank参数,但必须在此处进行定义
parser.add_argument("--local_rank", type=int, default=0)

# 将local_rank参数传入到进程中,见init_distributed_mode()函数
# local_rank:表示当前GPU编号
# local_world_size:使用的GPU数量
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(
	backend='nccl',
	init_method='env://',
	world_size=local_world_size,
	rank=local_rank
)
  1. DataLoader修改,需要使用torch.utils.data.DistributedSampler对数据集进行划分(把所有数据划分给不同的GPU),用于多个GPU单独的数据读取:
    此处,Dataset --> DistributedSampler --> BatchSampler --> DataLoader
    此外,Dataset、Sampler、DataLoader的关系可参考博文一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系,我觉得写的还挺清楚的(大概可以总结为,Dataset包含了所有的数据,Sampler指明了需要用到的数据的index,DataLoader根据Sampler从Dataset中提取数据用于训练
# 见setup_loader()函数中修改
"""
# 从数据集中采样划分出每块GPU所用的数据,组织形式为数据的index,如下:
	32104
	99491
	11488
	25070
	67216
	22453
	57418
	45591
	64625
	46036
	98404
	81477
"""
self.sampler_train = torch.utils.data.DistributedSampler(
	xx_set
)
"""
# 将每个GPU所用的数据,按照BATCH_SIZE进行组织,组织形式为数据index的list,list长度即为BATCH_SIZE,如下:
	[32104, 99491, 11488, 25070, 67216, 22453, 57418, 45591, 64625, 46036]
	[98404, 81477, 73638, 22696, 82657, 44563, 106537, 15772, 85536, 38823]
	......
"""
batch_sampler_train = torch.utils.data.BatchSampler(
	self.sampler_train, 
	batch_size=cfg.TRAIN.BATCH_SIZE, 
	drop_last=cfg.DATA_LOADER.DROP_LAST
)
# 根据batch_sampler_train(数据index)从xx_set数据集中提取出真实数据用于训练
self.training_loader = torch.utils.data.DataLoader(
	self.xx_set, # 数据集
	batch_sampler=batch_sampler_train, # 该GPU分配到的数据(以batch为单位组织)
	collate_fn=datasets.data_loader.sample_collate,  # 一个batch数据的组织方式
	pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
	num_workers=cfg.DATA_LOADER.NUM_WORKERS
)
  1. 模型初始化设置,使用torch.nn.parallel.DistributedDataParallel对模型进行包装,需要传入local_rank参数
# 使用torch.nn.parallel.DistributedDataParallel对模型进行包装,同时传入local_rank参数,表明模型是在哪一块GPU上
self.model = torch.nn.parallel.DistributedDataParallel(
	model.to(self.device),
	device_ids=[self.local_rank]  # 传入`local_rank`参数,即GPU编号
)
  1. 模型训练代码修改,训练过程中每开始新的epoch,基于torch.utils.data.BatchSampler创建的self.sampler_train都需要调用set_epoch(epoch_num)对训练集数据进行shuffle;并且需要调用torch.distributed.barrier()。见train()函数中修改内容
    如果不使用BatchSampler而直接使用Dataset --> DistributedSampler --> DataLoader(需要在DataLoader中指定batch_size参数),应该不用调用set_epoch(epoch_num)?(不确定,没试过)。

1. 单机多卡训练代码修改

模型训练入口主函数文件组织大致如下:

# main_multi_gpu.py
# 导入相关包
import torch
import ...

# 训练器定义
class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        # 固定随机数种子
        ...

		# 判断是否为多卡训练
		self.num_gpus = torch.cuda.device_count()
        self.distributed = self.num_gpus > 1
        
        # 针对单机多卡训练,进行初始化
        # 单机单卡训练无需此操作
        if self.distributed:
        	# 获取该进程的local_rank,数值和args.local_rank一致
            self.local_rank = init_distributed_mode()
        self.device = torch.device("cuda") if self.num_gpus > 0 else torch.device("cpu")
            
        # 训练集构建
        self.setup_dataset()
        self.setup_loader(0)  # 0表示初始epoch
        # 模型结构构建
        self.setup_network()
        # 一些无关的初始化操作
        ...
        
	def setup_dataset(self):
		self.xx_set = ...  # 创建训练集

	def setup_loader(self, epoch):
		if self.distributed:
            self.sampler_train = torch.utils.data.DistributedSampler(
                self.xx_set
            )
        else:
            self.sampler_train = torch.utils.data.RandomSampler(
                self.xx_set
            )
        batch_sampler_train = torch.utils.data.BatchSampler(
            self.sampler_train, 
            batch_size=cfg.TRAIN.BATCH_SIZE, 
            drop_last=cfg.DATA_LOADER.DROP_LAST
        )
        self.training_loader = torch.utils.data.DataLoader(
            self.xx_set,
            batch_sampler=batch_sampler_train,
            collate_fn=datasets.data_loader.sample_collate,  # 
            pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS
        )
		
	def setup_netword(self):
		model = ...  # 构建模型结构
		# DDP模型
		if self.distributed:
            self.model = torch.nn.parallel.DistributedDataParallel(
                model.to(self.device),
                device_ids=[self.local_rank]
            )
        else:
            self.model = torch.nn.parallel.DataParallel(model).to(self.device)

		# 一些不太相关的操作,比如损失函数、模型训练优化器的设置
		self.xe_criterion = ... # 模型训练交叉熵损失
		self.optim = ...  # 模型训练优化器
        ...

	# 模型训练核心
	def train():
		self.model.train()
		# 训练过程Epoch循环
		for epoch in range(MAX_EPOCH):
			# 如果是单机多卡训练,需要调用set_epoch,将数据进行shuffle
            if self.distributed:
                self.sampler_train.set_epoch(epoch)
            # 每个Epoch训练过程中的iteration循环
            for _data_ in self.training_loader:
            	# 1 模型前向运算,并计算损失
            	loss = ...
            	# 2 梯度清零
            	self.optim.zero_grad()
            	# 3 计算新梯度,并进行梯度裁减(梯度裁减为可选操作)
            	loss.backward()
            	# 4 梯度反传,模型参数更新
            	self.optim.step()
            	
				# 又是一些不太相关的操作,比如优化器lr衰减
				...
				
				# 进程间数据同步?(不确定是不是必须操作)
				if self.distributed:
                	torch.distributed.barrier()
        	# 又是一些不太相关的操作,比如模型的保存,模型的验证
        	...
        	# 进程间数据同步?
			if self.distributed:
                torch.distributed.barrier()

# 参数
def parse_args():
    '''
    Parse input arguments
    '''
    parser = argparse.ArgumentParser(description='Image Captioning')
    # 模型训练所需一些参数
    ...
    # DDP训练所需参数,--local_rank,不加它也有办法能跑起来,但是还是加了更规范一点
    parser.add_argument("--local_rank", type=int, default=0)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args

# 以下两个函数参考了DETR
# 禁用非主进程的输出
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

def init_distributed_mode():
    # 获取GPU编号
    # 初始化时使用get_rank()报错,难道只能初始化之后才能正常调用获取local_rank值?
    # local_rank = torch.distributed.get_rank()
    # local_world_size = torch.distributed.get_world_size()
    
    # 也可以传入args,通过args.local_rank获取local_rank值 (即当前GPU编号)
    # 通过torch.cuda.device_count()获取local_world_size值 (即GPU数量)
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ["RANK"])
        local_world_size = int(os.environ['WORLD_SIZE'])
        local_gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Error when get init distributed settings!')
        return
    
    torch.cuda.set_device(local_rank)
    print('| distributed init (rank {}): env://'.format(local_rank), flush=True)
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=local_world_size, # 所有的进程数,及GPU数量
        rank=local_rank
    )
    torch.distributed.barrier()
    # 禁用非主进程的输出,local_rank为0的进程作为主进程
    setup_for_distributed(local_rank==0)
    # 返回GPU编号
    return local_rank

if __name__ == '__main__':
    args = parse_args()

    if args.folder is not None:
        cfg_from_file(os.path.join(args.folder, 'config.yml'))
    cfg.ROOT_DIR = args.folder

    trainer = Trainer(args)
    trainer.train()

单机单卡训练(--*** ***表示模型训练所需传入的其他参数)

CUDA_VISIBLE_DEVICES=0 python main_multi_gpu.py --*** ***

单机多卡训练,无需指定--local_rank参数,使用torch.distributed.launch可以自动指定相关参数,在代码中可以从os.environ中获取相关参数(见init_distributed_mode()函数),但是必须得设置parser.add_argument("--local_rank", type=int, default=0)

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port=3141 --nproc_per_node 2 main_multi_gpu.py --*** ***

参考:

[1] 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
[2] DataParallel & DistributedDataParallel分布式训练
[3] [原创][深度][PyTorch] DDP系列第一篇:入门教程
[4] DETR源码 https://github.com/facebookresearch/detr

你可能感兴趣的:(深度瞎搞,pytorch,深度学习,DDP,单机多卡)