使用 PyTorch 进行分布式训练

欢迎关注 “小白玩转Python”,发现更多 “有趣”

引言

在本教程中,您将学习如何在单个节点上跨多个 GPU 并行 ML 模型训练的实践方面。您还将学习 PyTorch 的分布式数据并行框架的基础知识。

学习之前,我们先了解一下什么是DDP。

什么是 DDP?

DDP 是 PyTorch 中的一个库,它支持跨多个设备的梯度同步。这意味着什么?这意味着您可以通过跨多个 GPU并行处理几乎线性地加快模型训练。换句话说,在一台有两个 GPU 的机器上训练这个模型比在一个 GPU 上训练要节约大约一半的时间。许多云计算提供商,如 AWS 和 GCP,都提供多 gpu 计算机。例如,AWS 中的 ml.p2.8 xlarge 实例具有8个 GPU。它不是严格的线性的,因为从多个设备收集张量的性能开销很小。

DDP 同样支持分布在多个独立主机上的训练,这样的设置超出了本教程的范围。

它是如何工作的?

DDP 的工作原理是为每个 GPU 创建一个单独的 Python 进程。每个进程都使用一个不重叠的数据子集。注意,最多只能在一个 GPU 上运行一个 DDP 进程。也就是说,机器上可用的 DDP 进程可能少于 GPU。在这种情况下,一些 gpu 将继续未使用。对于一个脚本,不可能有比 GPU 更多的进程。

PyTorch 提供了产生多个进程的工具,以及将数据集分割成不重叠的子集的工具。

如果您有兴趣了解关于 DDP 实现的更多细节,请随意浏览 DDP 设计文档:https://pytorch.org/docs/master/notes/ddp.html。

术语

在我们开始之前,让我们熟悉一些 DDP 的概念:

size:进行训练的 GPU 设备的数量

rank:对GPU设备有一个序列的id号

DDP 使您的脚本可以通过命令行参数获得 rank 值。可以通过 torch.cuda.device_count()获得 size,前提是您希望使用所有可用的 GPU。

准备数据集

让我们从学习如何使用 PyTorch 将数据集拆分为不重叠的子集开始。首先,让我们看看如何创建简单的非分布式数据集:

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
                               download=True,
                               train=True)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True)

创建一个分布式的数据加载器,可以使用 torch.utils.data.DistributedSampler:

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data',
                               download=True,
                               train=True,
                               transform=transform)
# Create distributed sampler pinned to rank
sampler = DistributedSampler(train_dataset,
                             num_replicas=world_size,
                             rank=rank,
                             shuffle=True,  # May be True
                             seed=42)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=False,  # Must be False!
                          num_workers=4,
                          sampler=sampler,
                          pin_memory=True)

我们创建 DisstributedSampler 并将其传递给 DataLoader。在 DataLoader 上设置 shuffle = False 以避免混乱子集是至关重要的。Shuffle 过程是由采样器完成的,所以你可能想要设置 shuffle = True。

准备模型

下面是一个在单个设备环境中如何初始化模型的例子:

def create_model():
    model = nn.Sequential(
        nn.Linear(28*28, 128),  # MNIST images are 28x28 pixels
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10, bias=False)  # 10 classes to predict
    )
    return model
# Initialize the model
model = create_model()

为了使其在多 GPU 环境中工作,需要进行以下修改:

# Initialize the model
model = create_model()
# Create CUDA device
device = torch.device(f'cuda:{rank}')
# Send model parameters to the device
model = model.to(device)
# Wrap the model in DDP wrapper
model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)

训练周期

单设备训练周期的代码类似于:

for i in range(epochs):
    for x, y in train_loader:
        # do the training
        ...

在多 GPU 环境下,采样器必须知道哪个 epoch:

for i in range(epochs):
    train_loader.sampler.set_epoch(i)
    for x, y in train_loader:
        # do the training
        ...

从命令行参数获取rank

DDP 会将 -- local-rank 参数传递给你的脚本。你可以这样解析:

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
rank = args.local_rank

保存模型

在每个 GPU 上保存模型参数的副本,所以你应该只保存一次模型:

if rank == 0:
    torch.save(model.module.state_dict(), 'model.pt')

启动脚本

DDP 提供了一个启动实用程序,您可以使用它产生多个进程。如果你的机器有4个 GPU 可用,一个命令行看起来像这样:

python -m torch.distributed.launch --nproc_per_node=4 
ddp_tutorial_multi_gpu.py

注意,这里没有明确指定 -- local-rank。在调用脚本之前,PyTorch 会自动将其添加到命令行中。

确保它起作用

有几个简单的命令可以告诉你你的实现是正确的:

1. 打印数据加载程序的长度len(train_loader) .例如,您的数据集有10,000个示例,批量大小为100。这意味着数据加载器将有10,000/100 = 1,000个批次。这将是只使用一个设备时数据加载程序的长度。然而,在 DDP 中,这个数字将被设备的数量除以。因此,如果您使用相同的数据集和设置,但在2个 GPU 上进行训练,数据加载程序的长度将为1,000/2 = 500。

2. 使用 nvidia-smi 工具来监控 GPU 的使用。如果您的实现是正确的,那么您应该看到所有的 GPU 或多或少都得到了同样的使用。

·  END  ·

HAPPY LIFE

你可能感兴趣的:(分布式,python,大数据,深度学习,编程语言)