训练自定义游戏,构建Gymnasium训练环境
## 认识 Gymnasium

使用 stable_baselines3 时,只需要定义好 Gymnasium 环境、关注训练的奖励机制,将重点放在业务的开发上,而不是复杂的算法。

Gymnasium 提供了几个核心的 API:

| 方法 | 功能 | 返回值 |
| --- | --- | --- |
| reset() | 将环境重置为初始状态,开始新回合 | obs, info |
| step(action) | 环境向前推进一步,执行动作 | obs, reward, terminated, truncated, info |
| render() | 可视化环境,根据 render_mode 渲染图像或弹出窗口 | 视配置而定,通常无或为 np.array |
| close() | 释放环境资源,关闭窗口、清理内存 | 无 |

其中各个返回值的含义:

- observation (Object): 当前状态的描述。例如敌人、玩家的位置,玩家的状态等
- reward (Float): 上一步动作获得的奖励
- terminated (Bool): 是否由于任务逻辑结束。例如到达终点、掉进岩浆等
- truncated (Bool): 是否由于外部限制结束。例如达到最大步数 500 步
- info (Dict): 辅助诊断信息,模型训练通常不用,用于用户自定义调试或记录额外统计

## 手动构建环境案例

案例描述:利用 pygame 构建一个简单的游戏(躲避掉落方块),利用构建的奖励机制进行强化学习。

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env


class MyEnv(gym.Env):
    def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()
        # 初始化参数
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )
        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0
        self.current_speed = 5
        self.spawn_rate = 30
        return self._get_obs(), {}

    def step(self, action):
        reward = 0
        terminated = False
        truncated = False
        move_speed = 8
        if action == 1 and self.player_x > 0:  # 左移
            self.player_x -= move_speed
            reward -= 0.05
        if action == 2 and self.player_x < self.width - self.player_size:  # 右移
            self.player_x += move_speed
            reward -= 0.05
        self.frame_count += 1
        level = self.score // 5
        self.current_speed = 5 + level
        self.spawn_rate = 30 - level * 2
        spawn_rate = max(10, 30 - level)
        if self.frame_count >= spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0])  # [x, y]
        for enemy in self.enemies:
            enemy[1] += self.current_speed
            player_rect = pygame.Rect(self.player_x, self.player_y,
                                      self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1],
                                     self.enemy_size, self.enemy_size)
            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True
            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward += 1
        if not terminated:
            if self.score > 100:
                reward += 0.01
            reward += 0.01
        obs = self._get_obs()
        if self.render_mode == "human":
            self._render_window()
        return obs, reward, terminated, truncated, {}

    def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255),
                         (self.player_x, self.player_y, self.player_size, self.player_size))
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50),
                             (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
        img_array = pygame.surfarray.array3d(self.canvas)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
        return obs.astype(np.uint8)

    def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()


def train():
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)
    env = MyEnv()
    check_env(env)
    print("环境检查通过...")
    model_path = "models/dodge_ai.zip"
    if not os.path.exists(model_path):
        print("未发现旧模型,从头开始训练...")
        model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
        reset_timesteps = True
    else:
        print("发现旧模型,加载并继续训练...")
        model = PPO.load(
            model_path,
            env=env,
            device="cuda",
            custom_objects={"learning_rate": 0.0001,
                            "n_steps": 4096,
                            "batch_size": 256}
        )
        reset_timesteps = False
    print("开始训练...")
    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=reset_timesteps
    )
    model.save("models/dodge_ai")
    print("模型已保存")
    env.close()


def predict():
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()
        pygame.time.Clock().tick(30)


if __name__ == "__main__":
    train()
    predict()
```

## 代码解析

代码流程如下:构建游戏环境 → 训练模型 → 模型预测。本篇重点讲构建游戏环境,其中的 pygame 相关代码简略,另外两个流程参考之前文章。

### 构建游戏环境

初始化类,该类继承 gym.Env 类:

```python
class MyEnv(gym.Env):
```

构造函数 `__init__`:

```python
def __init__(self, render_mode=None):
    super(MyEnv, self).__init__()
    # 初始化参数
    self.width = 400
    self.height = 300
    self.player_size = 30
    self.enemy_size = 30
    self.render_mode = render_mode
    self.action_space = spaces.Discrete(3)
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
    )
    pygame.init()
    if self.render_mode == "human":
        self.screen = pygame.display.set_mode((self.width, self.height))
    self.canvas = pygame.Surface((self.width, self.height))
    self.font = pygame.font.SysFont("monospace", 15)
```

在构造函数中,我们主要完成的是声明训练的维度和输入。

输入:

```python
self.action_space = spaces.Discrete(3)
```

其中的 `self.action_space` 是固定名称的父类变量。`spaces.Discrete(3)` 声明输入的数量,例如向左、向右和不动 3 个输入。

观测维度:

`self.observation_space` 也是固定名称的父类变量,`spaces.Box` 声明观测维度。

```python
self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)
```

- low: 观测参数的最小值
- high: 观测参数的最大值
- shape: 声明维度。例如观测图片 shape=(高, 宽, RGB);观测一个平面 shape=(高, 宽)
- dtype: 每个变量的类型,这里选 np.uint8 能够节省训练成本(默认是浮点型的)

### 任务重置 reset

相当于初始化游戏状态,即游戏的重新开始。返回的是观测值和状态信息(用于调试日志):

```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    self.player_x = self.width // 2 - self.player_size // 2
    self.player_y = self.height - self.player_size - 10
    self.enemies = []
    self.score = 0
    self.frame_count = 0
    self.current_speed = 5
    self.spawn_rate = 30
    return self._get_obs(), {}
```

### 观测值 _get_obs

通过 pygame 画出画面,然后用 opencv 进行简单处理:

- 转换坐标轴(由于 opencv 坐标 xy 轴跟 pygame 的 xy 是颠倒的)
- 将画面缩放到 84 * 84,可以提高训练效率

```python
def _get_obs(self):
    self.canvas.fill((0, 0, 0))
    pygame.draw.rect(self.canvas, (50, 150, 255),
                     (self.player_x, self.player_y, self.player_size, self.player_size))
    for enemy in self.enemies:
        pygame.draw.rect(self.canvas, (255, 50, 50),
                         (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
    img_array = pygame.surfarray.array3d(self.canvas)
    img_array = np.transpose(img_array, (1, 0, 2))
    obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
    return obs.astype(np.uint8)
```

### 步 step(重要)

这个函数是强化训练的核心,规定了在一帧(或者一步)我们给 AI 的分数。分数的设置至关重要,这直接决定了训练出来的 AI 的质量。下面代码大部分都是游戏逻辑,主要讲设置奖励分数:

- 在 AI 进行移动时,惩罚 0.05 分
- 在 AI 存活时,奖励 0.01 分;游戏分数大于 100 时,存活奖励 0.02 分
- 在障碍物完全下落时,奖励 1 分
- 在与障碍物碰撞时,惩罚 10 分

```python
def step(self, action):
    reward = 0
    terminated = False
    truncated = False
    move_speed = 8
    if action == 1 and self.player_x > 0:  # 左移
        self.player_x -= move_speed
        reward -= 0.05
    if action == 2 and self.player_x < self.width - self.player_size:  # 右移
        self.player_x += move_speed
        reward -= 0.05
    self.frame_count += 1
    level = self.score // 5
    self.current_speed = 5 + level
    self.spawn_rate = 30 - level * 2
    spawn_rate = max(10, 30 - level)
    if self.frame_count >= spawn_rate:
        self.frame_count = 0
        enemy_x = random.randint(0, self.width - self.enemy_size)
        self.enemies.append([enemy_x, 0])  # [x, y]
    for enemy in self.enemies:
        enemy[1] += self.current_speed
        player_rect = pygame.Rect(self.player_x, self.player_y,
                                  self.player_size, self.player_size)
        enemy_rect = pygame.Rect(enemy[0], enemy[1],
                                 self.enemy_size, self.enemy_size)
        if player_rect.colliderect(enemy_rect):
            reward = -10
            terminated = True
        elif enemy[1] > self.height:
            self.enemies.remove(enemy)
            self.score += 1
            reward += 1
    if not terminated:
        if self.score > 100:
            reward += 0.01
        reward += 0.01
    obs = self._get_obs()
    if self.render_mode == "human":
        self._render_window()
    return obs, reward, terminated, truncated, {}
```

### 展示游戏画面

下面完全是 pygame 代码,用于显示游戏画面,这里就不解释了。

```python
def _render_window(self):
    self.screen.blit(self.canvas, (0, 0))
    text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
    self.screen.blit(text, (10, 10))
    pygame.display.flip()
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
```
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2476685.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!