DAY 23 训练

news2025/12/17 8:57:10

DAY 23 训练

  • DAY23 机器学习管道 pipeline
    • 基础概念
      • 转换器(Transformer)
      • 估计器(Estimator)
    • 管道(Pipeline)
    • 代码演示
      • 没有 pipeline 的代码
      • pipeline 的代码教学
        • 导入库和数据加载
        • 分离特征和标签,划分数据集
        • 定义预处理步骤
        • 构建完整 pipeline
        • 使用 Pipeline 进行训练和评估
    • 总结


DAY23 机器学习管道 pipeline

在机器学习领域,“管道”(pipeline)是一个至关重要的概念。它为我们提供了一种高效且结构化的方式来组织和执行机器学习工作流程。今天,我将带领大家一起深入理解机器学习管道的概念和应用,并通过代码演示如何构建一个完整的机器学习管道。

基础概念

转换器(Transformer)

转换器是用于对数据进行预处理和特征提取的 estimator。它实现了 transform 方法,用于对数据进行预处理和特征提取。常见的转换器包括数据缩放器(如 StandardScaler、MinMaxScaler)、特征选择器(如 SelectKBest、PCA)、特征提取器(如 CountVectorizer、TF-IDFVectorizer)等。

from sklearn.preprocessing import StandardScaler

# 初始化转换器
scaler = StandardScaler()

# 学习训练数据的缩放规则
scaler.fit(X_train)

# 应用规则到训练数据和测试数据
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

估计器(Estimator)

估计器是实现机器学习算法的对象或类。它用于拟合(fit)数据并进行预测(predict)。常见的估计器包括分类器(classifier)、回归器(regresser)、聚类器(clusterer)。

from sklearn.linear_model import LinearRegression

# 创建一个回归器
model = LinearRegression()

# 在训练集上训练模型
model.fit(X_train_scaled, y_train)

# 对测试集进行预测
y_pred = model.predict(X_test_scaled)

管道(Pipeline)

管道是一种将多个转换器和估计器按顺序连接在一起的机制,可以构建一个完整的数据处理和模型训练流程。在管道机制中,可以使用 Pipeline 类来组织和连接不同的转换器和估计器。

from sklearn.pipeline import Pipeline

# 构建一个简单的管道
pipeline = Pipeline([
    ('scaler', StandardScaler()),  # 数据标准化
    ('classifier', RandomForestClassifier())  # 分类器
])

# 在训练集上训练管道
pipeline.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = pipeline.predict(X_test)

代码演示

没有 pipeline 的代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# 设置中文字体(解决中文显示问题)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

data = pd.read_csv('data.csv')

# 先筛选字符串变量
discrete_features = data.select_dtypes(include=['object']).columns.tolist()

# Home Ownership 标签编码
home_ownership_mapping = {
    'Own Home': 1,
    'Rent': 2,
    'Have Mortgage': 3,
    'Home Mortgage': 4
}
data['Home Ownership'] = data['Home Ownership'].map(home_ownership_mapping)

# Years in current job 标签编码
years_in_job_mapping = {
    '< 1 year': 1,
    '1 year': 2,
    '2 years': 3,
    '3 years': 4,
    '4 years': 5,
    '5 years': 6,
    '6 years': 7,
    '7 years': 8,
    '8 years': 9,
    '9 years': 10,
    '10+ years': 11
}
data['Years in current job'] = data['Years in current job'].map(years_in_job_mapping)

# Purpose 独热编码
data = pd.get_dummies(data, columns=['Purpose'])
data2 = pd.read_csv("data.csv")
list_final = []
for i in data.columns:
    if i not in data2.columns:
        list_final.append(i)
for i in list_final:
    data[i] = data[i].astype(int)

# Term 0 - 1 映射
term_mapping = {
    'Short Term': 0,
    'Long Term': 1
}
data['Term'] = data['Term'].map(term_mapping)
data.rename(columns={'Term': 'Long Term'}, inplace=True)

continuous_features = data.select_dtypes(include=['int64', 'float64']).columns.tolist()

# 连续特征用中位数补全
for feature in continuous_features:
    mode_value = data[feature].mode()[0]
    data[feature].fillna(mode_value, inplace=True)

from sklearn.model_selection import train_test_split
X = data.drop(['Credit Default'], axis=1)
y = data['Credit Default']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings("ignore")

# --- 1. 默认参数的随机森林 ---
print("--- 1. 默认参数随机森林 (训练集 -> 测试集) ---")
import time
start_time = time.time()
rf_model = RandomForestClassifier(random_state=42)
rf_model.fit(X_train, y_train)
rf_pred = rf_model.predict(X_test)
end_time = time.time()

print(f"训练与预测耗时: {end_time - start_time:.4f} 秒")
print("\n默认随机森林 在测试集上的分类报告:")
print(classification_report(y_test, rf_pred))
print("默认随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, rf_pred))

pipeline 的代码教学

导入库和数据加载
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import warnings

warnings.filterwarnings("ignore")

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

data = pd.read_csv('data.csv')

print("原始数据加载完成,形状为:", data.shape)
分离特征和标签,划分数据集
y = data['Credit Default']
X = data.drop(['Credit Default'], axis=1)

print("\n特征和标签分离完成。")
print("特征 X 的形状:", X.shape)
print("标签 y 的形状:", y.shape)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("\n数据集划分完成 (预处理之前)。")
print("X_train 形状:", X_train.shape)
print("X_test 形状:", X_test.shape)
print("y_train 形状:", y_train.shape)
print("y_test 形状:", y_test.shape)
定义预处理步骤
object_cols = X.select_dtypes(include=['object']).columns.tolist()
numeric_cols = X.select_dtypes(exclude=['object']).columns.tolist()

ordinal_features = ['Home Ownership', 'Years in current job', 'Term']
ordinal_categories = [
    ['Own Home', 'Rent', 'Have Mortgage', 'Home Mortgage'],
    ['< 1 year', '1 year', '2 years', '3 years', '4 years', '5 years', '6 years', '7 years', '8 years', '9 years', '10+ years'],
    ['Short Term', 'Long Term']
]

ordinal_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('encoder', OrdinalEncoder(categories=ordinal_categories, handle_unknown='use_encoded_value', unknown_value=-1))
])

nominal_features = ['Purpose']

nominal_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False))
])

continuous_features = [f for f in X.columns if f not in ordinal_features + nominal_features]

continuous_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('scaler', StandardScaler())
])

preprocessor = ColumnTransformer(
    transformers=[
        ('ordinal', ordinal_transformer, ordinal_features),
        ('nominal', nominal_transformer, nominal_features),
        ('continuous', continuous_transformer, continuous_features)
    ],
    remainder='passthrough'
)

print("\nColumnTransformer (预处理器) 定义完成。")
构建完整 pipeline
pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', RandomForestClassifier(random_state=42))
])

print("\n完整的 Pipeline 定义完成。")
使用 Pipeline 进行训练和评估
print("\n--- 1. 默认参数随机森林 (训练集 -> 测试集) ---")
start_time = time.time()

pipeline.fit(X_train, y_train)

pipeline_pred = pipeline.predict(X_test)

end_time = time.time()

print(f"训练与预测耗时: {end_time - start_time:.4f} 秒")

print("\n默认随机森林 在测试集上的分类报告:")
print(classification_report(y_test, pipeline_pred))
print("默认随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, pipeline_pred))

总结

通过使用 pipeline,我们可以将整个机器学习工作流程封装成一个简洁的流程,提高代码的可读性和可维护性。同时,pipeline 还可以帮助我们防止数据泄露,简化超参数调优,提高模型的性能和稳定性。

在实际项目中,我们可以使用 pipeline 来构建复杂的机器学习工作流,提高我们的工作效率。希望今天的分享能够帮助大家更好地理解和使用机器学习管道。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2378604.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

回溯法理论基础 LeetCode 77. 组合 LeetCode 216.组合总和III LeetCode 17.电话号码的字母组合

目录 回溯法理论基础 回溯法 回溯法的效率 用回溯法解决的问题 如何理解回溯法 回溯法模板 LeetCode 77. 组合 回溯算法的剪枝操作 LeetCode 216.组合总和III LeetCode 17.电话号码的字母组合 回溯法理论基础 回溯法 回溯法也可以叫做回溯搜索法&#xff0c;它是一…

【进程控制二】进程替换和bash解释器

【进程控制二】进程替换 1.exec系列接口2.execl系列2.1execl接口2.2execlp接口2.3execle 3.execv系列3.1execv3.2总结 4.实现一个bash解释器4.1内建命令 通过fork创建的子进程&#xff0c;会继承父进程的代码和数据&#xff0c;因此本质上还是在执行父进程的代码 进程替换可以将…

JavaScript 的编译与执行原理

文章目录 前言&#x1f9e0; 一、JavaScript 编译与执行过程1. 编译阶段&#xff08;发生在代码执行前&#xff09;✅ 1.1 词法分析&#xff08;Lexical Analysis&#xff09;✅ 1.2 语法分析&#xff08;Parsing&#xff09;✅ 1.3 语义分析与生成执行上下文 &#x1f9f0; 二…

NHANES指标推荐:FMI

文章题目&#xff1a;Exploring the relationship between fat mass index and metabolic syndrome among cancer patients in the U.S: An NHANES analysis DOI&#xff1a;10.1038/s41598-025-90792-9 中文标题&#xff1a;探索美国癌症患者脂肪量指数与代谢综合征之间的关系…

【JDBC】JDBC常见错误处理方法及驱动的加载

MySQL8中数据库连接的四个参数有两个发生了变化 String driver "com.mysql.cj.jdbc.Driver"; String url "jdbc:mysql://127.0.0.1:3306/mydb?useSSLfalse&useUnicodetrue&characterEncodingutf8&serverTimezoneAsia/Shanghai"; 或者Strin…

车载以太网驱动智能化:域控架构设计与开发实践

title: 车载以太网驱动专用车智能化&#xff1a;域控架构设计与开发实践 date: 2023-12-01 categories: 新能源汽车 tags: [车载以太网, 电子电气架构, 域控架构, 专用车智能化, SOME/IP, AUTOSAR] 引言&#xff1a;专用车智能化转型的挑战与机遇 专用车作为城市建设与工业运输…

如何利用技术手段提升小学数学练习效率

在日常辅导孩子数学作业的过程中&#xff0c;我发现了一款比较实用的练习题生成工具。这个工具的安装包仅1.8MB大小&#xff0c;但基本能满足小学阶段的数学练习需求。 主要功能特点&#xff1a; 参数化出题 可自由设置数字范围&#xff08;如10以内、100以内&#xff09; 支…

BGP路由策略 基础实验

要求: 1.使用Preva1策略&#xff0c;确保R4通过R2到达192.168.10.0/24 2.用AS_Path策略&#xff0c;确保R4通过R3到达192.168.11.0/24 3.配置MED策略&#xff0c;确保R4通过R3到达192.168.12.0/24 4.使用Local Preference策略&#xff0c;确保R1通过R2到达192.168.1.0/24 …

第9讲、深入理解Scaled Dot-Product Attention

Scaled Dot-Product Attention是Transformer架构的核心组件&#xff0c;也是现代深度学习中最重要的注意力机制之一。本文将从原理、实现和应用三个方面深入剖析这一机制。 1. 基本原理 Scaled Dot-Product Attention的本质是一种加权求和机制&#xff0c;通过计算查询(Query…

双向长短期记忆网络-BiLSTM

5月14日复盘 二、BiLSTM 1. 概述 双向长短期记忆网络&#xff08;Bi-directional Long Short-Term Memory&#xff0c;BiLSTM&#xff09;是一种扩展自长短期记忆网络&#xff08;LSTM&#xff09;的结构&#xff0c;旨在解决传统 LSTM 模型只能考虑到过去信息的问题。BiLST…

MySQL UPDATE 执行流程全解析

引言 当你在 MySQL 中执行一条 UPDATE 语句时&#xff0c;背后隐藏着一套精密的协作机制。从解析器到存储引擎&#xff0c;从锁管理到 WAL 日志&#xff0c;每个环节都直接影响数据一致性和性能。 本文将通过 Mermaid 流程图 和 时序图&#xff0c;完整还原 UPDATE 语句的执行…

亚马逊云科技:开启数字化转型的无限可能

在数字技术蓬勃发展的今天&#xff0c;云计算早已突破单纯技术工具的范畴&#xff0c;成为驱动企业创新、引领行业变革的核心力量。亚马逊云科技凭借前瞻性的战略布局与持续的技术深耕&#xff0c;在全球云计算领域树立起行业标杆&#xff0c;为企业和个人用户提供全方位、高品…

【实测有效】Edge浏览器打开部分pdf文件显示空白

问题现象 Edge浏览器打开部分pdf文件显示空白或显示异常。 ​​​​​​​ ​​​​​​​ ​​​​​​​ 问题原因 部分pdf文件与edge浏览器存在兼容性问题&#xff0c;打开显示异常。 解决办法 法1&#xff1a;修改edge配置 打开edge浏览器&#x…

RJ连接器的未来:它还会是网络连接的主流标准吗?

RJ连接器作为以太网接口的代表&#xff0c;自20世纪以来在计算机网络、通信设备、安防系统等领域中占据了核心地位。以RJ45为代表的RJ连接器&#xff0c;凭借其结构稳定、信号传输可靠、成本低廉等优势&#xff0c;在有线网络布线领域被广泛采用。然而&#xff0c;在无线网络不…

Redis持久化机制详解:保障数据安全的关键策略

在现代应用开发中&#xff0c;Redis作为高性能的内存数据库被广泛使用。然而&#xff0c;内存的易失性特性使得持久化成为Redis设计中的关键环节。本文将全面剖析Redis的持久化机制&#xff0c;包括RDB、AOF以及混合持久化模式&#xff0c;帮助开发者根据业务需求选择最适合的持…

DeepSeek 大模型部署全指南:常见问题、优化策略与实战解决方案

DeepSeek 作为当前最热门的开源大模型之一&#xff0c;其强大的语义理解和生成能力吸引了大量开发者和企业关注。然而在实际部署过程中&#xff0c;无论是本地运行还是云端服务&#xff0c;用户往往会遇到各种技术挑战。本文将全面剖析 DeepSeek 部署中的常见问题&#xff0c;提…

嵌入式培训之数据结构学习(五)栈与队列

一、栈 &#xff08;一&#xff09;栈的基本概念 1、栈的定义&#xff1a; 注&#xff1a;线性表中的栈在堆区&#xff08;因为是malloc来的&#xff09;&#xff1b;系统中的栈区存储局部变量、函数形参、函数返回值地址。 2、栈顶和栈底&#xff1a; 允许插入和删除的一端…

RabbitMQ--进阶篇

RabbitMQ 客户端整合Spring Boot 添加相关的依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId> </dependency> 编写配置文件&#xff0c;配置RabbitMQ的服务信息 spri…

Android Studio报错Cannot parse result path string:

前言 最近在写个小Demo&#xff0c;参考郭霖的《第一行代码》&#xff0c;学习DrawerLayout和NavigationView&#xff0c;不知咋地&#xff0c;突然报错Cannot parse result path string:xxxxxxxxxxxxx 反正百度&#xff0c;问ai都找不到答案&#xff0c;报错信息是完全看不懂…

关于网站提交搜索引擎

发布于Eucalyptus-blog 一、前言 将网站提交给搜索引擎是为了让搜索引擎更早地了解、索引和显示您的网站内容。以下是一些提交网站给搜索引擎的理由&#xff1a; 提高可见性&#xff1a;通过将您的网站提交给搜索引擎&#xff0c;可以提高您的网站在搜索结果中出现的机会。当用…