PyTorch学习笔记 —— Categorical函数

一、介绍

Categorical函数来自包 torch.distributions,官方定义的接口如下:

class torch.distributions.Categorical(probs)

作用是创建以参数probs为标准的类别分布,样本是来自 “0 … K-1” 的整数,其中 K 是probs参数的长度。也就是说,按照传入的probs中给定的概率,在相应的位置处进行取样,取样返回的是该位置的整数索引。

如果 probs 是长度为 K 的一维列表,则每个元素是对该索引处的类进行抽样的相对概率。

如果 probs 是二维的,它被视为一批概率向量。

 

二、使用示例

probs = torch.FloatTensor([[0.05,0.1,0.85],[0.05,0.05,0.9]])

dist = Categorical(probs)
print(dist)
# Categorical(probs: torch.Size([2, 3]))

index = dist.sample()
print(index.numpy())
# [2 2]

 

你可能感兴趣的:(Pytorch)