Week 9: 函数逼近入门
回顾:表格型强化学习方法 (Tabular RL)
到目前为止,我们学习的 MC, TD, SARSA, Q-Learning 都属于表格型 (Tabular) 强化学习方法。
- 核心: 使用表格来存储和更新每个状态 (\(V(s)\)) 或每个状态-动作对 (\(Q(s, a)\)) 的价值。
- 例如,\(Q\) 表是一个 \(|S| \times |A|\) 大小的矩阵(或字典)。
- 适用性: 对于状态空间 \(S\) 和动作空间 \(A\) 都比较小且离散的问题效果很好(如 Gridworld, Blackjack, CliffWalking)。
局限性: 当状态或动作空间变得庞大甚至连续时,表格型方法面临巨大挑战。
为何需要函数逼近 (Function Approximation)?
表格型方法的局限主要体现在以下几个方面:
- 维度灾难 (Curse of Dimensionality) - 状态空间巨大:
- 许多现实世界的商业问题具有非常大的状态空间。
- 库存管理: 如果管理数千种商品,每种商品有多个库存水平,状态组合数量将是天文数字。\(S = (inv_1, inv_2, ..., inv_{1000})\)
- 动态定价: 如果状态包含历史价格、竞争对手价格、多种客户分群信息,状态空间会急剧膨胀。
- 游戏: 棋盘游戏(如围棋 \(10^{170}\) 状态)、视频游戏(像素级输入,状态近乎无限)。
- 使用表格存储所有状态(或状态-动作对)的价值变得不可行:
- 内存需求: 需要巨大的内存来存储 \(Q\) 表。
- 计算需求: 访问和更新这个巨大的表格非常耗时。
- 样本效率: 需要访问每个状态(或状态-动作对)很多次才能得到准确的价值估计,这在巨大状态空间中需要天文数字般的经验数据。
- 许多现实世界的商业问题具有非常大的状态空间。
- 连续状态空间 (Continuous State Spaces):
- 很多问题的状态是连续的。
- 机器人控制: 关节角度、速度。
- 金融交易: 股票价格、技术指标。
- 物理模拟: CartPole 的车位置、杆角度、速度都是连续的。
- 表格无法直接存储无限多的连续状态。虽然可以进行离散化(分箱),但这会导致信息损失,并且在高维连续空间中,离散化的格子数量仍然会爆炸式增长。
- 很多问题的状态是连续的。
- 连续动作空间 (Continuous Action Spaces):
- 动作也可能是连续的。
- 机器人控制: 电机施加的力矩。
- 自动驾驶: 方向盘转角、油门/刹车力度。
- 资源分配: 分配给不同项目的预算比例。
- 表格型方法(特别是基于 \(\max_a Q(s, a)\) 的方法如 Q-Learning)难以直接处理无限多的连续动作。
- 动作也可能是连续的。
核心问题: 我们无法为每一个可能的状态(或状态-动作对)单独学习和存储一个值。
解决方案: 使用函数逼近 (Function Approximation)。
函数逼近的基本思想
不再为每个 \(s\) 或 \((s, a)\) 存储独立的价值,而是用一个带参数的函数来近似价值函数:
- 近似状态值函数: \(\hat{V}(s, w) \approx V_{\pi}(s)\) 或 \(V^{*}(s)\)
- 近似动作值函数: \(\hat{Q}(s, a, w) \approx Q_{\pi}(s, a)\) 或 \(Q^{*}(s, a)\)
其中:
- \(\hat{V} / \hat{Q}\): 表示价值函数的近似值 (Approximation)。
- \(w\): 函数的参数 (Parameters / Weights)。这是一个维度远小于 \(|S| 或 |S| \times |A|\) 的向量。
- 函数形式: 可以是各种函数,例如:
- 线性函数 (Linear Function)
- 决策树 (Decision Tree)
- 神经网络 (Neural Network) (尤其是深度神经网络,即深度强化学习的基础)
- …
学习目标: 调整参数 \(w\),使得近似函数 \(\hat{V}(s, w)\) 或 \(\hat{Q}(s, a, w)\) 尽可能地接近真实的价值函数 \(V_{\pi}(s)\) / \(V^{*}(s)\) 或 \(Q_{\pi}(s, a)\) / \(Q^{*}(s, a)\)。
优势:
- 泛化 (Generalization):
- 函数逼近器可以从有限的经验中泛化到未见过或很少访问的状态。
- 相似的状态(根据函数的特征表示)会得到相似的价值估计。
- 学习一个状态的价值可以帮助改进对其他“相似”状态的价值估计。
- 处理大规模/连续空间:
- 参数 \(w\) 的数量远少于状态(或状态-动作对)的数量,大大减少了内存和计算需求。
- 可以直接处理连续的状态输入(例如,神经网络可以直接接收连续值的向量作为输入)。
学习过程:
函数逼近下的强化学习通常借鉴监督学习的思想。我们将 RL 算法产生的目标值 (Target)(例如,MC 回报 \(G_t\) 或 TD 目标 \(R + \gamma \hat{V}(S', w)\))视为“标签”,将当前的价值估计 \(\hat{V}(S, w)\) 或 \(\hat{Q}(S, A, w)\) 视为模型的“预测”。然后,我们定义一个损失函数 (Loss Function) 来衡量预测与目标之间的误差,并通过梯度下降 (Gradient Descent) 等优化算法来调整参数 \(w\) 以最小化这个损失。
例如,对于 TD(0) 预测 \(V_{\pi}\),更新参数 \(w\) 的一种常见方式是半梯度 TD(0) (Semi-gradient TD(0)):
- TD 误差: \(\delta_t = R_{t+1} + \gamma \hat{V}(S_{t+1}, w) - \hat{V}(S_t, w)\)
- 参数更新: \(w \leftarrow w + \alpha * \delta_t * \nabla \hat{V}(S_t, w)\)
其中 \(\nabla \hat{V}(S_t, w)\) 是近似函数 \(\hat{V}\) 对参数 \(w\) 在状态 \(S_t\) 处的梯度。这个更新规则试图将 \(\hat{V}(S_t, w)\) 朝着 TD 目标移动。
之所以称为”半梯度”,是因为虽然TD目标值\(R + \gamma \hat{V}(S', w)\)本身也依赖于参数\(w\),但在梯度计算时我们通常只考虑当前预测值\(\hat{V}(S_t, w)\)对\(w\)的梯度,而忽略目标值对\(w\)的梯度依赖。这种做法虽然简化了计算且在实际应用中通常有效,但可能导致理论上的收敛性问题。特别是在Off-Policy(异策略)、Bootstrapping(自举)和Function Approximation(函数逼近)三者同时存在的情况下,这种不稳定性尤为明显,这种情况被称为强化学习中的”死亡三角”(Deadly Triad)。
线性函数逼近 (Linear Function Approximation)
最简单的函数逼近形式之一是线性函数。
思想: 将状态 \(s\) 表示为一个特征向量 (Feature Vector) \(\phi(s)\)。 \(\phi(s) = (\phi_1(s), \phi_2(s), ..., \phi_d(s))^\top\)
其中 \(d\) 是特征的数量,通常远小于状态总数 \(|S|\)。特征可以是:
- 状态变量的原始值(如果状态是向量)。
- 状态变量的多项式组合。
- 基于状态的某种编码(如 Tile Coding, Radial Basis Functions)。
- 领域知识提取的关键指标。
线性价值函数近似: \(\hat{V}(s, w) = w^\top \phi(s) = \sum_{i=1}^d w_i \phi_i(s)\)
参数 \(w = (w_1, w_2, ..., w_d)^\top\) 是与特征对应的权重向量。价值被近似为特征的加权和。
参数更新 (半梯度 TD(0)):
\(\nabla \hat{V}(s, w) = \phi(s)\) (线性函数对参数的梯度就是其特征向量) \(w \leftarrow w + \alpha * [R + \gamma w^\top \phi(S') - w^\top \phi(s)] * \phi(s)\)
优点:
- 简单,计算高效。
- 理论性质相对较好(例如,线性函数逼近下的 TD(0) 通常能收敛)。
缺点:
- 特征工程 (Feature Engineering): 效果高度依赖于特征 \(\phi(s)\) 的好坏。设计好的特征需要大量的领域知识和尝试。
- 表达能力有限: 线性模型只能表示状态特征和价值之间的线性关系,可能无法捕捉复杂的非线性价值函数。
对于动作值函数 \(\hat{Q}(s, a, w)\),线性逼近可以表示为: \(\hat{Q}(s, a, w) = w^\top \phi(s, a)\) 其中 \(\phi(s, a)\) 是状态-动作对的特征向量。
概念 Lab/演示:表格方法的局限性 (CartPole)
回顾一下 CartPole-v1 环境:
- 目标: 控制小车左右移动,以保持杆子竖直不倒。
- 状态 (Observation): 一个包含 4 个连续值的向量:
- 小车位置 (Cart Position)
- 小车速度 (Cart Velocity)
- 杆子角度 (Pole Angle)
- 杆子角速度 (Pole Angular Velocity)
- 动作 (Action): 0 (向左推), 1 (向右推) (离散)。
- 奖励: 每保持一步奖励 +1。
- 结束条件: 杆子倾斜超过一定角度,或小车移出边界,或达到最大步数。
为什么表格型 Q-Learning/SARSA 难以处理 CartPole?
- 连续状态空间: 状态包含 4 个连续变量。理论上有无限多个可能的状态。
- 离散化的挑战:
- 我们可以尝试将每个连续变量离散化(分箱)。例如,将位置分成 10 个区间,速度分成 10 个区间,角度分成 20 个区间,角速度分成 20 个区间。
- 即使这样粗略的离散化,总的状态数也将是 10 * 10 * 20 * 20 = 40,000 个。
- 如果需要更精细的离散化,状态数量会急剧增加。
- 离散化会丢失状态变量的精确信息。两个非常接近但落在不同箱子里的状态会被视为完全不同;而同一个箱子里的两个相距较远的状态会被视为相同。
- 维度灾难: 随着状态维度的增加(即使是离散化后),所需的状态数量呈指数级增长。表格方法无法有效处理。
结论: 对于像 CartPole 这样具有连续状态(即使动作是离散的)的问题,表格型方法不再适用,我们必须使用函数逼近。
引入 Stable Baselines3 (SB3) 库
手动实现基于神经网络的函数逼近(如 DQN, A2C)涉及许多细节:网络结构设计、梯度计算、优化器选择、经验回放管理等。
Stable Baselines3 (SB3) 是一个基于 PyTorch 的开源库,它提供了可靠的、经过良好测试的深度强化学习 (Deep Reinforcement Learning, DRL) 算法实现。
主要目的: 让研究人员和开发者能够方便地使用 SOTA (State-of-the-Art) 或经典的 DRL 算法来解决问题,而无需从头实现算法的复杂细节。
核心特点:
- 多种 DRL 算法实现: 包括 DQN, A2C, PPO, SAC, TD3 等(我们将在后续课程中学习其中一些)。
- 统一的接口: 所有算法遵循类似的使用模式,易于切换和比较。
- 与 Gym/Gymnasium 兼容: 可以直接用于标准的 Gym/Gymnasium 环境。
- 预定义的策略网络: 为常见任务(如 MLP 网络用于向量输入,CNN 网络用于图像输入)提供了预定义的网络结构。
- 易于定制: 允许用户自定义网络结构、特征提取器等。
- 包含实用工具: 如回调函数 (Callbacks) 用于监控训练、保存模型、评估模型等。
基本用法概览 (以 DQN 为例,细节下周 Lab 讲解):
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env # 用于创建向量化环境
# 1. 创建环境 (可以是单个环境,或多个并行环境以加速训练)
# env = gym.make("CartPole-v1", render_mode="rgb_array") # 使用 rgb_array 模式进行训练
= make_vec_env("CartPole-v1", n_envs=4) # 创建 4 个并行环境
vec_env
# 2. 定义模型 (选择算法,指定策略网络类型,传入环境)
# "MlpPolicy": 使用多层感知机 (MLP) 作为 Q 网络
= DQN("MlpPolicy", vec_env, verbose=1, # verbose=1 打印训练信息
model =1e-4,
learning_rate=100000, # 经验回放缓冲区大小
buffer_size=1000, # 多少步后开始学习
learning_starts=32,
batch_size=1.0, # Target network update rate
tau=0.99,
gamma=4, # 每多少步训练一次
train_freq=1,
gradient_steps=1000, # Target network 更新频率
target_update_interval=0.1, # 探索率衰减的总步数比例
exploration_fraction=0.05, # 最终探索率
exploration_final_eps
)
# 3. 训练模型
# total_timesteps: 总的训练步数
=100000, log_interval=4) # log_interval 控制打印频率
model.learn(total_timesteps
# 4. 保存模型
"dqn_cartpole")
model.save(
# 5. 加载模型并使用 (评估/预测)
# del model # 删除现有模型 (可选)
# loaded_model = DQN.load("dqn_cartpole")
# # 使用加载的模型进行预测
# obs, info = vec_env.reset()
# for _ in range(1000):
# action, _states = loaded_model.predict(obs, deterministic=True) # deterministic=True 使用贪心策略
# obs, rewards, terminated, truncated, infos = vec_env.step(action)
# # 注意:vec_env 的 render 需要特殊处理,或者创建一个单独的非向量化环境来可视化
# # env.render()
# if any(terminated) or any(truncated):
# obs, info = vec_env.reset()
vec_env.close()# env.close() # 如果创建了单个环境
在本课程的后半部分,我们将重点学习如何使用 Stable Baselines3 来运行和理解 DRL 算法(如 DQN, A2C),而不是要求大家从头实现这些复杂的算法。你需要理解算法的核心思想、关键组件(如经验回放、目标网络)的作用,以及如何调整超参数来训练模型并分析结果。
下周预告: 深度 Q 网络 (Deep Q-Network, DQN)。我们将深入学习 DQN 如何使用神经网络逼近 Q 函数,以及经验回放和目标网络这两个关键技巧。Lab 6 将使用 Stable Baselines3 运行 DQN 解决 CartPole 问题。