用PyTorch手搓DDPG算法:从Actor-Critic到目标网络,一步步搞定连续控制
用PyTorch手搓DDPG算法从Actor-Critic到目标网络一步步搞定连续控制在强化学习领域连续控制问题一直是极具挑战性的研究方向。想象一下训练机器人完成精细操作或者让自动驾驶车辆在复杂环境中平稳行驶——这些场景都需要算法能够输出连续范围内的动作值。Deep Deterministic Policy GradientDDPG正是为解决这类问题而生的算法它巧妙地将深度神经网络与传统的Actor-Critic框架相结合成为攻克连续控制任务的利器。本文将带您从零开始实现DDPG算法重点解决实际编码中的三个核心难题如何设计四个协同工作的神经网络Actor、Critic及其目标网络、如何处理动作空间的连续输出、以及如何平衡探索与利用的关系。我们将以MountainCarContinuous-v0环境为实验场通过PyTorch代码逐层拆解算法实现细节让您不仅理解DDPG的工作原理更能掌握其工程实现的关键技巧。1. DDPG算法核心架构解析DDPG算法的精妙之处在于它融合了DQN和策略梯度的优点形成独特的双网络双目标结构。与离散动作空间的DQN不同DDPG的Actor网络直接输出连续动作值这使其特别适合控制类任务。1.1 四大神经网络分工class PolicyNet(nn.Module): # Actor主网络 def __init__(self, n_states, n_hiddens, n_actions, action_bound): super().__init__() self.fc1 nn.Linear(n_states, n_hiddens) self.fc2 nn.Linear(n_hiddens, n_actions) self.action_bound action_bound def forward(self, x): x torch.tanh(self.fc2(F.relu(self.fc1(x)))) return x * self.action_boundDDPG包含四个关键神经网络Actor网络策略函数μ(s|θ^μ)输入状态输出确定性动作Critic网络价值函数Q(s,a|θ^Q)评估状态-动作对的价值Target Actor策略目标网络μ(s|θ^μ)用于稳定训练Target Critic价值目标网络Q(s,a|θ^Q)提供TD目标基准这种分离设计解决了移动目标问题——当使用同一个网络既计算预测值又计算目标值时会导致训练过程不稳定。通过引入目标网络我们相当于为算法提供了一个相对固定的参考系。1.2 关键数学原理DDPG的核心更新规则建立在贝尔曼方程基础上Critic更新目标y r γ(1-done)Q(s,μ(s|θ^μ)|θ^Q)Actor更新策略∇θ^μ J ≈ E[∇a Q(s,a|θ^Q)|aμ(s) ∇θ^μ μ(s|θ^μ)]这两个公式揭示了DDPG的双重学习机制Critic学习准确评估动作价值而Actor则朝着提升Critic评分的方向优化策略。这种分工协作的模式使得算法既能处理连续动作空间又能保持较高的样本效率。2. 经验回放机制实现经验回放是DDPG稳定训练的关键组件它通过存储和重复利用历史经验打破了样本间的时序相关性。我们实现了一个高效的回放缓冲区2.1 回放缓冲区设计class ReplayBuffer: def __init__(self, capacity): self.buffer collections.deque(maxlencapacity) # 固定容量队列 def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions random.sample(self.buffer, batch_size) return zip(*transitions)技术细节说明使用collections.deque实现固定大小的循环缓冲区每个经验元组存储(s_t, a_t, r_{t1}, s_{t1}, done)五要素采样时随机抽取batch_size个独立样本打破时序相关性2.2 优先经验回放改进基础实现采用均匀采样但我们可以进一步优化class PrioritizedReplayBuffer(ReplayBuffer): def __init__(self, capacity, alpha0.6): super().__init__(capacity) self.priorities np.zeros(capacity) self.alpha alpha # 控制优先程度 self.pos 0 def add(self, *args): max_prio self.priorities.max() if self.buffer else 1.0 self.priorities[self.pos] max_prio super().add(*args) self.pos (self.pos 1) % self.buffer.maxlen优先回放根据TD误差调整采样概率使对学习更有价值的经验更频繁地被回放。这种改进可以显著提升样本利用率特别是在稀疏奖励场景中。3. 噪声探索策略实现确定性策略的一个固有问题是缺乏探索能力。DDPG通过添加噪声解决这一问题使Agent能够探索动作空间。3.1 高斯噪声实现def take_action(self, state): state torch.FloatTensor(state).unsqueeze(0).to(self.device) action self.actor(state).cpu().detach().numpy()[0] # 添加高斯噪声 noise self.sigma * np.random.randn(self.n_actions) return np.clip(action noise, -self.action_bound, self.action_bound)参数调节技巧sigma控制噪声强度通常从0.1开始逐步衰减训练初期可设置较大噪声增强探索训练后期减小噪声使策略趋于稳定使用np.clip确保动作不超出环境允许范围3.2 噪声退火策略更高级的实现可以采用噪声退火机制self.sigma max(0.01, self.sigma * 0.995) # 每步衰减噪声这种线性或指数衰减策略能够在训练初期充分探索在后期稳定策略。实际应用中还可以采用Ornstein-Uhlenbeck过程噪声它特别适合物理系统的惯性特性。4. 网络更新机制详解DDPG的训练过程涉及四种网络的协同更新这是算法实现中最复杂的部分。我们将拆解每个更新步骤的代码实现。4.1 Critic网络更新Critic的目标是最小化TD误差# 计算目标Q值 next_actions self.target_actor(next_states) next_q_values self.target_critic(next_states, next_actions) q_targets rewards self.gamma * (1 - dones) * next_q_values # 计算当前Q值 q_values self.critic(states, actions) # 计算损失并更新 critic_loss F.mse_loss(q_values, q_targets) self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step()关键点使用目标网络计算下一状态的动作和价值通过贝尔曼方程构造TD目标最小化预测值与目标值的均方误差4.2 Actor网络更新Actor的目标是最大化预期回报# 计算策略梯度 actor_actions self.actor(states) actor_loss -self.critic(states, actor_actions).mean() self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step()这里使用负号是因为PyTorch优化器默认执行最小化。本质上我们是在沿着Critic评估的梯度方向提升策略性能。4.3 目标网络软更新DDPG采用软更新而非硬更新def soft_update(self, net, target_net): for param, target_param in zip(net.parameters(), target_net.parameters()): target_param.data.copy_( self.tau * param.data (1 - self.tau) * target_param.data )参数选择建议tau通常设为0.001-0.01较小的tau使目标网络更新更平缓过大的tau可能导致训练不稳定5. 完整训练流程实现现在我们将所有组件整合构建完整的训练循环。以MountainCarContinuous-v0环境为例5.1 环境初始化env gym.make(MountainCarContinuous-v0) n_states env.observation_space.shape[0] n_actions env.action_space.shape[0] action_bound env.action_space.high[0] agent DDPG( n_statesn_states, n_hiddens64, n_actionsn_actions, action_boundaction_bound, sigma0.1, actor_lr1e-3, critic_lr1e-3, tau0.005, gamma0.99, devicedevice )5.2 训练循环for episode in range(200): state env.reset() episode_return 0 while True: action agent.take_action(state) next_state, reward, done, _ env.step(action) replay_buffer.add(state, action, reward, next_state, done) state next_state episode_return reward if len(replay_buffer) batch_size: transitions replay_buffer.sample(batch_size) agent.update(transitions) if done: break5.3 训练曲线分析典型的训练过程会呈现以下特征初期回报波动较大Agent在探索阶段随着经验积累策略逐渐稳定后期回报趋于收敛找到较优策略建议监控以下指标每回合总回报Critic损失值Actor策略更新幅度噪声强度变化6. 实战技巧与调优策略在实际实现DDPG时有几个关键点需要特别注意6.1 网络结构设计Actor网络架构建议最后一层使用tanh激活输出范围[-1,1]通过action_bound缩放输出到环境范围隐藏层不宜过深2-3层通常足够Critic网络设计要点输入为状态和动作的拼接最后一层线性输出无激活函数可考虑使用Layer Normalization稳定训练6.2 超参数调优参数典型值调节建议回放缓冲区大小1e5-1e6越大越好但受内存限制批量大小64-256太小导致不稳定太大降低样本效率Actor学习率1e-4-1e-3通常小于Critic学习率Critic学习率1e-3-3e-3可适当增大折扣因子γ0.95-0.99长周期任务取较大值软更新系数τ0.001-0.01越小更新越平缓6.3 常见问题排查训练不收敛的可能原因Critic损失爆炸尝试减小学习率梯度裁剪Actor策略退化检查噪声是否足够增大探索回报波动大增大回放缓冲区减小批量大小目标网络更新过快减小τ值调试技巧可视化网络权重分布监控梯度幅度检查动作值是否合理验证TD误差是否逐渐减小7. 进阶改进方向基础DDPG实现后可以考虑以下改进方案提升性能7.1 Twin Delayed DDPG (TD3)TD3算法针对DDPG的三个主要弱点进行了改进目标策略平滑减少Critic估计误差双Critic网络取最小值避免过估计延迟策略更新Critic更稳定后再更新Actor# TD3的双Critic实现示例 self.critic1 QValueNet(n_states, n_hiddens, n_actions).to(device) self.critic2 QValueNet(n_states, n_hiddens, n_actions).to(device) self.target_critic1 QValueNet(n_states, n_hiddens, n_actions).to(device) self.target_critic2 QValueNet(n_states, n_hiddens, n_actions).to(device)7.2 分布式DDPG通过多个Agent并行收集经验加速训练过程每个Worker有独立的探索策略共享中心化回放缓冲区定期同步主网络参数7.3 分层DDPG对于复杂任务可以设计分层控制高层策略制定子目标底层DDPG执行具体动作通过目标重标定连接不同层次在MountainCarContinuous环境中基础DDPG通常能在100-200个训练回合后找到解决方案。实际测试中一个配置得当的DDPG Agent可以将小车在约110步内推到目标位置远优于随机策略的300步表现。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2588836.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!