Categorical()`的参数有三个,分别为`probs`,`logits`,`validate_args
probs
比如传入probs=[0.4, 0.3, 0.2, 0.1]
,或者probs=[4.0, 3.0, 2.0, 1.0]
,代码是直接对传入的probs
进行归一化处理,对每个数据除以传入数据的累加和得到归一化后的数值,归一化的数据累加和为1。通过公式表示为:
p j = p j Σ i = 1 n p i p_j=\frac{p_j}{\varSigma _{i=1}^{n}p_i} pj=Σi=1npipj
经过处理后的所有 p i p_i pi的累加和为1,即
Σ n i = 1 p i = 1 \underset{i=1}{\overset{n}{\varSigma}}p_i=1 i=1Σnpi=1
import torch
from torch.distributions import Categorical
probs = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.4000, 0.3000, 0.2000, 0.1000])
print(pd) # Categorical(probs: torch.Size([4]))
print(probs) # tensor([4., 3., 2., 1.])
probs = torch.tensor([1, 2, 3, 4])
pd = Categorical(probs=probs)
print(pd.probs) # tensor([0.1000, 0.2000, 0.3000, 0.4000])
logits
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
就是对logits
中的每一个数据都减去其对数指数累加和,公式的最后一部分就是代码的具体实现。公式中减号后面的部分就是LogSumExp
,看字面意思很形象。
import torch
from torch.distributions import Categorical
logit = torch.tensor([4.0,3.0,2.0,1.0])
pd = Categorical(logits=logit)
print(pd.logits) # tensor([-0.4402, -1.4402, -2.4402, -3.4402])
print(pd) # Categorical(logits: torch.Size([4]))
logit = torch.tensor([1, 2, 3, 4])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits) # tensor([-3.4402, -2.4402, -1.4402, -0.4402])
logit = torch.tensor([0.4, 0.3, 0.2, 0.1])
pd = Categorical(logits=logit)
pd.logits
print(pd.logits) # tensor([-1.2425, -1.3425, -1.4425, -1.5425])
对上面公式进行验证 logits=[4, 3, 2, 1]
import math
logit = 4 - math.log(math.exp(4) + math.exp(3) + math.exp(2) + math.exp(1))
print(logits) # -0.4401896985611957
DataFrame.sample(n=None, frac=None, replace=False, weights=None, random_state=None, axis=None)
n:这是一个可选参数, 由整数值组成, 并定义生成的随机行数。
frac:它也是一个可选参数, 由浮点值组成, 并返回浮点值*数据帧值的长度。不能与参数n一起使用。
replace:由布尔值组成。如果为true, 则返回带有替换的样本。替换的默认值为false。
权重:它也是一个可选参数, 由类似于str或ndarray的参数组成。默认值”无”将导致相等的概率加权。如果正在通过系列赛;它将与索引上的目标对象对齐。在采样对象中找不到的权重索引值将被忽略, 而在采样对象中没有权重的索值将被分配零权重。如果在轴= 0时正在传递DataFrame, 则返回0。它将接受列的名称。如果权重是系列;然后, 权重必须与被采样轴的长度相同。如果权重不等于1;它将被标准化为1的总和。权重列中的缺失值被视为零。权重栏中不允许无穷大。
random_state:它也是一个可选参数, 由整数或numpy.random.RandomState组成。如果值为int, 则为随机数生成器或numpy RandomState对象设置种子。
axis:它也是由整数或字符串值组成的可选参数。 0或”行”和1或”列”。