torch.multinomial用法

以输入的张量作为权重进行多分布采样

torch.multinomial(input, num_samples,replacement=False, out=None) → LongTensor

对input的每一行进行num_samples次取值,输出为每次取值的索引。

input每一行中的元素是该索引被采样的权重。如果元素为0,那么其他位置被采样完之前,这个位置都不会被采样。

replacement=False为不放回采样,True为有放回采样。

例子:

import torch
input = torch.Tensor([5, 0, 10, 0])
#无放回
output = torch.multinomial(input, num_samples=4)
print(output)
#有放回
output = torch.multinomial(input, num_samples=4, replacement=True)
print(output)

输出:

tensor([2, 0, 3, 1])
tensor([2, 2, 2, 2])

你可能感兴趣的:(pytorch,人工智能,深度学习,机器学习)