可视化脚本包括了两个方法:远程下载 huggingface 上的数据集和使用本地数据集
脚本主要使用两个:
目前来说,ACT 采集训练用的是统一时间长度的数据集,此外,这两个脚本最大的问题在于不能裁剪,这也是比较好的升级方向;
目录
1 可视化运行
1.1 远程 html
1.2 本地数据集
2 代码详解 visualize_dataset_html.py
2.1 综述
2.2 流程概览
2.3 库引用
2.4 mian() 函数
2.5 关键函数
2.5.1 run_server() —— Flask 应用核心
2.5.2 get_ep_csv_fname(episode_id)
2.5.3 get_episode_data()
2.5.4 get_episode_video_paths
2.5.5 get_episode_language_instruction
2.5.6 get_dataset_info(repo_id)
2.5.7 visualize_dataset_html
3 代码详解 visualize_dataset.py
3.1 综述
3.2 流程概览
3.3 库引用
3.4 mian() 函数
3.5 关键函数
3.5.1 采样器(EpisodeSampler)
3.5.2 图像转换(to_hwc_uint8_numpy)
3.5.3 核心可视化函数(visualize_dataset())
1 可视化运行
1.1 远程 html
对于开源数据集,只需要在 huggingface 上查看 id,比如 aloha_static_coffee 这个:
点进去选择 use this dataset,可以看到id
然后运行脚本:
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/aloha_static_coffee
下载数据集后生成 web browser:http://127.0.0.1:9090
可以看到采集的各类信息:
可以看到结果保存地址:
其中,下载的数据集默认存储在了 /home/yejiangchen/.cache/huggingface/lerobot/lerobot
下次再运行会直接调用无需下载
此外,如果想运行本地数据集,则需要指定 --root:
python lerobot/scripts/visualize_dataset_html.py \
--root /home/yejiangchen/.cache/huggingface/lerobot/lerobot/aloha_static_coffee \
--repo-id lerobot/aloha_static_coffee
即可正常运行:
1.2 本地数据集
本地的话可以直接使用 visualize_dataset.py 脚本,测试一下之前下载的数据
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/aloha_static_coffee \
--episode-index 0
2 代码详解 visualize_dataset_html.py
2.1 综述
此脚本将 LeRobotDataset 中的视频+时序传感数据(动作、状态等)渲染成交互式网页,方便快速浏览与排查
-
视频:在浏览器原生 <video> 标签播放
-
时序数值:转成 CSV 字符串,交给前端 Dygraphs JavaScript 库即刻绘制折线图
-
语言任务描述:展示在同一页面
-
部署:内置 Flask 服务器(默认 127.0.0.1:9090)即可本地或经 SSH‑tunnel 远程查看
2.2 流程概览
main() -> visualize_dataset_html() ->(配置、软链接)-> run_server() ->(HTTP 请求)-> get_dataset_info()、get_episode_data()、get_episode_language_instruction() 等
main()
└─ 解析 CLI 参数
└─ (可选)加载本地/远程数据集 → LeRobotDataset 或 IterableNamespace
└─ visualize_dataset_html()
├─ 创建/复用输出目录(含模板与静态文件)
├─ (本地数据集)软链接视频到 static/videos
└─ run_server() ←– 关键:注册所有 Flask 路由
├─ "/" : 首页 / 数据集选择页
├─ "/<ns>/<name>" : 自动跳到 episode_0
└─ "/<ns>/<name>/episode_<id>" : 主可视化页面
2.3 库引用
import argparse # 用于解析命令行参数
import csv # 用于生成 CSV 格式字符串
import json # 用于解析和生成 JSON 数据
import logging # 用于日志记录
import re # 用于正则表达式处理
import shutil # 用于文件和目录操作,如复制、删除
import tempfile # 用于创建临时目录
from io import StringIO # 用于将字符串当作文件读写
from pathlib import Path # 用于跨平台路径操作
import numpy as np # 数值计算库
import pandas as pd # 数据处理库
import requests # 用于发起 HTTP 请求
from flask import Flask, redirect, render_template, request, url_for # Flask Web 框架核心组件
from lerobot import available_datasets # 导入可用数据集列表
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset # LeRobotDataset 类
from lerobot.common.datasets.utils import IterableNamespace # 简单 namespace 类型
from lerobot.common.utils.utils import init_logging # 初始化日志设置
2.4 mian() 函数
作为脚本入口,负责解析所有命令行参数并据此准备数据集实例,最后调用 visualize_dataset_html() 启动可视化流程
核心流程:
- 用 argparse 定义并读取参数(如 --repo-id、--root、--episodes、--serve 等)
- 根据 --load-from-hf-hub 决定是实例化完整的 LeRobotDataset(加载本地/缓存数据与视频),还是只拉取元信息 (get_dataset_info)
- 将解析好的 dataset 对象与其它参数传入 visualize_dataset_html()
参数 | 作用 | 典型值 |
---|---|---|
--repo-id | HF Hub 上的数据集 namespace/name | lerobot/pusht |
--root | 本地数据集根目录 | ./data |
--load-from-hf-hub | 整数;为 1 时只下拉 meta / parquet / mp4,不构造完整 LeRobotDataset | 0/1 |
--episodes | 想看的 episode 索引列表 | 0 3 5 |
--host , --port | Flask 服务地址 | 默认 127.0.0.1:9090 |
--tolerance-s | 时间戳容差,保证 fps 一致性 | 1e-4 |
def main():
# 入口:解析命令行并调用可视化函数
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`).",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6`).",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
"If not given, defaults to 1e-4."
),
)
args = parser.parse_args() # 解析命令行参数
kwargs = vars(args)
repo_id = kwargs.pop("repo-id") # 获取 repo-id 并从 kwargs 删除
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
dataset = None
if repo_id:
# 根据 load_from_hf_hub 决定实例化 LeRobotDataset 还是只读 meta
dataset = (
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
visualize_dataset_html(dataset, **vars(args)) # 调用主可视化入口
if __name__ == "__main__":
# 脚本直接运行时进入 main
main()
2.5 关键函数
2.5.1 run_server() —— Flask 应用核心
全局配置:app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # 每次刷新都拉最新资源
路由配置:
路由 | 功能 |
---|---|
/ | • 如果脚本在**“单数据集模式”**(已传 dataset ),立刻重定向到 episode 0• 否则渲染首页,列出推荐 ( featured_datasets ) + 全部可用数据集 (lerobot_datasets ) |
/<ns>/<name> | 纯跳转:把 <dataset>/episode_0 作为入口 |
/<ns>/<name>/episode_<id> | 主工作函数: 1.若脚本启动时没载数据,就动态 get_dataset_info() 2.检查数据集版本 <2 则拒绝(旧格式) 3.调用 get_episode_data() → CSV + 列信息;拼装 Video‑URL / Tasks‑Text4.把所有信息喂给 visualize_dataset_template.html 渲染 |
def run_server(
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
template_folder: Path,
):
"""
启动 Flask HTTP 服务,渲染可视化页面。
参数:
- dataset: 已加载的数据集实例或 None
- episodes: 要展示的 episode 列表或 None
- host, port: 服务监听地址与端口
- static_folder: 静态文件目录(视频、JS、CSS)
- template_folder: Jinja2 模板目录
"""
app = Flask(
__name__,
static_folder=static_folder.resolve(), # 静态资源路径
template_folder=template_folder.resolve() # 模板文件路径
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # 禁用浏览器缓存,确保每次都拉最新的资源
@app.route("/")
def hommepage(dataset=dataset):
"""应用根路由:根据有无 dataset 参数决定重定向或渲染选择页"""
if dataset:
# 如果在脚本启动时传入 dataset,直接跳转到第 0 集
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
# 否则尝试从 query 参数读取 dataset & episode 再跳转
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
# 默认渲染首页,列出 featured + 全部 available datasets
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
2.5.2 get_ep_csv_fname(episode_id)
简单工具,按约定返回某集 CSV 文件名 episode_{id}.csv
def get_ep_csv_fname(episode_id: int):
# 根据 episode 索引构造 CSV 文件名
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
2.5.3 get_episode_data()
把单个 episode 的多通道数值数据 -> 二维列表 -> CSV 字符串(返给前端 JS)
1. 列挑选
selected = [col for col, ft in ds.features.items()
if ft["dtype"] in ["float32", "int32"]]
selected.remove("timestamp")
2. 过滤高维张量:shape 维度 > 1 的列记入 ignored_columns,避免动态图崩溃
3. 列名展开:如果在 meta 里有 names 用定义好的;否则按 col_0 … col_n 生成
4. 取数据:本地 LeRobotDataset 利用 .episode_data_index 截取 parquet;Hub‑Only 直接 pd.read_parquet(url)
5. 转换为 CSV string(StringIO + csv.writer)
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""
获取 episode 的时序数据,并将其转换为 CSV 字符串返回。
Returns:
- csv_string: CSV 格式的整个 episode 数据
- columns: [{key: 原始列名, value: 展开后子列名列表}, ...]
- ignored_columns: 被忽略的高维列名称列表
"""
columns = [] # 存储展开后列的信息
# 选出所有 dtype 为 float32/int32 的数值列
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp") # timestamp 先单独处理
ignored_columns = [] # 高维列名称
for column_name in selected_columns:
shape = dataset.features[column_name]["shape"] # 列的原始 shape
shape_dim = len(shape)
if shape_dim > 1:
# 如果维度 >1,则忽略,不支持 Dygraph 绘多维张量
selected_columns.remove(column_name)
ignored_columns.append(column_name)
# CSV header: timestamp + 各子列名
header = ["timestamp"]
# 遍历每个一维列,展开成多列子名称
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
# 如果 meta 中定义了 names,则使用自定义子列名
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
# 否则按 col_0...col_n 展开
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names # 累加到 CSV header
# timestamp 放回最前
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
# 本地模式:根据 index 范围 select pandas DataFrame
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("pandas")
)
else:
# 远程模式:通过 HTTP 拉取 parquet,然后筛列
repo_id = dataset.repo_id
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index
)
)
df = pd.read_parquet(url)
data = df[selected_columns]
# 构造 numpy 二维数组:首列 timestamp,其余为各子列值
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# 写 CSV 到内存字符串
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
csv_writer.writerow(header)
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns, ignored_columns
2.5.4 get_episode_video_paths
仅在本地 LeRobotDataset 场景下,获取指定 episode 在底层 HF 数据集中的视频文件路径列表(内部没用到,备用)
- 找到该集第一帧在整表中的行索引
- 针对每个 dataset.meta.video_keys,在对应列读取 ["path"] 字段
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# hack: 取该 episode 第一帧索引以定位 video path
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.meta.video_keys
]
2.5.5 get_episode_language_instruction
仅在数据集包含 language_instruction 特征时调用,从对应行抽取并清洗掉 Tensor 的包装字符串,返回指令文本
- 判断 dataset.features 是否存在 language_instruction
- 取该集第一帧索引,读取字段,去掉前后缀冗余信息
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# 如果数据集含 language_instruction 特征,则提取并清洗字符串
if "language_instruction" not in dataset.features:
return None
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# 去除 Tensor 格式冗余包装
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
2.5.6 get_dataset_info(repo_id)
远程数据辅助:拉 meta/info.json 并包成 IterableNamespace
额外用 episodes.jsonl 找每一集的 tasks 列表
def get_dataset_info(repo_id: str) -> IterableNamespace:
# 远程拉取 meta/info.json 并转为 IterableNamespace
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
timeout=5
)
response.raise_for_status()
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
2.5.7 visualize_dataset_html
搭建静态目录结构(HTML 模板 + 静态资源),并根据是否已有数据集对象决定是否创建视频软链接,最后根据 serve 标志调用 run_server()
- 调用 init_logging() 初始化日志设置
- 计算模板目录 templates,创建或清空(若 force_override)输出目录以及 static 子目录
- 若传入本地 LeRobotDataset,在 static/videos 下打软链接指向数据集的 videos 文件夹
- 若 serve 为真,调用 run_server() 启动 Flask 服务
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
# 主函数:准备静态目录 & 启动服务器
init_logging() # 配置根日志级别等
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
# 未指定输出目录时,创建临时目录
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir) # 强制覆盖时先删掉
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None:
# 仅在无本地 dataset 且 serve=True 时进入 run_server
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
# 本地数据集:在 static/videos 创建软链接到 dataset.root/videos
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
# 启动服务器
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
3 代码详解 visualize_dataset.py
3.1 综述
此脚本基于 Rerun SDK,实现对 LeRobotDataset 中单个 episode 进行可视化或记录,主要有三种模式:
- 本地交互模式(mode="local") 直接在当前机器弹出可视化窗口,用于快速调试与观测
- 远端服务模式(mode="distant") 在数据存放的远端机器上启动 WebSocket+HTTP 服务,本地通过 rerun ws://… 连接浏览
- 离线保存模式(--save 1) 将整次会话记录到一个 .rrd 文件,后续可通过 rerun path/to/file.rrd 离线回放
其中,脚本既能实时显示视频帧,也能同步绘制动作、状态、奖励等时序数值
3.2 流程概览
main()
├─ 解析 CLI 参数
├─ LeRobotDataset(repo_id, root, tolerance_s)
└─ visualize_dataset(...)
├─ EpisodeSampler(dataset, episode_index)
├─ DataLoader(dataset, sampler, batch_size, num_workers)
├─ rr.init(namespace, spawn=local_viewer?)
├─ gc.collect() # 避免多 worker 卡死
├─ (mode=="distant")? rr.serve(web_port, ws_port)
├─ for batch in DataLoader:
│ ├─ for each frame in batch:
│ │ ├─ rr.set_time_sequence/frame_index
│ │ ├─ rr.set_time_seconds/timestamp
│ │ ├─ rr.log(Image) for each camera
│ │ └─ rr.log(Scalar) for each numeric field
└─ 会话结束
├─ local+save → rr.save(.rrd)
└─ distant → 阻塞等待 Ctrl–C
3.3 库引用
import argparse # 解析命令行参数模块
import gc # 垃圾回收模块,用于手动触发回收
import logging # 日志记录模块
import time # 时间相关函数模块
from pathlib import Path # 跨平台路径操作
from typing import Iterator # 类型提示:迭代器
import numpy as np # 数值计算库
import rerun as rr # Rerun SDK,用于实时可视化
import torch # PyTorch 深度学习库
import torch.utils.data # PyTorch 数据加载工具
import tqdm # 进度条库
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset # 自定义 LeRobotDataset 数据集类
3.4 mian() 函数
- 强制要求:--repo-id(数据集标识)和 --episode-index(要可视化的集号)
- 可选:数据集根目录、DataLoader 配置(--batch-size、--num-workers)、模式切换(--mode、--save、--output-dir、--web-port、--ws-port)等
- 最终实例化 LeRobotDataset 并调用 visualize_dataset()
def main():
# 脚本入口:解析参数并调用可视化函数
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HF Hub 上数据集标识,例如 `lerobot/pusht`。",
)
parser.add_argument(
"--episode-index",
type=int,
required=True,
help="要可视化的 episode 索引。",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="本地数据集根目录,例如 `--root data`。默认使用 HuggingFace 缓存。",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="保存 .rrd 文件的目录,当 `--save 1` 时生效。",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="DataLoader 的 batch 大小。",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="DataLoader 的并行工作进程数。",
)
parser.add_argument(
"--mode",
type=str,
default="local",
help=(
"可视化模式:'local' 或 'distant'。`
"local` 会本地弹出 viewer;`
"distant` 则启动服务供远程浏览。"
),
)
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="`--mode distant` 时的 HTTP 服务端口。",
)
parser.add_argument(
"--ws-port",
type=int,
default=9087,
help="`--mode distant` 时的 WebSocket 服务端口。",
)
parser.add_argument(
"--save",
type=int,
default=0,
help=(
"是否保存为 .rrd 文件,启用后会禁用弹窗。"
"使用 `--output-dir path` 指定目录。"
),
)
parser.add_argument(
"--tolerance-s",
type=float,
default=1e-4,
help=(
"时间戳容差,保证与 fps 一致。"
"传入 LeRobotDataset 构造参数。"
),
)
args = parser.parse_args() # 解析命令行
kwargs = vars(args)
repo_id = kwargs.pop("repo_id") # 提取 repo_id
root = kwargs.pop("root") # 提取 root
tolerance_s = kwargs.pop("tolerance_s") # 提取容差参数
logging.info("Loading dataset") # 日志:开始加载数据集
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) # 构造数据集
visualize_dataset(dataset, **vars(args)) # 调用可视化主函数
if __name__ == "__main__":
# 如果脚本被直接执行,则运行 main()
main()
3.5 关键函数
3.5.1 采样器(EpisodeSampler)
只遍历指定 episode 在底层数据表(Parquet)中的帧索引范围,供 PyTorch DataLoader 使用
class EpisodeSampler(torch.utils.data.Sampler):
# 自定义数据采样器,仅遍历指定 episode 的帧索引
def __init__(self, dataset: LeRobotDataset, episode_index: int):
# 根据 episode_index 从 dataset 中获取起始和结束的全局帧索引
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
# 保存帧索引范围,用于 DataLoader 的 sampler
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
# 返回一个针对帧索引的迭代器
return iter(self.frame_ids)
def __len__(self) -> int:
# 返回此 sampler 的总采样数量(即帧数)
return len(self.frame_ids)
3.5.2 图像转换(to_hwc_uint8_numpy)
把 PyTorch 的 C×H×W 浮点图像张量(float32, 值域 [0,1])转换为 NumPy 的 H×W×C uint8 数组(值域 [0,255]),以便 Rerun 显示
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
# 将 C×H×W 的 float32 Torch 张量转换为 H×W×C 的 uint8 NumPy 数组
assert chw_float32_torch.dtype == torch.float32 # 确保数据类型为 float32
assert chw_float32_torch.ndim == 3 # 确保是 3 维
c, h, w = chw_float32_torch.shape # 解包通道、高度、宽度
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
# 先乘 255,再转 uint8,然后 permute 到 HWC,最后转换为 NumPy
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy # 返回处理后的图像数组
3.5.3 核心可视化函数(visualize_dataset())
1. 初始化
- 构造 DataLoader(dataset, sampler=EpisodeSampler, batch_size, num_workers)
- 调用 rr.init() 启动 Rerun 会话
- 在远端模式下额外执行 rr.serve() 开启 WebSocket+HTTP 服务
2. 数据记录
- 遍历每个 batch、每帧
- 用 rr.set_time_sequence/rr.set_time_seconds 标注时间信息
- 对所有摄像头键(camera_keys)逐帧 rr.log(Image)
- 逐维 rr.log(Scalar) 记录 action、observation.state、next.reward、next.done、next.success 等数值
3. 会话收尾
- 本地保存模式:rr.save() 写出 .rrd 文件并返回路径
- 远端服务模式:进入阻塞循环以保持 WebSocket 连接,直至 Ctrl–C 退出
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
) -> Path | None:
# 主可视化函数,根据模式(Local/Distant)实时或离线记录并展示数据
if save:
# 如果要保存为 .rrd 文件,必须传入 output_dir
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
)
repo_id = dataset.repo_id # 获取数据集唯一标识
logging.info("Loading dataloader") # 日志:开始加载 DataLoader
episode_sampler = EpisodeSampler(dataset, episode_index) # 创建只遍历指定 episode 的 sampler
dataloader = torch.utils.data.DataLoader(
dataset, # 数据集
num_workers=num_workers, # 并行加载进程数
batch_size=batch_size, # 每个 batch 的帧数
sampler=episode_sampler, # 自定义 sampler
)
logging.info("Starting Rerun") # 日志:启动 Rerun 会话
if mode not in ["local", "distant"]:
# 不支持其它模式时抛错
raise ValueError(mode)
# 本地模式且不保存时,自动 spawn viewer;否则不弹出
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Rerun v0.16 前的 workaround:触发垃圾回收,避免多进程 DataLoader 卡住
gc.collect()
if mode == "distant":
# 远端模式:启动 WebSocket + HTTP 服务,不自动打开浏览器
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun") # 日志:开始写入 Rerun 数据
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# 遍历每个 batch,显示进度条
for i in range(len(batch["index"])):
# 记录时间序列:帧索引与时间戳
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# 遍历所有 camera key,记录图像
for key in dataset.meta.camera_keys:
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# 如果存在 action 字段,则按维度记录每个动作值
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
# 如果存在 observation.state,则按维度记录状态值
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
# 可选字段:next.done, next.reward, next.success
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
if mode == "local" and save:
# 本地保存模式:写入 .rrd 文件并返回路径
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
rr.save(rrd_path)
return rrd_path
elif mode == "distant":
# 远端模式:阻塞当前进程,直到手动按 Ctrl-C
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
4 本地数据集效果
python lerobot/scripts/visualize_dataset.py
--repo-id loacalhost/square_into_box
--root=./collections/square_into_box/
--episode-index 0