pytorch的随机抽样

目录

      • 1.如何设置随机种子?
      • 2.如何进行伯努利分布采样?
      • 3.如何进行多项式分布抽样?
      • 4.如何进行标准分布抽样?

1.如何设置随机种子?

torch.manual_seed(123)      #manual  手控的
# 如没有手动设置,则返回系统生成的随机种子;否则,返回手动设置的随机种子
seed = torch.initial_seed()
print("seed:{}".format(seed))
# 返回随机生成器的状态
state = torch.get_rng_state()       # state 状态
print("state:{}".format(state),len(state))
---------------------------------------------------------------------
result:
seed:123
state:tensor([123,   0,   0,  ...,   0,   0,   0], dtype=torch.uint8) 5056

2.如何进行伯努利分布采样?

伯努利分布(0-1分布)的结果只有0和1,并且0<=p<=1

torch.manual_seed(123)
a = torch.rand(3,3)
print(a)
b = torch.bernoulli(a)      # bernoulli()伯努利
print(b)
---------------------------------------------------------------------
result:
tensor([[0.2961, 0.5166, 0.2517],
        [0.6886, 0.0740, 0.8665],
        [0.1366, 0.1025, 0.1841]])
tensor([[0., 1., 0.],
        [1., 0., 1.],
        [0., 0., 0.]])

3.如何进行多项式分布抽样?

weights1 = torch.Tensor([20,10,3,2])
# torch.multinomial 第一个参数为多项式权重,可以是向量,也可以是矩阵,有权重决定‘下标’的抽样
# 数值越大,被采样到的几率越高
# 若为向量:replacement 代表是否又放回的抽样 True 行数=1 列数由num_sample指定   False: 行数为1  列数不超过 weights
# 不放回的采样个数当然要小于采样对象内部的元素个数
a = torch.multinomial(weights1,num_samples=4,replacement=False)   # multinomial 多项式
b = torch.multinomial(weights1,num_samples=10,replacement=True)
print(a)
print(b)
# 若为矩阵:replacement 代表是否又放回的抽样 True 行数=weights的行数 列数由num_sample指定   False: 行数为weights的行数  列数不超过 weights
weights2 = torch.Tensor([[20,10,3,2],[30,4,5,60]])
c = torch.multinomial(weights2,num_samples=3,replacement=False)
d = torch.multinomial(weights2,num_samples=15,replacement=True)
print(c)
print(d)
---------------------------------------------------------------------
result:
tensor([0, 1, 3, 2])
tensor([0, 1, 1, 1, 1, 2, 0, 0, 0, 2])
tensor([[1, 0, 2, 3],
        [0, 3, 2, 1]])
tensor([[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1],
        [0, 0, 0, 2, 3, 0, 3, 3, 1, 3, 0, 0, 3, 3, 3]])

4.如何进行标准分布抽样?

链接: 正态分布.

x = torch.normal(mean=0.5,std=torch.arange(0.1,1,0.0001))     # normal 标准
print(x)
y = torch.normal(mean=torch.arange(0.1,1,0.1),std=0.5)
print(y)
z= torch.normal(mean=torch.arange(0.1,1,0.1),std=torch.arange(0.1,1,0.1))
print(z)
plt.plot(torch.arange(0.1,1,0.0001).data.numpy(),x.data.numpy())
plt.show()
---------------------------------------------------------------------
result:
tensor([ 0.3764,  0.5162,  0.3239,  ...,  1.4991, -0.4632, -0.2401])
tensor([ 0.2708, -0.0637,  0.1827,  0.2584,  0.1495,  0.4239,  1.6819,  0.8798,
         1.2386])
tensor([0.0554, 0.3297, 0.2785, 1.1061, 0.4193, 0.6390, 0.5043, 1.1231, 1.8046])

正态分布采样结果:
pytorch的随机抽样_第1张图片

你可能感兴趣的:(PyTorch的攀登年华,pytorch,python,numpy)