自定义数据 微调CLIP (结合paper)

news2025/5/24 22:17:42

CLIP 是 Contrastive Language-Image Pre-training 的缩写,是一个擅长理解文本和图像之间关系的模型,下面是一个简单的介绍:

优点: CLIP 在零样本学习方面特别强大,它可以(用自然语言)给出图像的描述,并在基于该描述对新图像进行分类方面表现良好,例如,您可以将图像描述为“a”。猫的黑白照片”,CLIP 可以准确地对猫的新照片进行分类,即使它以前没有见过这些特定图像。
训练: CLIP 在从互联网收集的大量文本图像对数据集上进行训练,这使得它能够学习视觉概念及其描述之间的联系。
局限性: CLIP 也有缺点,训练的计算成本可能很高,并且在需要非常具体或抽象概念的任务上,或者对于与训练所用的文本描述非常不同的数据时,可能表现不佳。训练可能会将社会偏见引入模型中。

paper:Learning Transferable Visual Models From Natural Language Supervision

本文用CLIP做一个零样本分类,
CLIP训练的时候用的是图片和文本描述对,并没有分类的标签,那如何让CLIP做零样本分类?
我们需要给出标签的文本,让图像和所有的文本标签进行匹配,得分高的就是匹配到的标签文本。

paper中提到预测哪个文本整体与哪个图像配对,而不是该文本的准确单词。

在这里插入图片描述

下面通过一个kaggle数据集来具体说明。

这里选用indo fashion dataset, 它有15种印度服饰。

在这里插入图片描述
类别如下:
在这里插入图片描述

数据集结构:
其中images文件夹下又有train, val, test文件夹。

在这里插入图片描述

再看一下json文件,
image_path指的是上面images文件夹下的路径,
product_title是和图片对应的文本描述,训练的时候就是用图片和这个文本进行匹配。
class_label训练的时候不需要,最后验证分类是否正确时会用到。

在这里插入图片描述

import需要的库,定义数据集的文件夹,读取json数据

import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor,CLIPModel
from tqdm import tqdm

json_path = 'your_path/train_data.json'
image_path = 'your_path/images/train/'

input_data = []
with open(json_path, 'r') as f:
	for line in f:
		obj = json.loads(line)
		input_data.append(obj)

CLIP模型,如果不能download, 手动下载走offline模式。

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Setting our device to GPU (Cuda) and loading the pre-trained CLIP model.

device = "cuda:0" if torch.cuda.is_available() else "cpu" 

model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

定义Dataloader

# Define a custom dataset
class image_title_dataset():
	def __init__(self, list_image_path, list_txt):
		self.image_path = list_image_path

		# Tokenize text using CLIP's tokenizer
		self.title = clip.tokenize(list_txt)
	
    def __len__(self):
		# Define the length of the dataset
		return len(self.title)

	def __getitem__(self, idx):
        image = preprocess(Image.open(self.image_path[idx]))
		title = self.title[idx]
		return image, title

这里的dataset需要传入list_image_path和list_txt,
格式是这种:
list_image_path = [‘folder/image1.jpg’,‘folder2/image2.jpg’]
list_txt = [‘description for image1.jpg’ , ‘description for image2.jpg’]
所以要把image_path和product_title都装进list里面。

注意,CLIP的最大序列长度限制在76, 而有些文本描述非常长,需要截掉一部分,
当然截到76长度也有很多种方法,这里简单粗暴就从开头取长度76.

实际代码中,indo数据集不限制长度会报错,而博主觉得这个76可能是text被tokenize之后的token的长度,而不是原文本的长度,
因为把文本截到长度>77也是可以的。
而token的长度是由tokenize的算法决定的。具体最大极限文本长度是多少没测,这里简单地截取到77.

在这里插入图片描述

list_image_path = []
list_txt = []
for item in input_data:
  img_path = image_path + item['image_path'].split('/')[-1]
  
  caption = item['product_title'][:77]
  list_image_path.append(img_path)
  list_txt.append(caption)

dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=100, shuffle=True) 

# Function to convert model's parameters to FP32 format
#转精度省内存.
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

if device == "cpu":
    model.float()  # Convert the model's parameters to float if using CPU

optimizer用Adam,参数按paper中的设置.
不过博主的机器容纳不了这么大的batch_size, 具体batch_size设多少合适,需要自己去验证。

在这里插入图片描述
由于数据集比较小,lr设得更小一些。

optimizer = torch.optim.Adam(
    model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6 ,weight_decay=0.2) 

训练

paper中的训练是这样的
在这里插入图片描述

    for epoch in range(num_epochs):
        pbar = tqdm(train_dataloader, total=len(train_dataloader))
        for batch in pbar:
            optimizer.zero_grad()

            images, texts = batch

            images = images.to(device)
            texts = texts.to(device)

            logits_per_image, logits_per_text = model(images, texts)

            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

            total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_loss.backward()
            if device == "cpu":
                optimizer.step()
            else:
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)

            pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")
            if torch.isnan(total_loss).any():
                print("epoch {} loss is NaN".format(epoch))
                epoch = num_epochs
                break

训练中,遇到了这些问题:
loss出现了NaN, 调整batch_size能解决,batch_size不要太小。
loss降不下去了,看看paper中的参数,有哪些需要调整。

训练完之后,找来一张图片测试。
这里又有一些注意事项,
请看paper.
因为训练的时候是图片和一段文本描述匹配的,而不是图片和一个单词。
所以你做零样本分类时,类别文本最好不要只写一个单词,比如只写"Saree"。
你要写"A photo of Saree", 这就成了一个句子,效果就会好一些。

在这里插入图片描述

model, preprocess = clip.load("ViT-B/32", device=device)

checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])

clothing_items = [
    "Saree",
    "Lehenga",
    "Women Kurta",
    "Dupatta",
    "Gown",
    "Nehru Jacket",
    "Sherwani",
    "Men Kurta",
    "Men Mojari",
    "Leggings and Salwar",
    "Blouse",
    "Palazzo",
    "Dhoti Pants",
    "Petticoat",
    "Women Mojari"
]

这里你可能要问,那json文件里面的标签不是这么写的,比如"Women Kurta",json文件的标签是"women_kurta",
为什么不写成"women_kurta"。
这个博主是测试过的,写成json文件里面的标签形式准确率会降低,可能是因为"Women Kurta"更接近自然语言,更贴合训练数据吧。

把15个类别的标签都写成"A photo of {label}" 进行测试。

#你想测的第几张图片
index_ = 500
image_json = input_data[index_]
image_path = os.path.join("indo-fashion-dataset", image_json['image_path'])
image_class = image_json['class_label']
image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in clothing_items]).to(device)

with torch.no_grad():
    # Encode image and text
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    # Calculate similarity scores between image and text
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Normalize image and text features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Calculate similarity scores
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the top predictions
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{clothing_items[index]:>16s}: {100 * value.item():.2f}%")

# Display the image with its class label
plt.imshow(plt.imread(image_path))
plt.title(f"Image for class: {image_class}")
plt.axis('off')
plt.show()

请添加图片描述
请添加图片描述

训练中并没有精调参数,也没有训练很多epoch. 效果如下。
统计了一下测试集中7450张图片的top1和top3准确率。
top1: 77.7%, top3: 93.57%

请添加图片描述

paper中说CLIP 模型的 Top-5 准确率明显高于其 Top-1 准确率, 本文虽测的是top3, 但也是明显高于top1的。

在这里插入图片描述

又试了一下这种方法,这里效果并没有变好。

在这里插入图片描述

参考资料1
参考资料2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1608348.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【文件系统】 F2FS文件系统学习

一、基本介绍 1、F2FS History F2FS(Flash Friendly File System)是专门为Nand Flash设计的一个日志型文件系统,于2012年12月合入Linux3.8内核,Google也在2018年(Android P)将其吸收到安卓原生版本中&…

【DL水记】循环神经网络RNN的前世今生,Transformer的崛起,Mamba模型

文章目录 RNN网络简介传统RNN网络结构RNN的分类 长-短期记忆网络 (LSTM)GRU网络横空出世的Transformer网络Self-AttentionVisionTransformer Mamba模型Reference: RNN网络简介 “当人类接触新事物时,他们不会从头开始思考。就像你在阅读这篇文章时,你会根…

最新版的GPT-4.5-Turbo有多强

OpenAI再次用实力证明了,GPT依然是AI世界最强的玩家!在最新的AI基准测试中,OpenAI几天前刚刚发布的GPT-4-Turbo-2024-04-09版本,大幅超越了Claude3 Opus,重新夺回了全球第一的AI王座: 值得一提的是&#xf…

Assign Memory Resources to Containers and Pods

minikube addons enable metrics-server minikube addons enable metrics-server 是一个命令,用于在 Minikube 环境中启用 metrics-server 插件。 Minikube 是一个工具,可以在本地轻松创建和管理单节点 Kubernetes 集群,适合开发和测试。Mini…

二叉树进阶题目

1还原二叉树 #include<bits/stdc.h> using namespace std; const int N1e310; char pre[N],mid[N]; int w[N]; int ans; struct node{int l,r; }t[N]; int build(int prel,int prer,int midl,int midr){int ascpre[prel];int posw[asc];if(midl<pos)t[asc].lbuild(pre…

从 Elastic 的 Go APM 代理迁移到 OpenTelemetry Go SDK

作者&#xff1a;来自 Elastic Damien Mathieu 正如我们之前所分享的&#xff0c;Elastic 致力于帮助 OpenTelemetry&#xff08;OTel&#xff09;取得成功&#xff0c;这意味着在某些情况下构建语言 SDK 的分发版本。 Elastic 在观察性和安全数据收集方面战略性地选择了 OTel…

深入挖掘C语言 ----动态内存分配

开篇备忘录: "自给自足的光, 永远都不会暗" 目录 1. malloc和free1.1 malloc1.2 free 2. calloc和realloc2.1 calloc2.2 realloc 3. 总结C/C中程序内存区域划分 正文开始 1. malloc和free 1.1 malloc C语言提供了一个动态开辟内存的函数; void* malloc (size_t s…

Python中的迭代器:深入理解与实用指南

文章目录 1. 迭代器的基本概念2. Python中的迭代器实例3. 自定义迭代器3.1 例子3.2 详细过程 4. 迭代器的高级应用5. 常见问题与解答 迭代器是Python中非常核心的概念之一&#xff0c;在面试中也会被问到。下面我会详细介绍什么是迭代器&#xff0c;使用方法&#xff0c;以及使…

爬虫 | 基于 requests 实现加密 POST 请求发送与身份验证

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本项目旨在实现一个简单的 Python 脚本&#xff0c;用于向指定的 URL 发送 POST 请求&#xff0c;并通过特定的加密算法生成请求头中的签名信息。这个脚本的背后是与某个特定的网络服务交互&#xff0c;发送特定格式的 JSON 数据…

vi编辑器的用法linux中的vim编辑器大全

vim的介绍 vi 和 vim 命令是linux中强⼤的⽂本编辑器, 由于Linux系统⼀切皆⽂件&#xff0c;⽽配置⼀个服务就是在修改其配置⽂件的参数。 vim 编辑器是运维⼯程师必须掌握的⼀个⼯具, 没有它很多⼯作都⽆法完成。 其中有vi和vim两种 vi和vim的区别 Vim是Vi的升级版本&#…

source map 开发优化工具

什么是 Source map 简单来说 Source map 就是一个存储信息的文件&#xff0c;里面储存着位置信息。 Source map 英文释义&#xff1a;源程序映射。 位置信息&#xff1a;转换后的代码 对应的 转换前的代码 位置映射关系。 有了 Source map&#xff0c;就算线上运行的是转换…

el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面

问题描述 el-menu 该有的页面显示不出来第一个应该想到的问题首先就算检查是否多写了一个 , 导致显示不出来原有的页面 如图所示多写了一个&#xff0c;就会导致该有的页面显示不出来。

nVisual在线网络规划设计软件

●01● nVisual在线网络规划设计软件 在信息化快速发展的今天&#xff0c;网络基础设施的建设与优化变得尤为关键。为了满足现代通信行业对高效、精准的网络规划需求&#xff0c;nVisual在线网络规划设计软件应运而生&#xff0c;它通过集成先进的GIS技术和网络规划工具&#…

OpenHarmony鸿蒙南向开发案例:【智能门铃】

样例简介 智能门铃通过监控来访者信息&#xff0c;告诉主人门外是否有人按铃、有陌生人靠近或者无人状态。主人可以在数字管家中远程接收消息&#xff0c;并根据需要进行远程取消报警和一键开锁。同时&#xff0c;也可以通过室内屏幕获取门外状态。室内屏幕显示界面使用DevEco…

人工智能,如何走好下一步

以下文章来源&#xff1a;金台资讯 2023年&#xff0c;生成式人工智能在全球范围爆火&#xff0c;引发了人工智能领域新一轮的科技竞赛。眼下&#xff0c;人工智能不仅能辅助科学研究与艺术创作&#xff0c;还能实现自动驾驶、打造“无人农场”和“黑灯工厂”&#xff0c;成为解…

数据很重要,ASM磁盘组损坏,使用AMDU来抢救

欢迎关注“数据库运维之道”公众号&#xff0c;一起学习数据库技术! 本期将为大家分享“数据很重要&#xff0c;ASM磁盘组损坏&#xff0c;使用AMDU来抢救”的处置案例。这个案例对个人来说比较经典&#xff0c;下面我将把自己的处理思路进行整理与总结。 环境信息&#xff1…

户用光伏业务解决方案

一、光伏户用痛点 1、推广难 没有成熟的推广与拓客能力&#xff0c;造成时间和金钱大量浪费。 2、管理难 有内部和外部几十或者上百推广人员&#xff0c;管理纷杂&#xff0c;效率低下。 3、无制度 缺少有效的人员管理制度与系统&#xff0c;分辨不出优秀人才&#xff0c…

Let‘s Forkin‘ Dance!Tanssi 激励测试网活动全面启动

作者&#xff1a;Tanssi 编译&#xff1a;OneBlock 作为 Tanssi 社区和生态系统发展的重要推手&#xff0c;Tanssi 基金会推出了 Incentivized TestNet 活动 —— Let’s Forkin’ Dance。该活动旨在激励顶尖参与者&#xff0c;推动社区增长和网络活动&#xff0c;为今年晚些时…

c语言-快速排序

文章目录 代码工程运行结果 这个是升序排列&#xff0c;如果想降序排列,将下面两行的符号反过来即可; arr[right] < arr[key] arr[left] > arr[key]代码工程 #define _CRT_SECURE_NO_WARNINGS #include<stdio.h>void swap(int *v1, int *v2) {int temp *v1;*v1 …

【VTKExamples::Meshes】第 十四期 ExtractEdges

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例ExtractEdges,并解析接口vtkExtractEdges,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~…