class torch.distributions.multinomial.Multinomial()

class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)

创建由 total_count 和 probs 或 logits (但不是两者)参数化的多项式分布。

参数:

total_count -int-试验次数

probs -Tensor-事件概率

logits -Tensor-事件对数概率

注意:

1、当且仅当log_prob()被调用的时候,total_count 不需要指定

2、probs 是非负的、有限的、具有非零和,并且沿最后一维标准化后的和为1

3、sample() 所有参数和样本都需要一个共享的 total_count.

4、log_probs() 允许每个参数和样本使用不同的 total_count.

 

>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 21.,  24.,  30.,  25.])

>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])

你可能感兴趣的:(pytorch,pytorch)