最近由于想加速神经网络模型训练,便开始着手学习pytorch的分布式训练(DDP),结果踩了很多坑,在这里记录一下,便于以后查看,也同时分享给大家。
我是通过下面几篇博客学习pytorch分布式训练的,感觉写得都不错,很清晰明了,强烈推荐第一篇:
在训练过程中,我设置了两个进程,使其以数据并行的方式训练,但是在训练过程中,我发现两个进程的loss竟然不一致:
DDP可以自动实现不同进程间的梯度同步,从而使各进程的参数保持一致。因此出现这种情况主要是因为train_sampler使用了DistributedRandomIdentitySampler,该类的__iter__函数会根据进程的rank将数据集划分为不同部分,从而使各个进程读取的训练集中的数据不同,如下所示:
class DistributedRandomIdentitySampler(Sampler):
'''......此处省略其它代码......'''
# 从list_container中获取当前进程对应的样本索引
list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas]
assert len(list_container) == self.num_samples//self.num_instances
因此,虽然各个进程的模型参数一致,但训练时读取的数据不同,计算出的loss也就会不同。顺便提一句,每个进程的batch_size相同(手动设置),只是读取的数据不同,而DDP会将不同进程的batch数据对应的梯度进行聚合并同步梯度,因此DDP的实际batch_size是设置的batch_size的n倍,n为进程数量。
除此之外还有一种情况,就是不同进程的参数初始化不一致,这也会导致不同进程的训练出现loss值不同的情况。此时应该检查随机数种子是否固定,从而防止参数初始化的随机性。如下所示,即在不同进程中固定set_seed函数的输入seed值:
def set_seed(seed=None):
if seed is None:
return
random.seed(seed)
os.environ['PYTHONHASHSEED'] = ("%s" % seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
在测试阶段如果使用了DistributedInferenceSampler,会导致不同进程读取到的测试集数据不同,因此测试结果也不同,这是错误的。正确做法是testloader不使用DistributedInferenceSampler而直接使用默认的sampler(不指定sampler):
galleryloader = DataLoader(dataset=ImageDataset_test(dataset.gallery, transform=transform_test),
batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
pin_memory=True, drop_last=False, shuffle=False)
解决方法:在使用nn.parallel.DistributedDataParallel封装模型时加上find_unused_parameters=True:
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
解决方法:检查在训练过程中有没有出现同一个模型被前向执行两次及以上的情况,例如:
features1=model(img1)
features2=model(img2)
将同一个模型的前向输出过程写到一行代码中,如下所示:
features1, features2 = model(torch.cat((img1, img2), dim=0)).split(img1.size(0), dim=0)
如果这种方式没有成功解决报错,在nn.parallel.DistributedDataParallel封装模型时加上broadcast_buffers=False,即设置模型在每次迭代时不要提前将缓冲区内的数据覆盖到其它进程:
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=True, broadcast_buffers=False)