A quick note on torch.distributions.Categorical()

Contents

  • 1. The difference between the probs and logits parameters
  • 2. Getting the probabilities from probs and logits
  • 3. Sampling with sample()

1. The difference between the probs and logits parameters

Categorical() takes three parameters: probs, logits, and validate_args. Looking at its source code, we can see:

class Categorical(Distribution):
    def __init__(self, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):  # raise an error if neither or both of probs/logits are given
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:  # probs was passed
            if probs.dim() < 1:  # probs must be at least one-dimensional
                raise ValueError("`probs` parameter must be at least one-dimensional.")
            self.probs = probs / probs.sum(-1, keepdim=True)
        else:  # logits was passed
            if logits.dim() < 1:  # logits must be at least one-dimensional
                raise ValueError("`logits` parameter must be at least one-dimensional.")
            # Normalize
            self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        self._param = self.probs if probs is not None else self.logits
        self._num_events = self._param.size()[-1]
        batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
        super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

If probs is not None, for example probs=[0.4, 0.3, 0.2, 0.1] or probs=[4.0, 3.0, 2.0, 1.0], the code simply normalizes the input: each entry is divided by the sum of all entries, so the normalized values sum to 1. As a formula: $p_j = \frac{p_j}{\sum_{i=1}^{n}p_i}$. After this step, the values sum to 1, i.e. $\sum_{i=1}^{n}p_i = 1$.

We can verify this with code:

import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.4000, 0.3000, 0.2000, 0.1000])

probs = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(probs=probs)
print(pd.probs)  # tensor([0.4000, 0.3000, 0.2000, 0.1000])

If probs is not passed and logits is passed instead, the code handles it as follows:

self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)

logits.logsumexp() deserves a brief explanation; rather than repeating it here, I will simply point to an excellent existing write-up: 【关于LogSumExp】.

In short, Categorical() applies the following transformation to the logits it receives: $\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}\right) = \log(e^{x_j}) - \log\left(\sum_{i=1}^{n}e^{x_i}\right) = x_j - \log\left(\sum_{i=1}^{n}e^{x_i}\right)$. That is, from every entry of logits it subtracts the log of the sum of exponentials of all entries; the last expression in the formula is exactly what the code computes. The subtracted term is the LogSumExp, a very literal name.
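
As a side note, this normalization is exactly a log-softmax over the last dimension, which gives an easy way to sanity-check it; a minimal sketch:

import torch

logits = torch.tensor([4.0, 3.0, 2.0, 1.0])

# Manual normalization, exactly as in Categorical.__init__
manual = logits - logits.logsumexp(dim=-1, keepdim=True)

# The same computation expressed as a log-softmax over the last dimension
via_log_softmax = torch.log_softmax(logits, dim=-1)

print(torch.allclose(manual, via_log_softmax))  # True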

Enough talk; let's look at the actual behavior in code. As before, we pass logits=[0.4, 0.3, 0.2, 0.1] or logits=[4.0, 3.0, 2.0, 1.0]; the results are shown below:

import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits)  # tensor([-0.4402, -1.4402, -2.4402, -3.4402])

logit = torch.tensor([0.4,0.3,0.2,0.1])
pd = Categorical(logits=logit)
print(pd.logits)  # tensor([-1.2425, -1.3425, -1.4425, -1.5425])

Let's verify the logits=[4.0, 3.0, 2.0, 1.0] case by hand, using the formula introduced above: $\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}\right) = x_j - \log\left(\sum_{i=1}^{n}e^{x_i}\right)$.
For the first entry we get:

import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logit)  # -0.4401896985611957

This matches the first entry of pd.logits above, confirming our understanding.

As seen above, passing probs=[0.4, 0.3, 0.2, 0.1] or probs=[4.0, 3.0, 2.0, 1.0] yields the same probs, namely tensor([0.4000, 0.3000, 0.2000, 0.1000]).

Passing logits=[0.4, 0.3, 0.2, 0.1] or logits=[4.0, 3.0, 2.0, 1.0], however, gives different results: tensor([-1.2425, -1.3425, -1.4425, -1.5425]) and tensor([-0.4402, -1.4402, -2.4402, -3.4402]) respectively. This is because the logits go through an exponential, and the exponentials of 0.4 and 4 are not simply rescaled versions of each other, so the normalized logits (and therefore the distributions) differ.
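
To make the contrast concrete, here is a small side-by-side sketch; the first three printed tensors match the outputs shown earlier, and the last one is my own computation of softmax([0.4, 0.3, 0.2, 0.1]):

import torch
from torch.distributions import Categorical

vals_a = torch.tensor([4.0, 3.0, 2.0, 1.0])
vals_b = torch.tensor([0.4, 0.3, 0.2, 0.1])

# Interpreted as probs: both are just rescaled, so the distributions are identical
print(Categorical(probs=vals_a).probs)   # tensor([0.4000, 0.3000, 0.2000, 0.1000])
print(Categorical(probs=vals_b).probs)   # tensor([0.4000, 0.3000, 0.2000, 0.1000])

# Interpreted as logits: softmax is not scale-invariant, so the distributions differ
print(Categorical(logits=vals_a).probs)  # tensor([0.6439, 0.2369, 0.0871, 0.0321])
print(Categorical(logits=vals_b).probs)  # roughly tensor([0.2887, 0.2612, 0.2363, 0.2138])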

2. Getting the probabilities from probs and logits

In torch.distributions.Categorical(), the corresponding probability values can be obtained through logits and probs; the concrete implementation looks like this:

import torch
import torch.nn.functional as F
def logits_to_probs(logits, is_binary=False):
    r"""
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    if is_binary:
        return torch.sigmoid(logits)
    return F.softmax(logits, dim=-1)

def clamp_probs(probs):
    eps = torch.finfo(probs.dtype).eps  # smallest eps for probs' dtype such that 1.0 + eps != 1.0
    return probs.clamp(min=eps, max=1 - eps)  # clamp probs into the range [eps, 1 - eps]

def probs_to_logits(probs, is_binary=False):
    r"""
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
    return torch.log(ps_clamped)

@lazy_property
def logits(self):
    return probs_to_logits(self.probs)

@lazy_property
def probs(self):
    return logits_to_probs(self.logits)
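
As a quick sanity check of these helpers, here is a minimal sketch that reuses the two functions defined above (in PyTorch itself they live in torch.distributions.utils); the round trip recovers the normalized logits:

logits = torch.tensor([4.0, 3.0, 2.0, 1.0])

probs = logits_to_probs(logits)  # softmax over the last dimension
print(probs)                     # tensor([0.6439, 0.2369, 0.0871, 0.0321])

# Going back: probs_to_logits is log(clamped probs), i.e. normalized log-probabilities
print(probs_to_logits(probs))    # tensor([-0.4402, -1.4402, -2.4402, -3.4402])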

Suppose logits=[4.0, 3.0, 2.0, 1.0] is passed in. As introduced above, Categorical() processes the incoming logits with the formula $\log\left(\frac{e^{x_j}}{\sum_{i=1}^{n}e^{x_i}}\right) = x_j - \log\left(\sum_{i=1}^{n}e^{x_i}\right)$. If we then want Categorical.probs, the implementation in the code is:

F.softmax(logits, dim=-1)  # the is_binary case is not discussed here

A short derivation explains why this simple softmax gives back the correct probabilities.

The softmax itself is computed as follows: for $X = [x_1, x_2, x_3]$, $softmax(X) = \left[\frac{e^{x_1}}{e^{x_1}+e^{x_2}+e^{x_3}},\ \frac{e^{x_2}}{e^{x_1}+e^{x_2}+e^{x_3}},\ \frac{e^{x_3}}{e^{x_1}+e^{x_2}+e^{x_3}}\right]$

Suppose the logits passed to Categorical() are $logits = [l_1, l_2, l_3]$. After the normalization $l_j - \log\sum_{i=1}^{3}e^{l_i}$, the stored logits are
$logits = \left[l_1 - \log\sum_{i=1}^{3}e^{l_i},\ l_2 - \log\sum_{i=1}^{3}e^{l_i},\ l_3 - \log\sum_{i=1}^{3}e^{l_i}\right]$
Applying softmax to these normalized logits gives
$softmax(logits) = \left[\frac{e^{l_1 - \log\sum_{i=1}^{3}e^{l_i}}}{\sum_{j=1}^{3}e^{l_j - \log\sum_{i=1}^{3}e^{l_i}}},\ \frac{e^{l_2 - \log\sum_{i=1}^{3}e^{l_i}}}{\sum_{j=1}^{3}e^{l_j - \log\sum_{i=1}^{3}e^{l_i}}},\ \frac{e^{l_3 - \log\sum_{i=1}^{3}e^{l_i}}}{\sum_{j=1}^{3}e^{l_j - \log\sum_{i=1}^{3}e^{l_i}}}\right]$
Since $e^{l_j - \log\sum_{i=1}^{3}e^{l_i}} = \frac{e^{l_j}}{\sum_{i=1}^{3}e^{l_i}}$, the common factor $\frac{1}{\sum_{i=1}^{3}e^{l_i}}$ cancels between numerator and denominator, leaving
$softmax(logits) = \left[\frac{e^{l_1}}{e^{l_1}+e^{l_2}+e^{l_3}},\ \frac{e^{l_2}}{e^{l_1}+e^{l_2}+e^{l_3}},\ \frac{e^{l_3}}{e^{l_1}+e^{l_2}+e^{l_3}}\right]$
In other words, probs is obtained simply by applying softmax to the logits that were originally passed in.

Let's verify this with code:

import math
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)

print(pd.probs)  # tensor([0.6439, 0.2369, 0.0871, 0.0321])

num = math.exp(4.0)/(math.exp(4.0)+math.exp(3.0)+math.exp(2.0)+math.exp(1.0))
print(num)  # 0.6439142598879722

As you can see, the probs obtained from the distribution and the value we computed by hand agree, which confirms the derivation.

If the data passed in is probs, and we then use Categorical.logits to get the corresponding log-probabilities, the implementation is:

torch.log(ps_clamped)

Again ignoring the is_binary case, the implementation is simply a log transform (after clamping). Let's check with code:

import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)

print(pd.logits)  # tensor([-0.9163, -1.2040, -1.6094, -2.3026])

print(torch.log(torch.tensor(0.4)))  # tensor(-0.9163)

The two values are the same, again as expected.

3. Sampling with sample()

The code is straightforward:

def sample(self, sample_shape=torch.Size()):
    if not isinstance(sample_shape, torch.Size):
        sample_shape = torch.Size(sample_shape)
    # self._num_events = self._param.size()[-1]
    probs_2d = self.probs.reshape(-1, self._num_events)  # reshape probs to 2-D
    samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T  # sample with replacement
    return samples_2d.reshape(self._extended_shape(sample_shape))  # reshape back to the extended sample shape
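
Before going into the details, a minimal usage sketch (the drawn indices are random, so the sampled values shown in the comments are only illustrative):

import torch
from torch.distributions import Categorical

pd = Categorical(probs=torch.tensor([0.4, 0.3, 0.2, 0.1]))

print(pd.sample())                          # e.g. tensor(1): a single category index
print(pd.sample(torch.Size((5,))))          # e.g. tensor([0, 0, 2, 1, 0]): 5 draws
print(pd.sample(torch.Size((2, 3))).shape)  # torch.Size([2, 3])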

The sample() operation is fairly simple; two points are worth recording here:

(1) torch.multinomial()

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
  • Draws num_samples values from each row of input and returns the index of each draw;
  • The elements in each row of input are the weights with which the corresponding indices are sampled. If an element is 0, that position will not be drawn until all other (nonzero) positions have been drawn;
  • replacement=False means sampling without replacement, replacement=True means sampling with replacement;
  • When sampling without replacement, num_samples must be less than or equal to input.size(-1), otherwise an error is raised;

An example:

import torch
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
torch.multinomial(weights, 2)  # tensor([1, 2])
torch.multinomial(weights, 6) # without replacement: raises an error because n_sample > prob_dist.size(-1)
torch.multinomial(weights, 4, replacement=True)  # tensor([2, 1, 1, 2])

(2) numel()

  • numel() counts the number of elements in a tensor; here, sample_shape.numel() is the total number of samples to draw, as sketched below;
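
A small sketch of how sample_shape.numel() is used (the shapes are deterministic, the drawn values are random):

import torch
from torch.distributions import Categorical

sample_shape = torch.Size((2, 3))
print(sample_shape.numel())  # 6: multinomial is asked for 6 draws in a single call

pd = Categorical(probs=torch.tensor([0.4, 0.3, 0.2, 0.1]))
samples = pd.sample(sample_shape)
print(samples.shape)  # torch.Size([2, 3]): the 6 draws reshaped back to sample_shape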
