【强化学习】13 —— Actor-Critic 算法

news2025/7/27 20:57:20

文章目录

  • REINFORCE 存在的问题
  • Actor-Critic
  • A2C: Advantageous Actor-Critic
  • 代码实践
    • 结果
  • 参考

REINFORCE 存在的问题

  • 基于片段式数据的任务
    • 通常情况下,任务需要有终止状态,REINFORCE才能直接计算累计折扣奖励
  • 低数据利用效率
    • 实际中,REINFORCE需要大量的训练数据
  • 高训练方差(最重要的缺陷
    • 从单个或多个片段中采样到的值函数具有很高的方差

Actor-Critic

在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,用于指导策略的更新。REINFOCE 算法用蒙特卡洛方法来估计 Q ( s , a ) Q(s,a) Q(s,a),能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。
在这里插入图片描述
评论家Critic Q Φ ( s , a ) Q_\Phi (s,a) QΦ(s,a):

  • 学会准确估计当前演员策略(actor policy)的动作价值。通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。 Q Φ ( s , a ) ≃ r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] Q_\Phi(s,a)\simeq r(s,a)+\gamma\mathbb{E}_{s^{\prime}\thicksim p(s^{\prime}|s,a),a^{\prime}\thicksim\pi_\theta(a^{\prime}|s^{\prime})}[Q_\Phi(s^{\prime},a^{\prime})] QΦ(s,a)r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]

演员Actor π θ ( s , a ) \pi_\theta(s,a) πθ(s,a):

  • 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。 J ( θ ) = E s ∼ p , π θ [ π θ ( a ∣ s ) Q Φ ( s , a ) ] ∂ f ( θ ) ∂ θ = E π θ [ ∂ log ⁡ π θ ( a ∣ s ) ∂ θ Q Φ ( s , a ) ] \begin{aligned}J(\theta)&=\mathbb{E}_{s\sim p,\pi_\theta}[\pi_\theta(a|s)Q_\Phi(s,a)]\\\\\frac{\partial f(\theta)}{\partial\theta}&=\mathbb{E}_{\pi_\theta}\left[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}Q_\Phi(s,a)\right]\end{aligned} J(θ)θf(θ)=Esp,πθ[πθ(as)QΦ(s,a)]=Eπθ[θlogπθ(as)QΦ(s,a)]

A2C: Advantageous Actor-Critic

思想:通过减去一个基线函数来标准化评论家的打分

  • 更多信息指导:降低较差动作概率,提高较优动作概率
  • 进一步降低方差

优势函数(Advantage Function) A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^\pi(s,a)=Q^\pi(s,a)-V^\pi(s) Aπ(s,a)=Qπ(s,a)Vπ(s)
在这里插入图片描述
若只采用动作值的方式,虽然也会选择A2,但是方差相对会更大,同时所有的动作都是出于上升的状态,只是上升程度的问题。而采用优势函数的方式,部分动作的优势函数值是负的,可以直接降低相应动作的概率,同时方差更小。

状态-动作值和状态值函数 Q π ( s , a ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) ] \begin{aligned} Q^{\pi}(s,a)& =r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a),a^{\prime}\sim\pi_\theta(a^{\prime}|s^{\prime})}\left[Q_\Phi(s^{\prime},a^{\prime})\right] \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})] \end{aligned} Qπ(s,a)=r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]=r(s,a)+γEsp(ss,a)[Vπ(s)]

因此我们只需要拟合状态值函数来拟合优势函数 A π ( s , a ) = Q π ( s , a ) − V π ( s ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) − V π ( s ) ] ≃ r ( s , a ) + γ ( V π ( s ′ ) − V π ( s ) ) \begin{aligned} A^{\pi}(s,a)& =Q^\pi(s,a)-V^\pi(s) \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})-V^{\pi}(s)] \\ &\simeq r(s,a)+\gamma(V^{\pi}(s^{\prime})-V^{\pi}(s)) \end{aligned} Aπ(s,a)=Qπ(s,a)Vπ(s)=r(s,a)+γEsp(ss,a)[Vπ(s)Vπ(s)]r(s,a)+γ(Vπ(s)Vπ(s))


在策略梯度中,可以把梯度写成下面这个更加一般的形式: g = E [ ∑ t = 0 T ψ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0Tψtθlogπθ(atst)]其中, ψ t \psi_t ψt可以有很多种形式: 1. ∑ t ′ = 0 T γ t ′ r t ′ : 轨迹的总回报; 2. ∑ t ′ = t T γ t ′ − t r t ′ : 动作 a t 之后的回报; 3. ∑ t ′ = t T γ t ′ − t r t ′ − b ( s t ) : 基准线版本的改进 ; 4. Q π θ ( s t , a t ) : 动作价值函数; 5. A π θ ( s t , a t ) : 优势函数; 6. r t + γ V π θ ( s t + 1 ) − V π θ ( s t ) : 时序差分残差。 \begin{aligned} &1.\sum_{t^{\prime}=0}^T\gamma^{t^{\prime}}r_{t^{\prime}}:\textit{轨迹的总回报;} \\ &2.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}:\textit{动作}a_t\textit{之后的回报;} \\ &\begin{aligned}3.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}-b(s_t):\textit{基准线版本的改进};\end{aligned} \\ &4.Q^{\pi_\theta}(s_t,a_t):\textit{动作价值函数;} \\ &5.A^{\pi_\theta}(s_t,a_t):\textit{优势函数;} \\ &6.r_t+\gamma V^{\pi_\theta}(s_{t+1})-V^{\pi_\theta}(s_t):\textit{时序差分残差。} \end{aligned} 1.t=0Tγtrt:轨迹的总回报;2.t=tTγttrt:动作at之后的回报;3.t=tTγttrtb(st):基准线版本的改进;4.Qπθ(st,at):动作价值函数;5.Aπθ(st,at):优势函数;6.rt+γVπθ(st+1)Vπθ(st):时序差分残差。

REINFORCE 通过蒙特卡洛采样的方法对策略梯度的估计是无偏的,但是方差非常大。我们可以用形式(3)引入基线函数 b ( s t ) b(s_t) b(st)(baseline function)来减小方差。此外,我们也可以采用 Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是形式(4)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了 A A A函数,我们称之为优势函数(advantage function),这便是形式(5)。更进一步,我们可以利用等式 Q = r + γ V Q=r+\gamma V Q=r+γV得到形式(6)。

Actor 的更新采用策略梯度的原则,那 Critic 如何更新呢?我们将 Critic 价值网络表示为 V ω V_\omega Vω,参数为 ω \omega ω。于是,我们可以采取时序差分残差的学习方式,对于单个数据定义如下价值函数的损失函数: L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac12(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21(r+γVω(st+1)Vω(st))2

与 DQN 中一样,我们采取类似于目标网络的方法,将上式中 r + γ V ω ( s t + 1 ) r+\gamma V_\omega(s_{t+1}) r+γVω(st+1)作为时序差分目标,不会产生梯度来更新价值函数。因此,价值函数的梯度为:
∇ ω L ( ω ) = − ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) ∇ ω V ω ( s t ) \nabla_{\omega}\mathcal{L}(\omega)=-(r+\gamma V_{\omega}(s_{t+1})-V_{\omega}(s_{t}))\nabla_{\omega}V_{\omega}(s_{t}) ωL(ω)=(r+γVω(st+1)Vω(st))ωVω(st)

然后使用梯度下降方法来更新 Critic 价值网络参数即可。

算法伪代码:

在这里插入图片描述

代码实践

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import util

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

# 输入是某个状态,输出则是状态的价值。
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,
                device, numOfEpisodes, env):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.gamma = gamma
        self.device = device
        self.env = env
        self.numOfEpisodes = numOfEpisodes

    # 根据动作概率分布随机采样
    def takeAction(self, state):
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        action_probs = self.actor(state)
        action_dist = torch.distributions.Categorical(action_probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)
        truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)
        # 时序差分目标
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - terminateds + truncateds)
        # 时序差分误差
        td_delta = td_target - self.critic(states)
        log_probs = torch.log(self.actor(states).gather(1, actions))
        # 均方误差损失函数
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()

    def ACTrain(self):
        returnList = []
        for i in range(10):
            with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:
                for episode in range(int(self.numOfEpisodes / 10)):
                    # initialize state
                    state, info = self.env.reset()
                    terminated = False
                    truncated = False
                    episodeReward = 0
                    transition_dict = {
                        'states': [],
                        'actions': [],
                        'next_states': [],
                        'rewards': [],
                        'terminateds': [],
                        'truncateds': []
                    }
                    # Loop for each step of episode:
                    while 1:
                        action = self.takeAction(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        transition_dict['states'].append(state)
                        transition_dict['actions'].append(action)
                        transition_dict['next_states'].append(next_state)
                        transition_dict['rewards'].append(reward)
                        transition_dict['terminateds'].append(terminated)
                        transition_dict['truncateds'].append(truncated)
                        state = next_state
                        episodeReward += reward
                        if terminated or truncated:
                            break
                    self.update(transition_dict)
                    returnList.append(episodeReward)
                    if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报
                        pbar.set_postfix({
                            'episode':
                                '%d' % (self.numOfEpisodes / 10 * i + episode + 1),
                            'return':
                                '%.3f' % np.mean(returnList[-10:])
                        })
                    pbar.update(1)
        return returnList

结果

在这里插入图片描述
在这里插入图片描述
可以发现,Actor-Critic 算法很快便能收敛到最优策略,并且训练过程非常稳定,抖动情况相比 REINFORCE 算法有了明显的改进,这说明价值函数的引入减小了方差。

参考

[1] 伯禹AI
[2] https://www.davidsilver.uk/teaching/
[3] 动手学强化学习
[4] Reinforcement Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1156673.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java】多线程案例(单例模式,阻塞队列,定时器,线程池)

❤️ Author: 老九 ☕️ 个人博客:老九的CSDN博客 🙏 个人名言:不可控之事 乐观面对 😍 系列专栏: 文章目录 实现安全版本的单例模式饿汉模式类和对象的概念类对象类的静态成员与实例成员 懒汉模式如何保证…

C++设计模式_21_Iterator 迭代器(理解;面向对象的迭代器已过时;C++中使用泛型编程的方式实现)

Iterator 迭代器也是属于“数据结构”模式。GoF中面向对象的迭代器已经过时,C中目前使用泛型编程的方式实现,其他语言还在使用面向对象的迭代器。 文章目录 1. 动机(Motivation)2. 模式定义3. Iterator 迭代器代码分析4. 面向对象的迭代器与泛型编程实现…

一天写一个(前端、后端、全栈)个人简历项目(附详源码)

一、项目简介 此项目是用前端技术HTMLCSSjquery写的一个简单的个人简历项目模板,图片可点击放大查看,还可以直接下载你的word或者PDF的简历模板。 如果有需要的同学可以直接拿去使用,需自行填写个人的详细信息,发布,…

uniapp 开发微信小程序 v-bind给子组件传递函数,该函数中的this不是父组件的二是子组件的this

解决办法:子组件通过缓存子组件this然后,用bind改写this 这个方法因为定义了全局变量that 那么该变量就只能用一次,不然会有赋值覆盖的情况。 要么就弃用v-bind传入函数,改为emit传入自定义事件 [uniapp] uview(1.x) 二次封装u-navbar 导致…

程序开发设计原则

(图片来自网络) 单一职责 Single Responsibility Principle 不论是在设计类,接口还是方法,单一职责都会处处体现,单一职责的定义:我们把职责定义为系统变化的原因。所有在定 义类,接口&#xff…

CV2 将图片中某个点与中心点的角度变换成0-360度

众所周知,CV2中的坐标方向是这样的: 所以一般我们想计算图片中某个点P1(x1,y1)与中心点P0(x0,y0)的方向时,我们会先将y坐标翻上去,然后计算角度。即: p1_xint(x1) # p1_yint(y1)p0_xint(x0) #图像大小为512*512中心点坐标为25…

PO-提示json不能为空 not valid json at character 2 of ““““

问题描述: 调用第三方REST接口,提示提示json不能为空 not valid json at character 2 of """" 原因分析: 一般都是对方接收后出现错误没有处理,返回空值;有可能是他们映射有问题 解决方案&…

小程序获取头像和昵称的思路

小程序获取头像和昵称的基本方法是调用小程序自带的API wx.getUserProfile(),这也是小程序官方目前最推荐的做法。成功获取用户名头像之后,小程序允许保存调用的结果,以便下一次打开页面的时候自动显示头像和名字。保存用户名和头像并不是保存…

JMeter组件

1.JMeter常用组件 必须组件:测试计划,线程组(包含多个线程),取样器 测试计划,JMeter默认创建且仅有一个 线程组: 添加步骤: 选择TestPlan并点击鼠标右键添加 分类以及使用&…

MyBatisPlus 使用枚举

MyBatisPlus 使用枚举 表中的有些字段值是固定的,例如性别(男或女),此时我们可以使用MyBatis-Plus的通用枚举来实现 数据库表添加字段sex 创建通用枚举类型 Getter public enum SexEnum {MALE(1, "男"),FEMALE(2, &qu…

MySQL与MongoDB,该如何做技术选型?

hello,大家好,我是张张,「架构精进之路」公号作者。 引言 一般情况下,会考虑到MySQL与MongoDB如何做技术选型的时候,你一定是遇到了类似于非结构化数据JSON的存取难题,否则大家都直接MySQL开始搞起了。 为什…

西工大CSAPP第二章课后题2.56~2.58答案及解析

因为我获取并阅读CSAPP电子书的方式是通过第三方网站免费下载,没有付给原书作者相应的报酬,遵循价值交换原则,我会尽我所能通过博客的方式,推广这本书以及原书作者就职的大学,以此回馈原书作者的劳动成果。另外&#x…

【JMeter】逻辑控制器分类以及功能介绍

常用逻辑控制器的分类以及介绍 If Controller 满足if条件才会执行取样器 Loop Controller 对取样器循环多次 ForEach Controller

MySQL笔记--SQL语句

1--SQL的通用语法 2--SQL语句的分类 3--DDL语句 3-1--数据库操作 # 查询所有数据库 show databases;# 查询当前使用数据库 select database();# 创建数据库 create database 数据库名 create database if not exists 数据库名; # 不存在时创建,存在则不创建 creat…

【设计模式】第25节:行为型模式之“访问者模式”

一、简介 访问者模式允许一个或者多个操作应用到一组对象上,设计意图是解耦操作和对象本身,保持类职责单一、满足开闭原则以及应对代码的复杂性。 二、优点 分离操作和数据结构增加新操作更容易集中化操作 三、适用场景 数据结构稳定,操…

数据结构与算法-(7)---栈的应用拓展-前缀表达式转换+求值

🌈write in front🌈 🧸大家好,我是Aileen🧸.希望你看完之后,能对你有所帮助,不足请指正!共同学习交流. 🆔本文由Aileen_0v0🧸 原创 CSDN首发🐒 如…

测试C#调用Aplayer播放视频(3:编写简易播放器)

学习了参考文献1中的示例代码,也找出了前一篇文章中自己测试控件但无法播放视频的问题(没有将解码库文件复制到可执行程序所在的codecs文件夹内),本文基于APlayer组件编写简单的视频播放器,主要实现以下功能&#xff1…

iOS实现弹簧放大动画

效果图 实现代码 - (void)setUpContraints {CGFloat topImageCentery (SCREEN_HEIGHT - 370 * PLUS_SCALE) / 2;[self.topIconView mas_makeConstraints:^(MASConstraintMaker *make) {make.centerX.mas_equalTo(0);make.centerY.equalTo(self.view.mas_top).with.offset(t…

一定要看看的大模型【评测基准】及【评测报告】

评测标准 1.能力基础评测 为了检验大语言模型(LLM)的有效性和优越性,已有研究采用了大量的任务和基准数据集来进行实证评估和分析。根据任务定义,现有语言生成的任务主要可以分为语言建模、条件文本生成和代码合成任务。需要注意的是,代码合成不是典型的自然语言处理任务…

Qt6:子窗口向父窗口传值

终于解决了这个问题!这才怀着激动的心情跑来记录一下。你们是不知道这其中的艰辛啊,太难了,差亿点就放弃学Qt了…… 此处苦水省略一万字…… 关于子窗口向父窗口传值的方法,在网上搜了不下百遍,免费的、付费下载、会员…