PyTorch DDP 分布式训练实例

相关注释已经写在代码块中。

代码实例

'''
文件名: DDP.py
脚本启动指令:
if torch version < 1.12.0:
    python -m torch.distributed.launch --nproc_per_node=2 DDP.py
else:
    torchrun --nproc_per_node=2 DDP.py
'''

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import torch.nn.functional as F
from torch import distributed
from torch.utils.data import DataLoader
from torchvision import models


## Initialize the DDP process group (runs at import time).
# When launched via torchrun / torch.distributed.launch, RANK, LOCAL_RANK and
# WORLD_SIZE are present in the environment. If any of them is missing, treat
# this as a plain `python DDP.py` run and fall back to a single-process group
# on a fixed local TCP port.
# Only the environment lookups live inside `try`, so a KeyError raised by
# init_process_group itself is not mistaken for "not launched by torchrun".
try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
except KeyError:
    rank = 0
    local_rank = 0
    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )
else:
    # Launched by a launcher: use the default env:// init method, which takes
    # the rendezvous address from the environment.
    distributed.init_process_group("nccl")


def seed_all(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed. ``None`` (the default) falls back to 42; any other
            value, including ``0``, is used as given.
    """
    # `is None` rather than truthiness, so that 0 is honoured as a real seed.
    if seed is None:
        seed = 42
    torch.manual_seed(seed)
    # These CUDA calls are safe without a GPU (PyTorch ignores them when CUDA
    # is unavailable); they make the GPU seeding explicit.
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Only affects Python interpreters started after this point; the hash
    # randomization of the current process is fixed at interpreter startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def build_dataloader(root='your data root', batch_size=16, num_workers=2):
    """Build the distributed CIFAR-100 train/val DataLoaders for this rank.

    Requires the default process group to be initialized (DistributedSampler
    and the download barrier both use it).

    Args:
        root: directory CIFAR-100 is stored in / downloaded to.
        batch_size: per-process batch size; the global batch size is
            ``batch_size * world_size``.
        num_workers: DataLoader worker processes per rank.

    Returns:
        ``(trainloader, valloader)``, each driven by a DistributedSampler.
    """
    # Validation must use the same input normalization as training; otherwise
    # the model is scored on a differently distributed input.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([transforms.ToTensor(), normalize])
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Let rank 0 download/verify the dataset first so several processes do not
    # write the same archive under `root` at the same time. Each rank calls
    # barrier() exactly once: non-zero ranks before loading, rank 0 after.
    is_main = distributed.get_rank() == 0
    if not is_main:
        distributed.barrier()
    trainset = CIFAR100(root=root, train=True, download=True, transform=train_transform)
    valset = CIFAR100(root=root, train=False, download=True, transform=val_transform)
    if is_main:
        distributed.barrier()

    ## create samplers: each rank iterates its own shard of the dataset
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    # Evaluation order is irrelevant, so do not shuffle the validation shard.
    val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)

    # shuffle=False on the DataLoader because ordering is delegated to the sampler.
    trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers,
                             sampler=train_sampler, shuffle=False, pin_memory=True,
                             drop_last=True)
    # Keep the final partial batch so the whole validation set is scored.
    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers,
                           sampler=val_sampler, shuffle=False, pin_memory=True,
                           drop_last=False)

    return trainloader, valloader


def metric(logit, truth):
    """Return the top-1 accuracy of ``logit`` against class labels ``truth``.

    Args:
        logit: per-class scores with classes along dimension 1 (one row per
            sample).
        truth: ground-truth class index for each row of ``logit``.

    Returns:
        Fraction of rows whose highest-probability class equals the label, as
        a NumPy float.
    """
    probs = logit.softmax(dim=1)
    top1 = probs.topk(1, dim=1, largest=True, sorted=True).indices
    labels = truth.view(-1, 1).expand_as(top1)
    hits = top1.eq(labels).detach().cpu().numpy()
    return np.mean(hits)


def _train_one_epoch(model, trainloader, optimizer, loss_func, epoch, debug_grads=False):
    """Run one training epoch of the DDP-wrapped ``model`` on this rank."""
    model.train()
    # Reseed the DistributedSampler so each epoch uses a different shuffle.
    trainloader.sampler.set_epoch(epoch)

    for step, (data, label) in enumerate(trainloader):
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_func(prediction, label)

        # No explicit barrier and no all_reduce of the loss are needed: during
        # backward() DDP all-reduces and averages gradients across processes.
        # See https://pytorch.org/docs/stable/notes/ddp.html
        loss.backward()

        if debug_grads and step == 0:
            # Print per-parameter gradient norms so ranks can be compared;
            # after DDP's gradient averaging they should match on every rank.
            for name, param in model.named_parameters():
                if param.grad is not None:
                    print(f'rank = {rank}, name = {name}, grad_norm = {param.grad.norm().item()}')

        optimizer.step()


def _evaluate(model, valloader):
    """Return top-1 accuracy over the validation set, gathered from all ranks."""
    model.eval()
    total_correct = 0.0
    total_count = 0
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for data, label in valloader:
            data, label = data.to(local_rank), label.to(local_rank)
            prediction = model(data)

            ## Gather predictions and labels from every process.
            # all_gather needs same-shaped buffers on every rank; by default
            # DistributedSampler gives every rank the same number of samples,
            # so corresponding batches have matching shapes.
            gathered_pred = [torch.zeros_like(prediction) for _ in range(world_size)]
            gathered_label = [torch.zeros_like(label) for _ in range(world_size)]
            distributed.all_gather(gathered_pred, prediction)
            distributed.all_gather(gathered_label, label)
            prediction = torch.cat(gathered_pred)
            label = torch.cat(gathered_label)

            # Weight by sample count so a short final batch is not over-weighted.
            total_correct += metric(prediction, label) * label.numel()
            total_count += label.numel()

    if total_count == 0:
        return float('nan')
    return total_correct / total_count


def main(num_epochs=100, eval_every=5, ckpt_path=None, debug_grads=False):
    """Train ResNet-101 on CIFAR-100 with DDP, validating periodically.

    Args:
        num_epochs: number of training epochs.
        eval_every: run validation after every ``eval_every``-th epoch.
        ckpt_path: optional state_dict file to initialize the model from.
            It is loaded on rank 0 only; DDP broadcasts rank 0's weights to the
            other ranks when the wrapper is constructed.
        debug_grads: print per-parameter gradient norms for the first step of
            every epoch, to verify that gradients agree across ranks.
    """
    ## Seed all RNGs globally.
    seed_all(42)

    ## Bind this process to its GPU.
    torch.cuda.set_device(local_rank)

    trainloader, valloader = build_dataloader()

    model = models.resnet101(pretrained=False, num_classes=100).to(local_rank)

    if rank == 0 and ckpt_path is not None:
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cuda", local_rank)))

    ## Convert BatchNorm layers to SyncBatchNorm so training statistics are
    ## computed across all processes.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)

    ## Wrap in DDP. ResNet uses every parameter in forward, so
    ## find_unused_parameters stays False; enabling it only adds per-iteration
    ## overhead to search for unused parameters.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=False,
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    loss_func = nn.CrossEntropyLoss().to(local_rank)

    try:
        for epoch in range(num_epochs):
            _train_one_epoch(model, trainloader, optimizer, loss_func, epoch, debug_grads)

            ## Save model.module (the unwrapped model) so the checkpoint loads without DDP.
            if rank == 0:
                torch.save(model.module.state_dict(), "%d.ckpt" % epoch)

            if (epoch + 1) % eval_every == 0:
                avg_acc = _evaluate(model, valloader)
                # Every rank computes the same gathered accuracy; print it once.
                if rank == 0:
                    print(avg_acc)
    finally:
        ## Tear down the process group even if training fails.
        distributed.destroy_process_group()


if __name__ == "__main__":
    # Script entry point. The process group has already been created at import
    # time by the module-level initialization code, so main() can use it directly.
    main()

你可能感兴趣的：(pytorch, DDP, 图像分类)