PyTorch单机多卡分布式训练教程及代码示例

导师不是很懂PyTorch的分布式训练流程,我就做了个PyTorch单机多卡的分布式训练介绍,但是他觉得我做的没这篇好PyTorch分布式训练简明教程 - 知乎。这篇讲的确实很好,不过我感觉我做的也还可以,希望大家看完之后能给我一些建议。

目录

1.预备知识

1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。

1.2 World,Rank,Local Rank

1.2.1 World

1.2.2 Rank

1.2.3 Local Rank

2. PyTorch单机多卡数据并行

2.1 多进程启动

2.1.1 多进程启动示例

2.2 启动进程间通信

2.2.1 初始化成功示例

2.2.2 初始化失败示例

2.2.3 进程间通信示例

2.3. 单机多卡数据并行示例

后记:如何拓展到多机多卡?


1.预备知识

多卡训练涉及到多进程和进程间通信,因此有必要先解释一些进程间通信的概念。

1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。

众所周知,每个主机都可以同时运行多个进程,但是在通常情况下每个进程都是做各自的事情,各个进程间是没有关系的。

而在MPI中,我们可以拥有一组能够互相发消息的进程,但是这些进程可以分布在多个主机中,这时我们可以将主机称为节点(Node),进程称为工作结点(Worker)。

PyTorch单机多卡分布式训练教程及代码示例_第1张图片

由于PyTorch中的主要说法还是进程,所以后面也会统一采用主机和进程的说法。

1.2 World,Rank,Local Rank

对于一组能够互相发消息的进程,我们需要区分每一个进程,因此每个进程会被分配一个序号,称作rank。进程间可以通过指定rank来进行通信。

1.2.1 World

World可以认为是一个集合,由一组能够互相发消息的进程组成。

如下图中假如Host 1的所有进程和Host 2的所有进程都可以进行通信,那么它们就组成了一个World。

PyTorch单机多卡分布式训练教程及代码示例_第2张图片

因此,world size就表示这组能够互相通信的进程的总数,上图中world size为6。

1.2.2 Rank

Rank可以认为是这组能够互相通信的进程在World中的序号。

PyTorch单机多卡分布式训练教程及代码示例_第3张图片

1.2.3 Local Rank

Local Rank可以认为是这组能够互相通信的进程在它们相应主机(Host)中的序号。

即在每个Host中,Local rank都是从0开始。

PyTorch单机多卡分布式训练教程及代码示例_第4张图片

2. PyTorch单机多卡数据并行

数据并行本质上就是增大模型的batch size,但batch size也不是越大越好,所以一般对于大模型才会使用数据并行。

Pytorch进行数据并行主要依赖于它的两个模块multiprocessing和distributed。

所以首先介绍multiprocessing和distributed模块的基本用法。

2.1 多进程启动

由于Python多线程存在GIL(全局解释器锁),为了提高效率,Pytorch实现了一个multiprocessing多进程模块。用于在一个Python进程中启动额外的进程。

2.1.1 多进程启动示例

该程序启动了4个进程,每个进程会输出当前rank,表明与其他进程不同。

#run_multiprocess.py
#运行命令:python run_multiprocess.py
import torch.multiprocessing as mp

def run(rank, size):
    print("world size:{}. I'm rank {}.".format(size,rank))


if __name__ == "__main__":
    world_size = 4
    mp.set_start_method("spawn")
    #创建进程对象
    #target为该进程要运行的函数,args为target函数的输入参数
    p0 = mp.Process(target=run, args=(0, world_size))
    p1 = mp.Process(target=run, args=(1, world_size))
    p2 = mp.Process(target=run, args=(2, world_size))
    p3 = mp.Process(target=run, args=(3, world_size))

    #启动进程
    p0.start()
    p1.start()
    p2.start()
    p3.start()

    #当前进程会阻塞在join函数,直到相应进程结束。
    p0.join()
    p1.join()
    p2.join()
    p3.join()

输出结果:

world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.

2.2 启动进程间通信

虽然启动了多进程,但是此时进程间并不能进行通信,所以PyTorch设计了另一个distributed模块用于进程间通信。

init_process_group函数是distributed模块用于初始化通信模块的函数。

当该函数初始化成功则表明进程间可以进行通信。

2.2.1 初始化成功示例

只有当world size和实际启动的进程数匹配,init_process_group才可以初始化成功。

#multiprocess_comm.py
#运行命令:python multiprocess_comm.py

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。
    #由于是在单机上,所以用localhost的ip就可以了。
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    #端口可以是任意空闲端口
    os.environ['MASTER_PORT'] = '29500'
    #通信模块初始化
    #进程会阻塞在该函数,直到确定所有进程都可以通信。
    dist.init_process_group('gloo', rank=rank, world_size=size)
    print("world size:{}. I'm rank {}.".format(size,rank))


if __name__ == "__main__":
    world_size = 4
    mp.set_start_method("spawn")
    #创建进程对象
    #target为该进程要运行的函数,args为函数的输入参数
    p0 = mp.Process(target=run, args=(0, world_size))
    p1 = mp.Process(target=run, args=(1, world_size))
    p2 = mp.Process(target=run, args=(2, world_size))
    p3 = mp.Process(target=run, args=(3, world_size))

    #启动进程
    p0.start()
    p1.start()
    p2.start()
    p3.start()

    #等待进程结束
    p0.join()
    p1.join()
    p2.join()
    p3.join()

输出结果:

world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.

2.2.2 初始化失败示例

当将world size设置为2,但是实际却启动了4个进程,此时init_process_group就会报错。

#multiprocess_comm.py
#运行命令:python multiprocess_comm.py

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。
    #由于是在单机上,所以用localhost的ip就可以了。
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    #端口可以是任意空闲端口
    os.environ['MASTER_PORT'] = '29500'
    #通信模块初始化
    #进程会阻塞在该函数,直到确定所有进程都可以通信。
    dist.init_process_group('gloo', rank=rank, world_size=size)
    print("world size:{}. I'm rank {}.".format(size,rank))


if __name__ == "__main__":
    world_size = 2
    mp.set_start_method("spawn")
    #创建进程对象
    #target为该进程要运行的函数,args为target函数的输入参数
    p0 = mp.Process(target=run, args=(0, world_size))
    p1 = mp.Process(target=run, args=(1, world_size))
    p2 = mp.Process(target=run, args=(2, world_size))
    p3 = mp.Process(target=run, args=(3, world_size))

    #启动进程
    p0.start()
    p1.start()
    p2.start()
    p3.start()

    #当前进程会阻塞在join函数,直到相应进程结束。
    p0.join()
    p1.join()
    p2.join()
    p3.join()

输出结果:

RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1623448224956/work/third_party/gloo/gloo/context.cc:27] rank < size. 3 vs 2

2.2.3 进程间通信示例

当init_process_group初始化成功,进程间就可以进行通信了,这里我以集体通信Allreduce为例。

#multiprocess_allreduce.py
#运行命令:python multiprocess_allreduce.py

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    #通信模块初始化
    #进程会阻塞在该函数,直到确定所有进程都可以通信。
    dist.init_process_group('gloo', rank=rank, world_size=size)
    #每个进程都创建一个Tensor,Tensor值为该进程相应rank。
    param = torch.tensor([rank])
    print("rank {}: tensor before allreduce: {}".format(rank,param))
    #对该Tensor进行Allreduce。
    dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
    print("rank {}: tensor after allreduce: {}".format(rank,param))


if __name__ == "__main__":
    world_size = 4
    mp.set_start_method("spawn")
    #创建进程对象
    #target为该进程要运行的函数,args为target函数的输入参数
    p0 = mp.Process(target=run, args=(0, world_size))
    p1 = mp.Process(target=run, args=(1, world_size))
    p2 = mp.Process(target=run, args=(2, world_size))
    p3 = mp.Process(target=run, args=(3, world_size))

    #启动进程
    p0.start()
    p1.start()
    p2.start()
    p3.start()

    #当前进程会阻塞在join函数,直到相应进程结束。
    p0.join()
    p1.join()
    p2.join()
    p3.join()

输出结果:

rank 0: tensor before allreduce: tensor([0])
rank 2: tensor before allreduce: tensor([2])
rank 3: tensor before allreduce: tensor([3])
rank 1: tensor before allreduce: tensor([1])

rank 0: tensor after allreduce: tensor([6])
rank 3: tensor after allreduce: tensor([6])
rank 2: tensor after allreduce: tensor([6])
rank 1: tensor after allreduce: tensor([6])

2.3. 单机多卡数据并行示例

当可以启动多进程,并进行进程间通信后,实际上就已经可以进行单机多卡的分布式训练了。

但是Pytorch为了便于用户使用,所以在这之上又增加了很多更高层的封装,如DistributedDataParallel,DistributedSampler等。

所以为了便于理解这中间的一些流程,这里演示一下不使用这些封装时的单机多卡数据并行。

该示例代码和单机训练主要有两个区别:

(1)需要基于每个进程的rank将模型参数放置到不同的GPU。

(2) 在参数更新前需要对梯度进行Allreduce。

#multiprocess_training.py
#运行命令:python multiprocess_training.py
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#用于平均梯度的函数
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
#模型
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
    
def accuracy(outputs,labels):
    _, preds = torch.max(outputs, 1) # taking the highest value of prediction.
    correct_number = torch.sum(preds == labels.data)
    return (correct_number/len(preds)).item()


def run(rank, size):
    #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。
    #由于是在单机上,所以用localhost的ip就可以了。
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    #端口可以是任意空闲端口
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=size)

    #1.数据集预处理
    train_dataset = torchvision.datasets.MNIST(root='../data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    training_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)

    #2.搭建模型
    #device = torch.device("cuda:{}".format(rank))
    device = torch.device("cpu")
    print(device)
    torch.manual_seed(0)
    model = ConvNet().to(device)
    torch.manual_seed(rank)
    criterion = nn.CrossEntropyLoss() 
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.001,momentum=0.9) # fine tuned the lr
    #3.开始训练
    epochs = 15
    batch_num = len(training_loader)
    running_loss_history = []
    for e in range(epochs):
        for i,(inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device) 
            labels = labels.to(device)
            #前向传播
            outputs = model(inputs) 
            loss = criterion(outputs, labels) 
            optimizer.zero_grad() 
            #反传
            loss.backward() 
            #记录loss
            running_loss_history.append(loss.item())
            #参数更新前需要Allreduce梯度。
            average_gradients(model)
            #参数更新
            optimizer.step() 
            if (i + 1) % 50 == 0 and rank == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f},acc:{:.2f}'.format(e + 1, epochs, i + 1, batch_num,loss.item(),accuracy(outputs,labels)))


if __name__ == "__main__":
    world_size = 4
    mp.set_start_method("spawn")
    #创建进程对象
    #target为该进程要运行的函数,args为target函数的输入参数
    p0 = mp.Process(target=run, args=(0, world_size))
    p1 = mp.Process(target=run, args=(1, world_size))
    p2 = mp.Process(target=run, args=(2, world_size))
    p3 = mp.Process(target=run, args=(3, world_size))

    #启动进程
    p0.start()
    p1.start()
    p2.start()
    p3.start()

    #当前进程会阻塞在join函数,直到相应进程结束。
    p0.join()
    p1.join()
    p2.join()
    p3.join()

后记:如何拓展到多机多卡?

在多机多卡环境中初始化init_process_group还需要做一些额外的处理,主要考虑两个问题

(1)需要让其余进程知道rank=0进程的 IP:Port 地址,此时rank=0进程会在相应端口进行监听,其余进程则会给这个IP:Port发消息。这样rank=0进程就可以进行统计,确认初始化是否成功。这一步在PyTorch中是通过设置os.environ['MASTER_ADDR']和os.environ['MASTER_PORT']这两个环境变量来做的。

(2)需要为每个进程确定相应rank,通常采用的做法是给主机编号,因此多机多卡启动时给不同主机传入的参数肯定是不同的。此时参数可以直接手动在每个主机的代码上修改,也可以通过argparse模块在运行时传递不同参数来做。


 

你可能感兴趣的:(python,pytorch,分布式,人工智能,深度学习)