【深度强化学习 DRL 快速实践】逆向强化学习算法 (IRL)

news2025/7/10 19:40:42

Inverse Reinforcement Learning (IRL) 详解

什么是 Inverse Reinforcement Learning?

在传统的强化学习 (Reinforcement Learning, RL) 中,奖励函数是已知的,智能体的任务是学习一个策略来最大化奖励

而在逆向强化学习 (Inverse Reinforcement Learning, IRL) 中,情况相反:

  • 我们不知道奖励函数 缺失的
  • 但是我们有专家的示范轨迹(比如专家怎么开车、怎么走路): τ = ( s 0 , a 0 , s 1 , a 1 , … , s T ) \tau = (s_0, a_0, s_1, a_1, \dots, s_T) τ=(s0,a0,s1,a1,,sT)
  • 目标是:推断出奖励函数,使得专家行为在该奖励下是最优的

简单来说,IRL 是"从专家行为中推断动机"

  • Initialize an actor
  • In each iteration
    • The actor interacts with the environrment to obtain some trajectories
    • Define a reward functlon, which makes thetrajectories of the teacher better than the actor
    • The actor learns to maximize the reward based on the new reward function
  • Output the reward function and the actor learned from the reward function

IRL算法之 GAIL 算法详解

GAIL(生成对抗模仿学习)结合了:生成对抗网络 GAN(Generator 对抗 Discriminator)和 强化学习 Policy Gradient(策略梯度)

  • 让智能体学会产生像专家一样的轨迹,但不直接学习奖励函数,只通过模仿专家行为来训练策略
判别器 (Discriminator) :试图区分 “专家轨迹” 和 “生成器轨迹”

判别器的目标是最大化对数似然:判别器希望对于专家数据 expert 输出接近 1,对于生成数据 policy 输出接近 0
max ⁡ D E expert [ log ⁡ D ( s , a ) ] + E policy [ log ⁡ ( 1 − D ( s , a ) ) ] \max_D \mathbb{E}_{\text{expert}} [\log D(s, a)] + \mathbb{E}_{\text{policy}} [\log (1 - D(s, a))] DmaxEexpert[logD(s,a)]+Epolicy[log(1D(s,a))]

生成器(策略网络 Policy):试图“欺骗”判别器,让判别器以为它生成的轨迹是专家生成的

生成器的目标是最小化:
min ⁡ π E τ ∼ π [ log ⁡ ( 1 − D ( s , a ) ) ] \min_{\pi} \mathbb{E}_{\tau \sim \pi} [\log (1 - D(s, a))] πminEτπ[log(1D(s,a))]

这其实可以等价强化学习问题,奖励信号变成了:
r ( s , a ) = − log ⁡ ( 1 − D ( s , a ) ) r(s, a) = - \log (1 - D(s, a)) r(s,a)=log(1D(s,a))

  • 这样,跟标准的 policy gradient 非常类似,只不过奖励是来自判别器

GAIL 简单代码示例

import gym
from stable_baselines3 import PPO
from imitation.algorithms.adversarial import GAIL
from imitation.data.types import TrajectoryWithRew
from imitation.data import rollout

# 1. 创建环境
env = gym.make("CartPole-v1")

# 2. 加载或创建专家模型
expert = PPO("MlpPolicy", env, verbose=0)
expert.learn(10000)

# 3. 收集专家轨迹数据
trajectories = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=20)
)

# 4. 创建新模型作为 actor
learner = PPO("MlpPolicy", env, verbose=1)

# 5. 使用 GAIL 进行逆强化学习训练
gail_trainer = GAIL(
    venv=env,
    demonstrations=trajectories,
    gen_algo=learner
)
gail_trainer.train(10000)

# 6. 测试训练后的模型
obs = env.reset()
for _ in range(1000):
    action, _states = learner.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2343375.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《普通逻辑》学习记录——性质命题及其推理

目录 一、性质命题概述 二、性质命题的种类 2.1、性质命题按质的分类 2.2、性质命题按量的分类 2.3、性质命题按质和量结合的分类 2.4、性质命题的基本形式归纳 三、四种命题的真假关系 3.1、性质命题与对象关系 3.2、四种命题的真假判定 3.3、四种命题的对当关系 四、四种命题…

人工智能(AI)对网络管理的影响

近年来,人工智能(AI)尤其是大语言模型(LLMs)的快速发展,正在深刻改变网络管理领域。AI的核心价值在于其能够通过自动化、模式分析和智能决策,显著提升网络运维效率并应对复杂挑战。具体表现为: 快速数据查询与分析​​&#xff1…

embedding_model模型通没有自带有归一化层该怎么处理?

embedding_model 是什么: 嵌入式模型(Embedding)是一种广泛应用于自然语言处理(NLP)和计算机视觉(CV)等领域的机器学习模型,它可以将高维度的数据转化为低维度的嵌入空间&#xff0…

八大排序——冒泡排序/归并排序

八大排序——冒泡排序/归并排序 一、冒泡排序 1.1 冒泡排序 1.2 冒泡排序优化 二、归并排序 1.1 归并排序(递归) 1.2 递归排序(非递归) 一、冒泡排序 1.1 冒泡排序 比较相邻的元素。如果第一个比第二个大,就交换…

银发科技:AI健康小屋如何破解老龄化困局

随着全球人口老龄化程度的不断加深,如何保障老年人的健康、提升他们的生活质量,成为了社会各界关注的焦点。 在这场应对老龄化挑战的战役中,智绅科技顺势而生,七彩喜智慧养老系统构筑居家养老安全网。 而AI健康小屋作为一项创新…

命令行指引的尝试

效果 步骤 首先初始化一个空的项目,然后安装一些依赖 npm init -y npm install inquirer execa chalk ora至于这些依赖是干嘛的,如下图所示: 然后再 package.json 中补充一个 bin 然后再根目录下新建一个 index.js , 其中的内容如下 #!/…

【Dify系列教程重置精品版】第1课 相关概念介绍

文章目录 一、Dify是什么二、Dify有什么用三、如何玩转Dify?从螺丝刀到机甲战士的进阶指南官方网站:https://dify.ai github地址:https://github.com/langgenius/dify 一、Dify是什么 Dify(D​​efine + ​​I​​mplement + ​​F​​or ​​Y​​ou)。这是一款开源的大…

leetcode0106. 从中序与后序遍历序列构造二叉树-medium

1 题目:从中序与后序遍历序列构造二叉树 官方标定难度:中 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入…

Spring Boot默认缓存管理

Spring框架支持透明地向应用程序添加缓存,以及对缓存进行管理,其管理缓存的核心是将缓存应用于操作数据的方法,从而减少操作数据的执行次数,同时不会对程序本身造成任何干扰。Spring Boot继承了Spring框架的缓存管理功能&#xff…

XYNU2024信安杯-REVERSE(复现)

前言 记录记录 1.Can_you_find_me? 签到题,秒了 2.ea_re 快速定位 int __cdecl main_0(int argc, const char **argv, const char **envp) {int v4; // [esp0h] [ebp-1A0h]const char **v5; // [esp4h] [ebp-19Ch]const char **v6; // [esp8h] [ebp-198h]char v7;…

MySQL的MVCC【学习笔记】

MVCC 事务的隔离级别分为四种,其中Read Committed和Repeatable Read隔离级别,部分实现就是通过MVCC(Multi-Version Concurrency Control,多版本并发控制) 版本链 版本链是通过undo日志实现的, 事务每次修改…

达梦数据库压力测试报错超出全局hash join空间,适当增加HJ_BUF_GLOBAL_SIZE解决

1.名词解释:达梦数据库中的HJ_BUF_GLOBAL_SIZE是所有哈希连接操作可用的最大哈希缓冲区大小,单位为兆字节(MB) 2.达梦压测报错: 3.找到达梦数据库安装文件 4.压力测试脚本 import http.client import multiprocessi…

Oracle--SQL性能优化与提升策略

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 一、导致性能问题的内在原因 系统性能问题的底层原因主要有三个方面: CPU占用率过高导致资源争用和等待内存使用率过高导致内存不足并需…

六个能够白嫖学习资料的网站

一、咖喱君的资源库 地址:https://flowus.cn/galijun/share/de0f6d2f-df17-4075-86ed-ebead0394a77 这是一个学习资料/学习网站分享平台,包含了英语、法语、德语、韩语、日语、泰语等几十种外国语言的学习资料及平台,这个网站的优势就是外语…

IntelliJ IDEA 中配置 Spring MVC 环境的详细步骤

以下是在 IntelliJ IDEA 中配置 Spring MVC 环境的详细步骤: 步骤 1:创建 Maven Web 项目 新建项目 File -> New -> Project → 选择 Maven → 勾选 Create from archetype → 选择 maven-archetype-webapp。输入 GroupId(如 com.examp…

手机打电话时电脑坐席同时收听对方说话并插入IVR预录声音片段

手机打电话时电脑坐席同时收听对方说话并插入IVR预录声音片段 --本地AI电话机器人 前言 书接上一篇,《手机打电话通话时如何向对方播放录制的IVR引导词声音》中介绍了【蓝牙电话SDK示例App】可以实现手机app在电话通话过程中插播预先录制的开场白等语音片段的功能。…

SpringCloud——负载均衡

一.负载均衡 1.问题提出 上一篇文章写了服务注册和服务发现的相关内容。这里再提出一个新问题,如果我给一个服务开了多个端口,这几个端口都可以访问服务。 例如,在上一篇文章的基础上,我又新开了9091和9092端口,现在…

string的基本使用

string的模拟实现 string的基本用法string的遍历(三种方式):关于auto(自动推导):范围for: 迭代器普通迭代器(可读可改)const迭代器(可读不可改) string细小知识点string的常见接口引…

深入解析Mlivus Cloud核心架构:rootcoord组件的最佳实践与调优指南

作为大禹智库的向量数据库高级研究员,同时也是《向量数据库指南》的作者,我在过去30年的向量数据库和AI应用实战中见证了这项技术的演进与革新。今天,我将以专业视角为您深入剖析Mlivus Cloud的核心组件之一——rootcoord,这个组件在系统架构中扮演着至关重要的角色。如果您…

Python常用的第三方模块之【pymysql库】操作数据库

pymysql是在Python3.x版本中用于连接MySQL服务器的一个实现库,Python2中则是使用musqldb。 PyMySQL 是一个纯 Python 实现的 MySQL 客户端库,它允许我们直接在 Python 中执行 SQL 语句并与 MySQL 数据库进行交互。下面我们将详细介绍如何使用 PyMySQL 进…