pytorch设置随机种子来复现代码结果

文章目录

  • 前言
  • 一、随机种子的介绍
          • 总结一下,随机种子是一个起始值,用于生成伪随机数序列。指定相同的随机种子可以确保每次生成相同的随机数序列。
  • 二、代码
  • 三、代码解读
        • 1、os.environ[‘PYTHONHASHSEED’] = str(seed)
        • 2、torch.manual_seed(seed)
        • 3、torch.backends.cudnn.deterministic = True
        • 4、torch.backends.cudnn.benchmark = False
  • 四、补充


前言

深度学习中,模型完成后运行代码,会发现每次运行出来的结果都不一样。这种情况会导致我们在原有的模型基础做出修改后而跑出的结果,无法确定是改进了还是变得更差了。所以要固定随机种子来使每次运行的结果一样,以便比较不同模型的优劣。


一、随机种子的介绍

(1)随机种子是用于生成伪随机数序列的起始值。在计算机中,所谓的"随机"实际上是通过特定的算法生成的伪随机序列。而随机种子就是这个算法的起始输入。
(2)在很多编程语言和应用中,可以使用随机种子来控制伪随机数的生成。通过指定相同的随机种子,可以确保每次生成的随机数序列是一致的。这在某些需要可复现性的应用中非常有用,比如调试和测试。
(3)随机种子可以是任何整数值。常见的做法是使用当前系统时间作为随机种子,因为时间通常是不断变化的,这样可以产生看似随机的序列。也可以手动指定一个固定的种子值,以产生确定性的随机数序列。

总结一下,随机种子是一个起始值,用于生成伪随机数序列。指定相同的随机种子可以确保每次生成相同的随机数序列。

二、代码

将下面这段代码放在程序开始之前,代码如下(示例):

# 定义随机种子固定的函数
def get_random_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# 调用函数,设置随机种子为73
get_random_seed(73)

上述是使用 PyTorch 时比较常见的设置随机种子的代码。通常来说,在不同的 CPU 和 GPU 上,即使设置了相同的随机种子,也不能够保证能复现出相同的结果。但是理论上,在同一设备上是能够进行复现的,一般需要在主程序前,同时固定 PyTorchPythonNumpy 的随机种子。

三、代码解读

1、os.environ[‘PYTHONHASHSEED’] = str(seed)

主要是为了禁止 hash 随机化。

2、torch.manual_seed(seed)

torch.cuda.manual_seed(seed) 的功能是类似的,一个是设置当前 CPU 的随机种子,而另一个是设置当前 GPU 的随机种子,如果存在多个 GPU 可以使用 torch.cuda.manual_seed_all(seed) 对全部 GPU 都设置随机种子。

3、torch.backends.cudnn.deterministic = True

这行代码就如同其名字一样,确定是否使用确定性卷积算法(默认是 False ),如果为 True,则能保证在相同设备上的相同输入能够实现相同输出。

4、torch.backends.cudnn.benchmark = False

如果设置 torch.backends.cudnn.benchmark = True 会让程序在一开始时增加额外的预处理时间,以让整个 model 的卷积层寻找到最适合的、最有效率的卷积实现算法,进而实现网络加速,但是与此同时可能会导致结果不可复现。

四、补充

在确保实验结果可复现,还应该要注意 DataLoader,这是经常被忽略掉的地方,做法是使用 worker_init_fn 去保证可复现性(这里主要是多进程加载数据时会产生这个问题)。

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker
)

提示:我在复现时没有用到这部分代码,具体是否采用该复现方式需要根据自己的情况来决定。


本篇是我引用下面博主文章,在此基础上完成的
原文链接:https://blog.csdn.net/weixin_48249563/article/details/115031039

你可能感兴趣的:(pytorch,人工智能,python,深度学习)