Reinforcement Learning Principles with Python, Part 08: Actor-Critic
- Prerequisites
 - TD Error
 - REINFORCE
 - QAC
 - Advantage actor-critic (A2C)

- PyTorch implementation steps
 - Step 1
 - Step 2
 - Step 3
 - Training
 - Results

- Ref
 
This post follows the Actor-Critic Methods chapter of Shiyu Zhao's textbook Mathematical Foundations of Reinforcement Learning; please read it alongside that chapter. This series focuses only on implementing the mathematical concepts in code.
Prerequisites
TD Error
If $\hat v(s,w)$ denotes the approximate state-value function, the TD error is

$$r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w)$$
Define the loss function

$$J_w = E\big[\big(v(s_{t}) -\hat v(s_{t},w)\big)^2\big]$$
Minimizing $J_w$ by gradient descent, and substituting the TD target $r_{t+1}+\gamma \hat v(s_{t+1},w)$ for the unknown $v(s_t)$, gives

$$\begin{aligned} w_{k+1} &= w_k -\alpha\nabla_w J(w_k)\\ &= w_k -\alpha\Big(-2\,E\big[\big(r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w)\big)\,\nabla_w \hat v(s_{t},w)\big]\Big) \end{aligned}$$
Replacing the expectation with a stochastic (sampled) gradient, the update becomes

$$\begin{aligned} w_{k+1} &= w_k -\alpha\nabla_w J(w_k)\\ &= w_k +\alpha\big[r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w)\big]\,\nabla_w \hat v(s_{t},w)\\ &= w_k +\alpha\big[v(s_{t}) -\hat v(s_{t},w)\big]\,\nabla_w \hat v(s_{t},w) \end{aligned}$$

where the TD target $r_{t+1}+\gamma \hat v(s_{t+1},w)$ acts as a sampled estimate of $v(s_t)$.
For the action value $q$, the analogous update is

$$\begin{aligned} w_{k+1} &= w_k -\alpha\nabla_w J(w_k)\\ &= w_k +\alpha\big[r_{t+1}+\gamma \hat q(s_{t+1}, a_{t+1},w) -\hat q(s_{t}, a_{t},w)\big]\,\nabla_w \hat q(s_{t},a_{t},w) \end{aligned}$$
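As a concrete illustration, the following is a minimal sketch of this semi-gradient TD update for the state-value case in PyTorch. The names here (v_net, td_step) are illustrative only; v_net is assumed to map a state tensor to a scalar value, and the q-value version differs only in also indexing the chosen action.

import torch

def td_step(v_net, opt, s, r, s_next, gamma=0.9):
    # TD target r + gamma * v_hat(s', w), treated as a constant (detached)
    target = r + gamma * v_net(s_next).detach()
    td_error = target - v_net(s)
    # Minimizing the squared TD error reproduces (up to a constant factor)
    # w <- w + alpha * [r + gamma*v_hat(s',w) - v_hat(s,w)] * grad_w v_hat(s,w)
    loss = td_error.pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return td_error.item()
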
REINFORCE
Recall the REINFORCE update from the previous post:

$$\begin{aligned} \theta_{t+1} &= \theta_{t} + \alpha\nabla_{\theta}J(\theta_t)\\ &= \theta_{t} + \alpha\, E_{S\sim d,\,A\sim\pi(S,\theta)}\big[q(S,A)\,\nabla_{\theta}\ln\pi(A\mid S,\theta)\big] \end{aligned}$$
In general this expectation cannot be computed exactly, so it is estimated with a stochastic (sampled) gradient, which gives

$$\theta_{t+1} = \theta_{t} + \alpha\, q(s_t,a_t)\,\nabla_{\theta}\ln\pi(a_t\mid s_t,\theta_t)$$
QAC
The simplest actor-critic algorithm
- Actor (policy update): we use the REINFORCE-style rule to update the policy $\pi$, with the critic's estimate $\hat q$ as the weight:
$$\theta_{t+1} = \theta_{t} + \alpha\,\hat q(s_t,a_t,w)\,\nabla_{\theta}\ln\pi(a_t\mid s_t,\theta_t)$$
- Critic (value update): we update the action value $\hat q$ by minimizing the TD error (a Sarsa-style update):
$$w_{k+1} = w_k +\alpha\big[r_{t+1}+\gamma \hat q(s_{t+1}, a_{t+1},w_k) -\hat q(s_{t}, a_{t},w_k)\big]\,\nabla_w \hat q(s_{t},a_{t},w_k)$$
A minimal one-step sketch of this actor-critic loop is given below.
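The sketch assumes a policy network that outputs action probabilities for a state and a q-network that outputs one value per action; all names (qac_step, policy_net, q_net) are illustrative and not part of the original algorithm statement.

import numpy as np
import torch

def qac_step(env, state, policy_net, q_net, policy_opt, q_opt, gamma=0.9):
    # Actor: sample a ~ pi(.|s, theta)
    probs = policy_net(torch.Tensor(state)).reshape(-1)
    action = np.random.choice(len(probs), p=probs.detach().numpy())
    next_state, reward, terminated, truncated, _ = env.step(action)

    # Sample a' ~ pi(.|s', theta) for the Sarsa-style TD target
    next_probs = policy_net(torch.Tensor(next_state)).reshape(-1).detach().numpy()
    next_action = np.random.choice(len(next_probs), p=next_probs)

    # Critic: minimize the squared TD error of q_hat(s, a, w)
    q_sa = q_net(torch.Tensor(state)).reshape(-1)[action]
    with torch.no_grad():
        q_next = q_net(torch.Tensor(next_state)).reshape(-1)[next_action]
        target = reward + gamma * q_next * (0.0 if terminated else 1.0)
    critic_loss = (target - q_sa).pow(2)
    q_opt.zero_grad()
    critic_loss.backward()
    q_opt.step()

    # Actor: gradient ascent on q_hat(s, a, w) * log pi(a | s, theta)
    actor_loss = -q_sa.detach() * torch.log(probs[action])
    policy_opt.zero_grad()
    actor_loss.backward()
    policy_opt.step()
    return next_state, terminated or truncated
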
Advantage actor-critic (A2C)
The next step in reducing variance is to make the baseline state-dependent (a good idea, since different states can have very different baselines). Indeed, to judge how suitable a particular action is in a given state we use its total discounted reward, but that total reward can itself be written as the value of the state plus the advantage of the action: Q(s,a) = V(s) + A(s,a) (see Dueling DQN).
Once we know the value of each state (at least approximately), we can use it to compute the policy gradient and update the policy network so that it raises the probability of actions with a positive advantage and lowers the probability of actions with a negative advantage. The policy network (which returns a probability distribution over actions) is called the actor, because it tells us what to do. The other network is called the critic, because it lets us judge how good our actions were. This improvement has a well-known name: the advantage actor-critic method, usually abbreviated as A2C. Subtracting the state value as a baseline leaves the expected policy gradient unchanged:

$$E_{S\sim d,\,A\sim\pi(S,\theta)}\big[q(S,A)\,\nabla_{\theta}\ln\pi(A\mid S,\theta)\big] = E_{S\sim d,\,A\sim\pi(S,\theta)}\big[\nabla_{\theta}\ln\pi(A\mid S,\theta)\,\big(q(S,A) - v(S)\big)\big]$$
- Advantage (TD error):
$$\delta_t = r_{t+1}+\gamma\, \hat v(s_{t+1},w_t) - \hat v(s_{t},w_t)$$
- Actor (policy update): we update the policy $\pi$ with the REINFORCE-style rule, using the advantage $\delta_t$ as the weight:
$$\theta_{t+1} = \theta_{t} + \alpha\,\delta_t\,\nabla_{\theta}\ln\pi(a_t\mid s_t,\theta_t)$$
- Critic (value update):
1. We update the state value $\hat v$ by minimizing the squared error against a target $v(s_t)$: $w_{k+1} = w_k -\alpha\nabla_w\big[v(s_{t}) -\hat v(s_{t},w)\big]^2$
2. Here the target $v(s_t)$ is estimated by the discounted return $R$ actually observed from $s_t$ onwards, so the update becomes
3. $w_{k+1} = w_k -\alpha\nabla_w\big[R -\hat v(s_{t},w)\big]^2$
A compact sketch of the resulting loss terms is given right after this list.
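Before the full episode-based implementation below, here is a compact sketch of the two A2C loss terms for one batch of transitions. It assumes net is the A2CNet defined in Step 1, states is a (T, obs_size) tensor, actions a (T,) LongTensor and R a (T,) tensor of observed discounted returns; these variable names are illustrative.

probs, values = net(states)
values = values.reshape(-1)
advantage = R - values.detach()            # delta_t approximated by R_t - v_hat(s_t)
log_pi = torch.log(probs.gather(1, actions.reshape(-1, 1)).reshape(-1))
actor_loss = -(advantage * log_pi).mean()  # gradient ascent on delta_t * log pi
critic_loss = ((R - values) ** 2).mean()   # regress v_hat(s_t) towards R_t
total_loss = actor_loss + critic_loss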
 
PyTorch implementation steps
Step 1
- Initialize A2CNet so that it returns the policy pi(s, a) and the state value V(s)
 
import collections
import copy
import math
import random
import time
from collections import defaultdict
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter
class A2CNet(nn.Module):
    def __init__(self, obs_size, hidden_size, q_table_size):
        super().__init__()
        # 策略函数pi(s, a)
        self.policy_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, q_table_size),
            nn.Softmax(dim=1),
        )
        # 价值V(s)
        self.v_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )
    def forward(self, state):
        if len(torch.Tensor(state).size()) == 1:
            state = state.reshape(1, -1)
        return self.policy_net(state), self.v_net(state)
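A quick shape check of the network (a hedged usage sketch; the CartPole dimensions used here are just an example):

net = A2CNet(obs_size=4, hidden_size=64, q_table_size=2)  # e.g. CartPole: 4 observations, 2 actions
p, v = net(torch.rand(4))      # a single state is reshaped to a batch of 1 inside forward()
print(p.shape, v.shape)        # torch.Size([1, 2]) torch.Size([1, 1])
p, v = net(torch.rand(8, 4))   # a batch of 8 states works the same way
print(p.shape, v.shape)        # torch.Size([8, 2]) torch.Size([8, 1])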
 
Step 2
- Roll out episodes in the environment with the current policy πθ, saving the states (st), actions (at) and rewards (rt).
 - If an episode ends in a terminal state the bootstrap value is R = 0; if it is only truncated, Vθ(st+1) is used instead. In the code below, the reward produced by the environment is augmented with this bootstrap estimate.
 
def discount_reward(R, gamma):
    # R is the list of rewards collected so far
    n = len(R)
    dr = 0
    for i in range(n):
        dr += gamma**i * R[i]
    return dr


def generate_episode(env, n_steps, net, gamma, predict=False):
    # Roll out n_steps episodes with the current policy; if predict=True,
    # only return the mean episode length as an evaluation metric.
    episode_history = dict()
    r_list = []
    for i in range(n_steps):
        episode = []
        state, info = env.reset()
        while True:
            p, v = net(torch.Tensor(state))
            p = p.detach().numpy().reshape(-1)
            action = np.random.choice(list(range(env.action_space.n)), p=p)
            next_state, reward, terminated, truncated, info = env.step(action)
            # If the episode is truncated (not terminated), bootstrap:
            # v(state) = r + gamma * v(next_state)
            if truncated and not terminated:
                reward = reward + gamma * float(
                    net(torch.Tensor(next_state))[1].detach()
                )
            episode.append([state, action, next_state, reward, terminated])
            state = next_state
            if terminated or truncated:
                episode_history[i] = episode
                r_list.append(len(episode))
                break
    if predict:
        return np.mean(r_list)
    return episode_history


def calculate_t_discount_reward(reward_list, gamma):
    # Discounted return G_t for every time step, computed backwards
    discount_reward = []
    total_reward = 0
    for i in reward_list[::-1]:
        total_reward = total_reward * gamma + i
        discount_reward.append(total_reward)
    return discount_reward[::-1]
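A small usage check of the two reward helpers (toy numbers, purely illustrative):

rewards = [1, 1, 1]
print(discount_reward(rewards, gamma=0.9))              # 1 + 0.9 + 0.81 = 2.71 (up to float rounding)
print(calculate_t_discount_reward(rewards, gamma=0.9))  # roughly [2.71, 1.9, 1.0]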
 
Step 3
- Accumulate the policy gradient: $\theta_{t+1} = \theta_{t} + \alpha\,\delta_t\,\nabla_{\theta}\ln\pi(a_t\mid s_t,\theta_t)$
 - Accumulate the value gradient: $w_{k+1} = w_k -\alpha\nabla_w\big[R -\hat v(s_{t},w)\big]^2$
# Combined actor-critic loss over a batch of episodes
def loss(net, batch, gamma, entropy_beta=False):
    l = 0
    for episode in batch.values():
        reward_list = [
            reward for state, action, next_state, reward, terminated in episode
        ]
        state = [state for state, action, next_state, reward, terminated in episode]
        action = [action for state, action, next_state, reward, terminated in episode]
        # Discounted return G_t for every step of the episode
        qt = torch.Tensor(calculate_t_discount_reward(reward_list, gamma))
        pi, value = net(torch.Tensor(state))
        value = value.reshape(-1)
        # Actor loss: weight log-probabilities by the advantage
        # (return minus the detached value baseline), as in the A2C update above
        advantage = qt - value.detach()
        # Entropy bonus (maximum-entropy regularization)
        entropy_loss = -torch.sum((pi * torch.log(pi)), axis=1).mean() * entropy_beta
        pi = pi.gather(dim=1, index=torch.LongTensor(action).reshape(-1, 1))
        l_policy = -(advantage @ torch.log(pi).reshape(-1))
        if entropy_beta:
            l_policy -= entropy_loss
        # Critic loss: regress v_hat(s_t) towards the observed discounted return
        critic_loss = nn.MSELoss()(value, qt)
        l += l_policy + critic_loss
    return l / len(batch)
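A quick sanity check that the loss is a differentiable scalar (it reuses the helpers above with the CartPole setup from the training section; the names test_env and test_net are illustrative):

test_env = gym.make("CartPole-v1", max_episode_steps=200)
test_net = A2CNet(test_env.observation_space.shape[0], 64, test_env.action_space.n)
test_batch = generate_episode(test_env, 2, test_net, gamma=0.9)
print(loss(test_net, test_batch, gamma=0.9, entropy_beta=0.01))  # e.g. tensor(..., grad_fn=...)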
 
Training
# Initialize the environment
env = gym.make("CartPole-v1", max_episode_steps=200)
# env = gym.make("CartPole-v1", render_mode = "human")
state, info = env.reset()
obs_n = env.observation_space.shape[0]
hidden_num = 64
act_n = env.action_space.n
a2c = A2CNet(obs_n, hidden_num, act_n)
# Define the optimizer
opt = optim.Adam(a2c.parameters(), lr=0.01)
# TensorBoard logging
writer = SummaryWriter(log_dir="logs/PolicyGradient/A2C", comment="test1")
epochs = 200
batch_size = 20
gamma = 0.9
entropy_beta = 0.01
# Clip gradients so a single update cannot blow up the policy
CLIP_GRAD = 0.1
for epoch in range(epochs):
    batch = generate_episode(env, batch_size, a2c, gamma)
    l = loss(a2c, batch, gamma, entropy_beta)
    # Backpropagation
    opt.zero_grad()
    l.backward()
    # Gradient clipping
    nn_utils.clip_grad_norm_(a2c.parameters(), CLIP_GRAD)
    opt.step()
    # Evaluate: average episode length over 10 fresh episodes
    max_steps = generate_episode(env, 10, a2c, gamma, predict=True)
    writer.add_scalars(
        "Loss",
        {"loss": l.item(), "max_steps": max_steps},
        epoch,
    )
    print("epoch:{},  Loss: {}, max_steps: {}".format(epoch, l.detach(), max_steps))
 
Results

From the logged TensorBoard curves we can see that, compared with the several methods in the previous post, both the speed and the stability of convergence have improved noticeably.
Ref
[1] Shiyu Zhao, Mathematical Foundations of Reinforcement Learning
 [2] Maxim Lapan, Deep Reinforcement Learning Hands-On, Second Edition
















