2. 手写数字预测 gui版

news2025/6/3 7:23:06

2. 手写数字预测 gui版

  • 背景
  • 1.界面绘制
  • 2.处理图片
  • 3. 加载模型
  • 4. 预测
  • 5.结果
  • 6.一点小问题

在这里插入图片描述

背景

做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下

源代码可以去这里https://github.com/Leezed525/pytorch_toy拿

1.界面绘制

在这里插入图片描述

整个页面布局逻辑很简单,搭建一下就好了

class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()
        self.net = self.get_net()  # 获取数字预测模型
        self.setWindowTitle("PyQt 数字预测")
        self.setGeometry(100, 100, 500, 550)  # 设置主窗口的初始位置和大小,留出空间给按钮

        self.setFixedSize(500, 550)
        self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.WindowMaximizeButtonHint)

        central_widget = QWidget()  # 创建一个中央 QWidget
        self.setCentralWidget(central_widget)  # 设置中央 QWidget 为主窗口的中心部件
        layout = QVBoxLayout(central_widget)  # 为中央 QWidget 创建一个垂直布局

        # 创建一个水平布局
        operation_layer = QHBoxLayout()  # 创建一个水平布局用于放置操作区域
        left_operation_layer = QVBoxLayout()
        right_operation_layer = QVBoxLayout()

        self.canvas = DrawingCanvas(self)  # 创建 DrawingCanvas 实例
        canvas_label = QLabel("请在此处绘制数字")  # 创建一个标签,提示用户在画布上绘制数字
        canvas_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        canvas_label.setStyleSheet("font-size: 20px;")  # 设置标签的样式
        left_operation_layer.addWidget(canvas_label)  # 将标签添加到左侧操作区域布局中
        left_operation_layer.addWidget(self.canvas)

        left_operation_layer.setStretch(0, 1)
        left_operation_layer.setStretch(1, 10)  # 设置画布的伸缩比例,使其占据更多空间

        operation_layer.addLayout(left_operation_layer)  # 将左侧操作区域布局添加到操作层布局中
        # 右侧操作区域

        self.predict_label = QLabel("预测结果: ")  # 创建一个标签,显示预测结果
        right_operation_layer.addWidget(self.predict_label)
        self.predict_digit_labels = []
        for i in range(10):
            predict_digit_label = QLabel(f"数字 {i}: 0.00%")  # 创建标签显示每个数字的预测概率
            self.predict_digit_labels.append(predict_digit_label)  # 将标签添加到列表中
        for label in self.predict_digit_labels:
            right_operation_layer.addWidget(label)

        operation_layer.addLayout(right_operation_layer)  # 将右侧操作区域布局添加到操作层布局中
        operation_layer.setStretch(0, 10)
        operation_layer.setStretch(1, 1)

        layout.addLayout(operation_layer)  # 将操作层布局添加到主布局中
        # 按钮区布局
        button_layout = QHBoxLayout()  # 创建一个垂直布局用于放置按钮

        clear_button = QPushButton("清空画布")  # 清空画布按钮
        clear_button.clicked.connect(self.canvas.clear_canvas)  # 连接按钮的点击信号到清空画布方法

        predict_button = QPushButton("预测")  # 清空画布按钮
        predict_button.clicked.connect(self.predict)  # 连接按钮的点击信号到预测方法

        button_layout.addStretch(6)
        button_layout.addWidget(clear_button)
        button_layout.addWidget(predict_button)

        layout.addLayout(button_layout)  # 将按钮布局添加到主布局中

其中稍微有点心智压力的区域就是画图区域,这里配合ai然后再自行修改一下就好了,逻辑就是鼠标按住然后绘制,松开后停止绘制。

canvas代码

class DrawingCanvas(QWidget):
    """
    一个自定义的 QWidget 类,用作绘图画布。
    用户可以在此画布上用鼠标点击并拖动来绘制线条。
    """

    def __init__(self, parent=None):
        super().__init__(parent)  # 调用父类 QWidget 的构造函数
        self.setWindowTitle("绘图画布")  # 设置窗口标题
        self.setGeometry(100, 100, 280, 280)  # 设置窗口的初始位置和大小 (x, y, width, height)
        self.setMinimumSize(280, 280)

        # 创建一个 QImage 对象作为绘图缓冲区
        # 所有的绘图操作都在这个 QImage 上进行,然后整体绘制到屏幕,可以避免闪烁。
        # QImage.Format.Format_RGB32 是 PyQt6 中推荐的 RGBA 格式,支持透明度。
        self.image = QImage(self.size(), QImage.Format.Format_RGB32)
        # 将 QImage 填充为白色。
        self.image.fill(Qt.GlobalColor.white)

        self.drawing = False  # 一个布尔标志,指示当前是否正在进行鼠标拖拽绘图
        self.last_point = QPoint()  # 存储鼠标上次的位置,用于绘制连续的线条

        # 同样,颜色常量需要通过 Qt.GlobalColor 访问。
        self.pen_color = Qt.GlobalColor.black
        self.pen_size = 20

    def paintEvent(self, event):
        """
        绘制事件处理函数。
        每当窗口需要被重新绘制时(例如,首次显示、窗口大小改变、调用 update() 时),
        Qt 就会自动调用这个方法。
        """
        painter = QPainter(self)  # 创建一个 QPainter 对象,指定在当前 QWidget (self) 上进行绘制
        # 将 self.image (绘图缓冲区) 的内容绘制到当前 QWidget 的整个矩形区域内。
        painter.drawImage(self.rect(), self.image, self.image.rect())

    def mousePressEvent(self, event):
        # 检查是否是鼠标左键被按下。
        if event.button() == Qt.MouseButton.LeftButton:
            self.drawing = True  # 设置绘图标志为 True
            self.last_point = event.pos()  # 记录当前鼠标位置作为线条的起始点

    def mouseMoveEvent(self, event):
        """
        鼠标移动事件处理函数。
        当鼠标在窗口内移动时触发。
        """
        # 只有当正在绘图 (self.drawing 为 True) 并且鼠标左键被按住时才执行绘图操作。
        # event.buttons() 返回当前按下的所有鼠标按钮的位掩码,Qt.MouseButton.LeftButton 用于检查左键是否按下。
        if self.drawing and event.buttons() & Qt.MouseButton.LeftButton:
            painter = QPainter(self.image)  # 在 QImage (绘图缓冲区) 上创建 QPainter 进行绘制
            # 设置画笔的颜色、粗细和样式。
            painter.setPen(QPen(QColor(self.pen_color), self.pen_size,
                                Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin))
            # 绘制从上次记录的点到当前鼠标位置的直线
            painter.drawLine(self.last_point, event.pos())
            self.last_point = event.pos()  # 更新 last_point 为当前鼠标位置,为下一次绘制做准备
            self.update()  # 请求窗口重绘。这会间接调用 paintEvent,将 QImage 的最新内容显示到屏幕上。

    def mouseReleaseEvent(self, event):
        """
        鼠标释放事件处理函数。
        当用户释放鼠标按钮时触发。
        """
        # 检查是否是鼠标左键被释放。
        if event.button() == Qt.MouseButton.LeftButton:
            self.drawing = False  # 停止绘图

    def resizeEvent(self, event):
        """
        窗口大小改变事件处理函数。
        当窗口大小改变时触发。
        """
        # 如果新窗口的宽度或高度大于当前 QImage 的尺寸,则需要创建一个新的 QImage。
        if self.width() > self.image.width() or self.height() > self.image.height():
            new_image = QImage(self.size(), QImage.Format.Format_RGB32)
            # 填充新图像为白色
            new_image.fill(Qt.GlobalColor.white)
            painter = QPainter(new_image)
            # 将旧图像的内容绘制到新图像上,以保留已有的绘图。
            painter.drawImage(QPoint(0, 0), self.image)
            self.image = new_image  # 更新 self.image 为新的 QImage
            self.update()  # 请求重绘窗口

    def clear_canvas(self):
        """
        清空画布内容,将整个 QImage 重新填充为白色。
        """
        self.image.fill(Qt.GlobalColor.white)
        self.update()  # 请求重绘以显示空白画布

    def set_pen_size(self, size):
        """
        设置画笔粗细。
        """
        self.pen_size = size

2.处理图片

当布局完成后就只需要处理将图片变成输入的过程就好了,先给代码,在讲解

    def get_image(self):
        """
        获取当前画布上的图像数据。
        返回一个 QImage 对象,包含当前画布的绘图内容。
        """
        image = self.canvas.image
        # 将图像缩放到 28x28 像素并转换为灰度图
        scaled_image = image.scaled(
            28, 28,
            Qt.AspectRatioMode.IgnoreAspectRatio,  # 不保持宽高比
            Qt.TransformationMode.SmoothTransformation  # 平滑缩放
        )
        # 转换为 8 位灰度图
        grayscale_image = scaled_image.convertToFormat(QImage.Format.Format_Grayscale8)

        # 使用 qimage2ndarray.byte_view() 获取 NumPy 数组
        arr_3d = qimage2ndarray.byte_view(grayscale_image)
        arr = arr_3d.squeeze()

        # 将 NumPy 数组转换为 PyTorch 张量
        tensor_image = torch.from_numpy(arr).float()

        # --- 关键修正:添加颜色反转和标准化 ---
        # 1. 将像素值从 [0, 255] 归一化到 [0.0, 1.0]
        tensor_image = tensor_image / 255.0

        # 2. 颜色反转:如果你的模型是基于白色数字黑色背景训练的 而画布是黑色数字白色背景,则需要反转颜色
        tensor_image = 1.0 - tensor_image

        # 3. 标准化:应用训练时使用的均值和标准差
        # MNIST 均值和标准差
        mean = 0.1307
        std = 0.3081
        tensor_image = (tensor_image - mean) / std

        # 添加批次维度和通道维度,使形状变为 (1, 1, 28, 28)
        tensor_image = tensor_image.unsqueeze(0).unsqueeze(0).cuda()

        # --- 可视化 PyTorch 张量 ---
        # 为了可视化,我们先将其恢复到 [0,1] 范围,否则标准化后的值可能很难看
        # 逆标准化 (用于可视化,不影响模型输入)
        # visual_tensor = tensor_image * std + mean
        # # 确保在 [0,1] 范围内
        # visual_tensor = torch.clamp(visual_tensor, 0.0, 1.0)
        # plt.figure(figsize=(2, 2))
        # plt.imshow(visual_tensor.cpu().squeeze().numpy(), cmap='gray')
        # plt.title("input")
        # plt.axis('off')
        # plt.show()
        return tensor_image

其中有几个注意点
1.
目前的画布是白色的,画笔是黑色,但是mnist数据集的底是黑色的,画笔是白色的,因此需要使用

tensor_image = 1.0 - tensor_image

来将颜色取反,不然跟训练数据不一样模型无法良好运行。
2.
QT中的image是Qimage,转换成numpy代码有点麻烦,我这里图省事直接用了qimage2ndarray库,因此只需一行代码

arr_3d = qimage2ndarray.byte_view(grayscale_image)

就完成了这个操作。
3.
在输入到模型之前,要进行数据预处理,如上面的代码中

        # 3. 标准化:应用训练时使用的均值和标准差
        # MNIST 均值和标准差
        mean = 0.1307
        std = 0.3081
        tensor_image = (tensor_image - mean) / std

来优化模型效果。

3. 加载模型

这里的预训练权重就直接用了上一篇文章中训练出来的权重,还给她放到cuda上了,不过这么小的模型其实放不放其实都无所谓,没有太大的影响。

    def get_net(self):
        """
        获取数字预测模型。
        返回一个 DigitCNN 模型实例。
        """
        # 创建并返回一个 DigitCNN 模型实例
        net = DigitCNN()
        net.eval()
        net.cuda()
        net.load_state_dict(torch.load('./digit_CNN.pth'))
        return net

4. 预测

这里就没什么好说的了,就是简单地预测然后将结果同步到gui上了。

    def predict(self):
        """
        预测当前画布上绘制的数字。
        这里可以调用模型进行预测,并更新预测结果标签。
        """
        input = self.get_image()  # 获取当前画布上的图像数据
        # 使用模型进行预测
        with torch.no_grad():
            output = self.net(input)
        # 获取预测结果
        self.update_predict_result(output)

    def update_predict_result(self, output):
        _, predict = output.max(1)  # 获取预测的数字类别
        predict = predict.cpu().numpy()[0]
        # 更新预测结果标签
        self.predict_label.setText(f"预测结果: {predict}")
        # 更新每个数字的预测概率
        probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]
        for i, label in enumerate(self.predict_digit_labels):
            label.setText(f"数字 {i}: {probabilities[i] * 100:.2f}%")

5.结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

6.一点小问题

现在模型是可以用了,但是因为Mnist数据集本身的局限性,已经网络也比较小,泛化性能比较差(但是没差到不能用的地步),所以预测结果又是后会比较奇怪,例如:

.在这里插入图片描述
这是mnist数据集中的数据,可以看出这里的0大部分都是上面闭合,导致模型预测奇怪位置的闭合的0会失准。

还有其中的4大部分都是开口的,并没有闭合4上面的开口,导致写一个很标准的4反倒有时候会预测出错,还有其他的一些问题我就不赘述了。

总之如果想要模型想要获得更好的表现,一是可以增强一下模型的能力,第二个我觉得更重的是把数据好好清洗一下,有些数据真的太差了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2395028.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

特别篇-产品经理(三)

一、市场与竞品分析—竞品分析 1. 课后总结 案例框架:通过"小新吃蛋糕"案例展示行业分析方法,包含四个关键步骤: 明确目标行业调研确定竞品分析竞争策略输出结论 1)行业背景分析方法 PEST分析法:从四个…

【unity游戏开发——编辑器扩展】AssetDatabase公共类在编辑器环境中管理和操作项目中的资源

注意:考虑到编辑器扩展的内容比较多,我将编辑器扩展的内容分开,并全部整合放在【unity游戏开发——编辑器扩展】专栏里,感兴趣的小伙伴可以前往逐一查看学习。 文章目录 前言一、AssetDatabase常用API1、创建资源1.1 API1.2 示例 …

BLE协议全景图:从0开始理解低功耗蓝牙

BLE(Bluetooth Low Energy)作为一种针对低功耗场景优化的通信协议,已经广泛应用于智能穿戴、工业追踪、智能家居、医疗设备等领域。 本文是《BLE 协议实战详解》系列的第一篇,将从 BLE 的发展历史、协议栈结构、核心机制和应用领域出发,为后续工程实战打下全面认知基础。 …

【机器学习基础】机器学习入门核心算法:GBDT(Gradient Boosting Decision Tree)

机器学习入门核心算法:GBDT(Gradient Boosting Decision Tree) 1. 算法逻辑2. 算法原理与数学推导2.1 目标函数2.2 负梯度计算2.3 决策树拟合2.4 叶子权重计算2.5 模型更新 3. 模型评估评估指标防止过拟合 4. 应用案例4.1 金融风控4.2 推荐系…

基于开源AI大模型AI智能名片S2B2C商城小程序源码的销售环节数字化实现路径研究

摘要:在数字化浪潮下,企业销售环节的转型升级已成为提升竞争力的核心命题。本文基于清华大学全球产业研究院《中国企业数字化转型研究报告(2020)》提出的“提升销售率与利润率、打通客户数据、强化营销协同、构建全景用户画像、助…

Spring Cache核心原理与快速入门指南

文章目录 前言一、Spring Cache核心原理1.1 架构设计思想1.2 运行时执行流程1.3 核心组件协作1.4 关键机制详解1.5 扩展点设计1.6 与Spring事务的协同 二、快速入门实战三、局限性3.1 多级缓存一致性缺陷3.2 分布式锁能力缺失3.3 事务集成陷阱 总结 前言 在当今高并发、低延迟…

Redisson学习专栏(四):实战应用(分布式会话管理,延迟队列)

文章目录 前言一、为什么需要分布式会话管理?1.1 使用 Redisson 实现 Session 共享 二、订单超时未支付?用延迟队列精准处理2.1 RDelayedQueue 核心机制2.2 订单超时处理实战 总结 前言 在现代分布式系统中,会话管理和延迟任务处理是两个核心…

java程序从服务器端到Lambda函数的迁移与优化

source:https://www.jfokus.se/jfokus24-preso/From-Serverful-to-Serverless-Java.pdf 从传统的服务器端Java应用,到如今的无服务器架构。这不仅仅是技术名词的改变,更是开发模式和运维理念的一次深刻变革。先快速回顾一下我们熟悉的“服务…

使用yocto搭建qemuarm64环境

环境 yocto下载 # 源码下载 git clone git://git.yoctoproject.org/poky git reset --hard b223b6d533a6d617134c1c5bec8ed31657dd1268 构建 # 编译镜像 export MACHINE"qemuarm64" . oe-init-build-env bitbake core-image-full-cmdline 运行 # 跑虚拟机 export …

Linux系统下安装配置 Nginx

Windows Nginx https://nginx.org/en/download.htmlLinux Nginx https://nginx.org/download/nginx-1.24.0.tar.gz解压 tar -zxvf tar -zxvf nginx-1.18.0.tar.gz #解压安装依赖(如未安装) yum groupinstall "Development Tools" -y yum…

LiveGBS作为下级平台GB28181国标级联2016|2022对接海康大华宇视华为政务公安内网等GB28181国标平台查看级联状态及会话

LiveGBS作为下级平台GB28181国标级联2016|2022对接海康大华宇视华为政务公安内网等GB28181国标平台查看级联状态及会话 1、GB/T28181级联概述2、搭建GB28181国标流媒体平台3、获取上级平台接入信息3.1、向下级提供信息3.2、上级国标平台添加下级域3.3、接入LiveGBS示例 4、配置…

Gartner《2025 年软件工程规划指南》报告学习心得

一、引言 软件工程领域正面临着前所未有的变革与挑战。随着生成式人工智能(GenAI)等新兴技术的涌现、市场环境的剧烈动荡以及企业对软件工程效能的更高追求,软件工程师们必须不断适应和拥抱变化,以提升自身竞争力并推动业务发展。Gartner 公司发布的《2025 年软件工程规划…

Java Class类文件结构

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…

quasar electron mode如何打包无边框桌面应用程序

预览 开源项目Tokei Kun 一款简洁的周年纪念app,现已发布APK(安卓)和 EXE(Windows) 项目仓库地址:Github Repo 应用下载链接:Github Releases Preparation for Electron quasar dev -m elect…

【HW系列】—Windows日志与Linux日志分析

文章目录 一、Windows日志1. Windows事件日志2. 核心日志类型3. 事件日志分析实战详细分析步骤 二、Linux日志1. 常见日志文件2. 关键日志解析3. 登录爆破检测方法日志分析核心要点 一、Windows日志 1. Windows事件日志 介绍:记录系统、应用程序及安全事件&#x…

VIN码识别解析接口如何用C#进行调用?

一、什么是VIN码识别解析接口? VIN码不仅是车辆的“身份证”,更是连接制造、销售、维修、保险、金融等多个环节的数字纽带。而VIN码查询API,正是打通这一链条的关键工具。 无论是汽车电商平台、二手车商、维修厂,还是保险公司、金…

动态规划之网格图模型(一)

文章目录 动态规划之网格图模型(一)LeetCode 64. 最小路径和思路Golang 代码 LeetCode 62. 不同路径思路Golang 代码 LeetCode 63. 不同路径 II思路Golang 代码 LeetCode 120. 三角形最小路径和思路Golang 代码 LeetCode 3393. 统计异或值为给定值的路径…

PCB设计实践(三十)地平面完整性

在高速数字电路和混合信号系统设计中,地平面完整性是决定PCB性能的核心要素之一。本文将从电磁场理论、信号完整性、电源分配系统等多个维度深入剖析地平面设计的关键要点,并提出系统性解决方案。 一、地平面完整性的电磁理论基础 电流回流路径分析 在PC…

使用ray扩展python应用之流式处理应用

流式处理就是数据一来,咱们就得赶紧处理,不能攒批再算。这里的实时不是指瞬间完成,而是要在数据产生的那一刻,或者非常接近那个时间点,就做出响应。这种处理方式,我们称之为流式处理。 流式处理的应用场景…

IP证书的作用与申请全解析:从安全验证到部署实践

在网络安全领域,IP证书(IP SSL证书)作为传统域名SSL证书的补充方案,专为公网IP地址提供HTTPS加密与身份验证服务。本文将从技术原理、应用场景、申请流程及部署要点四个维度,系统解析IP证书的核心价值与操作指南。 一…