365打卡第R3周: RNN-心脏病预测

news2025/7/16 17:24:03

🍨 本文为🔗365天深度学习训练营中的学习记录博客
🍖 原作者:K同学啊

🏡 我的环境:

语言环境:Python3.10
编译器:Jupyter Lab
深度学习环境:torch==2.5.1    torchvision==0.20.1
------------------------------分割线---------------------------------

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns
 
#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

df = pd.read_csv("data/heart.csv")
 
df

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
 
X = df.iloc[:,:-1]
y = df.iloc[:,-1]
 
# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X  = sc.fit_transform(X)
 
X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)
 
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size = 0.1, 
                                                    random_state = 1)
 
X_train.shape, y_train.shape
 
from torch.utils.data import TensorDataset, DataLoader
 
train_dl = DataLoader(TensorDataset(X_train, y_train),
                      batch_size=64, 
                      shuffle=False)
 
test_dl  = DataLoader(TensorDataset(X_test, y_test),
                      batch_size=64, 
                      shuffle=False)

class model_rnn(nn.Module):
    def __init__(self):
        super(model_rnn, self).__init__()
        self.rnn0 = nn.RNN(input_size=13 ,hidden_size=200, 
                           num_layers=1, batch_first=True)
 
        self.fc0   = nn.Linear(200, 50)
        self.fc1   = nn.Linear(50, 2)
 
    def forward(self, x):
 
        out, hidden1 = self.rnn0(x) 
        out    = self.fc0(out) 
        out    = self.fc1(out) 
        return out   
 
model = model_rnn().to(device)
model

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)
 
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches
 
    return train_acc, train_loss
 
def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()
 
    test_acc  /= size
    test_loss /= num_batches
 
    return test_acc, test_loss
 
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs     = 50
 
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
 
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
 
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
 
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = opt.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
print("="*20, 'Done', "="*20)

import matplotlib.pyplot as plt
from datetime import datetime
#隐藏警告
import warnings
warnings.filterwarnings("ignore")        #忽略警告信息
 
current_time = datetime.now() # 获取当前时间
 
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 200        #分辨率
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

print("==============输入数据Shape为==============")
print("X_test.shape:",X_test.shape)
print("y_test.shape:",y_test.shape)
 
pred = model(X_test.to(device)).argmax(1).cpu().numpy()
 
print("\n==============输出数据Shape为==============")
print("pred.shape:",pred.shape)
 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 
# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)
 
plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
 
# 修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)
 
# 显示图
plt.tight_layout()  # 调整布局防止重叠
plt.show()
 
test_X = X_test[0].reshape(1, -1) # X_test[0]即我们的输入数据
 
pred = model(test_X.to(device)).argmax(1).item()
print("模型预测结果为:",pred)
print("=="*20)
print("0:不会患心脏病")
print("1:可能患心脏病")

测试集预测准确率为87.1%,为了进一步提高模型性能,做了如下改进:

1、调整数据加载方式(启用Shuffle),打乱训练数据顺序,使每个batch的数据分布更均匀,提升模型泛化能力。

# 修改前(shuffle=False)
train_dl = DataLoader(..., shuffle=False)

# 修改后
train_dl = DataLoader(..., shuffle=True)

2、调整优化器和学习率

# 修改前
learn_rate = 1e-4
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)

# 修改后
opt = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)  # 增加L2正则
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'max', patience=5)  # 动态调整学习率

修改之后预测率最高超过了90%,预测准确率也有所提高

------------------小结---------------------------

循环神经网络(RNN)是一种专门处理序列数据的神经网络模型,其核心特点可归纳如下:

一、​​序列建模能力​

RNN通过​​循环连接结构​​处理序列输入(如文本、语音、时间序列),每个时间步的隐藏状态(hidden state)不仅依赖当前输入,还继承前一时间步的状态。这种设计使其能够捕捉​​时间依赖性​​,例如句子中单词的上下文关系或股价的历史趋势。

二、​​参数共享机制​

所有时间步共享同一组权重参数(如输入-隐藏层和隐藏-隐藏层矩阵),显著减少模型复杂度并提升泛化能力。这一特性使其能灵活处理​​变长序列​​,无需为不同长度输入单独设计网络。

三、​​记忆性与动态状态​

RNN通过隐藏状态的递归更新实现​​记忆功能​​,理论上可保留无限长的历史信息。例如,在机器翻译中,模型能利用前文信息生成连贯的译文。但实际中受限于梯度问题,长程依赖能力较弱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2344137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【实战】基于强化学习的 Agent 训练框架全流程拆解

一、引言 在人工智能蓬勃发展的今天,强化学习(Reinforcement Learning, RL)作为让智能体(Agent)在复杂环境中自主学习并做出最优决策的核心技术,正日益受到关注。从游戏领域中击败人类顶尖选手的 AlphaGo&a…

【音视频】⾳频处理基本概念及⾳频重采样

一、重采样 1.1 什么是重采样 所谓的重采样,就是改变⾳频的采样率、sample format、声道数等参数,使之按照我们期望的参数输出。 1.2 为什么要重采样 为什么要重采样? 当然是原有的⾳频参数不满⾜我们的需求,⽐如在FFmpeg解码⾳频的时候…

Prompt 结构化提示工程

Prompt 结构化提示工程 目前ai开发工具都大同小异,随着deepseek的流行,ai工具的能力都差不太多,功能基本都覆盖到了。而prompt能力反而是需要更加关注的(说白了就是能不能把需求清晰的输出成文档)。因此大家可能需要加…

Pycharm 代理配置

Pycharm 代理配置 文章目录 Pycharm 代理配置1. 设置系统代理1.1 作用范围1.2 使用场景1.3 设置步骤 2. 设置 python 运行/调试代理2.1 作用范围2.2 使用场景2.3 设置步骤 Pycharm 工具作为一款强大的 IDE,其代理配置在实际开发中也是必不可少的,下面介绍…

Spring Native:GraalVM原生镜像编译与性能优化

文章目录 引言一、Spring Native与GraalVM基础1.1 GraalVM原理与优势1.2 Spring Native架构设计 二、原生镜像编译实践2.1 构建配置与过程2.2 常见问题与解决方案 三、性能优化技巧3.1 内存占用优化3.2 启动时间优化3.3 实践案例分析 总结 引言 微服务架构的普及推动了轻量级、…

药监平台上传数据报资源码不存在

问题:电子监管码上传药监平台提示“导入的资源码不存在” 现象:从生产系统导出的关联关系数据包上传到药监平台时显示: 原因:上传数据包的通道的资源码与数据包的资源码不匹配。 解决方法:检查药监平台和生产系统的药…

【Linux应用】交叉编译环境配置,以及最简单粗暴的环境移植(直接从目标板上复制)

【Linux应用】交叉编译环境配置,以及最简单粗暴的环境移植(直接从目标板上复制) 文章目录 交叉编译器含有三方库的交叉编译直接从目标板上复制编译环境glibc库不一致报错方法1方法2 附录:ZERO 3烧录ZERO 3串口shell外设挂载连接Wi…

CSS3布局方式介绍

CSS3布局方式介绍 CSS3布局(Layout)系统是现代网页设计中用于构建页面结构和控制元素排列的一组强大工具。CSS3提供了多种布局方式,每种方式都有其适用场景,其中最常用的是Flexbox和CSS Grid。 先看传统上几种布局方式&#xff…

FPGA设计 时空变换

1、时空变换基本概念 1.1、时空概念简介 时钟速度决定完成任务需要的时间,规模的大小决定完成任务所需要的空间(资源),因此速度和规模就是FPGA中时间和空间的体现。 如果要提高FPGA的时钟,每个clk内组合逻辑所能做的事…

《AI大模型趣味实战》智能Agent和MCP协议的应用实例:搭建一个能阅读DOC文件并实时显示润色改写过程的Python Flask应用

智能Agent和MCP协议的应用实例:搭建一个能阅读DOC文件并实时显示润色改写过程的Python Flask应用 引言 随着人工智能技术的飞速发展,智能Agent与模型上下文协议(MCP)的应用场景越来越广泛。本报告将详细介绍如何基于Python Flask框架构建一个智能应用&…

uniapp开发03-轮播图组件swiper的简单使用案例

uniapp开发03-轮播图组件swiper的简单使用案例!这个仅仅是官方提供的一个轮播图组件啊。实际上我们项目开发的时候,会应用到其他第三方公司的轮播图组件资源!效果更强大。兼容性更强。 废话不多说,我们直接上代码。分析代码。 &l…

【Android】四大组件之Service

目录 一、什么是Service 二、启停 Service 三、绑定 Service 四、前台服务 五、远程服务扩展 六、服务保活 七、服务启动方法混用 你可以把Service想象成一个“后台默默打工的工人”。它没有UI界面,默默地在后台干活,比如播放音乐、下载文件、处理…

TRO再添新案 TME再拿下一热门IP,涉及Paddington多个商标

4月2日和4月8日,TME律所代理Paddington & Company Ltd.对热门IP Paddington Bear帕丁顿熊的多类商标发起维权,覆盖文具、家居用品、毛绒玩具、纺织用品、游戏、电影、咖啡、填充玩具等领域。跨境卖家需立即排查店铺内的相关产品! 案件基…

WPF之项目创建

文章目录 引言先决条件创建 WPF 项目步骤理解项目结构XAML 与 C# 代码隐藏第一个 "Hello, WPF!" 示例构建和运行应用程序总结相关学习资源 引言 Windows Presentation Foundation (WPF) 是 Microsoft 用于构建具有丰富用户界面的 Windows 桌面应用程序的现代框架。它…

AI数字人:未来职业的重塑(9/10)

摘要:AI 数字人凭借计算机视觉、自然语言处理与深度学习技术,从虚拟形象进化为智能交互个体,广泛渗透金融、教育、电商等多领域,重构职业生态。其通过降本提效、场景拓展与体验升级机制,替代重复岗位工作,催…

深入解析Mlivus Cloud中的etcd配置:最佳实践与高级调优指南

作为大禹智库的向量数据库高级研究员,我在《向量数据库指南》一书中详细阐述了向量数据库的核心组件及其优化策略。今天,我将基于30余年的实战经验,深入剖析Mlivus Cloud中etcd这一关键依赖的配置细节与优化方法。对于希望深入掌握Mlivus Cloud的读者,我强烈建议参考《向量…

前端面试宝典---vue原理

vue的Observer简化版 class Observer {constructor(value) {if (!value || typeof value ! object) returnthis.walk(value) // 对对象的所有属性进行遍历并定义响应式}walk (obj) {Object.keys(obj).forEach(key > defineReactive(obj, key, obj[key]))} } // 定义核心方法…

PyTorch卷积层填充(Padding)与步幅(Stride)详解及代码示例

本文通过具体代码示例讲解PyTorch中卷积操作的填充(Padding)和步幅(Stride)对输出形状的影响,帮助读者掌握卷积层的参数配置技巧。 一、填充与步幅基础 填充(Padding):在输入数据边缘…

用go从零构建写一个RPC(仿gRPC,tRPC)--- 版本1

希望借助手写这个go的中间件项目,能够理解go语言的特性以及用go写中间件的优势之处,同时也是为了更好的使用和优化公司用到的trpc,并且作者之前也使用过grpc并有一定的兴趣,所以打算从0构建一个rpc系统,对于生产环境已…

django之账号管理功能

账号管理功能 目录 1.账号管理页面 2.新增账号 3.修改账号 4.账号重置密码 5.删除账号功能 6.所有代码展示集合 7.运行结果 这一片文章, 我们需要新增账号管理功能, 今天我们写到的代码, 基本上都是用到以前所过的知识, 不过也有需要注意的细节。 一、账号管理界面 …