Pytorch:torch.Generator()

PyTorch 通过 torch.Generator 类来操作随机数的生成

1. 默认的随机数生成器

import torch

# 设置默认的随机数种子
torch.manual_seed(0)

# 查看默认的随机数种子
torch.initial_seed()

2. 指定 torch.Generator随机数生成器

g = torch.Generator()

torch.Generator 实例的方法:

  • manual_seed(int): 设置随机数种子
  • initial_seed(): 获取随机数的种子
# 获取默认的 torch.Generator 实例
g_1 = torch.default_generator

# 查看指定随机数生成器的种子(结果也为 0)
g_1.initial_seed()

上述代码等价于

g_2 = torch.manual_seed(0)

通过关键字参数 generator 指定随机数生成器

# 1. 使用默认的随机数生成器
torch.manual_seed(1)

# 结果 tensor([0, 4, 2, 3, 1])
torch.randperm(5)

# 2. 手动创建随机数生成器
g = torch.Generator().manual_seed(1)

# 结果也为 tensor([0, 4, 2, 3, 1])
torch.randperm(5, generator=g)

3. 查看设备

torch.Generator 实例会区分 CPU 与 GPU 两种设备, 默认为 CPU 类型

# 结果为 device(type='cpu')
g.device

4. 获取状态 (没太懂):get_state()

一个torch.ByteTensor,其包含将生成器恢复到特定时间点的所有必要位。

print(g.get_state())

输出:

tensor([  1, 209, 156,  ...,   0,   0,   0], dtype=torch.uint8)

5. 设置状态:set_state()

g_cpu = torch.Generator()
print(g_cpu.get_state())

g_cpu_other = torch.Generator()
g_cpu_other.manual_seed(1)

g_cpu.set_state(g_cpu_other.get_state())
print(g_cpu.get_state())

输出:

tensor([  1, 209, 156,  ...,   0,   0,   0], dtype=torch.uint8)
tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8)

你可能感兴趣的:(Pytorch系列,pytorch,人工智能,python)