Cross-Entropy Method (CEM, 交叉熵方法) 与强化学习

转自:https://the0demiurge.blogspot.com/2017/08/cross-entropy-method-cem.html

 

前言

之前阅读Deep Reinforcement Learning: Pong from Pixels的时候,作者在文中最后提到“One should always try a BB gun before reaching for the Bazooka. In the case of Reinforcement Learning for example, one strong baseline that should always be tried first is the cross-entropy method (CEM), a simple stochastic hill-climbing “guess and check” approach inspired loosely by evolution.” (如果你做强化学习,就应该先用最原始的方法即CEM先测试一下,当然试验结果也可以充当benchmark进行对比。)

由于(居然)没找到相关的中文资料,我就来简要介绍一下吧。

简介

交叉熵方法是一种蒙特卡洛方法,主要用来优化和重要性采样。和进化算法类似,在空间中按照某种规则撒点,获得每个点的误差,再根据这些误差信息决定下一轮撒点的规则。交叉熵方法之所以叫这个名字,是因为该方法(从理论上来说)目标是最小化随机撒点得到的数据分布与数据实际分布的交叉熵(等价于最小化 KL 距离),尽量使采样分布(撒的点)与实际情况同分布。

该方法适当选取撒点规则就可以适应多目标优化等情况,在组合优化中也有许多应用。本文主要讨论 CEM 在强化学习的策略优化中的应用。

CEM 流程

简介说得花里胡哨,可是实际应用起来基本没有交叉熵什么事,实际应用通常是以下步骤:

 

 

  1. 首先,建模,将问题的解参数化。比如强化学习中,假设状态 S 为一个 n 维向量,动作总共有 2 种,最简单的想法就是建立一个 n 维参数向量 W ,求 ST  ×× W 得到一个标量 Q ,当 Q > 0 则采取第一种动作,否则采取第二种。强化学习问题转化为优化问题。(接下来其实可以使用任何优化算法求解最优的 W ,只不过交叉熵方法可以很快很稳定地收敛。)
  2. 假设参数 W 属于高斯分布,随机设置一个 n 维向量 μμ 和一个 n 维向量 σ2σ2 ,分别对应于 W 的每一维。
  3. 以 μμ 和 σ2σ2 为均值、方差采样得到 m 组参数 w1,w2,w3...wmw1,w2,w3...wm
  4. 计算每一组 w 的回报 reward
  5. 选取回报最高的 k (k
  6. 如果收敛则返回 reward 最大的一组 w, 否则重复步骤3~6

评价

从简介里说的蒙特卡洛方法就知道,该方法就是靠不断尝试,方法十分简单,所以也只是在简单的情况下十分有效。由此可见,简单的强化学习模型用进化算法也可解。

CEM只是最简单的方法,因此可用于验证最小网络结构是否能优化得出来,由于求解费时少,能够快速验证想法(还能快速验证代码是否有错)

参考资料

代码实现:https://gist.github.com/andrewliao11/d52125b52f76a4af73433e1cf8405a8f

 

维基百科中的伪代码:

1. mu:=-6; sigma2:=100; t:=0; maxits=100;    // Initialize parameters
2. N:=100; Ne:=10;                           //
3. while t < maxits and sigma2 > epsilon     // While maxits not exceeded and not converged
4.  X = SampleGaussian(mu,sigma2,N);         // Obtain N samples from current sampling distribution
5.  S = exp(-(X-2)^2) + 0.8 exp(-(X+2)^2);   // Evaluate objective function at sampled points
6.  X = sort(X,S);                           // Sort X by objective function values (in descending order)
7.  mu = mean(X(1:Ne)); sigma2=var(X(1:Ne)); // Update parameters of sampling distribution
8.  t = t+1;                                 // Increment iteration counter
9. return mu                                 // Return mean of final sampling distribution as solution


一篇非常棒的论文(理论讲得很详细清楚)
CROSS-ENTROPY FOR MONTE-CARLO TREE SEARCH
https://dke.maastrichtuniversity.nl/m.winands/documents/crossmc.pdf

你可能感兴趣的:((深度)增强学习)