有关于pytorch模型训练的可复现性

1. 同设备

如果是在硬件设备都一致的情况下,可以用以下几种措施来增强自己的模型性能可复现性:

随机种子

其中最主要的就是随机种子(包括python,numpy和pytorch等等),下面这段代码是笔者一直在用的:

def seed_torch(seed=42,deter=False):
    '''
    `deter` means use deterministic algorithms for GPU training reproducibility, 
    if set `deter=True`, please set the environment variable `CUBLAS_WORKSPACE_CONFIG` in advance
    '''
    seed = int(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
    torch.set_deterministic(deter)  # avoiding nondeterministic algorithms (see https://pytorch.org/docs/stable/notes/randomness.html)
    torch.use_deterministic_algorithms(deter)

注意:上述代码在torch==1.8.0上使用没有问题,但torch.set_deterministic这个接口在一些旧版本的torch里面(torch==1.12)是不支持的,请注释掉这一行。

最后两行的torch.set_deterministictorch.use_deterministic_algorithms是为了防止torch使用某些非决定性算法(详见官网描述),如果设置deter=True的话就可以使用决定性算法,降低随机性干扰,但需要提前在你的代码文件开头设置CUBLAS_WORKSPACE_CONFIG环境变量,否则会报错:

## os设置环境变量
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'  ## enable the deterministic method
## 当然也可以在cmd里面设置,`XXX.py`是你要运行的脚本
CUBLAS_WORKSPACE_CONFIG=:16:8 python XXX.py args

一旦固定随机种子,和使用决定性算法之后,在同一台设备上,基本上性能就能保持一致了。

Dataloader和shuffle

最好不要在训练的时候使用shuffle,如果非要用的话记得也要固定随机种子:

import random 
random.Random(args.seed).shuffle(train_data)

然后Dataloader里面的sampler也最好用SequentialSampler.

用CPU

参考torch官方的申明:
有关于pytorch模型训练的可复现性_第1张图片
事实上pytorch并不担保模型训练的完全可复现,如果想要严格保证不收随机性影响的话,那就用cpu训练,而不是GPU。很多使用GPU的优化操作(包括上述提到的非决定性方法)都会造成随机性。

2. 跨设备

很遗憾,跨设备可复现性基本不可能实现。硬件的差别,比方说cuda不同,浮点数优化产生细微差别,最后性能可能就会差的很多。
有关于pytorch模型训练的可复现性_第2张图片
详见该博客:Reproducibility over Different Machines

3. 参考:

  • https://pytorch.org/docs/stable/notes/randomness.html
  • https://discuss.pytorch.org/t/random-seed-with-external-gpu/102260/3
  • https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

你可能感兴趣的:(深度学习,pytorch,深度学习,python)