# TD3 Algorithm in Practice: Building a Reinforcement Learning Model from Scratch with PyTorch (Complete Code Included)
Reinforcement learning has shown great potential in areas such as robot control and autonomous driving, and TD3, an upgraded version of DDPG, has become a go-to algorithm for continuous action spaces thanks to its stability and efficiency. This article walks through implementing a complete TD3 model from scratch, covering not only the core code but also practical tuning tips and common pitfalls from real training runs.

## 1. Core Principles of TD3

TD3 (Twin Delayed Deep Deterministic Policy Gradient) addresses the Q-value overestimation problem of DDPG through three key techniques:

- **Twin Critic networks**: two independent Q-networks estimate the value, and the smaller of the two estimates is used as the update target.
- **Delayed policy updates**: the Critic is updated more often than the Actor, typically at a 2:1 ratio.
- **Target policy smoothing**: clipped noise is added to the target action to prevent the policy from exploiting narrow peaks in the Q-function.

```python
import torch.nn as nn

# Example of a twin-Critic network structure
class TwinCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # First Q-network
        self.q1_linear1 = nn.Linear(state_dim + action_dim, 256)
        self.q1_linear2 = nn.Linear(256, 256)
        self.q1_output = nn.Linear(256, 1)
        # Second Q-network
        self.q2_linear1 = nn.Linear(state_dim + action_dim, 256)
        self.q2_linear2 = nn.Linear(256, 256)
        self.q2_output = nn.Linear(256, 1)
```

> Note: in practice, make sure the two Q-networks do not end up with identical initial parameters; otherwise the twin estimation loses its purpose.

## 2. Environment Setup and Data Preparation

### 2.1 Installing Dependencies

It is recommended to create an isolated Python environment with conda:

```bash
conda create -n td3 python=3.8
conda activate td3
pip install torch==1.12.1 gym==0.21.0 numpy matplotlib
```

### 2.2 Experience Replay Buffer

An efficient replay buffer is a key ingredient for stable training:

```python
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        # Returns tuples of (states, actions, rewards, next_states, dones);
        # convert them to tensors before feeding them to the networks
        return zip(*transitions)

    def __len__(self):
        return len(self.buffer)
```

| Parameter | Recommended value | Description |
| --- | --- | --- |
| capacity | 1e6 | Replay buffer capacity |
| batch_size | 256 | Samples drawn per update |
| seed | 42 | Random seed |

## 3. Core Network Architectures

### 3.1 Actor (Policy) Network

The policy network directly outputs a deterministic action:

```python
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim), nn.Tanh()
        )
        self.max_action = max_action

    def forward(self, state):
        return self.max_action * self.net(state)
```

### 3.2 Critic (Value) Network

Design points for the twin-Q structure:

- Two independent forward paths
- Sharing the lower feature-extraction layers is optional
- The final output layers stay separate

```python
import torch

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Q1 network
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )
        # Q2 network
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=1)
        return self.q1(sa), self.q2(sa)

    def q1_value(self, state, action):
        sa = torch.cat([state, action], dim=1)
        return self.q1(sa)
```
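The training code in the next section assumes an agent object that exposes `select_action` and `update` methods, together with a `soft_update` helper for the target networks. The article does not show how these pieces are assembled, so here is a minimal sketch under that assumption; the class name `TD3Agent`, its constructor signature, and the `explore_noise` parameter are illustrative, not part of the original article.

```python
import copy

import numpy as np
import torch

def soft_update(net, target_net, tau):
    # Polyak averaging: target <- tau * net + (1 - tau) * target
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

class TD3Agent:
    def __init__(self, state_dim, action_dim, max_action, config):
        # Online networks and their target copies
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = copy.deepcopy(self.actor)
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = copy.deepcopy(self.critic)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config["lr_actor"])
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config["lr_critic"])

        # Hyperparameters taken from the config dict in Section 4.2
        self.max_action = max_action
        self.gamma = config["gamma"]
        self.tau = config["tau"]
        self.policy_noise = config["policy_noise"]
        self.noise_clip = config["noise_clip"]
        self.policy_freq = config["policy_freq"]
        self.total_it = 0  # counts critic updates, used for delayed policy updates

    def select_action(self, state, explore_noise=0.1):
        # Deterministic action from the policy plus Gaussian exploration noise
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state).numpy().flatten()
        action += np.random.normal(0, explore_noise * self.max_action, size=action.shape)
        return np.clip(action, -self.max_action, self.max_action)

    # update(self, buffer) is implemented in Section 4.3
```

During evaluation you would typically call `select_action` with `explore_noise=0` so the policy acts deterministically.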
## 4. Complete Training Pipeline

### 4.1 Training Loop Skeleton

```python
def train(env, agent, buffer, episodes=1000):
    for ep in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            buffer.add(state, action, reward, next_state, done)
            # batch_size comes from the hyperparameter table in Section 2.2 (e.g. 256)
            if len(buffer) > batch_size:
                agent.update(buffer)
            state = next_state
            episode_reward += reward
            if done:
                break
```

### 4.2 Key Hyperparameters

```python
# Hyperparameter configuration
config = {
    "gamma": 0.99,         # discount factor
    "tau": 0.005,          # soft-update coefficient
    "policy_noise": 0.2,   # target policy smoothing noise
    "noise_clip": 0.5,     # noise clipping range
    "policy_freq": 2,      # delayed policy update frequency
    "lr_actor": 1e-4,      # Actor learning rate
    "lr_critic": 1e-3,     # Critic learning rate
}
```

### 4.3 Update Logic

```python
import torch.nn.functional as F

def update(self, buffer):
    # Sample a batch of transitions
    state, action, reward, next_state, done = buffer.sample(batch_size)
    self.total_it += 1  # count this update for the delayed policy schedule

    # Compute the target Q-value
    with torch.no_grad():
        noise = (torch.randn_like(action) * self.policy_noise).clamp(
            -self.noise_clip, self.noise_clip)
        next_action = (self.actor_target(next_state) + noise).clamp(
            -self.max_action, self.max_action)
        target_Q1, target_Q2 = self.critic_target(next_state, next_action)
        target_Q = torch.min(target_Q1, target_Q2)
        target_Q = reward + (1 - done) * self.gamma * target_Q

    # Update the Critic
    current_Q1, current_Q2 = self.critic(state, action)
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # Delayed Actor and target-network updates
    if self.total_it % self.policy_freq == 0:
        actor_loss = -self.critic.q1_value(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Soft-update the target networks
        soft_update(self.critic, self.critic_target, self.tau)
        soft_update(self.actor, self.actor_target, self.tau)
```

## 5. Practical Debugging Tips

### 5.1 Troubleshooting Common Issues

**Training is unstable**

- Check that the replay buffer is large enough
- Verify that the target networks are being updated correctly
- Adjust the magnitude of the policy noise

**Returns are not improving**

- Try lowering the learning rates
- Check the network initialization scheme
- Increase the exploration noise

### 5.2 Performance Optimization

```python
# Speed up training with automatic mixed precision
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    # Forward pass
    current_Q1, current_Q2 = self.critic(state, action)
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

# Backward pass with gradient scaling
self.critic_optimizer.zero_grad()
scaler.scale(critic_loss).backward()
scaler.step(self.critic_optimizer)
scaler.update()
```

### 5.3 Monitoring and Visualization

```python
# Log the training process with TensorBoard
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
writer.add_scalar("Loss/actor", actor_loss.item(), global_step)
writer.add_scalar("Loss/critic", critic_loss.item(), global_step)
writer.add_scalar("Reward/episode", episode_reward, ep)
```
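To tie everything together, here is a minimal end-to-end usage sketch. It assumes the gym 0.21 API used throughout this article (4-tuple `step` return), the `config` dict and hyperparameter table above, and the hypothetical `TD3Agent` class sketched after Section 3; the environment name `Pendulum-v1` is just one example of a continuous-control task.

```python
import gym

# Wire up the environment, agent, and replay buffer defined above
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

batch_size = 256  # matches the recommended value from Section 2.2
agent = TD3Agent(state_dim, action_dim, max_action, config)
buffer = ReplayBuffer(capacity=int(1e6))

train(env, agent, buffer, episodes=1000)
```

In practice you would also seed the environment and PyTorch (the `seed` value from the table in Section 2.2) and periodically evaluate the policy without exploration noise.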