In DDP training, roughly speaking one process is launched per GPU: with two GPUs, the dataset is split into two shards and each GPU reads one of them.
The code below uses DDP distributed training correctly and can be used as a direct reference.
Note: this code only covers single-machine multi-GPU training; multi-machine multi-GPU has not been tried due to limited resources.
Command to run in the terminal:
python -m torch.distributed.launch --nproc_per_node 2 train.py
where 2 is the number of GPUs you have.
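A side note: on newer PyTorch releases (1.10+), torch.distributed.launch is deprecated in favor of torchrun; assuming such a release, the equivalent command would be
torchrun --nproc_per_node 2 train.py
This should work without changes here, because the script reads the rank via dist.get_rank() rather than parsing a --local_rank argument.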
import datetime
import os
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import joint_transforms
from config import msra10k_path
from datasets import ImageFolder
from misc import AvgMeter, check_mkdir
from model import R3Net
from torch.backends import cudnn
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
# torch.distributed.init_process_group(backend="nccl")
dist.init_process_group(backend='nccl', init_method='env://')
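# init_method='env://' makes init_process_group read MASTER_ADDR, MASTER_PORT, WORLD_SIZE
# and RANK from environment variables, which the launcher sets for every spawned process.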
batch_size = 12  # batch size per GPU (each process loads this many samples per step)
data_size = 25  # overall batch size; only referenced in the commented-out line further below
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
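# get_rank() returns the global rank; on a single machine it coincides with the local rank,
# which is one reason this script is limited to single-node training.
# set_device binds this process to its own GPU, so later .cuda() calls default to that device.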
#dist.init_process_group(backend='nccl', init_method='env://', world_size=2, rank=local_rank)
print(local_rank)  # note: each process prints its own rank, so this outputs 0 first and then 1
# device = torch.device("cuda", local_rank)
cudnn.benchmark = True
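# cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which helps when input sizes are fixed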
torch.manual_seed(2018)
ckpt_path = './ckpt'
exp_name = 'R3Net/train_model'
args = {
    'iter_num': 8000,
    'train_batch_size': 10,
    'last_iter': 0,
    'lr': 1e-3,
    'lr_decay': 0.9,
    'weight_decay': 5e-4,
    'momentum': 0.9,
    'snapshot': ''
}
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
target_transform = transforms.ToTensor()
train_set = ImageFolder(msra10k_path, joint_transform, img_transform, target_transform)
#dataset = train_set(data_size, local_rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
                                                                num_replicas=2,
                                                                rank=local_rank)
#sampler = DistributedSampler(dataset)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler)
#train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True)
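# Each process sees a different 1/num_replicas shard of the dataset, so the effective global
# batch size is batch_size * num_replicas (12 * 2 = 24 here). num_replicas is hard-coded to 2
# to match --nproc_per_node 2; dist.get_world_size() would generalize it. For epoch-based
# training one would normally call train_sampler.set_epoch(epoch) before every epoch to
# reshuffle; the iteration-based loop below does not do this.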
criterion = nn.BCEWithLogitsLoss().cuda()
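# Each process computes the loss only on its own shard of the data; DDP averages the
# gradients across processes during backward(), so the optimizer steps stay in sync.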
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
def main():
    net = R3Net()
    net = net.cuda()
    device = torch.device('cuda:%d' % local_rank)
    net = net.to(device)
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[local_rank, ],  # note: device_ids must be a list
                                              output_device=local_rank)
    #device = torch.device('cuda:%d' % 1)
    #net = torch.nn.DataParallel(net).module.to(device)
    #net.load_state_dict(torch.load('/home/yyb/pytorch_proj/R3Net/ckpt/R3Net/2020.7.3/1/12500.pth'))
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])
    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')
    train(net, optimizer)
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_sim_record, loss5_sim_record = AvgMeter(), AvgMeter()  ##
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs)  ##
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)
            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)
            curr_iter += 1
            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            # if curr_iter == 10500:
            #     torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
            #     torch.save(optimizer.state_dict(),
            #                os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
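            # Both processes write checkpoints to the same files below; a common refinement
            # (not in the original code) is to guard the torch.save calls with `if local_rank == 0:`
            # so that only one copy is written.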
            if curr_iter % 400 == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d_epoch.pth' % (curr_iter // 1250)))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_epoch_optim.pth' % (curr_iter // 1250)))
            if curr_iter % args['iter_num'] == 0:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
            if curr_iter == args['iter_num']:
                return
if __name__ == '__main__':
    main()