Pytorch多GPU训练:DataParallel和DistributedDataParallel

引言

Pytorch有两种方法实现多GPU训练,分别是DataParallel(DP)和DistributedDataParallel(DDP)。DP实现简单,但没有完全利用所有GPU资源,DDP实现相对复杂,但是更快,我建议使用DDP。


DP

DP使用torch.nn.DataParallel。原理是,假设用K个GPU训练,前向传播阶段,一个batch的数据会被平均分成K份,模型也会复制K份,分别送到每个GPU上;反向传播阶段,各复制模型产生的梯度会被累加到主模型上。batch size应该大于使用的GPU数量

Demo

我们写一份简单的程序实现一下torch.nn.DataParallel:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("\tIn: input size", input.size(),
              "output size", output.size())

        return output


# 超参数
input_size = 5    # 输入维数
output_size = 2   # 输出维数
batch_size = 30   # batch size
data_size = 30    # 样本数
gpus = [0, 1, 2]  # GPU索引

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)
model = Model(input_size, output_size)

# Multi-GPUS操作
model = nn.DataParallel(model, device_ids=gpus)
model = model.to(gpus[0])  # 主模型

for data in rand_loader:
    input = data.to(gpus[0])
    output = model(input)
    print("Out: input size", input.size(),
          "output_size", output.size())

GPUS = [0, 1, 2]
在这里插入图片描述
GPUS = [0, 1]
在这里插入图片描述
建议使用DDP代替DP。DP基于单进程,多线程,只能在一个机器上多卡训练,由于多线程之间GIL连接引入了额外开销,即使在一个机器上也比DDP慢;DDP基于多进程,可以在多个机器上训练,每个GPU由专有进程控制,训练更快!


DDP

下面的介绍都针对单机器多卡训练,因为这是我们最常见的情况。DDP的基本原理也是将模型复制到每个GPU上,收集每个GPU产生的梯度,平均这些梯度更新模型,然后同步所有GPU上的模型。首先了解一些必要的概念:

  • master node 主节点:负责同步,复制模型,加载模型,写日志的主GPU
  • process group 进程组:DDP给K个GPU各自分配一个进程,这K个进程构成进程组;进程组由一个后端控制协调,pytorch建议使用nccl后端
  • rank:各进程的索引,从0到K-1
  • world size:进程组中进程的总数K

DDP的实现流程可以概括为:设置进程组,分割数据,DDP化模型,训练模型,clean up。我们首先介绍各部分对应的代码,最后给出整体demo。

设置进程组

import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'  # 主节点地址
    os.environ['MASTER_PORT'] = '12355'      # 主节点端口,用于进程之间通讯
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

分割数据

可以使用torch.utils.data.distributed.DistributedSampler分割数据的索引,然后将每组索引送入dataloader组成batch。请注意,这里与DP有区别,DP将一个batch的数据分成K份,然后送入每个GPU,因此设置的batch size等于训练使用的batch size;DDP先将训练数据分成K份,然后送入每个GPU,再生成batch,因此实际训练使用的batch size等于设置的batch size再乘以K!

from torch.utils.data.distributed import DistributedSampler

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = Your_Dataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
    
    return dataloader

假设K = 3,共有10条数据。

  • 如果sampler.drop_last = False,会对索引个数不足的rank补零。例如数据索引为[0, 1, …, 9],那么这3个rank的索引分割结果可能是[0, 3, 6, 9],[0, 4, 7, 0]和[2, 5, 8, 0],sampler尽量保证每组的索引不重复,但由于补零操作,不可避免地使数据0有重复
  • 否则,会丢弃一些索引,此时的分割结果可能是[0, 3, 6],[1, 4, 7]和[2, 5, 8],其中索引9被丢弃了

请注意,设置num_workers = 0和pin_memory = False可以避免DDP下一些不必要的BUG。

DDP化模型

from torch.nn.parallel import DistributedDataParallel as DDP

model = Model().to(rank)
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

这边也有一些需要注意的点!

  • 当我们要访问DDP化的模型属性时,需要使用model.module.attr
  • 直接保存DDP模型时,state_dict会对所有参数增加module前缀
  • 如果要把DDP模型参数加载到非DDP模型时,需要删除参数前的module前缀

训练模型

使用spawn方法管理多进程,对于多进程来说,所有子进程和父进程运行的是相同的程序。

import torch.multiprocessing as mp

if __name__ == '__main__':
    world_size = 3        
    mp.spawn(
        main,
        args=(world_size),
        nprocs=world_size
    )

main是运行在每一个进程上的训练过程,main的第一个形参必须是rank,spawn会自动传递这个值给main,所以spawn.args只写了world_size参数。rank = 0是默认的主节点。同时注意在epoch和iter的循环之间必须加上dataloader.sampler.set_epoch(epoch),数据才能正确分割。

Clean Up

main的最后一行是clean up操作。

def cleanup():
    dist.destroy_process_group()

保存网络

在主节点保存网络,保存函数后面要加上dist.barrier()函数,暂停此时其他进程的运行,等待网络保存完毕。

if rank == 0:
   model.save_nets()
dist.barrier()  # 保存结束前其他 process 不运行

Demo

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.label = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'  # 主节点地址
    os.environ['MASTER_PORT'] = '12355'      # 主节点端口,用于进程之间通讯
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = RandomDataset(5, 60)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers,
                            drop_last=False, shuffle=False, sampler=sampler)
    return dataloader


def cleanup():
    dist.destroy_process_group()


def main(rank, world_size):
    # 建立进程组
    setup(rank, world_size)
    print("Rank", rank)
    # 分割数据
    dataloader = prepare(rank, world_size, batch_size=10)
    # DDP化模型
    model = Model(5, 5).to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
    # 训练模型
    loss = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(10):
        dataloader.sampler.set_epoch(epoch)  # 这个必须要加!
        for x, y in dataloader:
            optim.zero_grad()
            x = x.to(rank)
            y = y.to(rank)
            pred = model(x)
            l = loss(pred, y)
            l.backward()
            optim.step()
    # 保存网络
    # if rank == 0:
    #     model.save_nets()  # 自己编写保存参数函数
    # dist.barrier()
    # clean up
    cleanup()


if __name__ == "__main__":
    world_size = 3

    mp.spawn(
        main,
        args=(world_size,),
        nprocs=world_size,
    )

参考

链接1
链接2

你可能感兴趣的:(pytorch,深度学习,python)