使用 Keras 训练一个卷积神经网络(CNN)(入门篇)

news2024/12/5 20:52:29

在上一篇文章中,我们介绍了如何使用 Keras 训练一个简单的全连接神经网络(MLP)。本文将带你深入学习如何使用 Keras 构建和训练一个卷积神经网络(CNN),用于图像分类任务。我们将继续使用 MNIST 数据集,但这次我们将采用更适合图像数据的 CNN 架构。

目录

  1. 环境准备
  2. 导入必要的库
  3. 加载和预处理数据
  4. 构建卷积神经网络模型
  5. 编译模型
  6. 训练模型
  7. 评估模型
  8. 保存和加载模型
  9. 可视化训练过程
  10. 总结

1. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

2. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于数据可视化。

3. 加载和预处理数据

我们继续使用 Keras 自带的 MNIST 数据集。

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 查看数据形状
print(f"训练数据形状: {x_train.shape}, 训练标签形状: {y_train.shape}")
print(f"测试数据形状: {x_test.shape}, 测试标签形状: {y_test.shape}")

# 数据预处理
# 归一化:将像素值缩放到 0-1 之间
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# CNN 需要添加通道维度
x_train = np.expand_dims(x_train, -1)  # 形状变为 (60000, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)    # 形状变为 (10000, 28, 28, 1)

# 将标签转换为分类编码
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# 可视化部分数据
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i].reshape(28, 28), cmap=plt.cm.binary)
    plt.xlabel(np.argmax(y_train[i]))
plt.show()

说明:

  • CNN 需要输入数据具有通道维度,因此使用 np.expand_dims 添加一个维度。
  • MNIST 数据集是灰度图像,因此通道维度为 1。

4. 构建卷积神经网络模型

我们将构建一个简单的 CNN 模型,包含两个卷积层和两个池化层,最后接上全连接层进行分类。

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # 卷积层,32 个 3x3 卷积核
    layers.MaxPooling2D((2, 2)),  # 最大池化层,池化窗口 2x2

    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层,64 个 3x3 卷积核
    layers.MaxPooling2D((2, 2)),  # 最大池化层

    layers.Flatten(),  # 展平层
    layers.Dense(64, activation='relu'),  # 全连接层,64 个神经元
    layers.Dense(num_classes, activation='softmax')  # 输出层,10 个神经元
])

# 查看模型结构
model.summary()

说明:

  • Conv2D: 二维卷积层,用于提取图像特征。
  • MaxPooling2D: 最大池化层,用于下采样,减少参数数量。
  • Flatten: 将多维输入一维化,以便连接全连接层。
  • Dense: 全连接层,用于分类。

5. 编译模型

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和交叉熵损失函数。
  • 评估指标为准确率。

6. 训练模型

# 设置训练参数
batch_size = 128
epochs = 10

# 训练模型
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_split=0.1)  # 使用 10% 的训练数据作为验证集

说明:

  • 使用 10% 的训练数据作为验证集,以监控模型在验证集上的性能。

7. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\n测试准确率: {test_acc:.4f}")

8. 保存和加载模型

# 保存模型
model.save("mnist_cnn_model.h5")

# 加载模型
new_model = keras.models.load_model("mnist_cnn_model.h5")

9. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))

# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')

# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')

plt.show()

说明:

  • 通过可视化训练过程中的准确率和损失值,可以帮助我们了解模型的训练情况,判断是否存在过拟合或欠拟合。

10. 本节回顾

本节介绍了如何使用 Keras 构建和训练一个简单的卷积神经网络(CNN),用于手写数字识别任务。主要步骤包括:

  1. 环境准备和库导入: 确保安装了必要的库,并导入所需模块。
  2. 数据加载和预处理: 加载 MNIST 数据集,进行归一化,并添加通道维度。
  3. 构建 CNN 模型: 使用 Conv2D、MaxPooling2D、Flatten、Dense 等层构建模型。
  4. 编译模型: 指定优化器、损失函数和评估指标。
  5. 训练模型: 使用训练数据训练模型,并使用验证集监控性能。
  6. 评估模型: 在测试集上评估模型性能。
  7. 保存和加载模型: 将训练好的模型保存到磁盘,并可加载进行预测。
  8. 可视化训练过程: 通过绘制准确率和损失值曲线,了解模型的训练情况。

通过这个基础教程,你可以开始自行探索更复杂的 CNN 模型和更深入的应用,如图像分类、目标检测、图像分割等。

导师简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2240750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VUE实现点击导航栏进行切换右边内容

首先看看效果&#xff0c;左边导航栏进行切换&#xff0c;右边内容进行切换 代码如下 <div><el-tabs :tab-position"tabPosition" style"height: 800px;"><el-tab-pane label"通知通告">通知通告</el-tab-pane><el-t…

Power BI - Get Data from SQLite

1.简单介绍 Power BI支持连接的数据源很多&#xff0c;比如SQL Server, PostgreSQL,MYSQL, Dataverse等。如果要连接小型数据源Sqlite&#xff0c;由于没有内置的直接连接Sqlite选项&#xff0c;需要使用ODBC driver来进行连接 这边尝试一下使用Power BI去连接本地Sqlite数据…

DAY63||拓扑排序精讲 |dijkstra(朴素版)精讲

拓扑排序精讲 117. 软件构建 题目描述 某个大型软件项目的构建系统拥有 N 个文件&#xff0c;文件编号从 0 到 N - 1&#xff0c;在这些文件中&#xff0c;某些文件依赖于其他文件的内容&#xff0c;这意味着如果文件 A 依赖于文件 B&#xff0c;则必须在处理文件 A 之前处理文…

爬虫补环境案例---问财网(rpc,jsdom,代理,selenium)

目录 一.环境检测 1. 什么是环境检测 2.案例讲解 二 .吐环境脚本 1. 简介 2. 基础使用方法 3.数据返回 4. 完整代理使用 5. 代理封装 6. 封装所有使用方法 jsdom补环境 1. 环境安装 2. 基本使用 3. 添加参数形式 Selenium补环境 1. 简介 2.实战案例 1. 逆向目…

DICOM图像解析:深入解析DICOM格式文件的高效读取与处理(续)

目录 一、DICOM图像高效解析 1、处理压缩的像素数据 常见压缩算法及其处理方法 解压缩示例 2、多帧图像的处理 多帧图像解析流程 三维图像的体绘制 3、序列和嵌套数据元素 序列数据的解析 二、错误处理与数据验证 常见错误类型 错误处理策略 三、使用现有的DICOM库…

「QT」高阶篇 之 d-指针 的用法

✨博客主页何曾参静谧的博客&#x1f4cc;文章专栏「QT」QT5程序设计&#x1f4da;全部专栏「Win」Windows程序设计「IDE」集成开发环境「UG/NX」BlockUI集合「C/C」C/C程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「UG/NX」NX定制…

如何绑定洛谷账号

首先注册洛谷 然后登陆 点击键盘F12 点击加号 点击应用程序 在name中找到__client_id和_uid 再复制相应的value到下图右侧 然后点击confirm即可 愿我们都能成为我们想要去成为的人&#xff01; 花会沿路盛开&#xff0c;我们以后的路也会&#xff01; 追风赶月莫停留&…

5G CPE:为什么活动会场与商铺的网络成为最新选择

在快节奏的现代社会中&#xff0c;无论是举办一场盛大的活动还是经营一家繁忙的商铺&#xff0c;稳定的网络连接都是不可或缺的基石。然而&#xff0c;面对复杂的布线难题或高昂的商业宽带费用&#xff0c;许多场所往往陷入两难境地。幸运的是&#xff0c;5G CPE&#xff08;Cu…

【ACM独立出版|高校主办】第四届信号处理与通信技术国际学术会议(SPCT 2024)

第四届信号处理与通信技术国际学术会议&#xff08;SPCT 2024&#xff09; 2024 4th International Conference on Signal Processing and Communication Technology 2024年12月27-29日 中国深圳 www.icspct.com 会议亮点&#xff1a; 1、ACM独立出版&#xff0c;EI稳…

CPU的性能指标总结(学习笔记)

CPU 性能指标 我们先来回顾下&#xff0c;描述 CPU 的性能指标都有哪些。 首先&#xff0c;最容易想到的应该是 CPU 使用率&#xff0c;这也是实际环境中最常见的一个性能指标。 用户 CPU 使用率&#xff0c;包括用户态 CPU 使用率&#xff08;user&#xff09;和低优先级用…

python 同时控制多部手机

在这个智能时代,我们的手机早已成为生活和工作中不可或缺的工具。无论是管理多个社交媒体账号,还是处理多台设备上的事务,如何更高效地控制多个手机成为了每个人的痛点。 今天带来的这个的软件为你提供了一键控制多部手机的强大功能。无论是办公、娱乐,还是社交,你都能通过…

实用教程:如何无损修改MP4视频时长

如何在UltraEdit中搜索MP4文件中的“mvhd”关键字 引言 在视频编辑和分析领域&#xff0c;有时我们需要深入到视频文件的底层结构中去。UltraEdit&#xff08;UE&#xff09;和UEStudio作为强大的文本编辑器&#xff0c;允许我们以十六进制模式打开和搜索MP4文件。本文将指导…

java项目-jenkins任务的创建和执行

参考内容: jenkins的安装部署以及全局配置 1.编译任务的general 2.源码管理 3.构建里编译打包然后copy复制jar包到运行服务器的路径 clean install -DskipTests -Pdev 中的-Pdev这个参数用于激活 Maven 项目中的特定构建配置&#xff08;Profile&#xff09; 在 pom.xml 文件…

【扩散——BFS】

题目 代码 #include <bits/stdc.h> using namespace std; const int t 2020, off 2020; #define x first #define y second typedef pair<int, int> PII; int dx[] {0, 0, 1, -1}, dy[] {-1, 1, 0, 0}; int dist[6080][6080]; // 0映射到2020&#xff0c;2020…

C++编程:利用环形缓冲区优化 TCP 发送流程,避免 Short Write 问题

文章目录 1. 什么是 Short Write 问题&#xff1f;2. 如何解决 Short Write 问题&#xff1f;2.1 方法 1&#xff1a;将 Socket 设置为阻塞模式2.2 方法 2&#xff1a;用户态维护发送缓冲区 3. 用户态维护发送缓冲区实现3.1 核心要点3.2 代码实现3.3 测试程序 参考文档 1. 什么…

数据网格能替代数据仓库吗?

一、数据网格是什么&#xff1f; 数据网格&#xff1a;是一种新兴的数据管理架构和理念&#xff0c;主要用于解决大规模、复杂数据环境下的数据管理和利用问题。 核心概念&#xff1a; 1、数据即产品&#xff1a;将数据看作一种产品&#xff0c;每个数据域都要对其生产的数据负…

力扣经典面试26题删除有序数组中的重复项1

给你一个非严格递增排列的数组nums&#xff0c;请你原地删除重复出现的元素&#xff0c; 使每个元素 只出现一次&#xff0c;返回删除后数组的新长度。元素的相对顺序 应该保持 一致。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数量为 k&#xff0c; 你需要做以…

LLM: AI Mathematical Olympiad (上)

文章目录 一、项目简介二、first place 攻略三、必备知识1、COT思维链技术2、ToRA 四、first place 训练功略五、数据集构建1、COT数据集2、TIR数据集 六、数据集详细技术报告总结 本文较长分成两个部分分析 | ू•ૅω•́)ᵎᵎᵎ 第一部分&#xff1a;预备知识介绍和数据准备…

GA/T1400视图库平台EasyCVR视频融合平台HLS视频协议是什么?

在数字化时代&#xff0c;视频监控系统已成为保障安全、提升效率的关键技术。EasyCVR视频融合云平台&#xff0c;作为TSINGSEE青犀视频在“云边端”架构体系中的重要一环&#xff0c;专为大中型项目设计&#xff0c;提供了一个跨区域、网络化的视频监控综合管理系统平台。它不仅…

给阿里云OSS绑定域名并启用SSL

为什么要这么做&#xff1f; 问题描述&#xff1a; 当用户通过 OSS 域名访问文件时&#xff0c;OSS 会在响应头中增加 Content-Disposition: attachment 和 x-oss-force-download: true&#xff0c;导致文件被强制下载而不是预览。这个问题特别影响在 2022/10/09 之后新开通 OS…