1、高估问题产生的原因
原因1:由于噪声的存在,影响 m a x ( Q ) max(Q) max(Q)的估计最大值比真实的最大值更大,最小值比真实最小值更小;
原因2:Bootstrapping,DQN近似动作价值 Q Q Q,使用TD算法更新DQN,因为TD算法存在高估,更新DQN时造成高估,下一次TD更新时也会不断高估;
2、Target Network解决动作价值高估问题思路
使用Target Network计算: max a Q ( s t + 1 , a ; w − ) \max_aQ(s_{t+1},a;\mathbf{w}^-) maxaQ(st+1,a;w−)
TD learning with naïve update:
TD Target: y t = r t + γ ⋅ max a Q ( s t + 1 , a ; w ) . \begin{gathered} \text{TD Target: }\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}). \\ \end{gathered} TD Target: yt=rt+γ⋅amaxQ(st+1,a;w).
TD learning with target network:
TD Target: y t = r t + γ ⋅ max a Q ( s t + 1 , a ; w − ) \text{TD Target:}\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}^-) TD Target:yt=rt+γ⋅amaxQ(st+1,a;w−)
3、代码实现
实现带有target network的DQN
class DQNWithTargetNetwork:
def __init__(self, dim_state=None, num_action=None, discount=0.9):
self.discount = discount
self.Q = QNet(dim_state, num_action)
# 添加target network
self.target_Q = QNet(dim_state, num_action)
self.target_Q.load_state_dict(self.Q.state_dict())
def get_action(self, state):
# 使用最大价值的动作
qvals = self.Q(state)
return qvals.argmax()
def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
# 计算s_batch,a_batch对应的值。
qvals = self.target_Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
# 使用target Q网络计算next_s_batch对应的值。
next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
# 使用MSE计算loss。
loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
return loss
隔一段时间在再更新target network
# 加权更新target network
def soft_update(target, source, tau=0.01):
"""
update target by target = tau * source + (1 - tau) * target.
"""
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
4、对gather的理解
例如三维的input,从广播机制很容易理解。当dim==0,意味着
out[i][j][k]中的[i]指的是用[index[i][j][k]]取数据放到i的,out[j][k]指的是这两个维度与out同时变化
广播机制是计算循环的一种更快的机制,因此用循环来理解是一样的:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
等价于:
out = torch.zeros(index.shape)#定义zero空tensor
# 循环赋值
for j in range(input.shape[1]):
for k in range(input.shape[2]):
out[:, j, k] = input[index[i][j][k], j, k]
如果是其他维度可参考:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
一个例子:
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
>>tensor([[ 1, 1],
[ 4, 3]])
torch.gather — PyTorch 2.0 documentation
5、对detech的理解:
将tensor从计算图中分离,不进行梯度更新
torch.Tensor.detach — PyTorch 2.0 documentation