Categorical() takes three parameters: probs, logits, and validate_args. Examining its source code shows:
class Categorical(Distribution):
    def __init__(self, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):  # exactly one of probs/logits must be given; neither or both raises an error
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:  # probs was passed
            if probs.dim() < 1:  # probs has an invalid number of dimensions, raise an error
                raise ValueError("`probs` parameter must be at least one-dimensional.")
            self.probs = probs / probs.sum(-1, keepdim=True)
        else:  # logits was passed
            if logits.dim() < 1:  # logits has an invalid number of dimensions, raise an error
                raise ValueError("`logits` parameter must be at least one-dimensional.")
            # Normalize
            self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        self._param = self.probs if probs is not None else self.logits
        self._num_events = self._param.size()[-1]
        batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
        super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
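As a minimal sketch of the first check: the constructor rejects both the no-argument call and the both-arguments call with the same ValueError:

import torch
from torch.distributions import Categorical

probs = torch.tensor([0.4, 0.3, 0.2, 0.1])

try:
    Categorical()  # neither argument given
except ValueError as e:
    print(e)  # Either `probs` or `logits` must be specified, but not both.

try:
    Categorical(probs=probs, logits=probs.log())  # both arguments given
except ValueError as e:
    print(e)  # same error: the XOR check treats both cases identically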
If probs is not None, e.g. probs=[0.4, 0.3, 0.2, 0.1] or probs=[4.0, 3.0, 2.0, 1.0], the code simply normalizes the input: each entry is divided by the sum of all entries, so that the normalized values sum to 1. As a formula:

$$p_j = \frac{p_j}{\sum_{i=1}^{n} p_i}$$

After this step the $p_i$ sum to 1, i.e. $\sum_{i=1}^{n} p_i = 1$.
We can verify this in code:
import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.4000, 0.3000, 0.2000, 0.1000])
probs = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.4000, 0.3000, 0.2000, 0.1000])
If probs is not passed and logits is, the code handles it as follows:
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
This calls for a short explanation of logits.logsumexp(); for a fuller introduction, see the linked post 【关于LogSumExp】. In short, Categorical() applies the following transformation to the given logits:

$$\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n} e^{x_i}}\right) = \log(e^{x_j}) - \log\left(\sum_{i=1}^{n} e^{x_i}\right) = x_j - \log\left(\sum_{i=1}^{n} e^{x_i}\right)$$

That is, from every entry of logits it subtracts the log of the sum of exponentials of all entries; the last form of the formula is exactly what the code implements. The subtracted term is the LogSumExp, which is just what the name says.
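As a quick sanity check that logsumexp really computes log(sum(exp(...))), only in a numerically stabler way, here is a minimal sketch:

import torch

x = torch.tensor([4.0, 3.0, 2.0, 1.0])
print(torch.log(torch.exp(x).sum(-1)))  # tensor(4.4402) -- naive computation
print(x.logsumexp(dim=-1))              # tensor(4.4402) -- same value, stable even for large x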
Rather than talk further, let's look at the actual behavior. As before, we pass logits=[0.4, 0.3, 0.2, 0.1] or logits=[4.0, 3.0, 2.0, 1.0]:
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits) # tensor([-0.4402, -1.4402, -2.4402, -3.4402])
logit = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(logits=logit)
print(pd.logits) # tensor([-1.2425, -1.3425, -1.4425, -1.5425])
Let's verify logits=[4.0, 3.0, 2.0, 1.0] by hand, using the formula introduced above:

$$\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n} e^{x_i}}\right) = x_j - \log\left(\sum_{i=1}^{n} e^{x_i}\right)$$

For the first entry this gives:
import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logit) # -0.4401896985611957
The result, -0.4402, matches the first entry of pd.logits above, confirming our understanding.
Note, in passing, that probs=[0.4, 0.3, 0.2, 0.1] and probs=[4.0, 3.0, 2.0, 1.0] produce the same probs, namely tensor([0.4000, 0.3000, 0.2000, 0.1000]).
By contrast, logits=[0.4, 0.3, 0.2, 0.1] and logits=[4.0, 3.0, 2.0, 1.0] produce different logits: tensor([-1.2425, -1.3425, -1.4425, -1.5425]) and tensor([-0.4402, -1.4402, -2.4402, -3.4402]) respectively. The reason is that logits are exponentiated: e^4 and e^0.4 are very different values, and while dividing probs by their sum cancels a common scale factor, subtracting logsumexp from logits only cancels a common additive shift, not a multiplicative one.
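This difference can be seen directly as an invariance property; a minimal sketch (values taken from the runs above):

import torch
from torch.distributions import Categorical

logit = torch.tensor([4.0, 3.0, 2.0, 1.0])

# Adding a constant to every logit leaves the normalized logits unchanged...
print(Categorical(logits=logit).logits)         # tensor([-0.4402, -1.4402, -2.4402, -3.4402])
print(Categorical(logits=logit + 10.0).logits)  # same tensor
# ...but scaling them changes the distribution (logit * 0.1 is exactly the [0.4, 0.3, 0.2, 0.1] case).
print(Categorical(logits=logit * 0.1).logits)   # tensor([-1.2425, -1.3425, -1.4425, -1.5425])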
Inside torch.distributions.Categorical(), whichever of logits and probs you did not pass can still be read as a property; the conversion is implemented as follows:
import torch
import torch.nn.functional as F

def logits_to_probs(logits, is_binary=False):
    r"""
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    if is_binary:
        return torch.sigmoid(logits)
    return F.softmax(logits, dim=-1)

def clamp_probs(probs):
    eps = torch.finfo(probs.dtype).eps  # smallest value for probs' dtype such that 1.0 + eps != 1.0
    return probs.clamp(min=eps, max=1 - eps)  # restrict probs to the range [eps, 1 - eps]

def probs_to_logits(probs, is_binary=False):
    r"""
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
    return torch.log(ps_clamped)

# Inside Categorical: whichever of probs/logits was not given is derived lazily.
@lazy_property
def logits(self):
    return probs_to_logits(self.probs)

@lazy_property
def probs(self):
    return logits_to_probs(self.logits)
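Using the helper functions excerpted above (in PyTorch they live in torch.distributions.utils), a quick round-trip check shows that, for an already-normalized probability vector, the two conversions are inverses of each other:

probs = torch.tensor([0.4, 0.3, 0.2, 0.1])
logits = probs_to_logits(probs)          # tensor([-0.9163, -1.2040, -1.6094, -2.3026])
recovered = logits_to_probs(logits)      # softmax undoes the log for normalized inputs
print(torch.allclose(recovered, probs))  # True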
Suppose we pass logits=[4.0, 3.0, 2.0, 1.0]. As introduced above, Categorical() first transforms the input with

$$\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n} e^{x_i}}\right) = x_j - \log\left(\sum_{i=1}^{n} e^{x_i}\right)$$

If we then read Categorical.probs, the implementation reduces to:

F.softmax(logits, dim=-1)  # ignoring the is_binary case here

Let's show, in simple terms, why this implementation of def probs() recovers the original probabilities.
The softmax computation can be written as follows: given $X = [x_1, x_2, x_3]$,

$$\mathrm{softmax}(X) = \left[\frac{e^{x_1}}{e^{x_1}+e^{x_2}+e^{x_3}},\ \frac{e^{x_2}}{e^{x_1}+e^{x_2}+e^{x_3}},\ \frac{e^{x_3}}{e^{x_1}+e^{x_2}+e^{x_3}}\right]$$
Suppose the logits passed to Categorical() are $logits = [l_1, l_2, l_3]$. Applying the formula above, the stored (normalized) logits become

$$logits = \left[\,l_1 - \log\sum_{i=1}^{3} e^{l_i},\ \ l_2 - \log\sum_{i=1}^{3} e^{l_i},\ \ l_3 - \log\sum_{i=1}^{3} e^{l_i}\,\right]$$

Taking the softmax of these normalized logits gives (writing $S = \sum_{i=1}^{3} e^{l_i}$ for brevity):

$$\mathrm{softmax}(logits) = \left[\frac{e^{l_1 - \log S}}{e^{l_1 - \log S} + e^{l_2 - \log S} + e^{l_3 - \log S}},\ \frac{e^{l_2 - \log S}}{e^{l_1 - \log S} + e^{l_2 - \log S} + e^{l_3 - \log S}},\ \frac{e^{l_3 - \log S}}{e^{l_1 - \log S} + e^{l_2 - \log S} + e^{l_3 - \log S}}\right]$$

Since $e^{l_j - \log S} = e^{l_j}/S$ (division by $S$ itself, not by $\log S$), the common factor $1/S$ cancels between numerator and denominator:

$$\mathrm{softmax}(logits) = \left[\frac{e^{l_1}/S}{(e^{l_1}+e^{l_2}+e^{l_3})/S},\ \cdots\right] = \left[\frac{e^{l_1}}{e^{l_1}+e^{l_2}+e^{l_3}},\ \frac{e^{l_2}}{e^{l_1}+e^{l_2}+e^{l_3}},\ \frac{e^{l_3}}{e^{l_1}+e^{l_2}+e^{l_3}}\right]$$
In other words, reading probs is equivalent to applying softmax directly to the originally supplied logits. Let's verify this in code:
import math
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.probs) # tensor([0.6439, 0.2369, 0.0871, 0.0321])
num = math.exp(4.0)/(math.exp(4.0)+math.exp(3.0)+math.exp(2.0)+math.exp(1.0))
print(num) # 0.6439142598879722
The probs returned by the property match the value we computed by hand, confirming our understanding.
If the distribution was instead constructed from probs, Categorical.logits returns the corresponding log-probabilities, implemented as:
torch.log(ps_clamped)
Ignoring the is_binary case, this is nothing more than a log transform of the clamped probabilities. Let's check in code:
import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.logits) # tensor([-0.9163, -1.2040, -1.6094, -2.3026])
print(torch.log(torch.tensor(0.4))) # tensor(-0.9163)
The two values match: since the input probs are first normalized to [0.4, 0.3, 0.2, 0.1], the first entry of pd.logits is exactly torch.log(0.4). This confirms our understanding.
Finally, let's look at sampling. The implementation of sample() is short:
def sample(self, sample_shape=torch.Size()):
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    # self._num_events = self._param.size()[-1]
    probs_2d = self.probs.reshape(-1, self._num_events)  # flatten batch dimensions into a 2-D weight matrix
    samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T  # draw samples with replacement
    return samples_2d.reshape(self._extended_shape(sample_shape))  # reshape back to the requested shape
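A small usage sketch; the draws are random, so the outputs shown in the comments are just one possible outcome:

import torch
from torch.distributions import Categorical

pd = Categorical(probs=torch.tensor([0.4, 0.3, 0.2, 0.1]))
print(pd.sample())                 # e.g. tensor(0) -- a single index in [0, 4)
print(pd.sample(torch.Size([5])))  # e.g. tensor([0, 1, 0, 3, 1]) -- five i.i.d. draws
print(pd.sample((2, 3)).shape)     # torch.Size([2, 3])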
The sample() operation itself is simple; two details are worth recording:
(1) torch.multinomial()

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor

It draws num_samples values from each row of input and returns the index of each draw; the elements of each row of input are the sampling weights of the corresponding indices. If an element is 0, that index will not be drawn until every other (nonzero) position has been drawn. replacement=False means sampling without replacement, replacement=True means sampling with replacement. When sampling without replacement, num_samples must be no larger than input.size(-1), otherwise an error is raised. An example:
import torch
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
torch.multinomial(weights, 2) # tensor([1, 2])
torch.multinomial(weights, 6) # without replacement: raises an error, cannot sample n_sample > prob_dist.size(-1)
torch.multinomial(weights, 4, replacement=True) # tensor([2, 1, 1, 2])
(2)numel()
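numel() on a torch.Size returns the product of its entries, i.e. the total number of elements a tensor of that shape would hold; in sample() above it determines how many draws torch.multinomial must make before the result is reshaped. A quick example:

import torch

shape = torch.Size([2, 3])
print(shape.numel())              # 6 -- product of the dimensions
print(torch.zeros(2, 3).numel())  # 6 -- same count queried from a tensor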