SAC算法实战:用PyTorch手把手实现Soft Actor-Critic(附完整代码)
SAC算法实战用PyTorch手把手实现Soft Actor-Critic附完整代码强化学习领域近年来最令人兴奋的进展之一莫过于Soft Actor-CriticSAC算法的崛起。这个融合了最大熵原理与离线策略学习的算法不仅在机器人控制、游戏AI等领域展现出惊人效果更因其卓越的样本效率成为工业界的新宠。今天我们就抛开繁琐的数学推导直接从代码层面拆解SAC的实现奥秘。1. 环境搭建与核心网络架构在PyTorch中实现SAC首先要解决三个核心组件的建模问题策略网络Actor、软Q网络Critic和状态值网络Value。不同于传统Actor-Critic架构SAC需要维护两套Q网络来缓解过估计问题。import torch import torch.nn as nn import torch.nn.functional as F class PolicyNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim256): super().__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) self.mean nn.Linear(hidden_dim, action_dim) self.log_std nn.Linear(hidden_dim, action_dim) def forward(self, state): x F.relu(self.fc1(state)) x F.relu(self.fc2(x)) mean self.mean(x) log_std self.log_std(x) log_std torch.clamp(log_std, min-20, max2) return mean, log_std关键实现细节策略网络输出动作分布的均值和标准差而非直接输出动作使用log_std而非直接输出标准差确保数值稳定性通过clamp限制标准差范围防止训练崩溃对应的Q网络和V网络实现class QNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim256): super().__init__() self.fc1 nn.Linear(state_dim action_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) self.fc3 nn.Linear(hidden_dim, 1) def forward(self, state, action): x torch.cat([state, action], dim1) x F.relu(self.fc1(x)) x F.relu(self.fc2(x)) return self.fc3(x) class ValueNetwork(nn.Module): def __init__(self, state_dim, hidden_dim256): super().__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) self.fc3 nn.Linear(hidden_dim, 1) def forward(self, state): x F.relu(self.fc1(state)) x F.relu(self.fc2(x)) return self.fc3(x)2. 重参数化技巧与动作采样SAC的核心创新之一是将熵正则化引入目标函数这要求策略能够输出随机动作。我们采用重参数化技巧reparameterization trick实现可微分的随机采样class SACAgent: def get_action(self, state, deterministicFalse): state torch.FloatTensor(state).unsqueeze(0).to(self.device) mean, log_std self.policy_net(state) if deterministic: return torch.tanh(mean).detach().cpu().numpy()[0] std log_std.exp() normal torch.distributions.Normal(mean, std) z normal.rsample() # 重参数化采样 action torch.tanh(z) log_prob normal.log_prob(z) - torch.log(1 - action.pow(2) 1e-6) log_prob log_prob.sum(1, keepdimTrue) return action.detach().cpu().numpy()[0], log_prob.detach().cpu().numpy()数学原理动作采样a ~ π(·|s) tanh(μ σ⊙ε)其中ε ~ N(0,I)对数概率修正logπ(a|s) logN(z|μ,σ) - Σlog(1-tanh²(z)ε)注意tanh变换会导致概率密度变化需要对应的Jacobian修正项3. 损失函数与自动熵调节SAC的损失函数包含三个关键部分策略损失、Q函数损失和值函数损失。其中最具特色的是自动调节的温度参数α# 温度参数自动调节 self.target_entropy -torch.prod(torch.Tensor(action_dim)).item() self.log_alpha torch.zeros(1, requires_gradTrue, devicedevice) self.alpha_optim torch.optim.Adam([self.log_alpha], lrlr) # 策略损失 policy_loss (self.alpha * log_pi - q_value).mean() # Q函数损失 q1_loss F.mse_loss(q1, target_q.detach()) q2_loss F.mse_loss(q2, target_q.detach()) # 值函数损失 v_loss F.mse_loss(v, target_v.detach()) # 温度参数损失 alpha_loss -(self.log_alpha * (log_pi self.target_entropy).detach()).mean()超参数设置参考参数推荐值作用说明学习率3e-4Adam优化器基础学习率折扣因子γ0.99长期回报折扣率软更新参数τ0.005目标网络更新速率回放缓冲区大小1e6经验回放容量批大小256每次采样样本数4. 训练流程与调试技巧完整的训练循环需要特别注意目标网络的更新机制和梯度裁剪def update_parameters(self, batch): # 计算所有损失 ... # 策略网络优化 self.policy_optim.zero_grad() policy_loss.backward() torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5) self.policy_optim.step() # 自动熵调节 self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() # 软更新目标网络 with torch.no_grad(): for param, target_param in zip(self.v_net.parameters(), self.target_v_net.parameters()): target_param.data.copy_( self.tau * param.data (1 - self.tau) * target_param.data)常见问题排查指南训练初期回报不上升检查环境状态是否归一化尝试增大熵系数α的初始值验证网络架构是否有梯度流动训练后期策略崩溃降低学习率增加目标网络更新频率检查动作缩放范围是否合理Q值爆炸性增长添加梯度裁剪使用双重Q网络取最小值减小奖励缩放系数5. 实战效果优化策略在MuJoCo环境中测试时以下几个技巧能显著提升最终性能技巧组合包状态归一化维护运行均值和方差目标回报缩放动态调整奖励尺度延迟策略更新每2次Q网络更新1次策略优先经验回放关键transition重点学习# 优先经验回放示例 class PrioritizedReplayBuffer: def __init__(self, capacity, alpha0.6): self.alpha alpha self.priorities np.zeros((capacity,), dtypenp.float32) ... def sample(self, batch_size, beta0.4): probs self.priorities[:self.size] ** self.alpha probs / probs.sum() indices np.random.choice(len(probs), batch_size, pprobs) weights (len(self) * probs[indices]) ** (-beta) weights / weights.max() return indices, weights, self.buffer[indices]在HalfCheetah-v3环境中的典型学习曲线显示经过约1M步训练后SAC能够稳定达到6000以上的回报。相比之下PPO算法需要约3倍样本量才能达到相同水平这验证了SAC卓越的样本效率。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2511075.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!