DQN的原理和代码实现

文章目录

  • 1. 概述
  • 2. DQN的训练步骤
    • 2.1 初始化
    • 2.2 训练循环
    • 2.3 终止条件
    • 2.4 评估
  • 3. 代码示例


1. 概述

深度 Q 网络(Deep Q-Network, DQN)是强化学习中的一种重要算法,由 Google DeepMind 于2013年提出。DQN 结合了 Q 学习和深度学习,通过使用神经网络来近似 Q 值函数,解决了传统 Q 学习在高维状态空间中的问题。

2. DQN的训练步骤

2.1 初始化

  1. 环境:定义环境(例如,Atari游戏、CartPole等)。
  2. 网络:初始化两个神经网络:一个是在线网络(Online Network),另一个是目标网络(Target Network)。这两个网络的结构相同,但参数不同。
  3. 经验回放池:初始化一个经验回放池(Replay Buffer),用于存储经验(状态、动作、奖励、下一状态、是否终止)。
  4. 超参数:设置超参数,如学习率(learning rate)、折扣因子(discount factor γ)、探索率(exploration rate ε)、批量大小(batch size)等。

2.2 训练循环

  1. 初始化状态:从环境中获取初始状态 s 0 s_0 s0
  2. 选择动作:根据当前策略(ε-greedy策略)选择动作 a t a_t at
    • 以概率 ϵ \epsilon ϵ 随机选择一个动作。
    • 以概率 1 − ϵ 1 - \epsilon 1ϵ 选择最大化 Q 值的动作。
  3. 执行动作:在环境中执行动作 a t a_t at,观察新的状态 s t + 1 s_{t+1} st+1 和奖励 r t r_t rt,以及是否到达终止状态 d t d_t dt
  4. 存储经验:将经验 ( s t , a t , r t , s t + 1 , d t ) (s_t, a_t, r_t, s_{t+1}, d_t) (st,at,rt,st+1,dt) 存储到经验回放池中。
  5. 采样批次:从经验回放池中随机抽取一个批次的经验 ( s i , a i , r i , s i + 1 , d i ) (s_i, a_i, r_i, s_{i+1}, d_i) (si,ai,ri,si+1,di)
  6. 计算目标Q值
    • 对于每个经验 ( s i , a i , r i , s i + 1 , d i ) (s_i, a_i, r_i, s_{i+1}, d_i) (si,ai,ri,si+1,di),计算目标 Q 值 y i y_i yi
      y i = { r i if  d i r i + γ max ⁡ a ′ Q ( s i + 1 , a ′ ; θ − ) otherwise y_i = \begin{cases} r_i & \text{if } d_i \\ r_i + \gamma \max_{a'} Q(s_{i+1}, a'; \theta^-) & \text{otherwise} \end{cases} yi={riri+γmaxaQ(si+1,a;θ)if diotherwise
      其中, γ \gamma γ 是折扣因子, θ − \theta^- θ 是目标网络的参数。这个公式 r i + γ max ⁡ a ′ Q ( s i + 1 , a ′ ; θ − ) r_i + \gamma \max_{a'} Q(s_{i+1}, a'; \theta^-) ri+γmaxaQ(si+1,a;θ) 来源于最优贝尔曼方程
  7. 计算损失:计算当前网络的 Q 值 Q ( s i , a i ; θ ) Q(s_i, a_i; \theta) Q(si,ai;θ) 与目标 Q 值 y i y_i yi 之间的均方误差(MSE)损失:
    L = 1 N ∑ i = 1 N ( y i − Q ( s i , a i ; θ ) ) 2 L = \frac{1}{N} \sum_{i=1}^{N} (y_i - Q(s_i, a_i; \theta))^2 L=N1i=1N(yiQ(si,ai;θ))2
    其中, N N N 是批次大小。
  8. 更新网络:使用梯度下降法(如 Adam 优化器)最小化损失函数,更新在线网络的参数 θ \theta θ
  9. 更新目标网络:定期或逐步更新目标网络的参数 θ − \theta^- θ 为在线网络的参数 θ \theta θ
    θ − ← τ θ + ( 1 − τ ) θ − \theta^- \leftarrow \tau \theta + (1 - \tau) \theta^- θτθ+(1τ)θ
    其中, τ \tau τ 是一个小的更新率,通常设置为0.001或每一定步数完全更新一次。

2.3 终止条件

  1. 检查终止条件:如果达到预定的最大迭代次数或环境中的终止状态,停止训练。
  2. 保存模型:保存训练好的模型参数。

2.4 评估

  1. 评估模型:在测试环境中评估模型的性能,记录平均奖励、成功率等指标。

3. 代码示例

简化的 DQN 算法的 Python 代码示例,使用 PyTorch 实现:

import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import gym

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)

# 定义DQN代理
class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01, buffer_size=10000, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size

        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

        self.replay_buffer = deque(maxlen=buffer_size)
        self.criterion = nn.MSELoss()

    def choose_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.tensor([state], dtype=torch.float32)
            q_values = self.q_network(state)
            return torch.argmax(q_values).item()

    def store_experience(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def update_network(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(next_states, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
        next_q_values = self.target_network(next_states).max(1)[0]
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = self.criterion(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def decay_epsilon(self):
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

# 环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# 初始化代理
agent = DQNAgent(state_dim, action_dim)

# 训练循环
episodes = 1000
for episode in range(episodes):
    state = env.reset()
    total_reward = 0
    done = False

    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.store_experience(state, action, reward, next_state, done)
        agent.update_network()
        state = next_state
        total_reward += reward

    agent.update_target_network()
    agent.decay_epsilon()

    print(f'Episode {episode + 1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}')

# 评估
state = env.reset()
done = False
while not done:
    action = agent.choose_action(state)
    state, _, done, _ = env.step(action)
    env.render()

env.close()

欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

你可能感兴趣的:(NLP&机器学习,DQN,强化学习,深度学习)