强化学习实战:用Python手把手实现值迭代与策略迭代算法(附完整代码)
强化学习实战用Python手把手实现值迭代与策略迭代算法附完整代码强化学习作为机器学习的重要分支近年来在游戏AI、自动驾驶、机器人控制等领域展现出惊人潜力。对于初学者而言理解算法原理固然重要但真正掌握强化学习的精髓在于能够将数学公式转化为可运行的代码。本文将聚焦两种经典算法——值迭代和策略迭代通过Python代码实现带你深入理解其工作原理和工程实现细节。1. 环境准备与问题建模在开始编码前我们需要明确强化学习问题的基本框架。马尔可夫决策过程(MDP)是强化学习的数学基础包含状态集合S、动作集合A、转移概率P和奖励函数R四个核心要素。让我们先定义一个简单的网格世界环境作为实验场景import numpy as np class GridWorld: def __init__(self, size5): self.size size self.actions [up, down, left, right] self.goal (4, 4) # 右下角为目标位置 self.obstacles [(1, 1), (2, 2), (3, 3)] # 障碍物位置 self.rewards {self.goal: 10} # 到达目标奖励10分 self.gamma 0.9 # 折扣因子 def step(self, state, action): if state self.goal: return state, 0, True # 已到达目标 next_state list(state) if action up: next_state[0] max(0, next_state[0]-1) elif action down: next_state[0] min(self.size-1, next_state[0]1) elif action left: next_state[1] max(0, next_state[1]-1) elif action right: next_state[1] min(self.size-1, next_state[1]1) next_state tuple(next_state) if next_state in self.obstacles: return state, -1, False # 撞到障碍物 reward self.rewards.get(next_state, 0) done (next_state self.goal) return next_state, reward, done这个5×5的网格世界中右下角(4,4)是目标位置三个对角线位置设置了障碍物。智能体可以执行上、下、左、右四个动作到达目标获得10奖励碰到障碍物则获得-1惩罚。2. 值迭代算法实现值迭代基于贝尔曼最优方程通过不断更新状态值函数来逼近最优策略。其核心思想可以概括为初始化所有状态值V(s)对每个状态s计算所有可能动作的期望回报选择最大期望回报作为新的V(s)重复步骤2-3直到值函数收敛下面是Python实现代码def value_iteration(env, theta1e-6): V np.zeros((env.size, env.size)) while True: delta 0 for i in range(env.size): for j in range(env.size): state (i, j) if state in env.obstacles or state env.goal: continue v_old V[i][j] max_value -float(inf) for action in env.actions: next_state, reward, _ env.step(state, action) ni, nj next_state value reward env.gamma * V[ni][nj] if value max_value: max_value value V[i][j] max_value delta max(delta, abs(v_old - V[i][j])) if delta theta: break # 提取最优策略 policy {} for i in range(env.size): for j in range(env.size): state (i, j) if state in env.obstacles or state env.goal: continue best_action None best_value -float(inf) for action in env.actions: next_state, reward, _ env.step(state, action) ni, nj next_state value reward env.gamma * V[ni][nj] if value best_value: best_value value best_action action policy[state] best_action return V, policy关键实现细节theta参数控制收敛阈值当值函数更新幅度小于该值时停止迭代障碍物和目标状态的值固定为0不参与更新算法收敛后通过贪婪策略提取最优动作运行算法并可视化结果env GridWorld() V, policy value_iteration(env) print(最优值函数) print(np.round(V, 2)) print(\n最优策略示例) for i in range(env.size): for j in range(env.size): state (i, j) if state in policy: print(f状态({i},{j}): {policy[state]})3. 策略迭代算法实现策略迭代算法由策略评估和策略改进两个步骤交替进行策略评估计算当前策略下的状态值函数策略改进基于当前值函数改进策略重复上述步骤直到策略稳定Python实现如下def policy_evaluation(env, policy, V, theta1e-6): while True: delta 0 for i in range(env.size): for j in range(env.size): state (i, j) if state in env.obstacles or state env.goal: continue v_old V[i][j] action policy[state] next_state, reward, _ env.step(state, action) ni, nj next_state V[i][j] reward env.gamma * V[ni][nj] delta max(delta, abs(v_old - V[i][j])) if delta theta: break return V def policy_improvement(env, V, policy): policy_stable True for i in range(env.size): for j in range(env.size): state (i, j) if state in env.obstacles or state env.goal: continue old_action policy[state] best_action None best_value -float(inf) for action in env.actions: next_state, reward, _ env.step(state, action) ni, nj next_state value reward env.gamma * V[ni][nj] if value best_value: best_value value best_action action policy[state] best_action if old_action ! best_action: policy_stable False return policy, policy_stable def policy_iteration(env): # 初始化随机策略 policy {} for i in range(env.size): for j in range(env.size): state (i, j) if state not in env.obstacles and state ! env.goal: policy[state] np.random.choice(env.actions) V np.zeros((env.size, env.size)) while True: V policy_evaluation(env, policy, V) policy, policy_stable policy_improvement(env, V, policy) if policy_stable: break return V, policy策略迭代的特点策略评估阶段需要多次迭代直到值函数收敛策略改进阶段采用贪婪策略更新整体收敛速度通常比值迭代更快测试策略迭代算法env GridWorld() V_pi, policy_pi policy_iteration(env) print(策略迭代最优值函数) print(np.round(V_pi, 2)) print(\n策略迭代最优策略示例) for i in range(5): for j in range(5): state (i, j) if state in policy_pi: print(f状态({i},{j}): {policy_pi[state]})4. 算法对比与工程实践技巧值迭代和策略迭代都是解决MDP问题的动态规划方法但在实现细节和性能上存在差异特性值迭代策略迭代更新方式直接优化值函数交替进行策略评估与改进收敛速度相对较慢通常更快每次迭代计算成本较低较高需策略评估收敛中间策略不明确明确适用场景状态空间大策略空间简单工程实践中需要注意的几个关键点收敛条件设置# 值迭代收敛条件 if delta theta: break # 策略评估收敛条件 if delta theta: breaktheta值的选择需要权衡精度和计算成本通常设置为1e-4到1e-6之间。迭代次数限制为避免无限循环应设置最大迭代次数max_iter 1000 for _ in range(max_iter): # 迭代逻辑 if delta theta: break可视化监控实时绘制值函数或策略变化有助于调试import matplotlib.pyplot as plt def plot_value_function(V): plt.imshow(V, cmaphot) plt.colorbar() plt.show()性能优化技巧向量化计算使用NumPy矩阵运算替代循环并行化对状态更新进行并行处理稀疏矩阵对于大型状态空间使用稀疏矩阵存储# 向量化计算示例 def vectorized_value_iteration(env, theta1e-6): V np.zeros((env.size, env.size)) actions_map {up: (-1,0), down: (1,0), left: (0,-1), right: (0,1)} while True: V_prev V.copy() for action in env.actions: di, dj actions_map[action] next_i np.clip(np.arange(env.size)[:,None] di, 0, env.size-1) next_j np.clip(np.arange(env.size)[None,:] dj, 0, env.size-1) reward np.zeros((env.size, env.size)) reward[env.goal] 10 for obs in env.obstacles: reward[obs] -1 V_temp reward env.gamma * V_prev[next_i, next_j] if action env.actions[0]: V_new V_temp else: V_new np.maximum(V_new, V_temp) delta np.max(np.abs(V_new - V_prev)) V V_new if delta theta: break return V5. 高级话题截断策略迭代与贝尔曼最优方程截断策略迭代是值迭代和策略迭代的折中方案通过在策略评估阶段限制迭代次数来平衡计算成本和收敛速度。实现代码框架def truncated_policy_iteration(env, k3): policy initialize_random_policy() V np.zeros((env.size, env.size)) while True: # 截断策略评估只进行k次迭代 for _ in range(k): V partial_policy_evaluation(env, policy, V) # 策略改进 new_policy, policy_stable policy_improvement(env, V, policy) if policy_stable: break policy new_policy return V, policy贝尔曼最优方程为这些算法提供了理论基础V*(s) max_a [R(s,a) γ * Σ P(s|s,a) * V*(s)]其中V*(s)表示状态s的最优值函数这个方程表明最优值函数是自身的最佳估计。在实际项目中我发现值迭代更适合状态空间较大但策略空间相对简单的问题而策略迭代则在策略空间有明确结构时表现更优。截断策略迭代通过调整k值可以在两者之间找到平衡点通常k3到5就能获得不错的加速效果。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2469454.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!