强化学习实战:用Python手把手实现Grid World的Policy Iteration与Value Iteration

张开发
2026/6/10 5:13:00 15 分钟阅读
强化学习实战:用Python手把手实现Grid World的Policy Iteration与Value Iteration
强化学习实战用Python手把手实现Grid World的Policy Iteration与Value Iteration在强化学习的入门阶段Grid World网格世界是一个经典的实验环境。这个简单的二维网格不仅能够直观展示智能体如何学习最优策略还能帮助我们理解两种基础但强大的算法Policy Iteration策略迭代和Value Iteration价值迭代。本文将带你从零开始用Python实现这两种算法并通过可视化展示它们的学习过程。1. 环境搭建与问题定义首先我们需要构建一个4x4的Grid World环境。这个环境中状态空间16个格子编号从0到15左上角为0右下角为15动作空间4个基本动作上、下、左、右终止状态左上角(0)和右下角(15)为终止状态即时奖励每步移动获得-1奖励到达终止状态获得0奖励状态转移尝试移出边界时会保持在原状态让我们用Python定义这个环境import numpy as np class GridWorld: def __init__(self): self.size 4 self.terminal_states [0, 15] self.actions [up, down, left, right] self.action_effects { up: (-1, 0), down: (1, 0), left: (0, -1), right: (0, 1) } def get_next_state(self, state, action): if state in self.terminal_states: return state, 0, True row, col state // self.size, state % self.size dr, dc self.action_effects[action] new_row, new_col row dr, col dc # 边界检查 if not (0 new_row self.size and 0 new_col self.size): return state, -1, False new_state new_row * self.size new_col reward 0 if new_state in self.terminal_states else -1 return new_state, reward, (new_state in self.terminal_states)2. 策略迭代(Policy Iteration)实现策略迭代分为两个交替进行的步骤策略评估和策略改进。让我们逐步实现这两个部分。2.1 策略评估策略评估的目标是计算当前策略下的状态价值函数。我们使用迭代法来求解贝尔曼方程def policy_evaluation(env, policy, gamma1.0, theta1e-6): V np.zeros(env.size**2) while True: delta 0 for state in range(env.size**2): if state in env.terminal_states: continue old_value V[state] action policy[state] new_state, reward, _ env.get_next_state(state, action) V[state] reward gamma * V[new_state] delta max(delta, abs(old_value - V[state])) if delta theta: break return V2.2 策略改进基于评估得到的状态价值我们改进策略def policy_improvement(env, V, gamma1.0): policy np.zeros(env.size**2, dtypeint) for state in range(env.size**2): if state in env.terminal_states: continue action_values [] for action in env.actions: new_state, reward, _ env.get_next_state(state, action) action_values.append(reward gamma * V[new_state]) best_action np.argmax(action_values) policy[state] best_action return policy2.3 完整策略迭代将两个步骤结合起来def policy_iteration(env, gamma1.0): # 初始化随机策略 policy np.random.choice(len(env.actions), sizeenv.size**2) while True: V policy_evaluation(env, policy, gamma) new_policy policy_improvement(env, V, gamma) if np.array_equal(policy, new_policy): break policy new_policy return policy, V3. 价值迭代(Value Iteration)实现价值迭代将策略评估和改进合并为一个步骤直接迭代价值函数def value_iteration(env, gamma1.0, theta1e-6): V np.zeros(env.size**2) while True: delta 0 for state in range(env.size**2): if state in env.terminal_states: continue old_value V[state] action_values [] for action in env.actions: new_state, reward, _ env.get_next_state(state, action) action_values.append(reward gamma * V[new_state]) V[state] max(action_values) delta max(delta, abs(old_value - V[state])) if delta theta: break # 从最优价值函数中提取策略 policy np.zeros(env.size**2, dtypeint) for state in range(env.size**2): if state in env.terminal_states: continue action_values [] for action in env.actions: new_state, reward, _ env.get_next_state(state, action) action_values.append(reward gamma * V[new_state]) best_action np.argmax(action_values) policy[state] best_action return policy, V4. 结果分析与可视化让我们运行两种算法并比较结果env GridWorld() # 策略迭代 pi_policy, pi_values policy_iteration(env) print(策略迭代结果) print(策略, [env.actions[a] for a in pi_policy]) print(价值, pi_values.reshape(4,4)) # 价值迭代 vi_policy, vi_values value_iteration(env) print(\n价值迭代结果) print(策略, [env.actions[a] for a in vi_policy]) print(价值, vi_values.reshape(4,4))典型输出结果策略迭代结果 策略 [left, left, down, left, up, up, down, left, up, down, down, left, up, right, right, left] 价值 [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]] 价值迭代结果 策略 [left, left, down, left, up, up, down, left, up, down, down, left, up, right, right, left] 价值 [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]我们可以用matplotlib将价值函数可视化import matplotlib.pyplot as plt def plot_values(values, title): plt.figure(figsize(6,6)) plt.imshow(values.reshape(4,4), cmapcoolwarm) plt.colorbar() for i in range(4): for j in range(4): plt.text(j, i, f{values[i*4j]:.1f}, hacenter, vacenter) plt.title(title) plt.axis(off) plt.show() plot_values(pi_values, Policy Iteration Value Function) plot_values(vi_values, Value Iteration Value Function)5. 算法比较与实战建议5.1 性能对比特性策略迭代价值迭代收敛速度通常较快策略空间小可能较慢需要价值收敛每次迭代计算量较大需要完整策略评估较小单次更新适用场景策略空间小的问题状态空间大的问题实现复杂度较高两个独立阶段较低单一更新循环5.2 调试技巧收敛问题检查折扣因子γ是否合理通常接近1确保终止状态的价值固定为0验证状态转移和奖励计算是否正确性能优化对大型状态空间考虑异步更新使用向量化操作替代循环对价值迭代可以设置最大迭代次数防止无限循环可视化调试定期打印策略和价值函数绘制价值函数变化曲线对Grid World可以用箭头表示策略方向def plot_policy(policy, title): arrows [↑, ↓, ←, →] policy_arrows [arrows[p] for p in policy] plt.figure(figsize(6,6)) plt.imshow(np.zeros((4,4)), cmapbinary) for i in range(4): for j in range(4): plt.text(j, i, policy_arrows[i*4j], hacenter, vacenter, fontsize20) plt.title(title) plt.axis(off) plt.show() plot_policy(pi_policy, Policy Iteration Optimal Policy) plot_policy(vi_policy, Value Iteration Optimal Policy)5.3 扩展思考随机策略当前环境是确定性的尝试修改环境加入随机转移概率如80%概率按指令移动20%概率随机移动观察算法表现变化。不同奖励结构修改奖励函数比如某些格子有额外惩罚如-5加入正向奖励的宝藏格子设计必须绕行的障碍物大规模扩展将4x4网格扩展到更大尺寸如10x10观察算法收敛时间变化考虑如何优化。其他算法对比实现蒙特卡洛法和时序差分学习如Q-learning与动态规划方法比较优缺点。

更多文章