少样本苹果分类机器深度学习

news2025/7/27 4:01:23

场景:

样本少,且只有部分进行了标注。负样本类别(不是被标注的那些)不可穷尽,图像处理

步骤:

1,数据增强,扩充确认为普通苹果的样本数量
2,特征提取,使用VGG16模型提取图像特征
3,Kmeans模型尝试普通/其他苹果聚类,查看效果
4,Meanshift模型提升模型表现
5,数据降维PCA处理,提升模型表现

环境:

使用conda 安装:
tensorflow-gpu 2.10.1
keras 2.10.0
使用pip安装:
numpy
scipy
matplotlib
scikit-learn

操作解释:

1,因为数据量太少了,需要对数据进行增强 :有多种方式,旋转,平移,换色等。
2,准备数据集。找一个文件夹,单独建一个文件夹存放标注的样本
在这里插入图片描述
文件夹中只存放普通苹果的样本。
在这里插入图片描述
将增强生成的图片可以放到另一个train_data 文件夹中,然后在该文件夹中放入其他类型的苹果,最终结果如下:
在这里插入图片描述
使用代码:

#对数据进行增强
from keras.preprocessing.image import ImageDataGenerator
path = "E:\\BaiduNetdiskDownload\\DataSet\\Apple"
dst_path = "E:\\BaiduNetdiskDownload\\DataSet\\GenApple"
data_gen = ImageDataGenerator(rotation_range=10,   #这个表示旋转
                              width_shift_range=0.1,
                             height_shift_range=0.02,
                             horizontal_flip=True,   #水平翻转
                             vertical_flip=True)  #垂直翻转
gen = data_gen.flow_from_directory(path, target_size=(224, 224),
                                 batch_size=2,   # 表示每轮循环生成两张照片。
                                 save_to_dir=dst_path,
                                 save_prefix="gen",
                                 save_format="jpg")
for i in range(100):
    gen.next()

可以使用如下代码加载查看:

#from keras.preprocessing.image import load_img,img_to_array  #因为keras版本的缘故,无法适用load_img等方法,使用下面的utils进行加载。
from keras.utils import load_img, img_to_array
img_path = "E:\\BaiduNetdiskDownload\\DataSet\\train_data\\1.jpg"
img = load_img(img_path, target_size=(224, 224))  #224 大小是vgg16模型的适用的
from matplotlib import pyplot as plt
plt.imshow(img)
img = img_to_array(img)
print(img.shape)

3,使用VGG16,提取特征

# 加载模型,提取特征
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
import numpy as np
model_vgg = VGG16(weights="imagenet", include_top=False)
X = np.expand_dims(img, axis=0)  #增加一个维度
X = preprocess_input(X)   #预处理vgg可以使用
features = model_vgg.predict(X)   # 这一步就是借助vgg16提取图片的特征
features = features.reshape(1, 7*7*512)   #这一步就相当于全连接层的展开。

4,上面两段代码只是用vgg提取了单个图片的特征,下面用代码批量提取图片的特征。

#上面是逐个对图片进行提取特征
import os
import numpy as np
folder = "E:\\BaiduNetdiskDownload\\DataSet\\train_data"
dirs = os.listdir(folder)
img_path = []
for i in dirs:
    img_path.append(folder + "\\" + i)
def featureProcess(img_path, model):
    img = load_img(img_path,target_size=(224,224))
    img = img_to_array(img)
    X = np.expand_dims(img, axis=0)
    X = preprocess_input(X) #处理成VGG16可以处理的格式。
    X_VGG = model.predict(X)
    X_VGG = X_VGG.reshape(1, 7*7*512)
    return X_VGG
features_train = np.zeros([len(img_path), 7*7*512])  #这里
for i in range(len(img_path)):
    features_i = featureProcess(img_path[i], model_vgg)
    features_train[i] = features_i
    
X = features_train
print(type(X))
print(X.shape)   #230个苹果,10个普通的200个增强的,其他的是多余的

5,使用kmeans算法预测分类。

# 使用kmeans 模型进行聚类。 使用k均值聚类算法。
from sklearn.cluster import KMeans
cnn_kmeans = KMeans(n_clusters=2, max_iter=2000)  #分为两类, 最大迭代次数是2000次
cnn_kmeans.fit(X)
y_predict_kmeans = cnn_kmeans.predict(X) # 这里只是得到的 0 1 的预测结果,需要统计一下0,1各自的数量
from collections import Counter   #计数器,统计聚类算法分类对应的个数
print(Counter(y_predict_kmeans))

使用下面的代码可视化查看效果

normal_apple_id = 1
fig2 = plt.figure(figsize=(10,40))
for i in range(45):
    for j in range(5):
        img = load_img(img_path[i*5 + j])
        plt.subplot(45,5, i*5 +j +1)
        plt.title("apple" if y_predict_kmeans[i*5 +j]== normal_apple_id else "other")
        plt.imshow(img)
        plt.axis("off")  # 这个的功能是去掉坐标和边框

# 会发现预测效果不是很好

在这里插入图片描述
6,效果太差,尝试换一个聚类算法,使用MeanShift

# 预测效果太差,尝试换一个聚类算法,来查看效果
from sklearn.cluster import MeanShift, estimate_bandwidth
bw = estimate_bandwidth(X, n_samples = 140)  #均值漂移算法先确定,需要以多大宽度进行搜索
                                           # 每140个样本作为一个搜索。
cnn_ms = MeanShift(bandwidth = bw)
cnn_ms.fit(X)
y_predict_ms = cnn_ms.predict(X)
from collections import Counter
print(Counter(y_predict_ms))

可视化查看

normal_apple_id = 0
fig3 = plt.figure(figsize=(10,40))
for i in range(45):
    for j in range(5):
        img = load_img(img_path[i*5 + j])
        plt.subplot(45,5, i*5 +j +1)
        plt.title("apple" if y_predict_ms[i*5 +j]== normal_apple_id else "other")
        plt.imshow(img)
        plt.axis("off")    # 这个的功能是去掉坐标和边框
        # 效果改善了非常多

在这里插入图片描述
7,效果有所改善,可以继续进一步优化。因为一个模型的好坏更多取决于数据,数据多了后难免会有许多噪点。所以我们应该想怎么样去除它们,去除异常点,PCA降维都可以。

# 效果虽然改善了,但是还有其他维度没有去除,肯定是存在噪点的。所以采用PCA进行降维去噪
from sklearn.preprocessing import StandardScaler
stds = StandardScaler()
X_norm = stds.fit_transform(X)  #标准化
from sklearn.decomposition import PCA
pca = PCA(n_components=200)   # n_components 指定要降到的维度
X_pca = pca.fit_transform(X_norm)
var_ratio = pca.explained_variance_ratio_  # 获取pca处理后各维度的方差比例
print(np.sum(var_ratio))
print(X_pca.shape)

在这里插入图片描述
再使用MeanShift进行预测:

from sklearn.cluster import MeanShift, estimate_bandwidth
bw = estimate_bandwidth(X_pca, n_samples = 140)
cnn_pca_ms = MeanShift(bandwidth = bw)
cnn_pca_ms.fit(X_pca)
y_predict_pcs_ms = cnn_pca_ms.predict(X_pca)
normal_apple_id = 0
fig4 = plt.figure(figsize=(10,40))
for i in range(45):
    for j in range(5):
        img = load_img(img_path[i*5 + j])
        plt.subplot(45,5, i*5 +j +1)
        plt.title("apple" if y_predict_pcs_ms[i*5 +j]== normal_apple_id else "other")
        plt.imshow(img)
        plt.axis("off")    # 这个的功能是去掉坐标和边框

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/35347.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国内优秀的多用户商城系统盘点(2022年整理)

电商战略时代,越来越多的企业或商家选择将消费者引入自己建设的独立商城,如零食行业的良品铺子、三只松鼠,从而打造属于自己的IP形象。此时,挑选一款优秀的商城源码是企业的不二之选,既降低了电商从业者和创业者的入门…

Dubbo

致力于提供高性能和透明化的RPC远程服务调用方案,以及SOA服务治理方案 使用zookeeper作为注册中心registry dubbo.config.annotation下相关注解 Service:被该注解修饰的类,会对外发布,包括IP、端口、路径到注册中心Reference&am…

深度学习之路=====10=====>>Resnext(tensorflow2)

简介 类型:2017CVPR 作者: Kaiming He组 和其他轻量级网络特点一样,Resnext也是通过降低参数量来改进模型,提高模型精度的。该模型基于Inception的split-transform-merge范式和VGG堆叠网络,将Resnet的单路卷积变成多…

程序员注意!35岁前,别靠死工资过日子

《2022程序员职场洞察报告》显示,六成受访者的职级和薪酬原地踏步,仅38.3%程序员群体的工作发生过变动,升职加薪、搞副业、自由工作等。 近两年,伴随疫情及行业发展的不确定性,企业招聘以及人才求职双方都变得谨慎。越…

MFC程序设计——用button更改静态文本+显示内容并弹出新内容+静态文本动态打开位图

目录 一、新建基于对话框的MFC编程项目 二、设计界面 2.设置启动项 2.找到资源视图和Dialog 3.拖入控件 三、创建变量(关联对话框与静态文本) 四、写入控件代码 1.在文本上的应用 2.在图像上的应用 2.1初始化的方法 2.2控件导入的方法 3.控件…

TSC TTP244Pro 打码机出现的问题及解决方案

背景: 最近在使用TSC的TTP 244 Pro 打码机的过程中,出现了几个小问题,最后请教了专业的人员才解决了问题,现把需要注意的点记录如下: 准备: 先去TSC的** 官网 **上找关于适用于你的打码机和使用环境的驱…

数据结构(高阶)—— AVL树

目录 一、AVL树的基本概念 二、AVL树的结点定义 三、AVL树的插入 四、AVL树的旋转 1. 右单旋 2. 左单旋 3. 右左双旋 4. 左右双旋 五、AVL树的验证 六、AVL树的性能 七、源代码 一、AVL树的基本概念 二叉搜索树虽可以缩短查找的效率,但如果数据有序或…

CXL 2.0 Device配置空间寄存器组成

目录 1 配置空间 1.1 PCI Power Management Capability Structure 1.2 PCI Express Capability Structure 2 扩展配置空间 2.1 PCIe DVSEC for CXL Device 2.2 GPF DVSEC for CXL Devices 2.3 PCIe DVSEC for Flex Bus Port 2.4 Register Locator DVSEC CXL设备配置空间…

ThinkPHP架构

文章目录一、架构总览1.1、有关常用的一些概念入口文件应用模块控制器操作模型视图驱动行为命名空间【全限定类名】1.补充二、生命周期三、入口文件四、URL访问五、模块设计六、命明空间七、自动加载八、Traits引入九、API友好一、架构总览 ThinkPHP5.0应用基于MVC(…

前后端分页插件

PageHelper 是一个 MyBatis 的分页插件,支持多种数据库,可查看官网&#xff0c;负责将已经写好的 SQL 语句&#xff0c;进行SQL分页加工。无需你自己去封装以及关心 SQL 分页等问题&#xff0c;支持多种分页方式,如从第0或第一页开始, 使用很方便。 添加依赖 <dependency&…

线代 | 【提神醒脑】自用笔记串联二 —— 向量组 · 线性方程组 · 特征值与特征向量

本文总结参考于 kira 2023 线代提神醒脑技巧班。 笔记均为自用整理。加油!ヾ(◍∇◍)ノ゙ 四、向量组 4.1、向量组的线性相关性 ----------------------------------------------------------------------------------------------------------…

Linux 软链接 与 硬链接 的区别

Linux 软链接 与 硬链接 的区别 1、概念 ​  链接文件&#xff1a;是 Linux 操作系统中的一种文件&#xff0c;主要用于解决文件的共享使用问题&#xff0c;而链接的方式分为两种——软链接和硬链接。 ​  inode&#xff1a;是文件系统中存储文件元信息&#xff08;文件的…

Auddly Music Server的编译和安装

本文始于 2021 年 11 月&#xff0c;已经忘记了是什么原因一直没发&#xff0c;这次基本上全部重写了一遍&#xff0c;除了官方的图&#xff0c;所有图片都是重新截取的&#xff1b; 什么是 auddly &#xff1f; auddly 是一款自托管音乐流应用程序。 什么是 auddly-server &am…

模拟实现ATM系统——Java

目录 一、内容简介 二、基本流程 三、具体步骤 1.定义Account类 2.菜单栏 3.账户注册 (1)根据卡号查询账户信息 (2)生成随机卡号 (3)注册账户 4.账户登录 (1)验证码 (2)登录 5.账户展示界面 6.用户操作 (1)查询账户 (2)存款 (3)取款 (4)转账 (5)修改密码 …

旋转的骰子(二)

1.动画——旋转的骰子 上一次我们做了一个旋转的骰子(参看第2讲),这次我们想要点击按钮,让骰子旋转到特定的点数停下来! 2.分析需求——庖丁解牛 a.立方体特定的朝向

LiveData源码分析

先放整理流程图&#xff1a; 1.postValue调2次只触发1次&#xff1f; postValue本质是把新值保存到LiveData的mPendingData成员变量里&#xff0c;版本号1&#xff0c;把执行Runnable post到主线程&#xff0c;在主线程setValue。 多次调用会更新mPendingData的值&#xff0c…

opencv的极线几何

一、理论介绍 当我们使用针孔相机拍摄图像时&#xff0c;我们会丢失一个重要的信息&#xff0c;即图像的深度。一个解决方案如我们的眼睛的方式使用两个相机&#xff08;两只眼睛&#xff09;&#xff0c;这就是所谓的立体视觉。 PO1O2为极平面&#xff0c;l1和l2为极线,e1和…

基于webrtc的数据传输研究总结

什么是webrtc WebRTC (Web Real-Time Communications) 是一项实时通讯技术&#xff0c;它允许网络应用或者站点&#xff0c;在不借助中间媒介的情况下&#xff0c;建立浏览器之间点对点&#xff08;Peer-to-Peer&#xff09;的连接&#xff0c;实现视频流和&#xff08;或&…

最新阿里云ECS服务器S6/C6/G6/N4/R6/sn2ne/sn1ne/se1ne处理器CPU性能详解

阿里云ECS服务器S6/C6/G6/N4/R6/sn2ne/sn1ne/se1ne处理器CPU性能怎么样&#xff1f;阿里云服务器优惠活动机型有云服务器S6、计算型C6、通用型G6、内存型R6、云服务器N4、云服务器sn2ne、云服务器sn1ne、云服务器se1ne处理器CPU性能详解及使用场景说明。 1、阿里云服务器活动机…

全局唯一ID

文章目录前言MongoDB ObjectIdTwitter SnowflakeUUID前言 基于数据库设置其实初始值&#xff0c;以及增量步长。基于ZK,Redis,改良雪花集中式服务生成&#xff0c;远程调用获取id。基于并行空间划分&#xff0c;Snowflake&#xff08;8Byte字节64bit位&#xff09;&#xff0c…