部分依赖图(Partial Dependence Plots)以及实战-疾病引起原因解释

news2025/7/18 10:44:12

接上篇,特征重要性解释

特征重要性展示了每个特征发挥的作用情况,partial dependence plots可以展示一个特征怎样影响的了预测结果。
前提同样是应用在模型建立完成后进行使用,概述如下:

  • 首先选中一个样本数据,此时想观察Ball Possession列对结果的影响。

  • 保证其他特征列不变,改变当前观察列的值,例如选择40%,50%,60%(大小)分别进行预测,得到各自的结果。

  • 对比结果就能知道当前列(Ball Possession)对结果的影响情况。

  • 包: pdpbox

单特征观察

from matplotlib import pyplot as plt
from pdpbox import pdp, get_dataset, info_plots

pdp_goals = pdp.pdp_isolate(model=my_model, dataset=val_X, model_features=feature_names, feature='Goal Scored')

pdp.pdp_plot(pdp_goals, 'Goal Scored')
plt.show()

在这里插入图片描述
y轴表示预测结果的变化对比于基本模型,由于在观察时不可能只看一个样本数据,肯定要选择多个样本数据,蓝色区域表示的是置信度.

feature_to_plot = 'Distance Covered (Kms)'

rf_model = RandomForestClassifier(random_state=0).fit(train_X, train_y)

pdp_dist = pdp.pdp_isolate(model=rf_model, dataset=val_X, model_features=feature_names, feature=feature_to_plot)

pdp.pdp_plot(pdp_dist, feature_to_plot)
plt.show()

在这里插入图片描述

双特征观察

需要先改一下源码:将Anaconda3\Lib\site-packages\pdpbox 下pdp_plot_utils.py文件中contour_label_fontsize=fontsiz改成fontsize=fontsize

features_to_plot = ['Goal Scored', 'Distance Covered (Kms)']
inter1  =  pdp.pdp_interact(model=rf_model, dataset=val_X, model_features=feature_names, features=features_to_plot)

pdp.pdp_interact_plot(pdp_interact_out=inter1, feature_names=features_to_plot, plot_type='contour')
plt.show()

在这里插入图片描述

SHAP VALUES

可以直观的展示每一个特征对结果走势的影响
缺点:只传入一个样本点

row_to_show = 5
data_for_prediction = val_X.iloc[row_to_show] 
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)


my_model.predict_proba(data_for_prediction_array)

导入shap工具包可能出现问题,在Anaconda3/lib/site-packages/numpy/lib/arraypad.py中添加下面两个函数保存,重新加载即可,保存成utf-8

import shap  

#实例化
explainer = shap.TreeExplainer(my_model)

#计算
shap_values = explainer.shap_values(data_for_prediction)

返回的SHAP values中包括了negative和positive两种情况,通常选择一种(positive)即可

shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values[1], data_for_prediction)

在这里插入图片描述

Summary Plots

explainer = shap.TreeExplainer(my_model)

shap_values = explainer.shap_values(val_X)

shap.summary_plot(shap_values[1], val_X)

在这里插入图片描述

实战-疾病引起原因模型解释

疾病引起原因模型解释

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns 
from sklearn.ensemble import RandomForestClassifier 
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz 
from sklearn.metrics import roc_curve, auc 
from sklearn.metrics import classification_report 
from sklearn.metrics import confusion_matrix 
from sklearn.model_selection import train_test_split 
import eli5 
from eli5.sklearn import PermutationImportance
import shap 
from pdpbox import pdp, info_plots
np.random.seed(123) 

pd.options.mode.chained_assignment = None  

dt = pd.read_csv("heart.csv")
dt.head()

在这里插入图片描述

  • age:该人的年龄
  • sex:该人的性别(1 =男性,0 =女性)
  • cp:胸痛经历(值1:典型心绞痛,值2:非典型心绞痛,值3:非心绞痛,值4:无症状)
  • trestbps:该人的静息血压(入院时为mm Hg)
  • chol:人体胆固醇测量单位为mg / dl
  • fbs:该人的空腹血糖(> 120 mg / dl,1 = true; 0 = false)
  • restecg:静息心电图测量(0 =正常,1 =有ST-T波异常,2 =按Estes标准显示可能或明确的左心室肥厚)
  • thalach:达到了该人的最大心率
  • exang:运动诱发心绞痛(1 =是; 0 =否)
  • oldpeak:运动相对于休息引起的ST段压低('ST’与ECG图上的位置有关。点击此处查看更多内容)
  • slope:峰值运动ST段的斜率(值1:上升,值2:平坦,值3:下降)
  • ca:主要数量(0-3)
  • thal:称为地中海贫血的血液疾病(3 =正常; 6 =固定缺陷; 7 =可逆缺陷)
  • target:心脏病(0 =不,1 =是)
dt.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure', 'cholesterol', 'fasting_blood_sugar', 'rest_ecg', 'max_heart_rate_achieved',
       'exercise_induced_angina', 'st_depression', 'st_slope', 'num_major_vessels', 'thalassemia', 'target']

dt['sex'][dt['sex'] == 0] = 'female'
dt['sex'][dt['sex'] == 1] = 'male'

dt['chest_pain_type'][dt['chest_pain_type'] == 1] = 'typical angina'
dt['chest_pain_type'][dt['chest_pain_type'] == 2] = 'atypical angina'
dt['chest_pain_type'][dt['chest_pain_type'] == 3] = 'non-anginal pain'
dt['chest_pain_type'][dt['chest_pain_type'] == 4] = 'asymptomatic'

dt['fasting_blood_sugar'][dt['fasting_blood_sugar'] == 0] = 'lower than 120mg/ml'
dt['fasting_blood_sugar'][dt['fasting_blood_sugar'] == 1] = 'greater than 120mg/ml'

dt['rest_ecg'][dt['rest_ecg'] == 0] = 'normal'
dt['rest_ecg'][dt['rest_ecg'] == 1] = 'ST-T wave abnormality'
dt['rest_ecg'][dt['rest_ecg'] == 2] = 'left ventricular hypertrophy'

dt['exercise_induced_angina'][dt['exercise_induced_angina'] == 0] = 'no'
dt['exercise_induced_angina'][dt['exercise_induced_angina'] == 1] = 'yes'

dt['st_slope'][dt['st_slope'] == 1] = 'upsloping'
dt['st_slope'][dt['st_slope'] == 2] = 'flat'
dt['st_slope'][dt['st_slope'] == 3] = 'downsloping'

dt['thalassemia'][dt['thalassemia'] == 1] = 'normal'
dt['thalassemia'][dt['thalassemia'] == 2] = 'fixed defect'
dt['thalassemia'][dt['thalassemia'] == 3] = 'reversable defect'

dt.dtypes

dt['sex'] = dt['sex'].astype('object')
dt['chest_pain_type'] = dt['chest_pain_type'].astype('object')
dt['fasting_blood_sugar'] = dt['fasting_blood_sugar'].astype('object')
dt['rest_ecg'] = dt['rest_ecg'].astype('object')
dt['exercise_induced_angina'] = dt['exercise_induced_angina'].astype('object')
dt['st_slope'] = dt['st_slope'].astype('object')
dt['thalassemia'] = dt['thalassemia'].astype('object')

dt = pd.get_dummies(dt, drop_first=True)
dt.head()

X_train, X_test, y_train, y_test = train_test_split(dt.drop('target', 1), dt['target'], test_size = .2, random_state=10) 

model = RandomForestClassifier(max_depth=5)
model.fit(X_train, y_train)

estimator = model.estimators_[1]
feature_names = [i for i in X_train.columns]

y_train_str = y_train.astype('str')
y_train_str[y_train_str == '0'] = 'no disease'
y_train_str[y_train_str == '1'] = 'disease'
y_train_str = y_train_str.values

export_graphviz(estimator, out_file='tree.dot', 
                feature_names = feature_names,
                class_names = y_train_str,
                rounded = True, proportion = True, 
                label='root',
                precision = 2, filled = True)


from subprocess import call
call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png', '-Gdpi=600'])

from IPython.display import Image
Image(filename = 'tree.png')

在这里插入图片描述

y_predict = model.predict(X_test)
y_pred_quant = model.predict_proba(X_test)[:, 1]
y_pred_bin = model.predict(X_test)

confusion_matrix = confusion_matrix(y_test, y_pred_bin)
confusion_matrix
fpr, tpr, thresholds = roc_curve(y_test, y_pred_quant)

fig, ax = plt.subplots()
ax.plot(fpr, tpr)
ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c=".3")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.rcParams['font.size'] = 12
plt.title('ROC curve for diabetes classifier')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.grid(True)
plt.show()

在这里插入图片描述

auc(fpr, tpr)
perm = PermutationImportance(model, random_state=1).fit(X_test, y_test)
eli5.show_weights(perm, feature_names = X_test.columns.tolist())
base_features = dt.columns.values.tolist()
base_features.remove('target')

feat_name = 'num_major_vessels'
pdp_dist = pdp.pdp_isolate(model=model, dataset=X_test, model_features=base_features, feature=feat_name)

pdp.pdp_plot(pdp_dist, feat_name)
plt.show()
feat_name = 'age'
pdp_dist = pdp.pdp_isolate(model=model, dataset=X_test, model_features=base_features, feature=feat_name)

pdp.pdp_plot(pdp_dist, feat_name)
plt.show()

inter1  =  pdp.pdp_interact(model=model, dataset=X_test, model_features=base_features, features=['st_slope_upsloping', 'st_depression'])

pdp.pdp_interact_plot(pdp_interact_out=inter1, feature_names=['st_slope_upsloping', 'st_depression'], plot_type='contour')
plt.show()

inter1  =  pdp.pdp_interact(model=model, dataset=X_test, model_features=base_features, features=['st_slope_flat', 'st_depression'])

pdp.pdp_interact_plot(pdp_interact_out=inter1, feature_names=['st_slope_flat', 'st_depression'], plot_type='contour')
plt.show()

在这里插入图片描述


explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values[1], X_test)

在这里插入图片描述

def heart_disease_risk_factors(model, patient):

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(patient)
    shap.initjs()
    return shap.force_plot(explainer.expected_value[1], shap_values[1], patient)

data_for_prediction = X_test.iloc[1,:].astype(float)
heart_disease_risk_factors(model, data_for_prediction)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1010576.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

企业级镜像仓库Harbor的安装与配置

企业级镜像仓库Harbor的安装与配置 HarborHarbor概述安装Harbor配置 Harbor运行安装程序脚本登录启动与停止Harbor 登录Harbor仓库登录异常解决方案登录退出 推送拉取Harbor镜像镜像命名规范创建项目推送镜像拉取镜像 Harbor Harbor概述 Harbor是一个开源的容器镜像仓库管理系…

零售超市如何应对消费者需求?非常全面!

随着科技的飞速发展和消费者期望的不断演变,零售行业正经历着一场深刻的革命。传统零售模式逐渐被新零售模式所取代,而其中一个备受关注的元素是自动售货机。 自动售货机不仅在商场、车站和办公楼等高流量地点迅速扩张,还在重新定义我们如何购…

日志平台搭建第七章:Linux安装kafka-manager

相关链接https://github.com/yahoo/kafka-manager/releases kafka-manager-2.0.0.2下载地址 百度云链接:https://pan.baidu.com/s/1XinGcwpXU9YBF46qkrKS_A 提取码:tzvg 一、安装部署 1.把kafka-manager-2.0.0.2.zip拷贝到目录 /opt/app/elk 2.解压…

VSD Viewer 6.16.1(Visio绘图文件阅读器)

VSD Viewer是一款用于查看和打开Microsoft Visio文件的应用程序。Visio是一种流程图和图表设计工具,常用于创建各种类型的图形和图表,如组织结构图、流程图、网络拓扑图等。 VSD Viewer允许用户在没有安装Visio软件的情况下浏览和查看Visio文件。它提供…

变频器频率传感器信号转电压或电流信号隔离变送器0-1KHz / 0-5KHz / 0-10KHz转0-5V/0-10V/0-10mA/4-20mA

主要特性 精度等级&#xff1a;0.2 级全量程内极高的线性度&#xff08;非线性度<0.1%&#xff09;辅助电源/信号输入/信号输出&#xff1a; 2500VDC 三隔离辅助电源&#xff1a;5VDC&#xff0c;12VDC&#xff0c;24VDC 等单电源供电输入频率信号&#xff1a;0-1KHz / 0-5…

openGauss学习笔记-69 openGauss 数据库管理-创建和管理普通表-更新表中数据

文章目录 openGauss学习笔记-69 openGauss 数据库管理-创建和管理普通表-更新表中数据 openGauss学习笔记-69 openGauss 数据库管理-创建和管理普通表-更新表中数据 修改已经存储在数据库中数据的行为叫做更新。用户可以更新单独一行、所有行或者指定的部分行。还可以独立更新…

C++(Qt)软件调试---GCC编译参数学习-程序检测(13)

C(Qt)软件调试—GCC编译参数学习-程序检测&#xff08;13&#xff09; 文章目录 C(Qt)软件调试---GCC编译参数学习-程序检测&#xff08;13&#xff09;1、前言1.1 概述1.2 测试环境 2、GCC编译警告选项1.1 编译警告的作用1.2 GCC常用的编译警告选项 3、GCC程序检测选项1.1 性能…

解决Agora声网音视频在后台没有声音的问题

前言:本文会介绍 Android 与 iOS 两个平台的处理方式 一、Android高版本在应用退到后台时,系统为了省电会限制应用的后台活动,因此我们需要开启一个前台服务,在前台服务中发送常驻任务栏通知,以此来保证App 退到后台时不会被限制活动. 前台服务代码如下: package com.notify…

在ubuntu安装vncserver,可以打开远程桌面

1.网上查找vncserver的资料 2.在ubuntu20.04上面使用 https://www.5axxw.com/questions/simple/ywnpl5 参考了这个和其他很多 我直接运行vncserver&#xff0c;没有成功&#xff01; 当我使用这个命令时 xtigervncviewer -SecurityTypes VncAuth -passwd /home/yx/.vnc/passw…

构建高效的接口自动化测试框架思路

在选择接口测试自动化框架时&#xff0c;需要根据团队的技术栈和项目需求来综合考虑。对于测试团队来说&#xff0c;使用Python相关的测试框架更为便捷。无论选择哪种框架&#xff0c;重要的是确保 框架功能完备&#xff0c;易于维护和扩展&#xff0c;提高测试效率和准确性。今…

Biome-BGC生态系统模型与Python融合技术教程

详情点击公众号链接&#xff1a;Biome-BGC生态系统模型与Python融合技术教程 前言 Biome-BGC是利用站点描述数据、气象数据和植被生理生态参数&#xff0c;模拟日尺度碳、水和氮通量的有效模型&#xff0c;其研究的空间尺度可以从点尺度扩展到陆地生态系统。 在Biome-BGC模型…

crm、scrm、ocrm、acrm、ccrm等等分别是什么?有什么区别?

crm、SCRM、OCRM、ACRM、CCRM等等分别是什么&#xff1f;crm&#xff0c;scrm&#xff0c;ocrm&#xff0c;acrm&#xff0c;ccrm有什么区别&#xff1f;又有什么联系&#xff1f;这些系统各自之间都有哪些优势和缺点呢&#xff1f;今天将带领大家深入浅出的系统了解crm&#x…

idea 创建java web项目 run后出现404现象

1、创建新项目 创建的新项目只是单纯的java项目&#xff0c;如图 2、添加lib库文件&#xff0c;里面存放jar包&#xff0c;并导入库配置 这里要注意&#xff0c;需要先添加lib库文件再去配置模块和工件否则会出现404现象 3、打开模块设置&#xff0c;设置项目配置 将本…

c语言练习58:⾃定义类型:结构体

⾃定义类型&#xff1a;结构体 结构体的概念 结构是⼀些值的集合&#xff0c;这些值称为成员变量。结构的每个成员可以是不同类型的变量。 结构体是一个种自定义的数据类型&#xff0c;它可以由很多个默认数据类型组成。它主要用于描述复杂场景下的变量。 例如&#xff0c;想…

go语言基础--杂谈

常量运行期间&#xff0c;不可以改变的值 const PI float64 3.14字面常量 所谓字面常量&#xff0c;是值程序中硬编码的常量 123 //整型类型常量 156.78 //浮点类型的常量 true //布尔类型的常量 “abc”//字符串类型的常量//const PI float64 3.14 const PI 3.14 // //P…

在Kubernetes集群中部署 dolphindcheduler-3.1.8

温故知新 &#x1f4da;第一章 前言&#x1f4d7;背景&#x1f4d7;目的&#x1f4d7;总体方向 &#x1f4da;第二章 部署&#x1f4d7;安装helm&#x1f4d7;安装dolphindcheduler&#xff08;使用k8s的部署用户操作&#xff09;&#x1f4d5;通过命令验证&#x1f4d5;通过Ku…

OpenCV(四十二):Harris角点检测

1.Harris角点介绍 什么是角点&#xff1f; 角点指的是两条边的交点&#xff0c;图中红色圈起来的点就是角点。 Harris角点检测原理&#xff1a;首先定义一个矩形区域&#xff0c;然后将这个矩形区域放置在我的图像中&#xff0c;求取这个区域内所有的像素值之和&#xff0c;之…

【C语言】每日一题(半月斩)——day2

目录 一.选择题 1、以下程序段的输出结果是( ) 2、若有以下程序&#xff0c;则运行后的输出结果是&#xff08; &#xff09; 3、如下函数的 f(1) 的值为&#xff08; &#xff09; 4、下面3段程序代码的效果一样吗( ) 5、对于下面的说法&#xff0c;正确的是&#xf…

2024CFA一级notes百度网盘下载

024CFA一级notes百度网盘下载 2024CFA一级notes2024年CFA考纲已经正式发布&#xff0c;相比与老考纲&#xff0c;新考纲变化实在不算小。 面对2024年CFA新考纲的变化&#xff0c;我们在第一时间对2024年考试的新趋势和新变化&#xff0c;进行深度解读。具体总结如下&#xff…

ddtrace 系列篇之 dd-trace-java 项目编译

dd-trace-java 是 Datadog 开源的 java APM 框架&#xff0c;本文主要讲解如何编译 dd-trace-java 项目。 环境准备 JDK 编译环境(三个都要&#xff1a;jdk8\jdk11\jdk17) Gradle 8 Maven 3.9 (需要 15G 以上的存储空间存放依赖) Git >2 (低于会出现一想不到的异常&#xf…