The environment:

This is a simple grid world: the green square is the goal, white squares are walkable cells, and gray squares are traps.
We train a value table with both the Sarsa algorithm and the Q-learning algorithm.
The code follows:
(it was written in a Jupyter notebook, so the ordering looks a little odd)
def get_state(row,col):
    # Rows 0-2 and the bottom-left cell are ordinary ground;
    # the bottom-right cell is the terminal, the rest of the bottom row is trap
    if row!=3:
        return 'ground'
    elif col==0:
        return 'ground'
    elif col==11:
        return 'terminal'
    else:
        return 'trap'
def envirment(row,col,action):
    # Actions: 0 = up, 1 = down, 2 = left, 3 = right
    if action==0:
        row-=1
    elif action==1:
        row+=1
    elif action==2:
        col-=1
    elif action==3:
        col+=1
    # Clamp the position to the 4x12 board
    next_row=min(max(0,row),3)
    next_col=min(max(0,col),11)
    # Every step costs -1; falling into a trap costs -100; reaching the goal earns 100
    reward=-1
    if get_state(next_row,next_col)=='trap':
        reward=-100
    elif get_state(next_row,next_col)=='terminal':
        reward=100
    return next_row,next_col,reward
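As a quick sanity check (my addition, not part of the original notebook), the expected values below follow directly from the rules above: stepping up from the start cell costs the usual -1, while stepping right lands in the first trap cell.

# Hypothetical smoke test for envirment()
assert envirment(3, 0, 0) == (2, 0, -1)     # up from the start: ordinary ground
assert envirment(3, 0, 3) == (3, 1, -100)   # right from the start: first trap cell
assert envirment(3, 10, 3) == (3, 11, 100)  # last trap cell -> terminal
print('envirment() behaves as expected')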
import numpy as np
import random
Q_pi=np.zeros([4,12,4])# Q table: 4 rows x 12 cols x 4 actions
def get_action(row,col):# pick the next action (epsilon-greedy, epsilon=0.1)
    if random.random()<0.1:
        return random.choice(range(4))# explore: pick a random action
    else:
        return Q_pi[row,col].argmax()# exploit: pick the action with the largest Q_pi
def TD_sarsa(row,col,action,reward,next_row,next_col,next_action):
    TD_target=reward+0.9*Q_pi[next_row,next_col,next_action] # Sarsa target
#     TD_target=reward+0.9*Q_pi[next_row,next_col].max() # Q-learning target
    TD_error=Q_pi[row,col,action]-TD_target
    return TD_error
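Spelled out, the update performed in train() below, Q_pi[row, col, action] -= 0.1 * TD_error, is the standard tabular TD update with learning rate $\alpha = 0.1$ and discount $\gamma = 0.9$; the two algorithms differ only in the target:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \bigl[ y_t - Q(s_t, a_t) \bigr]$$
$$y_t^{\text{Sarsa}} = r_{t+1} + \gamma\, Q(s_{t+1}, a_{t+1}), \qquad y_t^{\text{Q-learning}} = r_{t+1} + \gamma \max_a Q(s_{t+1}, a)$$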
def train():
    for epoch in range(3000):
        # Start each episode in a random row of the leftmost column
        row = random.choice(range(4))
        col = 0
        action = get_action(row, col)
        reward_sum = 0
        while get_state(row, col) not in ['terminal', 'trap']:
            next_row, next_col, reward = envirment(row, col, action)
            reward_sum += reward
            next_action = get_action(next_row, next_col)
            # For Q-learning, next_action would not need to be passed in
            TD_error = TD_sarsa(row, col, action, reward, next_row, next_col, next_action)
            Q_pi[row, col, action] -= 0.1 * TD_error
            row = next_row
            col = next_col
            action = next_action
        if epoch % 150 == 0:
            print(epoch, reward_sum)
train()
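To train the Q-learning variant instead, swap in the commented-out target; a minimal sketch (the name TD_qlearning is mine, not from the original notebook):

def TD_qlearning(row, col, action, reward, next_row, next_col):
    # Off-policy target: bootstrap from the greedy action in the next state,
    # independent of the action the behavior policy will actually take
    TD_target = reward + 0.9 * Q_pi[next_row, next_col].max()
    TD_error = Q_pi[row, col, action] - TD_target
    return TD_error

In train(), the call then becomes TD_qlearning(row, col, action, reward, next_row, next_col), with next_action still sampled only to drive the environment forward.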
# Render the board, handy for testing
def show(row, col, action):
    graph = [
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',
        '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□',
        '□', '○', '○', '○', '○', '○', '○', '○', '○', '○', '○', '❤'
    ]
    action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]
    graph[row * 12 + col] = action
    graph = ''.join(graph)
    for i in range(0, 4 * 12, 12):
        print(graph[i:i + 12])
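For instance, show(3, 0, 3) renders the agent at the start cell pointing right (output follows from the logic above):

show(3, 0, 3)
# □□□□□□□□□□□□
# □□□□□□□□□□□□
# □□□□□□□□□□□□
# →○○○○○○○○○○❤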
from IPython import display
import time
def test():
    # Start point
    row = random.choice(range(4))
    col = 0
    # Play at most N steps
    for _ in range(200):
        # Stop if the current state is the terminal or a trap
        if get_state(row, col) in ['trap', 'terminal']:
            break
        # Choose the greedy action
        action = Q_pi[row, col].argmax()
        # Animate this step
        display.clear_output(wait=True)
        time.sleep(0.1)
        show(row, col, action)
        # Take the action
        row, col, reward = envirment(row, col, action)
test()
# Print the preferred action in every cell
for row in range(4):
    line = ''
    for col in range(12):
        action = Q_pi[row, col].argmax()
        action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]
        line += action
    print(line)
Results:
The actions indicated by the value table:

The test results are as follows:

Note that the Sarsa algorithm is tied to a policy function $\pi$: both $a_t$ and $a_{t+1}$ should be obtained through $\pi$. This code has no separate policy function $\pi$, however, so $a_t$ and $a_{t+1}$ are derived directly from the value table (the epsilon-greedy get_action). Sarsa typically serves as the 'referee' (the critic) in Actor-Critic methods.
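To make the on-policy structure explicit, one could introduce a separate stochastic policy; a minimal sketch, where softmax_policy is my illustration and not part of the original notebook:

def softmax_policy(row, col, temperature=1.0):
    # A hypothetical explicit policy pi(a|s): sample an action from a
    # softmax over the Q values instead of reading argmax off the table
    prefs = Q_pi[row, col] / temperature
    prefs = prefs - prefs.max()  # subtract the max for numerical stability
    probs = np.exp(prefs) / np.exp(prefs).sum()
    return np.random.choice(4, p=probs)

With such a function, both $a_t$ and $a_{t+1}$ in the Sarsa update would come from $\pi$ itself rather than from the value table.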