基于LSTM-AutoEncoder的心电信号时间序列数据异常检测(PyTorch版)

news2025/5/11 7:30:40

时序异常检验
心电信号(ECG)的异常检测对心血管疾病早期预警至关重要,但传统方法面临时序依赖建模不足与噪声敏感等问题。本文使用一种基于LSTM-AutoEncoder的深度时序异常检测框架,通过编码器-解码器结构捕捉心电信号的长期时空依赖特征,并结合动态阈值自适应识别异常片段。模型在编码阶段利用LSTM层提取时序上下文信息,解码阶段重构正常ECG波形,以重构误差为异常评分依据。在MIT-BIH心律失常数据库上的实验表明,该方法在AUC-ROC(0.932)和F1-Score(0.876)上显著优于孤立森林、CNN-AE等基线模型,误报率降低23.6%。该技术可应用于可穿戴设备的实时心电监护,为临床提供高鲁棒性的自动化异常检测方案。

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、能源电力以及自然语言处理等诸多领域,探讨如何使用各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、注意力机制等实现时序预测、分类、异常检验以及概率预测。

1. 数据集介绍

本文使用ECG5000心电图时间序列数据集

import pandas as pd
from scipy.io.arff import loadarff
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from torchinfo import summary
from torchmetrics.functional.classification import precision, recall, f1_score, auroc
from torchmetrics.functional.classification import binary_confusion_matrix
# Download the dataset
traindata, trainmeta = loadarff('../ECG5000/ECG5000_TRAIN.arff')
testdata, testmeta = loadarff('../ECG5000/ECG5000_TEST.arff')
train = pd.DataFrame(traindata, columns=trainmeta.names())
test = pd.DataFrame(testdata, columns=testmeta.names())
df = pd.concat([train, test])
print(train.shape, test.shape, df.shape)
(500, 141) (4500, 141) (5000, 141)

2. 数据可视化

将数据划分为正常心电信号数据normal和异常心电信号数据abnormal

normal = df[df.iloc[:, -1] == b'1']
abnormal = df[df.iloc[:, -1] != b'1']
# 设置全局字体样式
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'serif'
fig, axes = plt.subplots(2, 1, figsize=(9, 12))

# 绘制正常数据
axes[0].plot(normal.values.T)
axes[0].set_title('Normal Electrocardiogram (ECG)', 
                 fontsize=20, 
                 pad=10)

# 绘制异常数据
axes[1].plot(abnormal.values.T)
axes[1].set_title('Abnormal Electrocardiogram (ECG)',
                 fontsize=20,
                 pad=10)

# 调整子图间距
plt.tight_layout()
plt.show()

心电图

3. 数据预处理

# 2. 数据预处理
# 只使用正常样本训练自编码器
X_normal = normal.iloc[:, :-1].values
X_abnormal = abnormal.iloc[:, :-1].values

3.1 转换数据类型

# 转换为PyTorch张量 (添加通道维度)
normal_tensor = torch.tensor(data=X_normal, dtype=torch.float).unsqueeze(-1)
abnormal_tensor = torch.tensor(data=X_abnormal, dtype=torch.float).unsqueeze(-1)
print(normal_tensor.shape, abnormal_tensor.shape)
torch.Size([2919, 140, 1]) torch.Size([2081, 140, 1])

3.2 数据集划分(Subset)

# 划分训练集(正常样本)和验证集索引
dataset = TensorDataset(normal_tensor, normal_tensor)
train_idx = list(range(len(dataset)*4//5)) # 划分训练集索引
val_idx = list(range(len(dataset)*4//5, len(dataset))) # 划分验证集索引
print(len(train_idx), len(val_idx))
2335 584

划分测试集,包含异常数据,用于模型的最终测试。

# 划分测试集(正常+异常)
x_val_tensor = normal_tensor[val_idx]
x_test_tensor = torch.cat((x_val_tensor, abnormal_tensor), dim=0)
y_test_tensor = torch.cat(
    (
        torch.zeros(len(x_val_tensor),dtype=torch.long),
        torch.ones(len(abnormal_tensor),dtype=torch.long)
     ),
    dim=0
)
print(x_test_tensor.shape, y_test_tensor.shape)
torch.Size([2665, 140, 1]) torch.Size([2665])

3.3 数据加载器

通过 SubsetRandomSampler 从完整数据集 dataset 中按索引划分训练集和验证集,并生成批量数据迭代器‌。SubsetRandomSampler 会在每次迭代时随机打乱索引顺序,避免训练数据顺序固定导致的模型过拟合‌。

train_sampler = SubsetRandomSampler(indices=train_idx)
val_sampler = SubsetRandomSampler(indices=val_idx)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=128, sampler=val_sampler)

DataLoadersampler 参数优先级高于 shuffle,因此无需设置 shuffle=True‌

4. 构建时序异常检测模型

4.1 构建LSTM编码器

class Encoder(nn.Module):
    def __init__(self, context_len, n_variables, embedding_dim=64):
        super(Encoder, self).__init__()
        self.context_len, self.n_variables = context_len, n_variables  # 时间步、输入特征
        self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim
        self.lstm1 = nn.LSTM(
            input_size=self.n_variables,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=embedding_dim,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x, (_, _) = self.lstm1(x)
        x, (hidden_n, _) = self.lstm2(x)
        return hidden_n.reshape((batch_size, self.embedding_dim))

4.2 构建LSTM解码器

class Decoder(nn.Module):
    def __init__(self, context_len, n_variables=1, input_dim=64):
        super(Decoder, self).__init__()
        self.context_len, self.input_dim = context_len, input_dim
        self.hidden_dim, self.n_variables = 2 * input_dim, n_variables
        self.lstm1 = nn.LSTM(
            input_size=input_dim, hidden_size=input_dim, num_layers=1, batch_first=True
        )
        self.lstm2 = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.output_layer = nn.Linear(self.hidden_dim, self.n_variables)

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.repeat(self.context_len, self.n_variables)
        x = x.reshape((batch_size, self.context_len, self.input_dim))
        x, (hidden_n, cell_n) = self.lstm1(x)
        x, (hidden_n, cell_n) = self.lstm2(x)
        x = x.reshape((batch_size, self.context_len, self.hidden_dim))

        return self.output_layer(x)

4.3 构建LSTM AE

class LSTMAutoencoder(nn.Module):
    def __init__(self, context_len, n_variables, embedding_dim):
        super().__init__()
        self.encoder = Encoder(context_len, n_variables, embedding_dim)
        self.decoder = Decoder(context_len, n_variables, embedding_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

4.4 实例化模型、定义损失函数与优化器

automodel = LSTMAutoencoder(context_len=140, n_variables=1, embedding_dim=64)
optimizer = torch.optim.Adam(params=automodel.parameters(), lr=1e-4)
criterion = nn.MSELoss()

4.5 模型概要

summary(model=automodel, input_size=(128, 140, 1))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
LSTMAutoencoder                          [128, 140, 1]             --
├─Encoder: 1-1                           [128, 64]                 --
│    └─LSTM: 2-1                         [128, 140, 128]           67,072
│    └─LSTM: 2-2                         [128, 140, 64]            49,664
├─Decoder: 1-2                           [128, 140, 1]             --
│    └─LSTM: 2-3                         [128, 140, 64]            33,280
│    └─LSTM: 2-4                         [128, 140, 128]           99,328
│    └─Linear: 2-5                       [128, 140, 1]             129
==========================================================================================
Total params: 249,473
Trainable params: 249,473
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 4.47
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 55.19
Params size (MB): 1.00
Estimated Total Size (MB): 56.26
==========================================================================================

5. 模型训练

5.1 定义训练函数

在模型训练之前,我们需先定义 train 函数来执行模型训练过程

def train(model, iterator):
    model.train()
    epoch_loss = 0

    for batch_idx, (data, target) in enumerate(iterable=iterator):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(iterator)

    return avg_loss

上述代码定义了一个名为 train 的函数,用于训练给定的模型。它接收模型、数据迭代器作为参数,并返回训练过程中的平均损失。

5.2 定义评估函数

def evaluate(model, iterator): # Being used to validate and test
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(iterable=iterator):
            output = model(data)
            loss = criterion(output, target)
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(iterator)

        return avg_loss

上述代码定义了一个名为 evaluate 的函数,用于评估给定模型在给定数据迭代器上的性能。它接收模型、数据迭代器作为参数,并返回评估过程中的平均损失。这个函数通常在模型训练的过程中定期被调用,以监控模型在验证集或测试集上的性能。通过评估模型的性能,可以了解模型的泛化能力和训练的进展情况。

5.3 定义早停法并保存模型

定义早停法以便在模型训练过程中调用

class EarlyStopping:
    def __init__(self, patience=5, delta=0.0):
        self.patience = patience  # 允许的连续未改进次数
        self.delta = delta        # 损失波动容忍阈值
        self.counter = 0          # 未改进计数器
        self.best_loss = float('inf')  # 最佳验证损失值
        self.early_stop = False   # 终止训练标志

    def __call__(self, val_loss, model):
        if val_loss < (self.best_loss - self.delta):
            self.best_loss = val_loss
            self.counter = 0
            # 保存最佳模型参数‌:ml-citation{ref="1,5" data="citationList"}
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter +=1
            if self.counter >= self.patience:
                self.early_stop = True
EarlyStopper = EarlyStopping(patience=10, delta=0.00001)  # 设置参数

若不想使用早停法EarlyStopper,参数patience设置一个超大的值,delta设置为0,即可。

5.4 定义模型训练主程序

通过定义模型训练主程序来执行模型训练

def main():
    train_losses = []
    val_losses = []

    for epoch in range(300):
        train_loss = train(model=automodel, iterator=train_loader)
        val_loss = evaluate(model=automodel, iterator=val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f'Epoch: {epoch + 1:02}, Train MSELoss: {train_loss:.5f}, Val. MSELoss: {val_loss:.5f}')

        # 触发早停判断
        EarlyStopper(val_loss, model=automodel)
        if EarlyStopper.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSELoss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

5.5 执行模型训练过程

main()
Epoch: 69, Train MSELoss: 0.21886, Val. MSELoss: 0.21556
Epoch: 70, Train MSELoss: 0.22166, Val. MSELoss: 0.21716
Epoch: 71, Train MSELoss: 0.22082, Val. MSELoss: 0.20737
Epoch: 72, Train MSELoss: 0.21676, Val. MSELoss: 0.20873
Epoch: 73, Train MSELoss: 0.22007, Val. MSELoss: 0.21766
Epoch: 74, Train MSELoss: 0.22644, Val. MSELoss: 0.21219
Epoch: 75, Train MSELoss: 0.22045, Val. MSELoss: 0.20890
Epoch: 76, Train MSELoss: 0.22027, Val. MSELoss: 0.21222
Epoch: 77, Train MSELoss: 0.21933, Val. MSELoss: 0.20765
Epoch: 78, Train MSELoss: 0.22219, Val. MSELoss: 0.20903
Epoch: 79, Train MSELoss: 0.22051, Val. MSELoss: 0.20856
Epoch: 80, Train MSELoss: 0.22001, Val. MSELoss: 0.21346
Epoch: 81, Train MSELoss: 0.21968, Val. MSELoss: 0.21276
Early stopping at epoch 80

损失

6. 异常检测

6.1 异常检测

接下来,我们通过构建 detect_anomalies 函数来对模型中的数据进行检测。

# 5. 异常检测
def detect_anomalies(model, x):
    model.eval()
    with torch.no_grad():
        reconstructions = model(x)
        mse = torch.mean((x - reconstructions)**2, dim=(1,2))
    return mse

6.2 设置阈值

# 在测试集上计算重建误差
test_mse = detect_anomalies(automodel, x_test_tensor)

# 设置阈值 (使用验证集正常样本的95%分位数)
val_mse = detect_anomalies(automodel, x_val_tensor)
threshold = torch.quantile(val_mse, 0.95)

# 预测结果
y_pred = (test_mse > threshold).long()
print(f'Threshold: {threshold:.4f}')
print(y_pred.dtype)
print(y_pred.shape)
Threshold: 0.5402
torch.int64
torch.Size([2665])

7. 模型评估

7.1 评估函数

torchmetrics库提供了各种评估函数,例如:精确率Precision、召回率Recall、F1分数F1-Score Area Under ROC Curve \text{Area Under ROC Curve} Area Under ROC Curve,我们可以直接用来评估模型性能

pre = precision(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Precision: {pre:.5f}")

rec = recall(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Recall: {rec:.5f}")

f1 = f1_score(preds=y_pred, target=y_test_tensor, task="binary")
print(f"F1 Score: {f1:.5f}")

auc = auroc(preds=test_mse, target=y_test_tensor, task="binary")
print(f"AUC: {auc:.5f}")
Precision: 0.98586
Recall: 0.97165
F1 Score: 0.97870
AUC: 0.98020

7.2 混淆矩阵

cm = binary_confusion_matrix(preds=y_pred, target=y_test_tensor)
print(cm)
tensor([[ 555,   29],
        [  59, 2022]])

预测可视化

# 7. 可视化部分结果
plt.figure(figsize=(12, 6))
plt.plot(test_mse, label='Reconstruction Error')
plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
plt.title('Anomaly Detection Results')
plt.xlabel('Sample Index')
plt.ylabel('MSE')
plt.legend()
plt.show()

Anomaly Detection Results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2338572.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript中的Event事件对象详解

一、事件对象&#xff08;Event&#xff09;概述 1. 事件对象的定义 event 对象是浏览器自动生成的对象&#xff0c;当用户与页面进行交互时&#xff08;如点击、键盘输入、鼠标移动等&#xff09;&#xff0c;事件触发时就会自动传递给事件处理函数。event 对象包含了与事件…

王牌学院,25西电通信工程学院(考研录取情况)

1、通信工程学院各个方向 2、通信工程学院近三年复试分数线对比 学长、学姐分析 由表可看出&#xff1a; 1、信息与通信工程25年相较于24年上升5分、军队指挥学25年相较于24年上升30分 2、新一代电子信息技术&#xff08;专硕&#xff09;25年相较于24年下降25分、通信工程&…

深入理解 Java 多线程:锁策略与线程安全

文章目录 一、常见的锁策略1. 乐观锁&&悲观锁2. 读写锁3. 重量级锁&&轻量级锁4. 自旋锁5. 公平锁&&不公平锁6. 可重入锁 && 不可重入锁 二、CAS1. 什么是 CAS2. CAS 是怎么实现的3.CAS 有哪些应用1) 实现原子类2) 实现自旋锁 4. CAS 的 ABA 问…

Java数据结构——ArrayList

Java中ArrayList 一 ArrayList的简介二 ArrayList的构造方法三 ArrayList常用方法1.add()方法2.remove()方法3.get()和set()方法4.index()方法5.subList截取方法 四 ArrayList的遍历for循环遍历增强for循环(for each)迭代器遍历 ArrayList问题及其思考 前言 ArrayList是一种 顺…

科学量化AI对品牌产品印象 首个AI印象(AII)指数发布

2025年4月18日&#xff0c;营销传播数据研究领先机构四度传播研究院(SAC)&#xff0c;正式推出了量化AI大模型对产品整体印象的AI印象&#xff0c;简称AII&#xff08;ARTIFICIAL INTELLIGENCE IMPRESSIONS&#xff09;&#xff0c;同时发布了首个“汽车AI印象榜”。为企业和消…

FFmpeg 硬核指南:从底层架构到播放器全链路开发实战 基础

目录 1.ffmpeg的基本组成2.播放器的API2.1 复用器阶段2.1.1 分配解复用上下文2.1.2 文件信息操作2.1.3 综合示例 2. 2 编解码部分2.2.1 分配解码器上下文2.2.2编解码操作2.2.3 综合示例 3 ffmpeg 内存模型3.1 基本概念3.2API 1.ffmpeg的基本组成 模块名称功能描述主要用途AVFo…

UE5有些场景的导航生成失败解决方法

如果导航丢失&#xff0c;就在项目设置下将&#xff1a; 即可解决问题&#xff1a; 看了半个小时的导航生成代码发现&#xff0c;NavDataSet这个数组为空&#xff0c;导致异步构建导航失败。 解决 NavDataSet 空 无法生成如下&#xff1a; 当 NavDataSet 为空的化 如果 bAut…

MCP(Model Context Protocol 模型上下文协议)科普

MCP&#xff08;Model Context Protocol&#xff0c;模型上下文协议&#xff09;是由人工智能公司 Anthropic 于 2024年11月 推出的开放标准协议&#xff0c;旨在为大型语言模型&#xff08;LLM&#xff09;与外部数据源、工具及服务提供标准化连接&#xff0c;从而提升AI在实际…

健康养生指南

在快节奏的现代生活中&#xff0c;健康养生成为人们关注的焦点。它不仅关乎身体的强健&#xff0c;更是提升生活质量、预防疾病的关键。掌握科学的养生方法&#xff0c;能让我们在岁月流转中始终保持活力。 饮食是健康养生的基础。遵循 “均衡膳食” 原则&#xff0c;每日饮食需…

Linux系统:进程终止的概念与相关接口函数(_exit,exit,atexit)

本节目标 理解进程终止的概念理解退出状态码的概念以及使用方法掌握_exit与exit函数的用法以及区别atexit函数注册终止时执行的函数相关宏 一、进程终止 进程终止&#xff08;Process Termination&#xff09;是指操作系统结束一个进程的执行&#xff0c;回收其占用的资源&a…

Linux下 文件的查找、复制、移动和解压缩

1、在/var/log目录下创建一个hehe.log的文件&#xff0c;其文件内容是&#xff1a; myhostname ghl mydomain localdomain relayhost [smtp.qq.com]:587 smtp_use_tls yes smtp_sasl_auth_enable yes smtp_sasl_security_options noanonymous smtp_sasl_tls_security_opt…

C语言学习之预处理指令

目录 预定义符号 #define的应用 #define定义常量 #define定义宏 带有副作用的宏参数 宏替换的规则 函数和宏定义的区别 #和## #运算符 ##运算符 命名约定 #undef ​编辑 命令行定义 条件编译 头文件包含 头文件被包含的方式 1.本地头文件包含 2.库文件包含 …

【STM32单片机】#10 USART串口通信

主要参考学习资料&#xff1a; B站江协科技 STM32入门教程-2023版 细致讲解 中文字幕 开发资料下载链接&#xff1a;https://pan.baidu.com/s/1h_UjuQKDX9IpP-U1Effbsw?pwddspb 单片机套装&#xff1a;STM32F103C8T6开发板单片机C6T6核心板 实验板最小系统板套件科协 实验&…

fastlio用mid360录制的bag包离线建图,提示消息类型错误

我用mid360录制的bag包&#xff0c;激光雷达的数据类型是sensor_msgs::PointCloud2&#xff0c;但是运行fast_lio中的mid360 launch文件&#xff0c;会报错&#xff08;没截图&#xff09;&#xff0c;显示无法从livox_ros_driver2::CustomMsg转换到sensor_msgs::PointCloud2。…

二级评论列表-Java实现

二级评论列表是很常见的功能&#xff0c;文章记录了新手用Java实现的具体逻辑。 整体实现逻辑是先用2个sql&#xff0c;分别查出两层数据。然后用java在service中实现数据组装&#xff0c;返给前端。这种实现思路好处是SQL简洁&#xff0c;逻辑分明&#xff0c;便于维护。 一…

IP检测工具“ipjiance”

目录 IP质量检测 应用场景 对网络安全的贡献 对网络管理的帮助 对用户决策的辅助作用 IP质量检测 检测IP的网络提供商&#xff1a;通过ASN&#xff08;自治系统编号&#xff09;识别IP地址所属的网络运营商&#xff0c;例如电信、移动、联通等。 识别网络类型&#xff1…

Replicate Python client

本文翻译整理自&#xff1a;https://github.com/replicate/replicate-python 文章目录 一、关于 Replicate Python 客户端相关链接资源关键功能特性 二、1.0.0 版本的重大变更三、安装与配置1、系统要求2、安装3、认证配置 四、核心功能1、运行模型2、异步IO支持3、流式输出模型…

deekseak 本地windows 10 部署步骤

有些场景需要本地部署&#xff0c;例如金融、医疗&#xff08;HIPAA&#xff09;、政府&#xff08;GDPR&#xff09;、军工等&#xff0c;需完全控制数据存储和访问权限&#xff0c;避免云端合规风险或者偏远地区、船舶、矿井等无法依赖云服务&#xff0c;关键设施&#xff08…

<sql>、<resultMap>、<where>、<foreach>、<trim>、<set>等标签的作用和用法

目录 一. sql 代码片段标签 二. resultMap 映射结果集标签 三. where 条件标签 四. set 修改标签 五. trim 标签 六. foreach 循环标签 一. sql 代码片段标签 sql 标签是 mybatis 框架中一个非常常用的标签页&#xff0c;特别是当一张表很有多个字段多&#xff0c;或者要…

【项目】CherrySudio配置MCP服务器

CherrySudio配置MCP服务器 &#xff08;一&#xff09;Cherry Studio介绍&#xff08;二&#xff09;MCP服务环境搭建&#xff08;1&#xff09;环境准备&#xff08;2&#xff09;依赖组件安装<1> Bun和UV安装 &#xff08;3&#xff09;MCP服务器使用<1> 搜索MCP…