【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)

news2025/5/31 21:20:22

目录

  • 1 准备工作:python库包安装
    • 1.1 安装必要库
  • 案例说明:模拟视频帧的时序建模
    • ConvLSTM概述
    • 损失函数说明
    • (python全代码)
  • 参考

ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。

1 准备工作:python库包安装

1.1 安装必要库

pip install torch torchvision matplotlib numpy

案例说明:模拟视频帧的时序建模

🎯 目标:给定一个人工生成的动态图像序列(例如移动的方块),使用 ConvLSTM 对其进行建模,输出预测结果,并查看输出的维度和特征变化。

ConvLSTM概述

ConvLSTM 的基本结构,包括:

  • ConvLSTMCell:实现了一个时间步的 ConvLSTM 单元,类似于一个“时刻”的神经元。
  • ConvLSTM:实现了多层ConvLSTM结构,能够处理一整个时间序列的视频帧数据。

损失函数说明

MSE(均方误差) 衡量预测值和真实值之间的平均平方差。
在这里插入图片描述

关于训练终止条件:
可以根据 MSE是否达到某个阈值(如 < 0.001)提前终止训练,这是所谓的 “Early Stopping(提前停止)策略”。

(python全代码)

MSE损失函数曲线如下:可知MSE一直在下降,虽然存在振荡
在这里插入图片描述

前9帧图像及预测的第十帧图像得到的动图如下:
在这里插入图片描述

python完整代码如下:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# 设置字体
plt.rcParams['font.family'] = 'Times New Roman'

# 创建保存图像目录
os.makedirs("./Figures", exist_ok=True)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================================
# 一、ConvLSTM 模型结构
# ====================================

class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()
        padding = kernel_size // 2
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.conv = nn.Conv2d(input_channels + hidden_channels, 4 * hidden_channels, kernel_size, padding=padding, bias=bias)

    def forward(self, x, h_prev, c_prev):
        combined = torch.cat([x, h_prev], dim=1)
        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c

class ConvLSTM(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):
        super(ConvLSTM, self).__init__()
        self.num_layers = num_layers
        layers = []
        for i in range(num_layers):
            in_channels = input_channels if i == 0 else hidden_channels
            layers.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))
        self.layers = nn.ModuleList(layers)

    def forward(self, input_seq):
        b, t, c, h, w = input_seq.size()
        h_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]
        c_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]

        for time in range(t):
            x = input_seq[:, time]
            for i, layer in enumerate(self.layers):
                h_t[i], c_t[i] = layer(x, h_t[i], c_t[i])
                x = h_t[i]

        return h_t[-1]  # 返回最后一层最后一帧的隐藏状态

# ====================================
# 二、生成移动方块序列数据
# ====================================

def generate_moving_square_sequence(batch_size, time_steps, height, width):
    data = torch.zeros((batch_size, time_steps, 1, height, width))
    for b in range(batch_size):
        dx = np.random.randint(1, 3)
        dy = np.random.randint(1, 3)
        x = np.random.randint(0, width - 6)
        y = np.random.randint(0, height - 6)
        for t in range(time_steps):
            data[b, t, 0, y:y+5, x:x+5] = 1.0
            x = (x + dx) % (width - 5)
            y = (y + dy) % (height - 5)
    return data

# ====================================
# 三、模型、损失、优化器
# ====================================

class ConvLSTM_Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.convlstm = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)
        self.decoder = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def forward(self, input_seq):
        hidden = self.convlstm(input_seq)
        pred = self.decoder(hidden)
        return pred

model = ConvLSTM_Predictor().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ====================================
# 四、训练过程
# ====================================

mse_list = []
max_epochs = 100
mse_threshold = 0.001
height, width = 64, 64

for epoch in range(max_epochs):
    model.train()
    seq = generate_moving_square_sequence(8, 10, height, width).to(device)
    input_seq = seq[:, :9]
    target_frame = seq[:, 9, 0].unsqueeze(1)

    optimizer.zero_grad()
    output = model(input_seq)
    loss = criterion(output, target_frame)
    loss.backward()
    optimizer.step()

    mse = loss.item()
    mse_list.append(mse)

    print(f"Epoch {epoch+1}/{max_epochs}, MSE: {mse:.6f}")

    # 提前停止条件
    if mse < mse_threshold:
        print(f"✅ 提前停止:MSE 已达到阈值 {mse_threshold}")
        break

# ====================================
# 五、测试与可视化结果
# ====================================

model.eval()
with torch.no_grad():
    test_seq = generate_moving_square_sequence(1, 10, height, width).to(device)
    input_seq = test_seq[:, :9]
    true_frame = test_seq[:, 9, 0]
    pred_frame = model(input_seq)[0, 0].cpu().numpy()

# 保存输入帧
for t in range(9):
    frame = input_seq[0, t, 0].cpu().numpy()
    plt.imshow(frame, cmap='gray')
    plt.title(f"Input Frame t={t}")
    plt.colorbar()
    plt.savefig(f"./Figures/input_frame_{t}.png")
    plt.close()

# 保存 Ground Truth
plt.imshow(true_frame[0].cpu().numpy(), cmap='gray')
plt.title("Ground Truth Frame t=9")
plt.colorbar()
plt.savefig("./Figures/ground_truth_t9.png")
plt.close()

# 保存预测帧
plt.imshow(pred_frame, cmap='gray')
plt.title("Predicted Frame t=9")
plt.colorbar()
plt.savefig("./Figures/predicted_t9.png")
plt.close()

# 保存 MSE 曲线图
plt.plot(mse_list)
plt.title("Training MSE Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.grid(True)
plt.savefig("./Figures/mse_curve.png")
plt.close()

# ---------------- 生成动图 ----------------

frames = []

# 添加前9帧输入
for t in range(9):
    img = Image.open(f"./Figures/input_frame_{t}.png")
    frames.append(img.copy())

# 添加预测帧
img = Image.open("./Figures/predicted_t9.png")
frames.append(img.copy())

# 保存动图
frames[0].save("./Figures/sequence.gif", save_all=True, append_images=frames[1:], duration=500, loop=0)
print("✅ 所有图像和动图已保存至 ./Figures 文件夹")

参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2391837.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文解读】DETR: 用Transformer实现真正的End2End目标检测

1st authors: About me - Nicolas Carion‪Francisco Massa‬ - ‪Google Scholar‬ paper: [2005.12872] End-to-End Object Detection with Transformers ECCV 2020 code: facebookresearch/detr: End-to-End Object Detection with Transformers 1. 背景 目标检测&#…

ElasticSearch简介及常用操作指南

一. ElasticSearch简介 ElasticSearch 是一个基于 Lucene 构建的开源、分布式、RESTful 风格的搜索和分析引擎。 1. 核心功能 强大的搜索能力 它能够提供全文检索功能。例如&#xff0c;在海量的文档数据中&#xff0c;可以快速准确地查找到包含特定关键词的文档。这在处理诸如…

纤维组织效应偏斜如何影响您的高速设计

随着比特率继续飙升&#xff0c;光纤编织效应 &#xff08;FWE&#xff09; 偏移&#xff0c;也称为玻璃编织偏移 &#xff08;GWS&#xff09;&#xff0c;正变得越来越成为一个问题。今天的 56GB/s 是高速路由器中最先进的&#xff0c;而 112 GB/s 指日可待。而用于个人计算机…

Rust使用Cargo构建项目

文章目录 你好&#xff0c;Cargo&#xff01;验证Cargo安装使用Cargo创建项目新建项目配置文件解析默认代码结构 Cargo工作流常用命令速查表详细使用说明1. 编译项目2. 运行程序3.快速检查4. 发布版本构建 Cargo的设计哲学约定优于配置工程化优势 开发建议1. 新项目初始化​2. …

Python训练营打卡Day39

DAY 39 图像数据与显存 知识点回顾 1.图像数据的格式&#xff1a;灰度和彩色数据 2.模型的定义 3.显存占用的4种地方 a.模型参数梯度参数 b.优化器参数 c.数据批量所占显存 d.神经元输出中间状态 4.batchisize和训练的关系 作业&#xff1a;今日代码较少&#xff0c;理解内容…

UE5蓝图中播放背景音乐和使用代码播放声音

UE5蓝图中播放背景音乐 1.创建背景音乐Cube 2.勾选looping 循环播放背景音乐 3.在关卡蓝图中 Event BeginPlay-PlaySound2D Sound选择自己创建的Bgm_Cube 蓝图播放声音方法二&#xff1a; 使用代码播放声音方法一 .h文件中 头文件引用 #include "Kismet/GameplayS…

AI 赋能数据可视化:漏斗图制作的创新攻略

在数据可视化的广阔天地里&#xff0c;漏斗图以其独特的形状和强大的功能&#xff0c;成为展示流程转化、分析数据变化的得力助手。传统绘制漏斗图的方式往往需要耗费大量时间和精力&#xff0c;对使用者的绘图技能和软件操作熟练度要求颇高。但随着技术的蓬勃发展&#xff0c;…

用 Python 模拟下雨效果

用 Python 模拟下雨效果 雨天别有一番浪漫情怀&#xff1a;淅淅沥沥的雨滴、湿润的空气、朦胧的光影……在屏幕上也能感受下雨的美妙。本文将带你用一份简单的 Python 脚本&#xff0c;手把手实现「下雨效果」动画。文章深入浅出&#xff0c;零基础也能快速上手&#xff0c;完…

C#对象集合去重的一种方式

前言 现在AI越来越强大了&#xff0c;有很多问题其实不需要在去各个网站上查了&#xff0c;直接问AI就好了&#xff0c;但是呢&#xff0c;AI给的代码可能能用&#xff0c;也可能需要调整&#xff0c;但是自己肯定是要会的&#xff0c;所以还是总结一下吧。 问题 如果有一个…

在ROS2(humble)+Gazebo+rqt下,实时显示仿真无人机的相机图像

文章目录 前言一、版本检查检查ROS2版本 二、步骤1.下载对应版本的PX4(1)检查PX4版本(2)修改文件名(3)下载正确的PX4版本 2.下载对应版本的Gazebo(1)检查Gazebo版本(2)卸载不正确的Gazebo版本(3)下载正确的Gazebo版本 3.安装bridge包4.启动 总结 前言 在ROS2的环境下&#xff…

github双重认证怎么做

引言 好久没登陆github了&#xff0c; 今天登陆github后&#xff0c;提醒进行2FA认证。 查看了github通知&#xff0c;自 2023 年 3 月起&#xff0c;GitHub 要求所有在 GitHub.com 上贡献代码的用户启用一种或多种形式的双重身份验证 (2FA)。 假如你也遇到这个问题&#xf…

数据的类型——认识你的数据

第02篇&#xff1a;数据的类型——认识你的数据 写在前面&#xff1a;嗨&#xff0c;大家好&#xff01;我是蓝皮怪。在上一篇文章中&#xff0c;我们聊了统计学的基本概念&#xff0c;今天我们来深入了解一个非常重要的话题——数据的类型。你可能会想&#xff1a;"数据就…

第五十二节:增强现实基础-简单 AR 应用实现

引言 增强现实(Augmented Reality, AR)是一种将虚拟信息叠加到真实世界的技术,广泛应用于游戏、教育、工业维护等领域。与传统虚拟现实(VR)不同,AR强调虚实结合,用户无需完全沉浸到虚拟环境中。本文将通过Python和OpenCV库,从零开始实现一个基础的AR应用:在检测到特定…

LLaMaFactory 微调QwenCoder模型

步骤一&#xff1a;准备LLamaFactory环境 首先,让我们尝试使用github的方式克隆仓库: git config --global http.sslVerify false && git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git # 创建新环境&#xff0c;指定 Python 版本&#xff08;以 3.…

【最新版】Arduino IDE的安装入门Demo

1、背景说明 1、本教程编写日期为2025-5-24 2、Arduino IDE的版本为&#xff1a;Arduino IDE 2.3.6 3、使用的Arduino为Arduino Uno 1、ArduinoIDE的安装 1、下载。网址如下&#xff1a;官网 2、然后一路安装即可。 期间会默认安装相关驱动&#xff0c;默认安装即可。 3、安…

不起火,不爆炸,高速摄像机、数字图像相关DIC技术在动力电池新国标安全性能测试中的应用

2026年7月1日&#xff0c;我国将正式实施GB38031-2025《电动汽车用动力蓄电池安全要求》——这项被称为“史上最严电池安全令”的新国标&#xff0c;首次将“热失控不蔓延、不起火、不爆炸”从企业技术储备上升为强制性要求&#xff0c;标志着电池安全进入“零容忍”时代&#…

thinkadmin中使用layui日期选择器,数据库存储时间戳

form.html <div class="layui-form-item label-required-prev" id="jiezhi_time-div">

WSL中ubuntu通过Windows带代理访问github

WSL中ubuntu通过Windows带代理访问github 前言: WSL是Windows下的ubuntu访问工具&#xff0c;目前无法访问外网&#xff0c;因此需要配置一下。 步骤一 代理中进行如下设置: 步骤二 ubuntu22.04中修改配置 使用如下命令获取IP地址&#xff1a; ip route | grep default | aw…

RISC-V特权模式及切换

1 RISC-V特权模式基本概念 1.1 RISC-V特权模式介绍 RISC-V 指令集架构&#xff08;ISA&#xff09;采用多特权级别设计作为其核心安全机制&#xff0c;通过层次化的权限管理实现系统资源的隔离与保护。该架构明确定义了四个层次化的特权模式&#xff0c;按照权限等级由高至低…

【深度学习】11. Transformer解析: Self-Attention、ELMo、Bert、GPT

Transformer 神经网络 Self-Attention 的提出动机 传统的循环神经网络&#xff08;RNN&#xff09;处理序列信息依赖时间步的先后顺序&#xff0c;无法并行&#xff0c;而且在捕捉长距离依赖关系时存在明显困难。为了解决这些问题&#xff0c;Transformer 引入了 Self-Attent…