李哥深度学习班学习笔记——图像识别

news2026/3/18 23:26:02
一、导入依赖库​ import random #用于设置随机种子保证实验可复现 import torch #Pytorh核心库构建和训练神经网络 import torch.nn as nn #Pytorch神经网络层模块 import numpy as np #数值计算库处理矩阵 import os #操作系统交互处理文件路径 from PIL import Image #读取图片数据 from torch.utils.data import Dataset, DataLoader #自定义数据集和数据加载器 from tqdm import tqdm #进度条显示方便查看数据加载和训练进度 from torchvision import transforms #图像预处理裁剪、旋转、转张量等 import time #计时统计训练耗时 import matplotlib.pyplot as plt #绘制loss和acc曲线 from model_utils.model import initialize_model #自定义函数初始化预训练模型 ​二、读取数据集1.初始化方法Dataset是PyTorch的抽象类必须实现__init__/__getitem__/__len__三个方法用于定义数据加载逻辑。class food_Dataset(Dataset): #初始化方法 def __init__(self, path, modetrain): self.mode mode #保存数据模式train训练集/val验证集/semi半监督集 #semi模式只有图像没有标签其他模式有模式标签 if mode semi: self.X self.read_file(path) #X图像数据numpy数据shapeN224,224,3 else: self.X, self.Y self.read_file(path) self.Y torch.LongTensor(self.Y) #标签转为LongTensor if mode train: self.transform train_transform else: self.transform val_transform #读取文件方法 def read_file(self, path): #半监督学习读取无标签图像 if self.mode semi: file_list os.listdir(path) #列出路径下所有文件 xi np.zeros((len(file_list), HW, HW, 3), dtypenp.uint8) #初始化图像数组N张图224,224,3通道uint8类型 # 列出文件夹下所有文件名字 for j, img_name in enumerate(file_list): #遍历每个文件 img_path os.path.join(path, img_name) #拼接完整文件路径 img Image.open(img_path) #打开图片 img img.resize((HW, HW)) #统一大小缩放为224*224 xi[j, ...] img #将PIL图像转为numpy数组存入xi print(读到了%d个数据 % len(xi)) #打印读取数量 return xi #返回图像数据 #非半监督模式读取有标签图像 else: for i in tqdm(range(11)): #遍历11个类别 file_dir path /%02d % i #拼接类别文件夹路径 file_list os.listdir(file_dir) #列出该类别下所有文件 #初始化当前类别的图像/标签数组 xi np.zeros((len(file_list), HW, HW, 3), dtypenp.uint8) yi np.zeros(len(file_list), dtypenp.uint8) # 列出文件夹下所有文件名字 for j, img_name in enumerate(file_list): img_path os.path.join(file_dir, img_name) img Image.open(img_path) img img.resize((HW, HW)) xi[j, ...] img #存入图像 yi[j] i #存入标签当前类别i #合并所有类别的数据 if i 0: #第一类直接赋值 X xi Y yi else: #后续类别拼接axis0:按样本数维度拼接 #np.concatenate是合并不同类别的数据 X np.concatenate((X, xi), axis0) Y np.concatenate((Y, yi), axis0) print(读到了%d个数据 % len(Y)) #打印总数 return X, Y #返回图像标签 #获取单条数据方法训练时核心调用 def __getitem__(self, item): #半监督学习模式返回变换后的图像原始图像用于生成伪标签 if self.mode semi: return self.transform(self.X[item]), self.X[item] #非半监督模式返回变换后的图像标签 else: return self.transform(self.X[item]), self.Y[item] def __len__(self): return len(self.X)三、定义半监督学习数据集类生成伪标签半监督学习核心用训练好的模型给无标签数据打为标签筛选置信度高的样本加入训练。no_label_loader是无标签数据的DateLoaderthres0.99是置信度阈值只有模型预测置信度大于等于99%的样本才能被选为半监督样本保证伪标签准确。Softmax层将模型输出的logits转为概率dim1 表示按类别维度计算torch.no_grag()是推理模式的关键关闭梯度计算避免占用GPU内存class semiDataset(Dataset): #初始化方法 def __init__(self, no_label_loder, model, device, thres0.99): x, y self.get_label(no_label_loder, model, device, thres) #判断是否符合置信度的样本 if x []: self.flag False #无有效样本 else: self.flag True #有有效样本 self.X np.array(x) #为标签图像 self.Y torch.LongTensor(y) #伪标签转为LongTensor self.transform train_transform #用训练集变换 #生成伪标签方法 def get_label(self, no_label_loder, model, device, thres): model model.to(device) #模型迁移到指定设备GPU/CPU #初始化存储预测置信度、预测标签、有效图像、有效标签 pred_prob [] #预测置信度 labels [] #预测标签 x [] #有效图像 y [] #有效标签 #softmax层将模型输出的logits转为概率dim1表示按类别维度计算 soft nn.Softmax() #无梯度计算推理模式节省内存加快速度 with torch.no_grad(): #遍历无标签数据 for bat_x, _ in no_label_loder: bat_x bat_x.to(device)#数据移到设备 pred model(bat_x)#模型预测输出logits pred_soft soft(pred)#转为概率 #获取每个样本的最大概率置信度和对应标签 pred_max, pred_value pred_soft.max(1) #转为列表并存入 pred_prob.extend(pred_max.cpu().numpy().tolist()) labels.extend(pred_value.cpu().numpy().tolist()) #筛选置信度阈值的样本 for index, prob in enumerate(pred_prob): if prob thres: #加入有效图像原始图像 x.append(no_label_loder.dataset[index][1]) #调用到原始的getitem #加入对应伪标签 y.append(labels[index]) return x, y#返回有效图像和伪标签 def __getitem__(self, item): return self.transform(self.X[item]), self.Y[item] def __len__(self): return len(self.X)四、构建半监督数据加载器封装半监督数据的初始化逻辑返回DataLoader或NoneshffleFalse:半监督样本已筛选不能洗牌def get_semi_loader(no_label_loder, model, device, thres): # 初始化半监督数据集 semiset semiDataset(no_label_loder, model, device, thres) # 无有效样本返回None if semiset.flag False: return None # 有有效样本构建DataLoader else: semi_loader DataLoader(semiset, batch_size16, shuffleFalse) return semi_loader五、定义模型使用ResNet18进行迁移学习迁移学习把大佬的特征提取器即模型和参数拿来训练有预训练模型的时候尽量用预训练模型from torchvision.models import resnet18 #加载ImageNet预训练权重w比随机初始化收敛更快、精度更高 model resnet18(pretrained True)#把大佬训练好的w拿来用 #获得最后一层全连接的输入特征值 in_fetures model.fc.in_features#模型的维度 #替换最后一层全连接将1000分类改为11分类 model.fc nn.linear(in_fetures,11)六、训练和验证主函数1.交叉熵损失用来衡量预测的概率分布和真实标签有多像2.np.sum(np.argmax(pred.detach().cpu().numpy(), axis1) target.cpu().numpy())np.argmax():计算最大值下标axis1横轴[ [[0.1, 0.5, 0.4], 1 1 true[0.2, 0.3, 0.5], 2 0 false[0.2, 0.1, 0.7], 2 2 true] ]3.为什么要做无梯度计算计算梯度是为了训练模型不计算梯度视为了省资源、提速、不破坏模型。在训练时需要算梯度用来更新权重测试和部署时只需要输出结果。def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path): #模型移到设备 model model.to(device) #半监督学习加载器初始为None semi_loader None #存储训练 plt_train_loss [] plt_val_loss [] #验证的loss和acc用于绘图 plt_train_acc [] plt_val_acc [] max_acc 0.0#最佳验证准确率用于保存最优模型 #遍历每个epoch for epoch in range(epochs): #初始化本轮的loss和acc train_loss 0.0 val_loss 0.0 train_acc 0.0 val_acc 0.0 semi_loss 0.0 semi_acc 0.0 start_time time.time()#记录本轮开始时间 #训练阶段 model.train()#模型设为训练模式启动dropout/bn等训练层 #遍历训练集 for batch_x, batch_y in train_loader: x, target batch_x.to(device), batch_y.to(device)#数据移到设备 pred model(x)#模型预测 train_bat_loss loss(pred, target)#计算损失交叉熵损失分类任务标配 train_bat_loss.backward()#反向传播计算梯度梯度回传 optimizer.step() #优化器更新权重 optimizer.zero_grad()# 更新参数 之后要梯度清零否则会累积梯度 train_loss train_bat_loss.cpu().item()#累加损失转为cpu值避免GPU内存占用 train_acc np.sum(np.argmax(pred.detach().cpu().numpy(), axis1) target.cpu().numpy())#计算准确率预测标签与真实标签对比求和 #计算本轮训练平均loss和acc plt_train_loss.append(train_loss / train_loader.__len__()) plt_train_acc.append(train_acc/train_loader.dataset.__len__()) #记录准确率 #如果有半监督数据训练半监督样本 if semi_loader! None: for batch_x, batch_y in semi_loader: x, target batch_x.to(device), batch_y.to(device) pred model(x) semi_bat_loss loss(pred, target) semi_bat_loss.backward() optimizer.step() # 更新参数 之后要梯度清零否则会累积梯度 optimizer.zero_grad() semi_loss semi_bat_loss.cpu().item()#累加半监督损失 semi_acc np.sum(np.argmax(pred.detach().cpu().numpy(), axis1) target.cpu().numpy())#累加半监督准确率 print(半监督数据集的训练准确率为, semi_acc/train_loader.dataset.__len__()) #----------------------------验证阶段---------------------------------- model.eval()#模型设为评估模式关闭dropout/bn更新 #无梯度计算 with torch.no_grad(): for batch_x, batch_y in val_loader: x, target batch_x.to(device), batch_y.to(device) pred model(x) val_bat_loss loss(pred, target) val_loss val_bat_loss.cpu().item() val_acc np.sum(np.argmax(pred.detach().cpu().numpy(), axis1) target.cpu().numpy()) plt_val_loss.append(val_loss / val_loader.__len__()) plt_val_acc.append(val_acc / val_loader.dataset.__len__()) #---------------------------------------------------------------------------- #生成半监督数据保存最优模型 #每3个epoch且验证准确率0.6时生成半监督数据 if epoch%3 0 and plt_val_acc[-1] 0.6: semi_loader get_semi_loader(no_label_loader, model, device, thres) #保存最优模型验证准确率最高时 if val_acc max_acc:#只有验证准确率超过之前的最大值时才保存避免保存差的模型 torch.save(model, save_path) max_acc val_acc #打印本轮训练信息 print([%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f % \ (epoch, epochs, time.time() - start_time, plt_train_loss[-1], plt_val_loss[-1], plt_train_acc[-1], plt_val_acc[-1]) ) # 打印训练结果。 注意python语法 %2.2f 表示小数位为2的浮点数 后面可以对应。 #绘制loss曲线 plt.plot(plt_train_loss) plt.plot(plt_val_loss) plt.title(loss) plt.legend([train, val]) plt.show() #绘制acc曲线 plt.plot(plt_train_acc) plt.plot(plt_val_acc) plt.title(acc) plt.legend([train, val]) plt.show()如果训练loss下降但验证loss上升说明过拟合如果两者都下降说明收敛正常。七、主程序初始化启动训练#数据路径 #训练集文件路径 train_path rF:\pycharm\beike\classification\food_classification\food-11_sample\training\labeled #验证集文件路径 val_path rF:\pycharm\beike\classification\food_classification\food-11_sample\validation #无标签文件路径 no_label_path rF:\pycharm\beike\classification\food_classification\food-11_sample\training\unlabeled\00 #初始化数据集 train_set food_Dataset(train_path, train) val_set food_Dataset(val_path, val) no_label_set food_Dataset(no_label_path, semi) #构建DataLoader批处理洗牌多线程 train_loader DataLoader(train_set, batch_size16, shuffleTrue) val_loader DataLoader(val_set, batch_size16, shuffleTrue) no_label_loader DataLoader(no_label_set, batch_size16, shuffleFalse) #训练参数 lr 0.001#学习率 loss nn.CrossEntropyLoss()#损失函数分类任务 #优化器AdamX带权重衰减的Adam防止过拟合 optimizer torch.optim.AdamW(model.parameters(), lrlr, weight_decay1e-4) #设备优先GPU无则CPU device cuda if torch.cuda.is_available() else cpu #模型保存路径 save_path model_save/best_model.pth #训练轮次 epochs 15 #半监督学习置信度阈值 thres 0.99 #启动训练 train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2408660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…

网络编程(Modbus进阶)

思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式

一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明&#xff1a;假设每台服务器已…

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明

LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造&#xff0c;完美适配AGV和无人叉车。同时&#xff0c;集成以太网与语音合成技术&#xff0c;为各类高级系统&#xff08;如MES、调度系统、库位管理、立库等&#xff09;提供高效便捷的语音交互体验。 L…

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…