告别巨型Q表!用PyTorch手把手实现价值函数逼近(VFA),搞定CartPole游戏
告别巨型Q表用PyTorch手把手实现价值函数逼近VFA搞定CartPole游戏当你在Gymnasium的CartPole环境中第一次尝试Q-Learning时是否曾被那个不断膨胀的Q表格吓到状态空间稍微复杂些内存占用就会指数级增长。这就是传统表格型强化学习方法的致命伤——维度灾难Curse of Dimensionality。今天我们将用PyTorch实现价值函数逼近Value Function Approximation用神经网络这个万能函数逼近器来替代笨重的Q表格。1. 为什么需要价值函数逼近在经典CartPole问题中小车的状态由四个连续变量构成小车位置x小车速度v杆角度θ杆角速度ω如果用离散化方法处理假设每个维度分成20个区间动作空间有2个动作左/右那么Q表大小将是20^4 * 2 320,000 个条目这种存储方式存在三个致命缺陷内存爆炸状态维度增加时存储需求呈指数增长泛化性差相似状态无法共享经验效率低下查表操作在连续空间变得极其低效函数逼近的核心思想是用参数化函数$Q(s,a;w)$代替Q表其中w是可训练参数。PyTorch实现的优势在于自动微分简化梯度计算GPU加速提升训练速度灵活的神经网络架构设计提示VFA不仅适用于离散动作空间稍加修改就能扩展到连续控制问题2. 环境准备与特征工程首先建立我们的实验环境import gymnasium as gym import torch import torch.nn as nn env gym.make(CartPole-v1) state_dim env.observation_space.shape[0] # 4 action_dim env.action_space.n # 2对于线性逼近器特征设计至关重要。我们采用多项式特征增强表现力def polynomial_features(state, degree2): 将4维状态转换为多项式特征 x, v, theta, omega state features [ 1, x, v, theta, omega, x*v, x*theta, x*omega, v*theta, v*omega, theta*omega, x**2, v**2, theta**2, omega**2 ] return torch.FloatTensor(features)这种特征工程比原始状态更适合线性模型捕捉非线性关系。不同特征处理方式对比特征类型维度优点缺点原始状态4简单直接无法捕捉非线性多项式特征15增强非线性能力维度增长快神经网络自定义自动学习特征需要更多数据3. 构建PyTorch逼近器我们实现两种典型的函数逼近器3.1 线性逼近器class LinearVFA(nn.Module): def __init__(self, feature_dim, action_dim): super().__init__() self.weights nn.Linear(feature_dim, action_dim, biasFalse) def forward(self, features): return self.weights(features)3.2 神经网络逼近器class NeuralVFA(nn.Module): def __init__(self, state_dim, action_dim, hidden_size64): super().__init__() self.net nn.Sequential( nn.Linear(state_dim, hidden_size), nn.ReLU(), nn.Linear(hidden_size, action_dim) ) def forward(self, state): return self.net(state)关键参数初始化技巧def init_weights(m): if type(m) nn.Linear: nn.init.xavier_uniform_(m.weight) if m.bias is not None: m.bias.data.fill_(0.01) model.apply(init_weights) # 应用Xavier初始化4. 训练流程实现完整的训练循环包含这些关键步骤经验收集使用ε-greedy策略与环境交互def select_action(state, epsilon): if random.random() epsilon: return env.action_space.sample() else: with torch.no_grad(): q_values model(state) return q_values.argmax().item()TD目标计算实现Sarsa风格的更新current_q model(current_state)[action] next_q model(next_state)[next_action] # Sarsa风格 target reward gamma * next_q * (1 - done) loss F.mse_loss(current_q, target)参数更新PyTorch标准优化流程optimizer.zero_grad() loss.backward() optimizer.step()完整的训练参数配置参数推荐值作用γ (gamma)0.99折扣因子ε初始值1.0探索率ε衰减率0.995线性衰减学习率0.001Adam优化器批次大小32每次更新样本数隐藏层大小64神经网络宽度5. 效果对比与调优经过2000轮训练后两种方法的性能对比指标线性VFA神经网络VFA收敛步数~800~400最高得分200500训练速度快(1x)慢(3x)稳定性高需要调参常见问题解决方案震荡不收敛减小学习率增加批次大小得分卡在200调整γ接近1增强长期回报考虑探索不足采用ε衰减策略如epsilon max(0.01, epsilon * 0.995) # 指数衰减进阶技巧使用经验回放打破数据相关性实现Double DQN减少过高估计添加优先级采样提升重要经验利用率# 示例优先级回放缓冲区 class PriorityBuffer: def __init__(self, capacity): self.capacity capacity self.buffer [] self.priorities np.zeros(capacity) def add(self, experience, priority): if len(self.buffer) self.capacity: self.buffer.append(experience) else: idx np.argmin(self.priorities) self.buffer[idx] experience self.priorities[len(self.buffer)-1] priority在实际测试中当杆子快要倒下时神经网络VFA能捕捉到更细微的状态变化。比如当θ0.2且ω0.5时模型会强烈建议向相反方向移动而线性模型对这种非线性关系的反应要迟钝许多。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2610179.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!