PyTorch tutorial: torch.nn.parallel.DistributedDataParallel (DDP distributed training)

Roughly speaking, DDP training starts one process per GPU. With two GPUs, the dataset is split into two shards and each GPU reads one of them.
The code below uses DDP distributed training correctly and can be used directly as a reference.
Note: this code only covers single-machine multi-GPU training; multi-machine multi-GPU training has not been tested here due to limited resources.
Command to run in the terminal:

python -m torch.distributed.launch --nproc_per_node 2 train.py
where 2 is the number of GPUs on the machine (the launcher starts one process per GPU).
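
On newer PyTorch versions (1.10 and later), torch.distributed.launch has been superseded by torchrun. Since this script reads its rank from torch.distributed rather than from a --local_rank argument, the equivalent command should be:

torchrun --nproc_per_node 2 train.py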

import datetime
import os

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import joint_transforms
from config import msra10k_path
from datasets import ImageFolder
from misc import AvgMeter, check_mkdir
from model import R3Net
from torch.backends import cudnn

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

# torch.distributed.init_process_group(backend="nccl")
dist.init_process_group(backend='nccl', init_method='env://')
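# init_method='env://' reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the
# environment variables that torch.distributed.launch / torchrun set for each process.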
batch_size = 12  # per-process (per-GPU) batch size used by the DataLoader below
data_size = 25  # total batch size (unused in this script)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
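# dist.get_rank() returns the global rank; on a single machine it equals the local rank,
# and torch.cuda.set_device(local_rank) binds this process to its own GPU.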
#dist.init_process_group(backend='nccl', init_method='env://', world_size=2, rank=local_rank)
print(local_rank)  # Note: each process prints its own rank, so this prints 0 and then 1
# device = torch.device("cuda", local_rank)
cudnn.benchmark = True
torch.manual_seed(2018)

ckpt_path = './ckpt'
exp_name = 'R3Net/train_model'

args = {
    'iter_num': 8000,
    'train_batch_size': 10,
    'last_iter': 0,
    'lr': 1e-3,
    'lr_decay': 0.9,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'snapshot': ''
}

joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
target_transform = transforms.ToTensor()

train_set = ImageFolder(msra10k_path, joint_transform, img_transform, target_transform)
#dataset = train_set(data_size, local_rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
                                                                num_replicas=2,
                                                                rank=local_rank)
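# The sampler partitions the dataset indices across the processes, so each GPU sees a
# disjoint shard of roughly len(train_set) / 2 samples per epoch. num_replicas can also be
# omitted, in which case it defaults to the world size of the process group.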
#sampler = DistributedSampler(dataset)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler)
#train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True)

criterion = nn.BCEWithLogitsLoss().cuda()
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')


def main():
    net = R3Net()
    net = net.cuda()
    device = torch.device('cuda:%d' % local_rank)
    net = net.to(device)
    net = nn.parallel.DistributedDataParallel(net,
                                               device_ids=[local_rank],  # note: device_ids must be a list
                                               output_device=local_rank)
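    # After wrapping, each process runs its own forward/backward pass; during backward() DDP
    # all-reduces (averages) the gradients so every replica applies the same parameter update.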
    #device = torch.device('cuda:%d' % 1)
    #net = torch.nn.DataParallel(net).module.to(device)
    #net.load_state_dict(torch.load('/home/yyb/pytorch_proj/R3Net/ckpt/R3Net/2020.7.3/1/12500.pth'))

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])


    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')
    train(net, optimizer)


def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
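        # Note: for a different shuffle on each pass over the data, DistributedSampler expects
        # train_sampler.set_epoch(epoch) to be called at the start of every epoch.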
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_sim_record, loss5_sim_record = AvgMeter(), AvgMeter()  ##

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs) ##
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)

            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)


            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f],[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            # if curr_iter == 10500:
            #     torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
            #     torch.save(optimizer.state_dict(),
            #                os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
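            # Every DDP process reaches the saves below; it is common to guard them with
            # 'if local_rank == 0:' so that only one process writes each checkpoint file.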
            if curr_iter % 400 == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d_epoch.pth' % (curr_iter // 1250)))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_epoch_optim.pth' % (curr_iter // 1250)))

            if curr_iter % args['iter_num'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter == args['iter_num']:
                return


if __name__ == '__main__':
    main()
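
To tie the pieces together, here is a minimal single-node DDP skeleton that follows the same pattern as the script above. The toy tensor dataset and linear model are placeholders and not part of the original code; the DDP-specific steps (process-group init, per-process device binding, DistributedSampler, DDP wrapping) mirror the full script, and the skeleton also shows the set_epoch call and the rank-0 checkpoint guard that the full script leaves out.

# minimal_ddp.py -- run with: torchrun --nproc_per_node 2 minimal_ddp.py
import os

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    dist.init_process_group(backend='nccl', init_method='env://')
    local_rank = int(os.environ.get('LOCAL_RANK', dist.get_rank()))
    torch.cuda.set_device(local_rank)

    # toy data and model purely for illustration
    dataset = TensorDataset(torch.randn(1000, 16), torch.randn(1000, 1))
    sampler = DistributedSampler(dataset)          # one shard per process
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    model = nn.Linear(16, 1).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(5):
        sampler.set_epoch(epoch)                   # reshuffle differently every epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()                        # gradients are all-reduced here
            optimizer.step()

        if dist.get_rank() == 0:                   # save once, without the 'module.' prefix
            torch.save(model.module.state_dict(), 'epoch_%d.pth' % epoch)

    dist.destroy_process_group()


if __name__ == '__main__':
    main()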

