Same as the previous post:
This part still follows the order of Lecture 6 of Bolei Zhou's course.
Main reference: the course Intro to Reinforcement Learning, Bolei Zhou.
Code used in this article: https://github.com/ThousandOfWind/RL-basic-alg.git
The DDPG and TD3 implementations follow Addressing Function Approximation Error in Actor-Critic Methods; the TD3 authors' code is very clear and readable!
SAC mainly follows OpenAI's official tutorial, plus a PyTorch reimplementation (code) and a second PyTorch reimplementation.
Q-learning is marked here only to emphasise DQN's target network and experience replay buffer.
However, despite the close ties to DQN, the algorithms in this post are mainly aimed at continuous control!
DDPG inherits DQN's machinery very directly, so if you have used DQN a lot its structure is easy to follow.
One thing to note in the code: unlike the previous series, the actor and critic networks in DDPG are trained separately.
self.Q = DDPG_Critic(param_set)
self.actor = DDPG_Actor(param_set)
self.targetQ = copy.deepcopy(self.Q)
self.targetA = copy.deepcopy(self.actor)
self.critic_optimiser = Adam(params=self.Q.parameters(), lr=self.learning_rate)
self.actor_optimiser = Adam(params=self.actor.parameters(), lr=self.learning_rate)
currentQ = self.Q(obs, action_index)
targetQ = (reward + self.gamma * (1-done) * self.targetQ(next_obs, self.targetA(next_obs))).detach()
critic_loss = F.mse_loss(currentQ, targetQ)
self.writer.add_scalar('Loss/TD_loss', critic_loss.item(), self.step )
# Optimize the critic
self.critic_optimiser.zero_grad()
critic_loss.backward()
self.critic_optimiser.step()
# Optimize the actor: maximise Q(s, actor(s)), i.e. minimise its negated mean over the batch
actor_loss = - self.Q(obs, self.actor(obs)).mean()
self.writer.add_scalar('Loss/pi_loss', actor_loss.item(), self.step )
self.actor_optimiser.zero_grad()
actor_loss.backward()
self.actor_optimiser.step()
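To connect the code to the math: the actor step above is gradient ascent on the critic's estimate of the actor's own action, i.e. it maximises

$$\underset{s \sim \mathcal{D}}{\mathrm{E}}\left[Q_{\phi}\big(s,\mu_{\theta}(s)\big)\right]$$

which is why the loss is the negated, batch-averaged critic output: the gradient flows from the critic's output through the action $\mu_{\theta}(s)$ back into the actor's parameters.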
In addition, the target networks are usually updated with a soft (Polyak) update:
for param, target_param in zip(self.Q.parameters(), self.targetQ.parameters()):
    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.targetA.parameters()):
    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
A question that has long puzzled me: when actions are continuous the gradient clearly flows from the critic into the actor, but when actions are discrete, how does the gradient get back to the actor? (The discrete-SAC section at the end of this post comes back to this.)
And what are the respective pros and cons of soft and hard target updates?
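For comparison, a hard update would simply copy the online weights into the target networks every fixed number of steps, DQN-style; a minimal sketch, with update_interval being a hypothetical hyper-parameter not in this repo:

# hard target update (DQN-style): full copy every `update_interval` steps
if self.step % update_interval == 0:
    self.targetQ.load_state_dict(self.Q.state_dict())
    self.targetA.load_state_dict(self.actor.state_dict())

Roughly, the soft update nudges the target a little at every step so the TD target drifts smoothly, while the hard update keeps the target frozen between copies and then changes it in one jump.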
TD3 is a bit like an extended version of Double DQN.
Its improvements are right there in the name (Twin Delayed DDPG): twin critics with a clipped-min target, and delayed (less frequent) policy updates.
currentQ1, currentQ2 = self.Q(obs, action_index)
targetQ1, targetQ2 = self.targetQ(next_obs, self.targetA(next_obs))
targetQ = th.min(targetQ1, targetQ2)
targetQ = (reward + self.gamma * (1-done) * targetQ).detach()
critic_loss = F.mse_loss(currentQ1, targetQ) + F.mse_loss(currentQ2, targetQ)
self.writer.add_scalar('Loss/TD_loss', critic_loss.item(), self.step )
# Optimize the critic
self.critic_optimiser.zero_grad()
critic_loss.backward()
self.critic_optimiser.step()
self.step += 1
if self.step - self.last_update > self.pi_update_frequncy:
    self.last_update = self.step
    # delayed policy update: the actor follows only the first critic, averaged over the batch
    q1, q2 = self.Q(obs, self.actor(obs))
    actor_loss = - q1.mean()
    self.writer.add_scalar('Loss/pi_loss', actor_loss.item(), self.step )
    self.actor_optimiser.zero_grad()
    actor_loss.backward()
    self.actor_optimiser.step()
    # target networks are also only updated on the delayed schedule
    for param, target_param in zip(self.Q.parameters(), self.targetQ.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    for param, target_param in zip(self.actor.parameters(), self.targetA.parameters()):
        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
"Maximum entropy" means we additionally want to maximise the entropy of the policy:
$$H(P)=\underset{x \sim P}{\mathrm{E}}[-\log P(x)]$$
As you would expect, this term encourages the agent to explore the environment as uniformly as possible.
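Concretely, the entropy bonus is added to the reward at every step, so the objective SAC optimises (in the Spinning Up tutorial's notation) is

$$\pi^{*}=\arg\max_{\pi}\;\underset{\tau \sim \pi}{\mathrm{E}}\left[\sum_{t=0}^{\infty}\gamma^{t}\Big(r(s_{t},a_{t},s_{t+1})+\alpha\,H\big(\pi(\cdot\mid s_{t})\big)\Big)\right]$$

where the temperature $\alpha$ trades off return against entropy.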
The SAC losses are then as follows:
$$L\left(\phi_{i}, \mathcal{D}\right)=\underset{\left(s, a, r, s^{\prime}, d\right) \sim \mathcal{D}}{\mathrm{E}}\left[\left(Q_{\phi_{i}}(s, a)-y\left(r, s^{\prime}, d\right)\right)^{2}\right]$$

$$y\left(r, s^{\prime}, d\right)=r+\gamma(1-d)\left(\min _{j=1,2} Q_{\phi_{\text{targ}, j}}\left(s^{\prime}, \tilde{a}^{\prime}\right)-\alpha \log \pi_{\theta}\left(\tilde{a}^{\prime} \mid s^{\prime}\right)\right), \quad \tilde{a}^{\prime} \sim \pi_{\theta}\left(\cdot \mid s^{\prime}\right)$$
One thing puzzled me here: the usual definition of entropy is $-\sum_x p(x)\log p(x)$, so why does the entropy term above keep only the $-\log\pi$ half?
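The short answer is that the missing "half" is supplied by the sampling itself: the entropy is an expectation over actions, and drawing $\tilde{a}^{\prime} \sim \pi_{\theta}(\cdot\mid s^{\prime})$ in the target is exactly that expectation, so only the single-sample estimate of $-\log\pi$ remains inside it:

$$H\big(\pi_{\theta}(\cdot\mid s^{\prime})\big)=\underset{\tilde{a}^{\prime}\sim\pi_{\theta}(\cdot\mid s^{\prime})}{\mathrm{E}}\big[-\log\pi_{\theta}(\tilde{a}^{\prime}\mid s^{\prime})\big]\approx-\log\pi_{\theta}(\tilde{a}^{\prime}\mid s^{\prime})$$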
SAC keeps TD3's twin-Q-network trick,
but it also differs in the following ways:
The third point I am honestly not entirely clear on; roughly, TD3's network outputs the action directly, while SAC outputs a distribution over actions, so...
Unlike in TD3, there is no explicit target policy smoothing. TD3 trains a deterministic policy, and so it accomplishes smoothing by adding random noise to the next-state actions. SAC trains a stochastic policy, and so the noise from that stochasticity is sufficient to get a similar effect.
But I did not actually see the target-policy-smoothing part in this TD3 code.
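For completeness, target policy smoothing in the original TD3 implementation adds clipped Gaussian noise to the target action before the target critics evaluate it. A minimal sketch of how the target computation above would change; policy_noise, noise_clip and max_action are hyper-parameter names from the TD3 paper, not from this post's code:

# target policy smoothing (TD3's third trick): perturb the target action with clipped noise
next_action = self.targetA(next_obs)
noise = (th.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
next_action = (next_action + noise).clamp(-max_action, max_action)
targetQ1, targetQ2 = self.targetQ(next_obs, next_action)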
First, because we want the continuous action to come with its own distribution, we need to introduce the TanhNormal distribution below; I have not yet dug into exactly how it works.
import torch
from torch.distributions import Distribution, Normal


class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        z = self.normal.sample_n(n)
        if return_pre_tanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return:
        """
        if pre_tanh_value is None:
            # recover z = arctanh(x) from the squashed value
            pre_tanh_value = torch.log(
                (1 + value) / (1 - value)
            ) / 2
        # change-of-variables correction for the tanh squashing
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()
        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        # reparameterization trick: z = mean + std * eps, eps ~ N(0, 1)
        z = (
            self.normal_mean +
            self.normal_std *
            Normal(
                torch.zeros(self.normal_mean.size()),
                torch.ones(self.normal_std.size())
            ).sample()
        )
        z.requires_grad_()
        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
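As a rough usage sketch (the GaussianPolicyHead name, layer sizes and the log-std clamp range are illustrative assumptions, not taken from the referenced repos): a SAC-style actor head outputs a mean and a log-std, wraps them in the TanhNormal above, and the tanh-corrected log_prob is what later feeds the α·logπ terms.

import torch
import torch.nn as nn

class GaussianPolicyHead(nn.Module):
    """Toy actor head producing a TanhNormal over the action."""
    def __init__(self, obs_dim, action_dim, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.mean_head = nn.Linear(hidden, action_dim)
        self.log_std_head = nn.Linear(hidden, action_dim)

    def forward(self, obs):
        h = self.body(obs)
        log_std = self.log_std_head(h).clamp(-20, 2)  # assumed clamp range
        return TanhNormal(self.mean_head(h), log_std.exp())

# rsample keeps gradients for the actor loss; log_prob includes the tanh correction
policy = GaussianPolicyHead(obs_dim=4, action_dim=2)
dist = policy(torch.randn(8, 4))
action, pre_tanh = dist.rsample(return_pretanh_value=True)
log_prob = dist.log_prob(action, pre_tanh_value=pre_tanh).sum(-1, keepdim=True)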
def get_action(self, observation, sample=False):
    obs = th.FloatTensor(observation)
    dist = self.actor(obs)
    action = dist.sample() if sample else dist.mean
    action = action.clamp(*self.action_range)
    return action
currentQ1, currentQ2 = self.Q(obs, action_index)
next_dist = self.actor(next_obs)
next_action = next_dist.rsample()
targetnextQ1, targetnextQ2 = self.targetQ(next_obs, next_action)
next_log_prob = next_dist.log_prob(next_action).sum(-1, keepdim=True)
targetV = th.min(targetnextQ1, targetnextQ2) - self.alpha * next_log_prob
targetQ = (reward + self.gamma * (1-done) * targetV).detach()
critic_loss = F.mse_loss(currentQ1, targetQ) + F.mse_loss(currentQ2, targetQ)
# Optimize the critic
self.critic_optimiser.zero_grad()
critic_loss.backward()
self.critic_optimiser.step()
dist = self.actor(obs)
action = dist.rsample()
q1, q2 = self.Q(obs, action)
q = th.min(q1, q2)
log_prob = dist.log_prob(action).sum(-1, keepdim=True)
actor_loss = (self.alpha.detach() * log_prob - q).mean()
self.actor_optimiser.zero_grad()
actor_loss.backward()
self.actor_optimiser.step()
self.alpha_optimiser.zero_grad()
alpha_loss = (self.alpha *
(-log_prob - self.target_entropy).detach()).mean()
alpha_loss.backward()
self.alpha_optimiser.step()
Because a lot of later work builds on this line, and has even spread to discrete-action control, it is worth looking at how to make it discrete.
If we simply feed the discrete action into the network, the gradient obviously cannot flow back, so the natural idea is to use the action probabilities instead: the critic no longer outputs a single value but, like an ordinary DQN network, outputs $|A|$ values, which are then combined with the action probabilities to get the expected Q.
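Since the policy now assigns an explicit probability to every action, the expectation over $a^{\prime}\sim\pi$ can be computed exactly instead of sampled, so the soft value target used in the code below is

$$V(s^{\prime})=\sum_{a^{\prime}}\pi_{\theta}(a^{\prime}\mid s^{\prime})\Big(\min_{j=1,2}Q_{\phi_{\text{targ},j}}(s^{\prime},a^{\prime})-\alpha\log\pi_{\theta}(a^{\prime}\mid s^{\prime})\Big)$$

which is exactly what the targetV line computes.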
currentQ1, currentQ2 = self.Q(obs)
currentQ1 = currentQ1.gather(1, action_index)
currentQ2 = currentQ2.gather(1, action_index)
next_action_index, next_action_log_probs, next_pi = self.actor(next_obs)
target_next_Q1, target_next_Q2 = self.targetQ(next_obs)
target_next_Q = th.min(target_next_Q1, target_next_Q2)
targetV = (next_pi * (target_next_Q - self.alpha * next_action_log_probs)).sum(dim=1, keepdim=True)
targetQ = (reward + self.gamma * (1-done) * targetV).detach()
critic_loss = F.mse_loss(currentQ1, targetQ) + F.mse_loss(currentQ2, targetQ)
action_index, action_log_probs, pi = self.actor(obs)
q1, q2 = self.Q(obs)
q = (th.min(q1, q2) * pi).sum(dim=1, keepdim=True)
entropies = -(action_log_probs * pi).sum(dim=1, keepdim=True)
actor_loss = (- self.alpha.detach() * entropies - q).mean()
alpha_loss = (self.alpha *
(entropies.detach() - self.target_entropy).detach()).mean()
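One detail the snippets above rely on without showing is self.target_entropy. A common heuristic (an assumption on my part, not taken from this repo, and exact values differ between codebases) is minus the action dimensionality for continuous SAC and a fraction of the maximum entropy log|A| for the discrete variant:

import math

# common entropy-target heuristics (action_dim, n_actions, continuous_actions are assumed names)
if continuous_actions:
    # continuous SAC: target minus the action dimensionality
    target_entropy = -float(action_dim)
else:
    # discrete SAC: target a fraction of the maximum possible entropy log|A|
    target_entropy = 0.98 * math.log(n_actions)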