【技术总结】Pytorch 复现性设置

前言

有时候我们需要保证程序的可复现性,比如需要提交可复现代码的比赛。因此需要学习如何保证pytorch的代码可以复现。

代码设置

具体的代码如下所示,注意需要在程序开始有随机化操作之前调用这个函数,比如main函数的第一行代码。

def seed_everything(seed: int):  # fix the seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

代码解释

random.seed(seed)
比较容易理解,就是对python原生的包random进行随机种子初始化。

np.random.seed(seed)
由于程序中经常用 numpy, 所以也需要设置 numpy的种子。

torch.manual_seed(seed)
pytorch 中的随机种子设置。

torch.backends.cudnn.benchmark = False,
如果这个选项被打开的话,一些运算会从多个实现里面测试,找到最快的实现方法。但是如果切换机器/甚至不切换(计算资源发生变化),选择的实现方法可能不同,得到的结果也可能不同。因此需要关闭改选项。

torch.backends.cudnn.deterministic = True 或者 torch.use_deterministic_algorithms(True)
本身pytorch的有些运算的实现就会输出非确定性的结果,因此将这个开关打开,pytorch如果遇到非确定的实现就会报错。并且pytorch也会主动避免使用非确定实现,选择确定性的实现。

特殊情形

如果我们使用多进程,比如在pytorch DDP训练模式的时候,也需要注意Dataloader所使用的随机种子。一般来说,如果开启多进程,每个进程的pytorch的随机种子base_seed + worker_id。base_seed 是主进程所使用的随机种子。如果想获得进程使用的随机种子,可以使用命令torch.utils.data.get_worker_info().seed 或者 torch.initial_seed()。
需要注意的一点是即使pytorch会更改种子,其他的库比如numpy,random的随机种子可能和主线程保持一致。为了防止在这种情况下每个线程使用一样的随机数字,可以通过更改dataloader的 woker_init_fn, generator来修正。
如下所示:

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

参考文档

pytorch dataloader randomness

你可能感兴趣的:(AI算法常用技术,pytorch,python,深度学习)