pytorch车牌识别

news2025/6/8 17:01:05

目录

  • 使用pytorch库中CNN模型进行图像识别
    • 收集数据集
    • 定义CNN模型
      • 卷积层
      • 池化层
      • 全连接层
    • CNN模型代码
    • 使用模型

在这里插入图片描述

使用pytorch库中CNN模型进行图像识别

收集数据集

可以去找开源的数据集或者自己手做一个
最终整合成 类别分类的图片文件
在这里插入图片描述

定义CNN模型

卷积层

功能:提取特征

概念

  1. 卷积层输入层通道数

如果输入数据是彩色图像,那么通常情况下,输入数据具有三个通道(红、绿、蓝),因此第一个卷积层的输入通道数应该为3。
如果输入数据是灰度图像,那么输入通道数通常为 1。

  1. 卷积层输出层通道数

卷积层的输出通道数控制着该层提取的特征的数量和复杂度。更多的输出通道意味着网络可以学习更多种类的特征,但过多的输出通道数会导致复杂度和过拟合。

池化层

功能:使卷积层的特征更加明显,对图像进行降维压缩(舍弃无关特征,避免过拟合),提高神经网络的泛华能力。
问题:

  1. 最大池化操作

最大池化操作是一种常用的池化操作,用于减少特征图的空间维度并保留最重要的特征信息
在这里插入图片描述

# 定义最大池化层,池化窗口大小为 2x2,步幅为 2
max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)

全连接层

将特征进行整合,然后归一化,对各种分类情况都输入一个概率,根据概率进行分类

CNN模型代码

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
# 进度条工具
from tqdm import tqdm

# 数据集中的类别数
num_classes = len(os.listdir('./数据集'))
# 训练的轮数
num_epochs = 10
# 30次:['陕', '陕', 'U', 'U', '6', '6', '6', '6']
# 10次:['陕', 'A', 'D', '0', '6', '6', '6', '6']

# 一、定义数据预处理和数据加载器
transform = transforms.Compose([
    # 固定图像大小
    transforms.Resize((64, 64)),
    # 将图像转换为灰度图像
    transforms.Grayscale(),
    # 将图像转换为张量
    transforms.ToTensor(),
])
# 使用ImageFolder定义数据集,标签为序号
train_dataset = ImageFolder(root='./数据集', transform=transform)
# 数据加载器,每个批次包含32张图像
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


# 二、定义 CNN 模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # 卷积层1  1代表单通道,黑白;32代表输出通道;3代表3*3的卷积核, 1代表在最外围补一圈0
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        # 池化层1  最大池化操作,2代表尺寸减半
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层2 ,32对于卷积层1的输出通道数
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 全连接层 64输出通道数,16*16代表压缩后的尺寸,生成长度128向量
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    # 前向传播 返回输出结果
    def forward(self, x):
        # 卷积1
        x = self.conv1(x)
        # 激活函数/激化函数 引入非线性变化,增强神经网络复杂性
        x = torch.relu(x)
        # 池化
        x = self.pool(x)
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 三、初始化模型、损失函数和优化器
model = CNNModel()
criterion = nn.CrossEntropyLoss()
# 学习率一般设0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 四、只要当主文件运行时候,才训练模型
if __name__ == "__main__":
    for epoch in range(num_epochs):
        running_loss = 0.0
        print(f'Epoch : {epoch + 1}/{num_epochs}')
        # 显示每轮的进度条
        for images, labels in tqdm(train_loader):
            #  将优化器中存储的之前计算的梯度归零
            optimizer.zero_grad()
            # 将输入图像数据 images 输入到模型中进行前向传播,得到模型的输出
            outputs = model(images)
            # 损失函数 criterion 计算模型 输出 与 真实标签 之间的损失值。
            loss = criterion(outputs, labels)
            # 对损失值进行反向传播,计算模型参数的梯度
            loss.backward()
            # 据优化算法(梯度下降)更新模型参数,最小化损失函数
            optimizer.step()
            running_loss += loss.item()

        # 输出每个 epoch 的平均损失
        epoch_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')

        # 保存模型
        torch.save(model.state_dict(), 'cnn_model.pt')

使用模型

import torch
from PIL import Image
from torch.utils.data import dataset
from cnn_model import transform, train_dataset, CNNModel

# 加载整个模型
model = CNNModel()
# 将模型设置为评估模式
model.eval()
checkpoint = torch.load('./cnn_model.pt')
model.load_state_dict(checkpoint)


# 使用模型进行预测,识别单个文字图片
def predict_image(image_path):
    image = Image.open(image_path)
    # 转换图片格式
    image = transform(image)
    # 只进行前向传播
    with torch.no_grad():
        output = model(image)
    # ImageFolder输出的标签是文件序号,argmax找到张量output中的最大值
    predicted_idx = torch.argmax(output).item()
    print(predicted_idx)
    # 将输出转换成对应序号的文件名
    if predicted_idx < len(train_dataset.classes) :
        predicted_label = train_dataset.classes[predicted_idx]
        return predicted_label
    else:
        return "null"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1592609.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

R:普通分组柱状图

输入文件实例&#xff08;存为csv格式&#xff09; library(ggplot2) library(ggbreak)# 从CSV文件中读取数据 setwd("C:/Users/fordata/Desktop/研究生/第二个想法(16s肠型&#xff0b;宏基因组功能)/第二篇病毒组/result/otherDB") data <- read.csv("feta…

【软件设计师知识点】一、计算机系统基础知识

文章目录 冯诺依曼计算机CPUCPU 的功能CPU 的组成 数据表示进制转换单位换算定点数浮点小数IEEE 754标准浮点数的运算 校验码奇偶校验码海明码循环冗余校验码&#xff08;CRC&#xff09; 指令系统指令格式寻址方式指令集指令流水线 存储系统存储器的层次化结构存储器的分类相联…

前端vue: 使用ElementUI适配国际化

i18n介绍 i18n&#xff08;其来源是英文单词 internationalization的首末字符i和n&#xff0c;18为中间的字符数&#xff09;是“国际化”的简称。 前端国际化步骤 1、安装i18n插件 安装插件时候&#xff0c;注意必须指定版本号&#xff0c;不然安装会报错。 npm i vue-i1…

【opencv】示例-points_classifier.cpp 使用不同机器学习算法在二维空间中对点集进行分类...

#include "opencv2/core.hpp" // 包含OpenCV核心功能的文件 #include "opencv2/imgproc.hpp" // 包含OpenCV图像处理功能的文件 #include "opencv2/ml.hpp" // 包含OpenCV机器学习模块的文件 #include "opencv2/highgui.hpp" // 包含O…

docker部署Prometheus+AlertManager实现邮件告警

文章目录 一、环境准备1、硬件准备&#xff08;虚拟机&#xff09;2、关闭防火墙&#xff0c;selinux3、所有主机安装docker 二、配置Prometheus1、docker启动Prometheus 三、添加监控节点1、docker启动node-exporter 四、Prometheus配置node-exporter1、修改prometheus.yml配置…

使用Python实现自动化网页答题功能-模拟考试篇

介绍 在驾驶员考试网站上进行模拟考试python自动答题 自动化原理 该脚本使用了自动化模块 DrissionPage 中的 ChromiumPage 类来实现网页的自动化操作。通过定位网页元素和模拟点击操作&#xff0c;完成了选择答案和提交答卷的过程。 用途与注意事项 用途&#xff1a;该脚本…

kafka学习笔记03

SpringBoot2.X项目搭建整合Kafka客户端依赖配置 用自己对应的jdk版本。 先加上我们的web依赖。 添加kafka依赖: SpringBoot2.x整合Kafka客户端adminApi单元测试 设置端口号。 新建一个kafka测试类&#xff1a; 创建一个初始化的Kafka服务。 设置kafka的名称。 测试创建kafka。…

在vue中配置样式 max-width:100px时,发现和width:100px一样没有对应的递增到最大宽度的效果?怎么回事?怎么解决?

原因&#xff1a; 可能时vue的样式大部分和display相关&#xff0c;有很多的联系&#xff0c;导致不生效 解决&#xff1a; 对设置max-width样式的元素设置display:inline-block;属性&#xff0c;即可生效&#xff0c;实现随着子元素的扩展而扩展并增加固定到最大的宽度

使用Postman发送跨域请求实验

使用Postman发送跨域请求 1 跨域是什么&#xff1f;2 何为同源呢?3 跨域请求是如何被检测到的&#xff1f;4 Postman跨域请求测试4.1 后端准备4.2 测试用例4.2.1 后端未配置跨域请求(1) 前端不跨域&#xff08;2&#xff09;前端跨域 4.2.2 后端配置跨域信息&#xff08;1&…

Springboot整合mybatis_plus + redis(使用原生的方式)

首次&#xff0c;创建一个springboot项目&#xff0c;勾选相应的依赖Lombok、Web 添加依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>…

基于Python的深度学习的中文情感分析系统(V2.0),附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

13 Php学习:面向对象

PHP 面向对象 面向对象&#xff08;Object-Oriented&#xff0c;简称 OO&#xff09;是一种编程思想和方法&#xff0c;它将程序中的数据和操作数据的方法封装在一起&#xff0c;形成"对象"&#xff0c;并通过对象之间的交互和消息传递来完成程序的功能。面向对象编…

理想大模型实习面试题6道|含解析

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学&#xff0c;针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总…

在 Google Cloud 上轻松部署开放大语言模型

今天&#xff0c;“在 Google Cloud 上部署”功能正式上线&#xff01; 这是 Hugging Face Hub 上的一个新功能&#xff0c;让开发者可以轻松地将数千个基础模型使用 Vertex AI 或 Google Kubernetes Engine (GKE) 部署到 Google Cloud。 Model Garden (模型库) 是 Google Clou…

Qt | 视频播放器(multimedia、multimediawidgets)

QT +=multimedia 通俗解释: 此代码行告诉编译器在构建应用程序时包含多媒体库。这意味着您的应用程序将能够播放和显示音频和视频文件。 使用分步说明构建模型: 创建一个新的 Qt 项目。 在 .pro 文件中添加以下行: QT += multimedia 导入必要的多媒体头文件: #include &l…

Android 加密之 打包为arr 项目依赖或者为jar

Android 加密之 打包为arr 项目依赖或者为jar 1. 修改build.gradle plugins {//id com.android.application// 1. 修改为libraryid com.android.library }android {namespace com.dzq.iccid2compileSdk 33defaultConfig {//applicationId "com.dzq.iccid2"// 2. 注…

嵌入式驱动学习第七周——I2C子系统

前言 I2C子系统详解&#xff0c;本篇博客从内核源码的角度来看I2C子系统。 嵌入式驱动学习专栏将详细记录博主学习驱动的详细过程&#xff0c;未来预计四个月将高强度更新本专栏&#xff0c;喜欢的可以关注本博主并订阅本专栏&#xff0c;一起讨论一起学习。现在关注就是老粉啦…

爱奇艺APP Android低端机性能优化

01 背景介绍 在智能手机市场上&#xff0c;高端机型经常备受瞩目&#xff0c;但低端机型亦占据了不可忽视的份额。众多厂商为满足低端市场的需求&#xff0c;不断推出低配系列手机。另外过去几年的中高端机型&#xff0c;随着系统硬件的快速迭代&#xff0c;现已经被归类为低端…

学习Rust的第4天:常见编程概念

欢迎来到学习Rust的第四天&#xff0c;基于Steve Klabnik的《The Rust Programming Language》一书。昨天我们做了一个 猜谜游戏 &#xff0c;今天我们将探讨常见的编程概念&#xff0c;例如&#xff1a; Variables 变量Constants 常数Shadowing 阴影Data Types 数据类型Functi…

错题记录(2)

来源&#xff1a; FPGA开发/数字IC笔试系列(10) 笔试刷题 笔试 | 海思2022数字IC模拟卷&#xff08;真题模拟&#xff0c;带解析&#xff09; 运算符优先级