Target Network缓解DQN的动作价值的高估问题

1、高估问题产生的原因

原因1:由于噪声的存在,影响 m a x ( Q ) max(Q) max(Q)的估计最大值比真实的最大值更大,最小值比真实最小值更小;

原因2:Bootstrapping,DQN近似动作价值 Q Q Q,使用TD算法更新DQN,因为TD算法存在高估,更新DQN时造成高估,下一次TD更新时也会不断高估;


2、Target Network解决动作价值高估问题思路

使用Target Network计算: max ⁡ a Q ( s t + 1 , a ; w − ) \max_aQ(s_{t+1},a;\mathbf{w}^-) maxaQ(st+1,a;w)

TD learning with naïve update:
TD Target:  y t = r t + γ ⋅ max ⁡ a Q ( s t + 1 , a ; w ) . \begin{gathered} \text{TD Target: }\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}). \\ \end{gathered} TD Target: yt=rt+γamaxQ(st+1,a;w).
TD learning with target network:
TD Target: y t = r t + γ ⋅ max ⁡ a Q ( s t + 1 , a ; w − ) \text{TD Target:}\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}^-) TD Target:yt=rt+γamaxQ(st+1,a;w)


3、代码实现

实现带有target network的DQN

class DQNWithTargetNetwork:
    def __init__(self, dim_state=None, num_action=None, discount=0.9):
        self.discount = discount
        self.Q = QNet(dim_state, num_action)
        # 添加target network
        self.target_Q = QNet(dim_state, num_action)
        self.target_Q.load_state_dict(self.Q.state_dict())

    def get_action(self, state):
        # 使用最大价值的动作
        qvals = self.Q(state)
        return qvals.argmax()

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # 计算s_batch,a_batch对应的值。
        qvals = self.target_Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # 使用target Q网络计算next_s_batch对应的值。
        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
        # 使用MSE计算loss。
        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
        return loss

隔一段时间在再更新target network

# 加权更新target network
def soft_update(target, source, tau=0.01):
    """
    update target by target = tau * source + (1 - tau) * target.
    """
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

4、对gather的理解
例如三维的input,从广播机制很容易理解。当dim==0,意味着

out[i][j][k]中的[i]指的是用[index[i][j][k]]取数据放到i的,out[j][k]指的是这两个维度与out同时变化

广播机制是计算循环的一种更快的机制,因此用循环来理解是一样的:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0

等价于:

out = torch.zeros(index.shape)#定义zero空tensor

# 循环赋值
for j in range(input.shape[1]):
    for k in range(input.shape[2]):
        out[:, j, k] = input[index[i][j][k], j, k]

如果是其他维度可参考:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

一个例子:

t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

>>tensor([[ 1,  1],
         [ 4,  3]])

torch.gather — PyTorch 2.0 documentation


5、对detech的理解:
将tensor从计算图中分离,不进行梯度更新

torch.Tensor.detach — PyTorch 2.0 documentation

你可能感兴趣的:(机器学习,python,开发语言,深度学习,人工智能)