Pytorch一些不常见函数解析(持续更新)

1. Categorical()

torch.distributions.Categorical()

可以按照一定概率产生具体数字,比如:

import torch
from torch.distributions import Categorical

rand = Categorical(torch.tensor([0.25,0.25,0.25,0.25]))
print(rand.sample())
# tensor(3)

这个Categorical()还有一些有趣的功能,比如可以求策略梯度REINFORCE,有个小例子:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

注意,这里的策略梯度REINFORCE的公式为:
在这里插入图片描述

  • 左边是(神经网络)的参数
  • Alpha是学习速率,r是奖励(reward),p则是在状态s以及给定策略pi中执行动作a的概率

而上述代码中的m.log_prob(value)函数则是公式中的log部分

你可能感兴趣的:(Pytorch论文复现,强化学习系列,神经网络)