Week 9: 函数逼近入门

回顾:表格型强化学习方法 (Tabular RL)

到目前为止,我们学习的 MC, TD, SARSA, Q-Learning 都属于表格型 (Tabular) 强化学习方法。

  • 核心: 使用表格来存储和更新每个状态 (\(V(s)\)) 或每个状态-动作对 (\(Q(s, a)\)) 的价值。
    • 例如,\(Q\) 表是一个 \(|S| \times |A|\) 大小的矩阵(或字典)。
  • 适用性: 对于状态空间 \(S\) 和动作空间 \(A\) 都比较离散的问题效果很好(如 Gridworld, Blackjack, CliffWalking)。

局限性: 当状态或动作空间变得庞大甚至连续时,表格型方法面临巨大挑战。

为何需要函数逼近 (Function Approximation)?

表格型方法的局限主要体现在以下几个方面:

  1. 维度灾难 (Curse of Dimensionality) - 状态空间巨大:
    • 许多现实世界的商业问题具有非常大的状态空间。
      • 库存管理: 如果管理数千种商品,每种商品有多个库存水平,状态组合数量将是天文数字。\(S = (inv_1, inv_2, ..., inv_{1000})\)
      • 动态定价: 如果状态包含历史价格、竞争对手价格、多种客户分群信息,状态空间会急剧膨胀。
      • 游戏: 棋盘游戏(如围棋 \(10^{170}\) 状态)、视频游戏(像素级输入,状态近乎无限)。
    • 使用表格存储所有状态(或状态-动作对)的价值变得不可行
      • 内存需求: 需要巨大的内存来存储 \(Q\) 表。
      • 计算需求: 访问和更新这个巨大的表格非常耗时。
      • 样本效率: 需要访问每个状态(或状态-动作对)很多次才能得到准确的价值估计,这在巨大状态空间中需要天文数字般的经验数据。
  2. 连续状态空间 (Continuous State Spaces):
    • 很多问题的状态是连续的。
      • 机器人控制: 关节角度、速度。
      • 金融交易: 股票价格、技术指标。
      • 物理模拟: CartPole 的车位置、杆角度、速度都是连续的。
    • 表格无法直接存储无限多的连续状态。虽然可以进行离散化(分箱),但这会导致信息损失,并且在高维连续空间中,离散化的格子数量仍然会爆炸式增长。
  3. 连续动作空间 (Continuous Action Spaces):
    • 动作也可能是连续的。
      • 机器人控制: 电机施加的力矩。
      • 自动驾驶: 方向盘转角、油门/刹车力度。
      • 资源分配: 分配给不同项目的预算比例。
    • 表格型方法(特别是基于 \(\max_a Q(s, a)\) 的方法如 Q-Learning)难以直接处理无限多的连续动作。

核心问题: 我们无法为每一个可能的状态(或状态-动作对)单独学习和存储一个值。

解决方案: 使用函数逼近 (Function Approximation)

函数逼近的基本思想

不再为每个 \(s\)\((s, a)\) 存储独立的价值,而是用一个带参数的函数来近似价值函数:

  • 近似状态值函数: \(\hat{V}(s, w) \approx V_{\pi}(s)\)\(V^{*}(s)\)
  • 近似动作值函数: \(\hat{Q}(s, a, w) \approx Q_{\pi}(s, a)\)\(Q^{*}(s, a)\)

其中:

  • \(\hat{V} / \hat{Q}\): 表示价值函数的近似值 (Approximation)
  • \(w\): 函数的参数 (Parameters / Weights)。这是一个维度远小于 \(|S| 或 |S| \times |A|\) 的向量。
  • 函数形式: 可以是各种函数,例如:
    • 线性函数 (Linear Function)
    • 决策树 (Decision Tree)
    • 神经网络 (Neural Network) (尤其是深度神经网络,即深度强化学习的基础)

学习目标: 调整参数 \(w\),使得近似函数 \(\hat{V}(s, w)\)\(\hat{Q}(s, a, w)\) 尽可能地接近真实的价值函数 \(V_{\pi}(s)\) / \(V^{*}(s)\)\(Q_{\pi}(s, a)\) / \(Q^{*}(s, a)\)

优势:

  1. 泛化 (Generalization):
    • 函数逼近器可以从有限的经验中泛化到未见过或很少访问的状态。
    • 相似的状态(根据函数的特征表示)会得到相似的价值估计。
    • 学习一个状态的价值可以帮助改进对其他“相似”状态的价值估计。
  2. 处理大规模/连续空间:
    • 参数 \(w\) 的数量远少于状态(或状态-动作对)的数量,大大减少了内存和计算需求。
    • 可以直接处理连续的状态输入(例如,神经网络可以直接接收连续值的向量作为输入)。

学习过程:

函数逼近下的强化学习通常借鉴监督学习的思想。我们将 RL 算法产生的目标值 (Target)(例如,MC 回报 \(G_t\) 或 TD 目标 \(R + \gamma \hat{V}(S', w)\))视为“标签”,将当前的价值估计 \(\hat{V}(S, w)\)\(\hat{Q}(S, A, w)\) 视为模型的“预测”。然后,我们定义一个损失函数 (Loss Function) 来衡量预测与目标之间的误差,并通过梯度下降 (Gradient Descent) 等优化算法来调整参数 \(w\) 以最小化这个损失。

例如,对于 TD(0) 预测 \(V_{\pi}\),更新参数 \(w\) 的一种常见方式是半梯度 TD(0) (Semi-gradient TD(0)):

  • TD 误差: \(\delta_t = R_{t+1} + \gamma \hat{V}(S_{t+1}, w) - \hat{V}(S_t, w)\)
  • 参数更新: \(w \leftarrow w + \alpha * \delta_t * \nabla \hat{V}(S_t, w)\)

其中 \(\nabla \hat{V}(S_t, w)\) 是近似函数 \(\hat{V}\) 对参数 \(w\) 在状态 \(S_t\) 处的梯度。这个更新规则试图将 \(\hat{V}(S_t, w)\) 朝着 TD 目标移动。

半梯度 (Semi-gradient)

之所以称为”半梯度”,是因为虽然TD目标值\(R + \gamma \hat{V}(S', w)\)本身也依赖于参数\(w\),但在梯度计算时我们通常只考虑当前预测值\(\hat{V}(S_t, w)\)\(w\)的梯度,而忽略目标值对\(w\)的梯度依赖。这种做法虽然简化了计算且在实际应用中通常有效,但可能导致理论上的收敛性问题。特别是在Off-Policy(异策略)、Bootstrapping(自举)和Function Approximation(函数逼近)三者同时存在的情况下,这种不稳定性尤为明显,这种情况被称为强化学习中的”死亡三角”(Deadly Triad)。

线性函数逼近 (Linear Function Approximation)

最简单的函数逼近形式之一是线性函数。

思想: 将状态 \(s\) 表示为一个特征向量 (Feature Vector) \(\phi(s)\)\(\phi(s) = (\phi_1(s), \phi_2(s), ..., \phi_d(s))^\top\)

其中 \(d\) 是特征的数量,通常远小于状态总数 \(|S|\)。特征可以是:

  • 状态变量的原始值(如果状态是向量)。
  • 状态变量的多项式组合。
  • 基于状态的某种编码(如 Tile Coding, Radial Basis Functions)。
  • 领域知识提取的关键指标。

线性价值函数近似: \(\hat{V}(s, w) = w^\top \phi(s) = \sum_{i=1}^d w_i \phi_i(s)\)

参数 \(w = (w_1, w_2, ..., w_d)^\top\) 是与特征对应的权重向量。价值被近似为特征的加权和。

参数更新 (半梯度 TD(0)):

\(\nabla \hat{V}(s, w) = \phi(s)\) (线性函数对参数的梯度就是其特征向量) \(w \leftarrow w + \alpha * [R + \gamma w^\top \phi(S') - w^\top \phi(s)] * \phi(s)\)

优点:

  • 简单,计算高效。
  • 理论性质相对较好(例如,线性函数逼近下的 TD(0) 通常能收敛)。

缺点:

  • 特征工程 (Feature Engineering): 效果高度依赖于特征 \(\phi(s)\) 的好坏。设计好的特征需要大量的领域知识和尝试。
  • 表达能力有限: 线性模型只能表示状态特征和价值之间的线性关系,可能无法捕捉复杂的非线性价值函数。

对于动作值函数 \(\hat{Q}(s, a, w)\),线性逼近可以表示为: \(\hat{Q}(s, a, w) = w^\top \phi(s, a)\) 其中 \(\phi(s, a)\) 是状态-动作对的特征向量。

概念 Lab/演示:表格方法的局限性 (CartPole)

回顾一下 CartPole-v1 环境:

  • 目标: 控制小车左右移动,以保持杆子竖直不倒。
  • 状态 (Observation): 一个包含 4 个连续值的向量:
    1. 小车位置 (Cart Position)
    2. 小车速度 (Cart Velocity)
    3. 杆子角度 (Pole Angle)
    4. 杆子角速度 (Pole Angular Velocity)
  • 动作 (Action): 0 (向左推), 1 (向右推) (离散)。
  • 奖励: 每保持一步奖励 +1。
  • 结束条件: 杆子倾斜超过一定角度,或小车移出边界,或达到最大步数。

为什么表格型 Q-Learning/SARSA 难以处理 CartPole?

  1. 连续状态空间: 状态包含 4 个连续变量。理论上有无限多个可能的状态。
  2. 离散化的挑战:
    • 我们可以尝试将每个连续变量离散化(分箱)。例如,将位置分成 10 个区间,速度分成 10 个区间,角度分成 20 个区间,角速度分成 20 个区间。
    • 即使这样粗略的离散化,总的状态数也将是 10 * 10 * 20 * 20 = 40,000 个。
    • 如果需要更精细的离散化,状态数量会急剧增加。
    • 离散化会丢失状态变量的精确信息。两个非常接近但落在不同箱子里的状态会被视为完全不同;而同一个箱子里的两个相距较远的状态会被视为相同。
  3. 维度灾难: 随着状态维度的增加(即使是离散化后),所需的状态数量呈指数级增长。表格方法无法有效处理。

结论: 对于像 CartPole 这样具有连续状态(即使动作是离散的)的问题,表格型方法不再适用,我们必须使用函数逼近。

引入 Stable Baselines3 (SB3) 库

手动实现基于神经网络的函数逼近(如 DQN, A2C)涉及许多细节:网络结构设计、梯度计算、优化器选择、经验回放管理等。

Stable Baselines3 (SB3) 是一个基于 PyTorch 的开源库,它提供了可靠的、经过良好测试的深度强化学习 (Deep Reinforcement Learning, DRL) 算法实现。

主要目的: 让研究人员和开发者能够方便地使用 SOTA (State-of-the-Art) 或经典的 DRL 算法来解决问题,而无需从头实现算法的复杂细节。

核心特点:

  • 多种 DRL 算法实现: 包括 DQN, A2C, PPO, SAC, TD3 等(我们将在后续课程中学习其中一些)。
  • 统一的接口: 所有算法遵循类似的使用模式,易于切换和比较。
  • 与 Gym/Gymnasium 兼容: 可以直接用于标准的 Gym/Gymnasium 环境。
  • 预定义的策略网络: 为常见任务(如 MLP 网络用于向量输入,CNN 网络用于图像输入)提供了预定义的网络结构。
  • 易于定制: 允许用户自定义网络结构、特征提取器等。
  • 包含实用工具: 如回调函数 (Callbacks) 用于监控训练、保存模型、评估模型等。

基本用法概览 (以 DQN 为例,细节下周 Lab 讲解):

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env # 用于创建向量化环境

# 1. 创建环境 (可以是单个环境,或多个并行环境以加速训练)
# env = gym.make("CartPole-v1", render_mode="rgb_array") # 使用 rgb_array 模式进行训练
vec_env = make_vec_env("CartPole-v1", n_envs=4) # 创建 4 个并行环境

# 2. 定义模型 (选择算法,指定策略网络类型,传入环境)
# "MlpPolicy": 使用多层感知机 (MLP) 作为 Q 网络
model = DQN("MlpPolicy", vec_env, verbose=1, # verbose=1 打印训练信息
            learning_rate=1e-4,
            buffer_size=100000, # 经验回放缓冲区大小
            learning_starts=1000, # 多少步后开始学习
            batch_size=32,
            tau=1.0, # Target network update rate
            gamma=0.99,
            train_freq=4, # 每多少步训练一次
            gradient_steps=1,
            target_update_interval=1000, # Target network 更新频率
            exploration_fraction=0.1, # 探索率衰减的总步数比例
            exploration_final_eps=0.05, # 最终探索率
           )

# 3. 训练模型
# total_timesteps: 总的训练步数
model.learn(total_timesteps=100000, log_interval=4) # log_interval 控制打印频率

# 4. 保存模型
model.save("dqn_cartpole")

# 5. 加载模型并使用 (评估/预测)
# del model # 删除现有模型 (可选)
# loaded_model = DQN.load("dqn_cartpole")

# # 使用加载的模型进行预测
# obs, info = vec_env.reset()
# for _ in range(1000):
#     action, _states = loaded_model.predict(obs, deterministic=True) # deterministic=True 使用贪心策略
#     obs, rewards, terminated, truncated, infos = vec_env.step(action)
#     # 注意:vec_env 的 render 需要特殊处理,或者创建一个单独的非向量化环境来可视化
#     # env.render()
#     if any(terminated) or any(truncated):
#         obs, info = vec_env.reset()

vec_env.close()
# env.close() # 如果创建了单个环境
课程重点

在本课程的后半部分,我们将重点学习如何使用 Stable Baselines3 来运行和理解 DRL 算法(如 DQN, A2C),而不是要求大家从头实现这些复杂的算法。你需要理解算法的核心思想、关键组件(如经验回放、目标网络)的作用,以及如何调整超参数来训练模型并分析结果。


下周预告: 深度 Q 网络 (Deep Q-Network, DQN)。我们将深入学习 DQN 如何使用神经网络逼近 Q 函数,以及经验回放和目标网络这两个关键技巧。Lab 6 将使用 Stable Baselines3 运行 DQN 解决 CartPole 问题。