torch.multinomial()理解

torch.multinomial(input, num_samples,replacement=False, out=None) → LongTensor

作用是对input的每一行做n_samples次取值,输出的张量是每一次取值时input张量对应行的下标。

输入是一个input张量,一个取样数量,和一个布尔值replacement。

input张量可以看成一个权重张量,每一个元素代表其在该行中的权重。如果有元素为0,那么在其他不为0的元素

被取干净之前,这个元素是不会被取到的。

n_samples是每一行的取值次数,该值不能大于每一样的元素数,否则会报错。

replacement指的是取样时是否是有放回的取样,True是有放回,False无放回。

看官方给的例子:
>>> weights = torch.Tensor([0, 10, 3, 0]) # create a Tensor of weights
>>> torch.multinomial(weights, 4)

 1
 2
 0
 0
[torch.LongTensor of size 4]

>>> torch.multinomial(weights, 4, replacement=True)

 1
 2
 1
 2
[torch.LongTensor of size 4]

输入是[0,10,3,0],也就是说第0个元素和第3个元素权重都是0,在其他元素被取完之前是不会被取到的。

所以第一个multinomial取4次,可以试试重复运行这条命令,发现只会有2种结果:[1 2 0 0]以及[2 1 0 0],以[1 2 0 0]这种情况居多。这其实很好理解,第1个元素权重比第2个元素权重要大,所以先取第1个元素的概率就会大。在第1和2个元素取完之后,剩下了2个没有权重的元素,它们才会被取到。但实际上权重为0的元素被取到时也不会显示正确的下标,关于0的下标问题我还没有想到很合理的解释,先行略过。

而第二个multinomial取4次,发现就只会出现1和2这两个元素了。这是因为replacement为真,所以有放回,就永远也不会取到权重为0的元素了。

再试试输入二维张量,则返回的也会成为一个二维张量,行数为输入的行数,列数为n_samples,即每一行都取了n_samples次,取法和一维张量相同。

你可能感兴趣的:(PyTorch学习笔记)