Pytorch有两种方法实现多GPU训练,分别是DataParallel(DP)和DistributedDataParallel(DDP)。DP实现简单,但没有完全利用所有GPU资源,DDP实现相对复杂,但是更快,我建议使用DDP。
DP使用torch.nn.DataParallel。原理是,假设用K个GPU训练,前向传播阶段,一个batch的数据会被平均分成K份,模型也会复制K份,分别送到每个GPU上;反向传播阶段,各复制模型产生的梯度会被累加到主模型上。batch size应该大于使用的GPU数量。
我们写一份简单的程序实现一下torch.nn.DataParallel:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class Model(nn.Module):
def __init__(self, input_size, output_size):
super(Model, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, input):
output = self.fc(input)
print("\tIn: input size", input.size(),
"output size", output.size())
return output
# 超参数
input_size = 5 # 输入维数
output_size = 2 # 输出维数
batch_size = 30 # batch size
data_size = 30 # 样本数
gpus = [0, 1, 2] # GPU索引
rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
batch_size=batch_size, shuffle=True)
model = Model(input_size, output_size)
# Multi-GPUS操作
model = nn.DataParallel(model, device_ids=gpus)
model = model.to(gpus[0]) # 主模型
for data in rand_loader:
input = data.to(gpus[0])
output = model(input)
print("Out: input size", input.size(),
"output_size", output.size())
GPUS = [0, 1, 2]
GPUS = [0, 1]
建议使用DDP代替DP。DP基于单进程,多线程,只能在一个机器上多卡训练,由于多线程之间GIL连接引入了额外开销,即使在一个机器上也比DDP慢;DDP基于多进程,可以在多个机器上训练,每个GPU由专有进程控制,训练更快!
下面的介绍都针对单机器多卡训练,因为这是我们最常见的情况。DDP的基本原理也是将模型复制到每个GPU上,收集每个GPU产生的梯度,平均这些梯度更新模型,然后同步所有GPU上的模型。首先了解一些必要的概念:
DDP的实现流程可以概括为:设置进程组,分割数据,DDP化模型,训练模型,clean up。我们首先介绍各部分对应的代码,最后给出整体demo。
import torch.distributed as dist
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost' # 主节点地址
os.environ['MASTER_PORT'] = '12355' # 主节点端口,用于进程之间通讯
dist.init_process_group("nccl", rank=rank, world_size=world_size)
可以使用torch.utils.data.distributed.DistributedSampler分割数据的索引,然后将每组索引送入dataloader组成batch。请注意,这里与DP有区别,DP将一个batch的数据分成K份,然后送入每个GPU,因此设置的batch size等于训练使用的batch size;DDP先将训练数据分成K份,然后送入每个GPU,再生成batch,因此实际训练使用的batch size等于设置的batch size再乘以K!
from torch.utils.data.distributed import DistributedSampler
def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
dataset = Your_Dataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
return dataloader
假设K = 3,共有10条数据。
请注意,设置num_workers = 0和pin_memory = False可以避免DDP下一些不必要的BUG。
from torch.nn.parallel import DistributedDataParallel as DDP
model = Model().to(rank)
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
这边也有一些需要注意的点!
使用spawn方法管理多进程,对于多进程来说,所有子进程和父进程运行的是相同的程序。
import torch.multiprocessing as mp
if __name__ == '__main__':
world_size = 3
mp.spawn(
main,
args=(world_size),
nprocs=world_size
)
main是运行在每一个进程上的训练过程,main的第一个形参必须是rank,spawn会自动传递这个值给main,所以spawn.args只写了world_size参数。rank = 0是默认的主节点。同时注意在epoch和iter的循环之间必须加上dataloader.sampler.set_epoch(epoch),数据才能正确分割。
main的最后一行是clean up操作。
def cleanup():
dist.destroy_process_group()
在主节点保存网络,保存函数后面要加上dist.barrier()函数,暂停此时其他进程的运行,等待网络保存完毕。
if rank == 0:
model.save_nets()
dist.barrier() # 保存结束前其他 process 不运行
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
self.label = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index], self.label[index]
def __len__(self):
return self.len
class Model(nn.Module):
def __init__(self, input_size, output_size):
super(Model, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, input):
output = self.fc(input)
return output
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost' # 主节点地址
os.environ['MASTER_PORT'] = '12355' # 主节点端口,用于进程之间通讯
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
dataset = RandomDataset(5, 60)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers,
drop_last=False, shuffle=False, sampler=sampler)
return dataloader
def cleanup():
dist.destroy_process_group()
def main(rank, world_size):
# 建立进程组
setup(rank, world_size)
print("Rank", rank)
# 分割数据
dataloader = prepare(rank, world_size, batch_size=10)
# DDP化模型
model = Model(5, 5).to(rank)
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
# 训练模型
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
for epoch in range(10):
dataloader.sampler.set_epoch(epoch) # 这个必须要加!
for x, y in dataloader:
optim.zero_grad()
x = x.to(rank)
y = y.to(rank)
pred = model(x)
l = loss(pred, y)
l.backward()
optim.step()
# 保存网络
# if rank == 0:
# model.save_nets() # 自己编写保存参数函数
# dist.barrier()
# clean up
cleanup()
if __name__ == "__main__":
world_size = 3
mp.spawn(
main,
args=(world_size,),
nprocs=world_size,
)
链接1
链接2