基于卷积神经网络和Pyqt5的猫狗识别小程序

news2025/5/10 10:00:54

任务描述

猫狗分类任务(Dogs vs Cats)是Kaggle平台在2013年举办的一个经典计算机视觉竞赛。官方给出的Kaggle Dogs vs Cats 数据集中包括由12500张猫咪图片和12500张狗狗图片组成的训练集,12500张未标记照片组成的测试集。选手需要在规定时间内搭建模型,并在测试集中取得尽可能高的准确率。

数据集中的图片长这样👆,是不是还蛮贴近生活呢~

本文将设计一个简单的三层卷积神经网络,完成分类任务。基于Pyqt5设计交互界面,在测试集中检验分类结果。工作量不大非常适合新手作为入门项目~最终功能实现如下图所示。


 代码介绍

我觉得新手最有成就感的事情莫过于跑通一个功能了,我们抛开原理不谈,先完整运行程序。代码我是用pycharm写的,新建一个classfy_program项目,在classfy_program项目中创建两个py文件,分别用于训练和测试数据集。此外还创建data文件夹,用于存放训练和测试用的图片。

cat_dog_test.py代码:

import sys
import torch
import torch.nn as nn
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QVBoxLayout, QFileDialog
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import Qt
from torchvision import transforms
from PIL import Image

# ---------------- 定义模型结构 ----------------
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, 2)  # 猫 vs 狗

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, 128 * 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ---------------- 图像预处理 ----------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# 类别名称(根据你的训练数据文件夹顺序)
class_names = ['猫', '狗']

# ---------------- 主窗口 ----------------
class CatDogClassifier(QWidget):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("猫狗分类器")
        self.setGeometry(100, 100, 400, 500)

        # 加载模型
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CNN().to(self.device)
        self.model.load_state_dict(torch.load("model_checkpoint.pth", map_location=self.device))
        self.model.eval()

        # 创建界面组件
        self.image_label = QLabel("请点击按钮加载图片", self)
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setFixedSize(300, 300)

        self.result_label = QLabel("", self)
        self.result_label.setAlignment(Qt.AlignCenter)

        self.load_button = QPushButton("加载图片", self)
        self.load_button.clicked.connect(self.load_image)

        # 布局
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        layout.addWidget(self.result_label)
        layout.addWidget(self.load_button)
        self.setLayout(layout)

    def load_image(self):
        file_path, _ = QFileDialog.getOpenFileName(self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg)")
        if file_path:
            # 显示图片
            pixmap = QPixmap(file_path)
            pixmap = pixmap.scaled(300, 300, Qt.KeepAspectRatio)
            self.image_label.setPixmap(pixmap)

            # 进行预测
            image = Image.open(file_path).convert("RGB")
            input_tensor = transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                _, predicted = torch.max(output, 1)
                result = class_names[predicted.item()]

            self.result_label.setText(f"预测结果:{result}")

# ---------------- 运行主程序 ----------------
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = CatDogClassifier()
    window.show()
    sys.exit(app.exec_())

cat_dog_train.py代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 数据预处理与加载
transform = transforms.Compose([
    transforms.Resize((224,224)),  # 调整图片尺寸为224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 适用于自定义模型的标准化
])

# 设置训练集和验证集
train_dataset = datasets.ImageFolder(root='data/train', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 自定义卷积神经网络模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 第一层卷积
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 输入通道为3(RGB),输出通道为32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # 池化层
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 全连接层
        self.fc1 = nn.Linear(128 * 28 * 28, 512)  # 调整为经过池化后的维度
        self.fc2 = nn.Linear(512, 2)  # 输出为2(猫 vs 狗)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + ReLU + 最大池化
        x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + ReLU + 最大池化
        x = self.pool(torch.relu(self.conv3(x)))  # 第三层卷积 + ReLU + 最大池化
        x = x.view(-1, 128 * 28 * 28)  # 展平
        x = torch.relu(self.fc1(x))  # 全连接层 + ReLU
        x = self.fc2(x)  # 输出层
        return x


# 实例化模型
model = CNN()

# 使用GPU训练(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # 将模型移动到GPU(如果可用)

# 损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # 训练阶段
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU(如果可用)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # 输出每轮训练损失与准确率
    print(
        f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

    # 保存模型检查点
    torch.save(model.state_dict(), 'model_checkpoint.pth')

print("Training complete!")

data数据集的下载链接在这里了:

Download Kaggle Cats and Dogs Dataset from Official Microsoft Download Center


希望看到这里的大家能够点一个小小的赞❤❤,后续会持续更新更多内容~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2372212.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解锁健康养生新境界

在追求高品质生活的当下,健康养生早已超越 “治未病” 的传统认知,成为贯穿全生命周期的生活艺术。它如同精密的交响乐,需饮食、运动、心理与生活习惯多维度协奏,方能奏响生命的强音。 饮食养生讲究 “顺时、适性”。遵循二十四节…

基于OpenCV的人脸识别:EigenFaces算法

文章目录 引言一、概述二、代码解析1. 准备工作2. 加载训练图像3. 设置标签4. 准备测试图像5. 创建和训练识别器6. 进行预测7. 显示结果 三、代码要点总结 引言 人脸识别是计算机视觉领域的一个重要应用,今天我将通过一个实际案例来展示如何使用OpenCV中的EigenFac…

【深度学习新浪潮】智能追焦技术全解析:从算法到设备应用

一、智能追焦技术概述 智能追焦是基于人工智能和自动化技术的对焦功能,通过深度学习算法识别并持续跟踪移动物体(如人、动物、运动器械等),实时调整焦距以保持主体清晰,显著提升动态场景拍摄成功率。其核心优势包括: 精准性:AI 算法优化复杂运动轨迹追踪(如不规则移动…

网络研讨会开发注册中, 5月15日特励达力科,“了解以太网”

在线研讨会主题 Understanding Ethernet - from basics to testing & optimization 了解以太网 - 从基础知识到测试和优化 注册链接# https://register.gotowebinar.com/register/2823468241337063262 时间 北京时间 2025 年 5 月 15 日 星期四 下午 3:30 - 4:30 适宜…

STL?vector!!!

一、前言 之前我们借助手撕string加深了类和对象相关知识,今天我们将一起手撕一个vector,继续深化类和对象、动态内存管理、模板的相关知识 二、vector相关的前置知识 1、什么是vector? vector是一个STL库中提供的类模板,它是存储…

从黔西游船侧翻事件看极端天气预警的科技防线——疾风气象大模型如何实现精准防御?

近日,贵州省黔西市一起载人游船侧翻事故令人痛心。调查显示,事发时当地突遇强风暴雨,水面突发巨浪导致船只失控。这一事件再次凸显:在极端天气频发的时代,传统“经验式防灾”已不足够,唯有依靠智能化的气象预警技术,才能筑牢安全底线。 极端天气预警的痛点:为什么传统方…

FastChat部署大模型

一、前提条件 1、系统环境(使用的 autodl 算力平台) 2、安装相关库 安装 modescope pip3 install -U modelscope # 或使用下方命令 # pip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple安装 fastchat git clone https://gi…

智汇云舟亮相第二十七届北京科博会

5月8日,备受瞩目的第二十七届中国北京国际科技产业博览会(以下简称:北京科博会)在国家会议中心盛大开幕。作为我国科技领域的重要盛会,北京科博会汇聚了众多前沿科技成果与创新力量,为全球科技产业交流搭建…

Redis最新入门教程

文章目录 Redis最新入门教程1.安装Redis2.连接Redis3.Redis环境变量配置4.入门Redis4.1 Redis的数据结构4.2 Redis的Key4.3 Redis-String4.4 Redis-Hash4.5 Redis-List4.6 Redis-Set4.7 Redis-Zset 5.在Java中使用Redis6.缓存雪崩、击穿、穿透6.1 缓存雪崩6.2 缓冲击穿6.3 缓冲…

北斗三号手持终端设备功能与应用

北斗三号卫星系统是我国自主建设、独立运行的全球卫星导航系统。通过多颗不同轨道卫星组成的,这些卫星持续向地球发射携带精确时间和位置信息的信号。地面上的北斗手持终端接收到至少四颗卫星信号后,利用信号传播时间差,通过三角函数等算法&a…

opencv中的图像特征提取

图像的特征,一般是指图像所表达出的该图像的特有属性,其实就是事物的图像特征,由于图像获得的多样性(拍摄器材、角度等),事物的图像特征有时并不特别突出或与无关物体混杂在一起,因此图像的特征…

【JVM-GC调优】

一、预备知识 掌握GC相关的VM参数,会基本的空间调整掌握相关工具明白一点:调优跟应用、环境有关,没有放之四海而皆准的法则 二、调优领域 内存锁竞争cpu占用io 三、确定目标 【低延迟】:CMS、G1(低延迟、高吞吐&a…

shopping mall(document)

shopping mall(document) 商城的原型,学习,优化,如何比别人做的更好,更加符合大众的习惯 抄别人会陷入一个怪圈,就是已经习惯了,也懒了,也不带思考了。 许多产品会迫于…

qiankun微前端任意位置子应用

qiankun微前端任意位置子应用 主项目1、安装qiankun2、引入注册3、路由创建4、路由守卫 二、子项目1、安装sh-winter/vite-plugin-qiankun2、main.js配置3、vite.config.js配置 三、问题解决 主项目 1、安装qiankun npm i qiankun -S2、引入注册 创建存放子应用页面 //whpv…

第十五章,SSL VPN

前言 IPSec 和 SSL 对比 IPSec远程接入场景---client提前安装软件,存在一定的兼容性问题 IPSec协议只能够对感兴趣的流量进行加密保护,意味着接入用户需要不停的调整策略,来适应IPSec隧道 IPSec协议对用户访问权限颗粒度划分的不够详细&…

spring5.x讲解介绍

Spring 5.x 是 Spring Framework 的重要版本升级,全面拥抱现代 Java 技术栈,其核心改进涵盖响应式编程、Java 8支持、性能优化及开发模式创新。以下从特性、架构和应用场景三个维度详细解析: 一、核心特性与架构改进 Java 8 全面支持 Spring …

荣耀A8互动娱乐组件部署实录(第3部分:控制端结构与房间通信协议)

作者:曾在 WebSocket 超时里泡了七天七夜的苦命人 一、控制端总体架构概述 荣耀A8控制端主要承担的是“运营支点”功能,也就是开发与运营之间的桥梁。它既不直接参与玩家行为,又控制着玩家的行为逻辑和游戏规则触发机制。控制端的主要职责包…

levelDB的数据查看(非常详细)

起因:.net大作业天气预报程序(WPF)答辩时,老师问怎么维持数据持久性的,启动时加载的数据存在哪里,我明白老师想考的应该是json文件的解析(正反),半天没答上来存那个文件了(老师默认这个文件是自…

在Fiddler中添加自定义HTTP方法列并高亮显示

在Fiddler中添加自定义HTTP方法列并高亮显示 Fiddler 是一款强大的 Web 调试代理工具,允许开发者检查和操作 HTTP 流量。一个常见需求是自定义 Web Sessions 列表,添加显示 HTTP 方法(GET、POST 等)的列,并通过颜色区…

基于公共卫生大数据收集与智能整合AI平台构建测试:从概念到实践

随着医疗健康数据的爆发式增长,如何有效整合、分析和利用这些数据已成为公共卫生领域的重要挑战。传统方法往往难以应对数据的复杂性、多样性和海量性,而人工智能技术的迅猛发展为解决这些挑战提供了新的可能性。基于数据整合与公共卫生大数据的AI平台旨在构建一个全面的生态…