【强化学习】TensorFlow2实现DQN(处理CartPole问题)

news2025/8/17 2:26:40

文章目录

  • 1. 情景介绍
  • 2. DQN(Deep Q Network)核心思路:
  • 3. DQN算法流程
  • 4. 代码实现以及注释
  • 5. 实验结果

文章阅读预备知识:Q Learning算法的基本流程、TensorFlow2多层感知机的实现。

1. 情景介绍

CartPole问题:黑色小车上面支撑的一个连接杆,连杆会自由摆动,我们需要控制黑色小车,通过控制小车左右移动,保持连杆的平衡。
在这里插入图片描述
该问题的动作空间是离散的且有限的,只有两种执行动作(0或1),但是该问题的状态空间是一个连续空间,且每个状态是一个四维向量。

执行动作:

### Action Space

The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction of the fixed force the cart is pushed with.
| Num | Action                 |
|-----|------------------------|
| 0   | Push cart to the left  |
| 1   | Push cart to the right |

状态空间:

### Observation Space

The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

| Num | Observation           | Min                  | Max                |
|-----|-----------------------|----------------------|--------------------|
| 0   | Cart Position         | -4.8                 | 4.8                |
| 1   | Cart Velocity         | -Inf                 | Inf                |
| 2   | Pole Angle            | ~ -0.418 rad (-24°)  | ~ 0.418 rad (24°)  |
| 3   | Pole Angular Velocity | -Inf                 | Inf                |

**Note:** While the ranges above denote the possible values for observation space of each element, it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
-  The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates if the cart leaves the `(-2.4, 2.4)` range.
-  The pole angle can be observed between  `(-.418, .418)` radians (or **±24°**), but the episode terminates if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)

2. DQN(Deep Q Network)核心思路:

  • 因传统的Q Learning、Sarsa算法不适合处理状态空间和动作空间是连续空间的问题,因此 使用深度学习神经网络表示Q函数(代替Q表),训练的数据是状态s,训练的标签是状态s对应的每个动作的Q值,即标签是由Q值组成的向量,向量的长度与动作空间的长度相同。
  • Q 值的更新与Q Learning 算法相同。
  • 动作选择的算法使用 ϵ \epsilon ϵ-贪婪算法,其中 ϵ \epsilon ϵ可以是静态的也可以随时间设置动态变化。
  • 定义一段记忆体(经验回放池、Replay Memory),在记忆体中保存具体某一时刻的当前状态、奖励、动作、迁移到下一个状态、状态是否结束等信息,定期冲记忆体中随机选择固定大小的一段记忆训练神经网络。

3. DQN算法流程

在这里插入图片描述
论文地址: Playing Atari with Deep Reinforcement Learning(https://arxiv.org/pdf/1312.5602.pdf)

4. 代码实现以及注释

版本信息

  • Python:3.7.0
  • TensorFlow: 2.5
  • gym:0.23.1
# -*- coding: utf-8 -*-
import random
import gym  # 版本0.23.1
import numpy as np
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

EPISODES = 1000

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000) # 记忆体使用队列实现,队列满后根据插入顺序自动删除老数据
        self.gamma = 0.95    # discount rate
        self.epsilon = 0.4  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()
        # 可视化MLP结构
        # plot_model(self.model, to_file='dqn-cartpole-v0-mlp.png', show_shapes=False)

    def _build_model(self):
        # Neural Net for Deep-Q learning Model
        model = Sequential() # 顺序模型,搭建神经网络(多层感知机)
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse',optimizer=Adam(lr=self.learning_rate)) # 指定损失函数以及优化器
        return model

    # 在记忆体(经验回放池)中保存具体某一时刻的当前状态信息
    def remember(self, state, action, reward, next_state, done):
        # 当前状态、动作、奖励、下一个状态、是否结束
        self.memory.append((state, action, reward, next_state, done))

    # 根据模型预测结果返回动作
    def act(self, state):
        if np.random.rand() <= self.epsilon: # 如果随机数(0-1之间)小于epsilon,则随机返回一个动作
            return random.randrange(self.action_size) # 随机返回动作0或1
        act_values = self.model.predict(state) # eg:[[0.35821578 0.11153378]]
        # print("model.predict act_values:",act_values)
        return np.argmax(act_values[0])  # returns action 返回价值最大的

    # 记忆回放,训练神经网络模型
    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done: # 没有结束
                target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0]))
            target_f = self.model.predict(state)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0) # 训练神经网络

    # 加载模型权重文件
    def load(self, name):
        self.model.load_weights(name)
        
    # 保存模型 (参数:filepath)
    def save(self, name):
        self.model.save_weights(name)


if __name__ == "__main__":
    env = gym.make('CartPole-v0')

    print(env.action_space)
    print(env.observation_space)

    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    print("state_size:",state_size) # 4
    print("action_size:",action_size) # 2

    agent = DQNAgent(state_size, action_size)

    done = False
    batch_size = 32
    avg=0

    for e in range(EPISODES): # 循环学习次数,每次学习都需要初始化环境
        state = env.reset() # 环境初始化,返回state例如[-0.1240581  -1.3752123   0.18474717  2.2276523 ]
        state = np.reshape(state, [1, state_size]) # 扩展维度(用于神经网络训练) # [[-0.1240581  -1.3752123   0.18474717  2.2276523 ]]
        for time in range(500): # 每次学习的步长为500
            env.render() # 渲染可视化图像
            # print(state)
            action = agent.act(state) # 根据模型预测结果返回动作
            # print(action) # 0或者1
            next_state, reward, done, _ = env.step(action) # 返回下一个状态、奖励、以及是否结束游戏(当摆杆出界或倾斜浮动等状态信息不符要求或步长大于内置值时结束游戏)
            reward = reward if not done else -10  # 结束游戏时,设置奖励为-10
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done) # 放入记忆体
            state = next_state
            if done:
                print("episode: {}/{}, score(time): {}" .format(e, EPISODES, time))
                avg += time
                break
        # 定期检查记忆大小,进行记忆回放
        if len(agent.memory) > batch_size:
            agent.replay(batch_size)
    print("Avg score:{}".format(avg/1000))

5. 实验结果

前期,智能体(Agent)控制小车移动只能玩10秒左右

episode: 0/1000, score(time): 8
episode: 1/1000, score(time): 10
episode: 2/1000, score(time): 8
episode: 3/1000, score(time): 14
episode: 4/1000, score(time): 8
episode: 5/1000, score(time): 10
episode: 6/1000, score(time): 9
episode: 7/1000, score(time): 11
episode: 8/1000, score(time): 12
episode: 9/1000, score(time): 9
episode: 10/1000, score(time): 15
episode: 11/1000, score(time): 9
episode: 12/1000, score(time): 13
episode: 13/1000, score(time): 11
episode: 14/1000, score(time): 9
episode: 15/1000, score(time): 12
episode: 16/1000, score(time): 9
episode: 17/1000, score(time): 11

通过神经网络模型的不断训练…

episode: 267/1000, score(time): 155
episode: 268/1000, score(time): 188
episode: 269/1000, score(time): 100
episode: 270/1000, score(time): 136
episode: 271/1000, score(time): 126
episode: 272/1000, score(time): 155
episode: 273/1000, score(time): 179
episode: 274/1000, score(time): 104
episode: 275/1000, score(time): 111
episode: 276/1000, score(time): 199
episode: 277/1000, score(time): 128
episode: 278/1000, score(time): 199

可以看到智能体(Agent)的游戏水平不断提高

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/14830.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【并发编程五】c++进程通信——信号量(semaphore)

【并发编程五】c进程通信——信号量&#xff08;semaphore&#xff09;一、概述二、信号量三、原理四、过程1、进程A过程2、进程B过程五、demo1、进程A2、进程B六、输出七、windows api介绍1. 创建信号量 CreateSemaphore()2. 打开信号量 OpenSemaphore()3. 等待 WaitForSingle…

一种基于IO口的模拟串口(LOG)实现方法

一、使用背景 当MCU的串口不够用时&#xff0c;可以通过IO模拟的方式将任意一个具有输出功能的管脚配置为串口输出&#xff0c;从而方便开发和调试。 二、实现原理 通过IO口模拟串口发送波形&#xff0c;配置对应的波特率等信息&#xff0c;然后映射printf函数&#xff0c;从…

基于粒子群优化算法的冷热电联供型综合能源系统运行优化(Matlab代码实现)

&#x1f468;‍&#x1f393;个人主页&#xff1a;研学社的博客 &#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜…

redis 支持的数据类型

Redis 数据库支持五种数据类型。 字符串&#xff08;string&#xff09; 哈希&#xff08;hash&#xff09; 列表&#xff08;list&#xff09; 集合&#xff08;set&#xff09; 有序集合&#xff08;sorted set&#xff09; 位图 ( Bitmaps ) 基数统计 ( HyperLogLogs ) 字…

Vue3.2 + Element-Plus 二次封装 el-table(Pro版)

前言 &#x1f4d6; ProTable 组件目前已是 2.0版本&#x1f308;&#xff0c;在 1.0版本 中大家提出的问题与功能优化&#xff0c;目前已经得到优化和解决。 &#x1f600; 欢迎大家在使用过程中发现任何问题或更好的想法&#xff0c;都可以在下方评论区留言&#xff0c;或者我…

【计算机网络】局域网体系结构、以太网Ethernet详解

注&#xff1a;最后有面试挑战&#xff0c;看看自己掌握了吗 文章目录局域网LAN决定局域网的要素网络拓扑传输介质局域网的分类以太网令牌环网FDDI网----Fiber Distributed Data InterfaceATM网---Asynchronous Transfer Mode无线局域网WLAN----Wireless Local Area NetworkMAC…

Red Hat Enterprise Linux (RHEL) 9 更新了哪些新特性?

文章目录1. 前言2. 软件3. 支持的硬件架构4. GNOME更新到40版5. 安全和身份6. 构建容器的通用基础镜像7. 改进了用于管理 RHEL 9 的 Cockpit Web 控制台1. 前言 体验一下最新的rhel 9.0 是什么感觉。它会飞吗&#xff1f; Red Hat Enterprise Linux (RHEL) 9现已普遍可用 (GA…

吃柿子的禁忌靠谱吗?

图片来源&#xff1a;pixabay 秋冬是柿子上市的季节&#xff0c;虽然柿子并不是苹果、香蕉这样的大宗水果&#xff0c;但是秋天不吃个柿子&#xff0c;冬天不吃个柿饼&#xff0c;总觉得少了点什么。 关于吃柿子有很多禁忌&#xff0c;比如说柿子不能与螃蟹同时吃&#xff0c;柿…

​怎么保留硬盘数据合并分区 ,如何才能合并且不丢失数据

硬盘分区合并是比较常见的操作&#xff0c;​怎么保留硬盘数据合并分区&#xff0c;还是具有一定的难度。因为在Windows操作系统中&#xff0c;用户可以通过磁盘管理来实现硬盘分区合并&#xff0c;但是要删除该磁盘分区右侧的相邻分区&#xff0c;但是对于部分不懂计算机的用户…

Tailscale的子网路由和出口节点

2 年前&#xff0c;老苏写了 『 外网访问群晖的新方案Tailscale 』&#xff0c;第一次隆重的给大家推荐了 Tailscale&#xff0c;但当时还有很多功能并不具备&#xff0c;比如今天要介绍的 Subnet Router 和 Exit Node 【特别说明】&#xff1a;老苏使用的是DSM6 &#xff0c;所…

RabbitMQ初步到精通-第一章-消息中间件介绍

第一章 消息中间件介绍 1.MQ概述 MQ全称是Message Queue&#xff0c;消息的队列&#xff0c;因为是队列&#xff0c;所以遵循FIFO 先进先出的原则&#xff0c;它是一种跨进程的通信机制&#xff0c;用于上下游传递消息。 在互联网架构中&#xff0c;MQ是一种非常常见的上下游“…

论文阅读笔记 | 三维目标检测——VeloFCN算法

如有错误&#xff0c;恳请指出。 文章目录paper&#xff1a;《Vehicle Detection from 3D Lidar Using Fully Convolutional Network》 对于64线激光雷达全范围扫描出来的点云进行特征图的构建。对于具体的点&#xff08;xyz坐标&#xff09;&#xff0c;其在水平方向上可以通…

一个是证书服务和web安全访问配置,一个是PGP的使用

一个是证书服务和web安全访问配置&#xff0c;一个是PGP的使用 IIS介绍 IIS是本机自带的服务&#xff0c;用于上线web网页&#xff1b;虽然是自带但因为非开发人员用不到&#xff0c;所以属于预安装&#xff1b;在本机搜索下载即可&#xff0c; 打开后 证书服务&#xff0c;认…

LeetCode[105]从前序与中序遍历序列构造二叉树

难度&#xff1a;中等 题目&#xff1a; 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示例 1: 输入: preorder [3,9,20,15,7], inorder [9,3,1…

Vue基础4

Vue基础4计算属性姓名案例 - 第一种用click.keyup的方法姓名案例 - 第二种用v-model双向绑定的方法姓名案例 - 第三种使用methods方法姓名案例 - 第四种使用计算属性的方法计算属性的简写—只考虑读取&#xff0c;不考虑修改时候使用监视属性第一种普通写法第二种用计算属性的写…

【信号处理】卡尔曼(Kalman)滤波(Matlab代码实现)

&#x1f468;‍&#x1f393;个人主页&#xff1a;研学社的博客 &#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜…

Java内部类分类

文章目录内部类分类局部内部类的使用匿名内部类成员内部类静态内部类一个类的内部又完整的嵌套了另一个类结构。被嵌套的类称为内部类(inner class),嵌套其他类的类称为外部类(outer class)。是我们类的第五大成员 思考:类的五大成员是哪些? - 属性、方法、构造器、代码块、内…

Windows安装Git教程(2022.11.18 Git2.38.1)

&#xff08;1&#xff09;首先前往Git官网&#xff0c;下载安装文件&#xff1a; &#xff08;2&#xff09;打开安装程序&#xff0c;把Only show new options的勾去掉&#xff0c;点击Next&#xff1a; &#xff08;3&#xff09;此处可以选用默认设置&#xff0c;也可以勾…

ProCAST一键导出有限元模型的几何拓扑和属性信息

第一次将ProCast有限元后处理中的数据导出&#xff0c;当时没有经验&#xff0c;方法比较粗暴&#xff0c;详情见文章&#xff1a;ProCast导出节点应力数据并格式化。 最近发现了一种更高效的数据导出“新姿势”&#xff0c;能够快速得到有限元模型的几何拓扑和节点属性数据&a…

电科大离散数学-2-命题逻辑-2

目录 2.7 范式 2.7.1 范式的定义 2.7.2 范式存在定理 2.8 主析取范式和主合取范式 2.8.1 极小项和极大项的定义和编码 2.8.2 极小项和极大项的性质 2.8.3 主析取范式和主合取范式的定义 2.8.4 主范式求解定理 2.8.5 真值表技术 2.8.6 范式的相互转化 2.8.7 主范式的…