作业2 CNN实现手写数字识别

news2025/5/13 1:09:55
# 导入必要库
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns  # 用于高级可视化
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import time  # 用于计时

# ======================
# 1. 数据加载与预处理
# ======================

# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 数据预处理
# 归一化并添加通道维度(CNN需要通道信息)
x_train = x_train.reshape((60000, 28, 28, 1)).astype('float32') / 255
x_test = x_test.reshape((10000, 28, 28, 1)).astype('float32') / 255

# 将标签转换为one-hot编码
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# ======================
# 2. 构建CNN模型
# ======================
model = keras.Sequential([
    # 第一卷积层:32个3x3滤波器,ReLU激活,输入28x28x1
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),  # 下采样

    # 第二卷积层:64个3x3滤波器,ReLU激活
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    # 全连接层前处理
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),  # 防止过拟合

    # 输出层:10个类别,softmax激活
    layers.Dense(10, activation='softmax')
])

# ======================
# 3. 模型编译与训练
# ======================
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy', # 分类交叉熵
    metrics=['accuracy']   # 准确率
)

# 训练配置
epochs = 15
batch_size = 128
validation_split = 0.1  # 使用10%训练数据作为验证集

# 训练模型并记录历史数据
start_time = time.time()
history = model.fit(
    x_train, y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_split=validation_split,
    verbose=1  # 显示训练进度
)
training_time = time.time() - start_time

# ======================
# 4. 模型评估与可视化
# ======================

# 打印训练信息
print(f"\nTraining completed in {training_time:.2f} seconds")
print(f"Test accuracy: {model.evaluate(x_test, y_test, verbose=0)[1]:.4f}")

# ======================
# 可视化1:训练过程曲线
# ======================
plt.figure(figsize=(12, 4))

# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

# ======================
# 可视化2:混淆矩阵(修正版)
# ======================
# 获取预测结果
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_test, axis=1)

# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred_classes)

# 手动设置类别标签(MNIST 是 0-9)
class_names = [str(i) for i in range(10)]

# 可视化混淆矩阵
plt.figure(figsize=(10, 8))
sns.heatmap(cm, 
            annot=True,      # 在单元格中显示数值
            fmt='d',         # 数值格式为整数(适用于混淆矩阵的计数)
            cmap='Blues',    # 颜色映射(蓝色渐变)
            xticklabels=class_names,  # X轴标签(类别名称)
            yticklabels=class_names)  # Y轴标签(类别名称)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

# ======================
# 可视化3:错误预测样本
# ======================
# 找出预测错误的样本
errors = (y_pred_classes != y_true)
error_samples = x_test[errors]
true_labels = y_true[errors]
pred_labels = y_pred_classes[errors]

# 显示前15个错误样本
plt.figure(figsize=(15, 6))
for i in range(min(15, len(error_samples))):
    plt.subplot(3, 5, i + 1)
    plt.imshow(error_samples[i].reshape(28, 28), cmap='gray')
    plt.title(f"True: {true_labels[i]}, Pred: {pred_labels[i]}")
    plt.axis('off')
plt.tight_layout()
plt.show()

# ======================
# 可视化4:特征图可视化
# ======================
# 获取第一个卷积层的输出
layer_outputs = [layer.output for layer in model.layers[:2]]
activation_model = keras.models.Model(inputs=model.input, outputs=layer_outputs)
activations = activation_model.predict(x_test[0:1])

# 显示第一卷积层的特征图
plt.figure(figsize=(12, 6))
first_layer_activation = activations[0]
for i in range(32):  # 显示前32个滤波器
    plt.subplot(4, 8, i + 1)
    plt.imshow(first_layer_activation[0, :, :, i], cmap='viridis')
    plt.axis('off')
plt.suptitle('First Convolutional Layer Activations', fontsize=16)
plt.show()

运行结果 

Epoch 15/15
422/422 [==============================] - 16s 38ms/step - loss: 0.0184 - accuracy: 0.9941 - val_loss: 0.0295 - val_accuracy: 0.9938

Training completed in 343.88 seconds
Test accuracy: 0.9931
313/313 [==============================] - 1s 2ms/step

 ======================
# 3. 新增功能:随机展示20张测试集样本(调整到模型训练之后)
# ======================
def show_random_samples(model, x_test, y_test, num_samples=20):
    """显示随机测试样本及其预测结果"""
    # 确保模型已训练
    if not hasattr(model, 'layers'):
        raise ValueError("Model must be trained first")

    # 生成预测结果
    y_pred = model.predict(x_test)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # 获取真实标签
    y_true = np.argmax(y_test, axis=1)

    # 随机选择样本
    sample_indices = random.sample(range(len(x_test)), num_samples)

    # 创建可视化
    plt.figure(figsize=(16, 18))
    plt.suptitle("Random Handwritten Digit Samples with Predictions\n(Green=Correct, Red=Wrong)",
                 fontsize=16, y=1.03)

    rows, cols = 4, 5
    plt.subplots_adjust(hspace=0.5, wspace=0.3)

    # 使用新版Matplotlib API
    cmap = plt.colormaps.get_cmap('RdYlGn')  # 修复弃用警告

    for i, idx in enumerate(sample_indices):
        ax = plt.subplot(rows, cols, i + 1)
        img = x_test[idx].squeeze()

        # 显示图像
        plt.imshow(img, cmap='gray')

        # 获取标签信息
        true_label = y_true[idx]
        pred_label = y_pred_classes[idx]

        # 设置标题和颜色
        color = 'green' if true_label == pred_label else 'red'
        title = f'True: {true_label}\nPred: {pred_label}'
        plt.title(title, color=color, fontsize=10, pad=8)

        plt.axis('off')

    plt.tight_layout()
    plt.show()

 

混淆矩阵基础结构

1. 矩阵布局(以二分类为例)

2. 关键指标计算

 

TensorFlow Keras 核心组件

1. 常用层类型

2. 构建模型的三种方式

方式1:顺序模型(Sequential API)

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([
    Dense(128, activation='relu', input_shape=(784,)),  # 输入层
    Dense(64, activation='relu'),                       # 隐藏层
    Dense(10, activation='softmax')                     # 输出层
])

方式2:函数式API(Functional API)

from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense

input_layer = Input(shape=(784,))
hidden = Dense(128, activation='relu')(input_layer)
output = Dense(10, activation='softmax')(hidden)
model = Model(inputs=input_layer, outputs=output)

方式3:子类化模型(Subclassing)

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

3. 模型编译与训练

# 编译模型
model.compile(
    optimizer='adam',                 # 优化器(自动调参)
    loss='sparse_categorical_crossentropy',  # 损失函数(分类任务)
    metrics=['accuracy']              # 评估指标
)

# 训练模型
history = model.fit(
    x_train, 
    y_train,
    batch_size=32,                    # 每批样本数
    epochs=10,                        # 训练轮次
    validation_split=0.2              # 验证集比例
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2338844.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WordPress自定义页面与文章:打造独特网站风格的进阶指南

文章目录 引言一、理解WordPress页面与文章的区别二、主题与模板层级:自定义的基础三、自定义页面模板:打造专属页面风格四、自定义文章模板:打造个性化文章呈现五、使用自定义字段和元数据:增强内容灵活性六、利用WordPress钩子&…

PHP最新好看UI个人引导页网页源码

PHP最新好看UI个人引导页网页源码 采用PHP、HTML、CSS及JavaScript等前端技术,构建了一个既美观又实用的个人主页解决方案。 源码设计初衷在于提供一个高度可定制、跨平台兼容的模板,让用户无需深厚的编程基础,即可快速搭建出专业且富有创意的…

arkTs:使用回调函数的方法实现子组件向父组件传值

使用回调函数的方法实现子组件向父组件传值 1 主要内容说明2 实现步骤2.1 父组件中定义回调函数2.2 子组件声明并调用回调函数2.3 注意事项 3 源码3.1 父组件3.2 子组件3.3 源码效果显示截图 4 结语5 定位日期 1 主要内容说明 本文源码是一套 父组件与子组件之间双向数据传递的…

VBA 调用 dll 优化执行效率

问题描述 之前excel 用vba写过一个应用,请求的是aws lambda 后端, 但是受限于是云端服务,用起来响应特别慢,最近抽了点时间准备优化下,先加了点日志看看是哪里慢了 主方法代码如下,函数的主要目的是将 Excel 工作簿的…

Django-Friendship 项目常见问题解决方案

Django-Friendship 项目常见问题解决方案 django-friendship Django app to manage following and bi-directional friendships 项目地址: https://gitcode.com/gh_mirrors/dj/django-friendship Django-Friendship 是一个基于 Django 的应用,它允许创建和管…

AI时代下 你需要和想要了解的英文缩写含义

在AI智能时代下,越来愈多的企业都开始重视并应用以及开发AI相关产品,这个时候都会或多或少的涉及到英文,英文还好,但是如果是缩写,如果我们没有提前了解过,我们往往很难以快速Get到对方的意思。在这里&…

2025年对讲机选购指南:聚焦核心参数与场景适配

在无线通信领域,对讲机始终占据着专业通讯工具的独特地位。随着5G时代到来和物联网技术深化,2025年的对讲机市场正呈现智能化、专业化、场景化的升级趋势。面对琳琅满目的产品,选购者需从通信性能、环境适应性、智能集成度三个维度进行综合考…

C/C++ 动态链接详细解读

1. 为什么要动态链接? 1.1 静态链接浪费内存和磁盘空间 静态链接的方式对于计算机内存和磁盘空间浪费非常严重,特别是多进程操作系统的情况下,静态链接极大的浪费了内存空间。在现在的Linux系统中,一个普通的程序会使用的C 语言静…

python flask 项目部署

文章目录 概述 windows 部署准备工作使用 Waitress 部署 Flask 应用 linux 部署**2. 使用 WSGI 服务器**示例:使用 Gunicorn nginx反向代理**5. 使用进程管理工具**示例:使用 Systemd 概述 在 Windows 上使用 Waitress 部署 Flask 应用是一个不错的选择…

Java课程内容大纲(附重点与考试方向)

本文是在传统 Java 教程框架基础上,加入了重点提示与考试思路,适合用于课程备考、知识查漏与面试准备。 第1章:Java语言基础 ⭐ 重点知识: Java平台特点(跨平台性、JVM) JDK、JRE、JVM 区别 Java 程序的…

200+短剧出海平台:谁能成为“海外红果”?

2025年,短剧的国际市场表现令人瞩目。仅在两年前,业界关注的焦点仍是美国市场,如今国产短剧应用已成功打入包括印尼、巴西、美国、墨西哥、印度、菲律宾、泰国、日本、哥伦比亚及韩国在内的多个国家,轻松获得超过500万次下载。 市…

Visio导出清晰图片步骤

在Visio里画完图之后如何导出清晰的图片?👇 ①左上角单击【文件】 ②导出—更改文件类型—PNG/JPG ③分辨率选择【打印机】,大小选择【源】,即可。 ④选择保存位置并命名 也可以根据自己需要选择是否需要【透明底】哈。 选PNG 然…

Linux系统:详解进程等待wait与waitpid解决僵尸进程

本节重点 理解进程等待的相关概念掌握系统调用wait与waitpid的使用方法输出型status参数的存储结构阻塞等待与非阻塞等待 一、概念 进程等待是操作系统中父进程与子进程协作的核心机制,指父进程通过特定方式等待子进程终止并回收其资源的过程。这一机制的主要目的…

IntelliJ IDEA clean git password

IntelliJ IDEA clean git password 清除git密码 方法一:(这个要特别注意啊,恢复默认设置,你的插件什么要重新下载了) File->Manage IDE Settings->Restore Default Settings以恢复IDEA的默认设置(可选); 清空…

【已更新完毕】2025泰迪杯数据挖掘竞赛C题数学建模思路代码文章教学:竞赛智能客服机器人构建

完整内容请看文末最后的推广群 基于大模型的竞赛智能客服机器人构建 摘要 随着国内学科和技能竞赛的增多,参赛者对竞赛相关信息的需求不断上升,但传统人工客服存在效率低、成本高、服务不稳定和用户体验差的问题。因此,设计一款智能客服机器…

ACI EP Learning Whitepaper 3. Disabling IP Data-plane Learning 功能

目录 1. 使用场景 1.1 未disable IP data-plane learning时 1.2 disable IP data-plane learning后 2. 一代Leaf注意事项 3. L2 未知单播注意事项 1. 使用场景 Windows网卡的动态负载均衡绑定模式等。或多个设备共享相同VIP并通过ARP/GARP/ND来宣告VIP切换时,这些外部设…

C++入门七式——模板初阶

目录 函数模板 函数模板概念 函数模板格式 函数模板的原理 函数模板的实例化 模板参数的匹配原则 类模板 类模板的定义格式 类模板的显式实例化 当面对下面的代码时,大家会不会有一种无力的感觉?明明这些代码差不多,只是因为类型不…

【教程】检查RDMA网卡状态和测试带宽 | 附测试脚本

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~ 目录 检查硬件和驱动状态 测试RDMA通信 报错修复 对于交换机的配置,可以看这篇: 【教程】详解配置多台主机通过交换机实现互…

(二)Trae 配置C++ 编译

Trae配置c编译 零 CMake 编译C0.1 下载安装0.2 安装设置0.3 三种编译方式(见 下文 一 二 三)0.4 调试 (见 下文四) 一 使用MSVC方式编译1.1 安装编译环境1.2安装插件1.3 设置文件 二 使用GCC方式2.1 安装编译环境2.1.1下载:[MinGw](https://gcc-mcf.lhmouse.com/)2.1.2安装:(以…

日本公司如何实现B2B商城订货系统的自动化和个性化?

在日本构建具备前后台日文本地化、业务员代客下单、一客一价、智能拆单发货的B2B电商系统,需结合日本商业习惯与技术实现。以下是关键模块的落地方案: 一、系统架构设计 1. 前端本地化 语言与UI适配 采用全日语界面,包含敬语体系&#xff08…