强化学习实战:从CartPole到Doom的策略梯度算法
1. 项目概述当强化学习遇上经典控制问题最近在复现经典论文时我重新把玩了下OpenAI Gym里的CartPole环境顺手用PyTorch实现了Policy Gradient算法。这个看似简单的平衡杆问题其实包含了强化学习最核心的试错学习思想。更让我惊喜的是同样的算法框架稍作调整后居然在ViZDoom的3D环境中也能跑出不错的效果。今天就来拆解这个从玩具问题到第一人称射击游戏的算法迁移之旅。2. 核心原理策略梯度的数学之美2.1 从概率分布到梯度更新Policy Gradient的核心思想非常直观让智能体在环境中尝试各种动作增加带来高回报的动作概率减少低回报动作概率。用数学表达就是# 伪代码示例 probs policy_network(state) action torch.multinomial(probs, 1) loss -torch.log(probs[action]) * discounted_reward这里的关键在于损失函数设计使用-log(prob)表示动作概率的负对数似然乘以discounted_reward作为权重因子反向传播时高回报的动作梯度会获得更大更新幅度2.2 折扣回报与基线技巧原始REINFORCE算法存在高方差问题我通常采用两种改进折扣回报计算def compute_returns(rewards, gamma0.99): R 0 returns [] for r in reversed(rewards): R r gamma * R returns.insert(0, R) return returns引入基线baselineadvantages returns - returns.mean() # 减去均值作为基线 loss -torch.log(probs[action]) * advantages[step]3. CartPole环境实战3.1 网络架构设计对于CartPole这种低维状态空间两层全连接网络足矣class PolicyNet(nn.Module): def __init__(self): super().__init__() self.fc1 nn.Linear(4, 128) # 4维状态空间 self.fc2 nn.Linear(128, 2) # 2个离散动作 def forward(self, x): x F.relu(self.fc1(x)) return F.softmax(self.fc2(x), dim-1)注意最后一层一定要用softmax保证输出是概率分布3.2 训练流程关键点我的训练循环包含几个重要技巧for episode in range(1000): states, actions, rewards [], [], [] state env.reset() # 数据收集阶段 while True: prob policy_net(torch.FloatTensor(state)) action torch.multinomial(prob, 1).item() next_state, reward, done, _ env.step(action) states.append(state) actions.append(action) rewards.append(reward) state next_state if done: break # 计算折扣回报 returns compute_returns(rewards) # 策略更新 optimizer.zero_grad() for s, a, R in zip(states, actions, returns): prob policy_net(torch.FloatTensor(s)) loss -torch.log(prob[a]) * R loss.backward() optimizer.step()4. Doom环境迁移挑战4.1 视觉输入处理ViZDoom的屏幕分辨率通常是160×120的RGB图像需要引入CNN处理class DoomPolicyNet(nn.Module): def __init__(self): super().__init__() self.cnn nn.Sequential( nn.Conv2d(3, 32, kernel_size5, stride2), nn.ReLU(), nn.Conv2d(32, 64, kernel_size3, stride2), nn.ReLU(), nn.Flatten() ) self.fc nn.Linear(64*18*13, 3) # 假设输出3个动作 def forward(self, x): x x.permute(0, 3, 1, 2) # NHWC - NCHW x self.cnn(x) return F.softmax(self.fc(x), dim-1)4.2 帧堆叠技巧为了获取时序信息我采用4帧堆叠作为状态输入state np.stack([frame1, frame2, frame3, frame4], axis-1)5. 性能优化实战技巧5.1 并行环境采样使用多进程加速数据收集from multiprocessing import Process, Queue def worker(env_id, queue): env gym.make(env_id) while True: # ...收集轨迹数据... queue.put((states, actions, rewards))5.2 熵正则化防止策略过早收敛probs policy_net(state) entropy -torch.sum(probs * torch.log(probs)) loss -torch.log(probs[action]) * advantage - 0.01 * entropy6. 调试与问题排查6.1 常见失败模式策略不收敛检查折扣因子gamma是否过大建议0.9-0.99尝试减小学习率从3e-4开始调试回报波动剧烈增加基线函数复杂度尝试PPO等改进算法6.2 监控指标我习惯记录这些关键指标print(fEpisode {episode}: fReturn{sum(rewards):.1f}, fMax Prob{max(probs):.2f}, fEntropy{entropy:.2f})7. 进阶扩展方向7.1 连续动作空间对于需要精确控制力度的场景如机器人控制可以改用高斯策略class GaussianPolicy(nn.Module): def forward(self, x): mu self.mu_head(x) # 均值 std torch.exp(self.std_head(x)) # 标准差 return torch.distributions.Normal(mu, std)7.2 混合离散-连续动作某些环境如赛车游戏需要同时处理离散动作换挡连续动作方向盘角度可以用不同的网络头处理不同类型动作。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2550190.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!