【第三十九周】ViLT

news2025/6/9 16:27:31

ViLT

  • 摘要
  • Abstract
  • 文章信息
  • 介绍
    • 提取视觉特征的方式的演变
    • 模态融合的两种方式
    • 四种不同的 VLP 模型
    • Q&A
  • 方法
    • 模型结构
    • 目标函数
    • Whole Word Masking(WWM)
  • 实验结果
  • 总结

摘要

本篇博客介绍了ViLT(Vision-and-Language Transformer),这是一种简洁高效的视觉语言预训练模型,旨在解决传统 VLP 方法中过度依赖深度卷积网络和区域监督带来的计算开销高和表达受限等问题。该模型的核心思想是将图像与文本统一通过线性投影进行嵌入处理,并在共享的 Transformer 中进行深度跨模态融合,从而彻底移除视觉特定的 CNN 主干和目标检测器。ViLT 针对图像特征提取复杂、慢的问题,提出了基于 patch projection 的轻量级视觉嵌入方式,使视觉输入的处理成本降至最低,并结合 Masked Language Modeling(MLM)与 Image-Text Matching(ITM)等预训练任务提升下游性能。实验表明,ViLT 在多项视觉语言任务中以极低的计算资源达成与传统重模型相近甚至更优的效果。其优势在于结构极简、推理速度快、易于部署,但也存在对细粒度视觉语义建模能力不足的局限。

Abstract

This blog introduces ViLT (Vision-and-Language Transformer), a simple and efficient vision-language pretraining model designed to address the high computational cost and limited expressive power caused by the heavy reliance on deep convolutional networks and region-level supervision in traditional VLP methods. The core idea of ViLT is to unify image and text inputs through linear projection and perform deep cross-modal fusion within a shared Transformer, thereby completely removing the need for CNN-based visual backbones and object detectors. To tackle the complexity and latency of visual feature extraction, ViLT employs a lightweight patch projection-based embedding strategy that minimizes visual input processing overhead, while leveraging pretraining objectives such as Masked Language Modeling (MLM) and Image-Text Matching (ITM) to enhance downstream performance. Experiments show that ViLT achieves comparable or even better results than traditional heavy models on various vision-language tasks, with significantly lower computational cost. Its strengths lie in its minimal design, fast inference, and ease of deployment, although it has limitations in modeling fine-grained visual semantics.


文章信息

Title:ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision
Author:Wonjae Kim, Bokyung Son, Ildoo Kim
Source:https://arxiv.org/abs/2102.03334


介绍

在ViLT的提出之前,视觉与语言预训练(VLP)在跨模态任务(如图文检索、视觉问答等)上取得了显著进展。然而,现有方法大多依赖计算量大的图像特征提取模块(如基于目标检测的区域特征或CNN提取的全局特征),这不仅降低了效率,还限制了模型的表达能力。本文介绍的 ViLT(Vision-and-Language Transformer)是一种极简的VLP模型,其核心创新是完全摒弃传统卷积或区域监督方法,改用类似文本处理的“无卷积”方式直接建模视觉输入。实验证明,ViLT 比之前的方法快数十倍,同时在下游任务上保持甚至超越SOTA性能。

提取视觉特征的方式的演变

在这里插入图片描述

之前的 VLP 模型的 text embedding 基本上都使用类 BERT 结构,但是 visual embedding 存在着差异。往往处理视觉特征的网络越复杂,模型效果就越好,所以提取视觉特征是现有 VLP 模型的瓶颈。如上图所示,获取visual embedding的方法总共有三大类:

  • Region feture(区域特征):通常采用 Faster R-CNN 二阶段检测器提取区域性特征,这种方法的成本是最高的,比如图像经过ResNet101 backbone(一系列的卷积层用于提取图像的feature maps)提取特征,再经过RPN(Region Proposal Network,区域生成网络)得到一些RoI(Region of interest,感兴趣区域),然后使用NMS(Non-Maximum Suppression,非极大抑制)过滤冗余的RoI,最后经过RoI Head(在RPN生成的候选区域中,对候选区域进行分类和边界框回归的神经网络模块)得到一些一维的向量(Region Feature),也就是一个个bounding box。
  • Grid Feature(网格特征):将CNN backbone得到的feature map,作为网格特征,大大降低了计算量。比如将ResNet50最后得到的7×7特征图拉直为一个序列,或者是上一层的14×14的特征图。
  • Patch Projection(块投影):受ViT中patch projection层效果的启发,使用类似ViT模型中的patch projection层直接得到patch embeddings。ViLT是首个这么做的,第一,这种方法不需要使用额外的网络;其次,不需要缓存特征,Region feture和Grid feature都需要先使用预训练的模型提前抽取好图片特征,然后再训练。虽然这样训练还是比较轻量的,但在部署的时候是一个很大的局限性,因为真实场景里每时每刻都在生成新数据,都需要抽取新数据的特征,这时推理速度就是一大瓶颈。

在文本方面,这些模型都是基本一样的,通过一个Embedding矩阵,变成一个个的word token。得到了视觉的序列和文本的序列后输入到Modality Interaction(基本都是Transformer)进行模态之间的融合。

模态融合的两种方式

  • Single-stream:单通路结构,文本特征和图像特征直接concat连接,然后输入一个transformer进行交互;
  • Dual-stream:双通道结构,文本特征和图像特征分别过一个文本模型和图像模型,充分挖掘单模态特征,然后再经过一个transformer layer做融合。

这两种方法的效果相差不大,但 Dual-stream 的参数量和计算量更大,成本更高,所以在ViLT中作者采用了 Single-stream 的方式。

四种不同的 VLP 模型

作者总结了视觉-语言模型(Vision-and-Language Models)的 四种计算范式分类,通过比较视觉嵌入(VE)、文本嵌入(TE)和模态交互(MI) 三个模块的相对计算量(矩形高度表示),展示不同模型的设计差异。
在这里插入图片描述

  1. Visual Embedder (VE):图像特征提取(如CNN、区域检测器、线性投影)。
  2. Textual Embedder (TE):文本特征提取(如BERT、Word2Vec)。
  3. Modality Interaction (MI):跨模态交互(如Transformer注意力机制)。

(a)VE > TE > MI
视觉嵌入计算量最大,文本次之,模态交互最小。VSE/ SCAN等模型的做法,视觉特征的处理远大于文本特征,模态融合只使用了简单的点乘操作或很简单的浅层attention网络。
(b)VE = TE > MI
视觉和文本嵌入计算量相当,均大于模态交互。CLIP就是属于此种类别,每个模态单独使用transformer encoder,两者计算量相当,特征融合部分,只是简单的计算了一下图文特征的相似性。CLIP特别适合需要图文特征(GroupViT/GLIP等)或者是图文检索的任务,但做VQA或者visual reasoning(视觉推理,更难的VQA)这种需要视觉推理的任务时,会稍逊一筹。因为一个简单的不可学习的点乘,是没法做深层次的特征融合和分析的。
(c)VE > MI > TE
模态交互计算量显著增加,视觉/文本嵌入比重降低。比如ViLBERT、UNITER、Pixel-BERT等。
(d) MI > VE = TE
视觉嵌入极简(线性投影,VE耗时低),文本用标准BERT(TE),模态交互(MI)成为主要计算部分。ViLT 就是属于此种模型。

Q&A

对于文本侧,使用 Transformer 是最好的选择,但对于 VLP 来说,就必须将图像的像素转成带有语义性质的具有离散性的特征。这样才能与文本的 tokens 有效匹配,后续才能进行两种模态的融合。
那么为什么之前的 VLP 用目标检测器来处理图像特征?

  • 图像的像素太多,Transformer 处理不了这么长的序列。
  • 目标检测是一个离散化的过程,返回的是 bounding box,它代表一个个物体,有明确的语义信息(可类比文本中的token),而且还是离散化的。
  • 当时的VLP 下游任务主要是VQA(Visual Question Answering)、Image Captioning、Image Retrieval 等,这些任务往往都跟物体有非常直接的联系,有非常强的对物体的依赖性。

方法

模型结构

下图展示了 ViLT(Vision-and-Language Transformer) 的核心架构,它是一种纯Transformer的多模态模型,完全摒弃了CNN和区域检测器,实现了高效的视觉-语言联合建模。
在这里插入图片描述
输入:

  • 文本经过pre-trained BERT tokenizer得到word embedding(前面有CLS token,图中*表示);
  • 图片经过ViT patch projection层得到patch embedding(也是用*表示CLS token);
  • 文本特征+文本位置编码+模态嵌入得到最终的text embedding,图像这边也是类似的操作得到image embedding;二者concat拼接之后,一起输入transformer layer,然后做MSA交互(多头自注意力)

注意:在Single-stream模型中,图文特征是直接拼在一起输入一个transformer,为了区分两种不同的模态,需要加上模态嵌入。
下面结合论文中的公式介绍具体流程:
在这里插入图片描述
公式1: 文本输入嵌入
t class t_\text{class} tclass:文本模态的 [CLS] 向量; t i T t_i^T tiT:第 i i i 个 token 的词向量嵌入; T pos T^\text{pos} Tpos:文本的位置嵌入(position embeddings); [ ⋅ ; ⋅ ] [\cdot; \cdot] [;]:表示拼接操作。最终得到的是一组带有位置信息的文本 token 嵌入序列。
公式2: 图像输入嵌入
图像被划分为 N N N 个 patch,每个 patch 被展平后通过线性投影生成向量 v i V v_i^V viV v class v_\text{class} vclass:图像模态的 [CLS] 向量; V pos V^\text{pos} Vpos:图像 patch 的位置嵌入。类似于文本嵌入,图像也被编码成带位置信息的序列。
公式3:拼接图文输入 + 模态类型嵌入
t type t^\text{type} ttype v type v^\text{type} vtype:模态类型嵌入,用于标识每个 token 是来自文本还是图像;文本与图像嵌入分别加上类型标识后拼接,形成最终输入序列 z 0 z^0 z0 z 0 z^0 z0 是 Transformer 编码器的初始输入。
公式4、5: Transformer 编码器的每层前向传播
ViLT 使用标准的 ViT Transformer Block;每层包含 LayerNorm、Multi-head Self-Attention (MSA)、前馈网络(MLP)和残差连接;ViLT 使用 Pre-LN 架构,即 LayerNorm 在模块输入前; z d z^d zd 表示第 d d d 层的输出。
公式6: 多模态输出池化表示
z 0 D z^D_0 z0D:Transformer 最后一层第一个位置的输出向量(即 [CLS] token); W pool W^\text{pool} Wpool:可学习的线性投影矩阵; p p p 是最终的图文融合表示,用于分类或匹配任务。

目标函数

ViLT使用了一般VLP模型常用的目标函数,即图文匹配损失( ITM,image text matching)和 BERT的掩码学习损失(MLM,Masked Language Modeling)。另外,ViLT还使用了Word Patch Alignment(WPA),简单理解就是将文本和图像的输出都当做一个概率分布,然后使用最优运输理论计算一下两者的距离。

Whole Word Masking(WWM)

不按照token,而按整个词进行 Mask。在 VLP(视觉语言预训练)场景中,WWM 可以显著增强跨模态信息的利用。例如:如果只遮了 “gi” 和 “##fe”,模型可能只用剩下的 “##raf” 推断出 “giraffe”,这就使图像信息的作用变弱了。WWM 迫使模型必须依靠图像信息来预测整个词的语义,从而提升图文联合建模效果。

实验结果

在这里插入图片描述
如图所示,ViLT相比于region feature的方法速度快了60倍,相比于grid feature的方法快了4倍,而且下游任务表现出相似甚至更好的性能。
在这里插入图片描述
在这里插入图片描述
从table2、table3和table4中可以看出,相对于region和grid的方法,ViLT在下游任务表现出相似甚至更好的性能。
在这里插入图片描述
从上图的可视化结果可看出,ViLT学到了word和image patch之间的对应关系。
下面是HuggingFace Transformers实现的vilt模型:

import torch
from torch import nn
from transformers.models.vilt.configuration_vilt import ViltConfig

class ViltEmbeddings(nn.Module):
    """
    构造图像 patch + 文本 token 的嵌入输入,模仿 ViT 和 BERT 的做法。
    """
    def __init__(self, config: ViltConfig):
        super().__init__()
        # 文本词嵌入(等同于 BERT)
        self.text_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # 图像 patch 嵌入(等同于 ViT)
        patch_size = config.patch_size
        self.patch_embed = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size
        )
        # 初始化 image [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # 图像 patch 的位置嵌入
        num_patches = (config.image_size // patch_size) ** 2
        self.image_position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        # 图像模态的 token type embedding
        self.modality_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)

    def forward(self, input_ids, pixel_values):
        # 文本嵌入:token + position + token_type
        seq_length = input_ids.size(1)
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        x_text = self.text_embeddings(input_ids) + \
                 self.position_embeddings(positions) + \
                 self.token_type_embeddings(torch.zeros_like(input_ids))

        # 图像嵌入:Patch Projection + [CLS] token + position + modality type
        x_image = self.patch_embed(pixel_values)  # [B, H, sqrtP, sqrtP]
        B, C, H, W = x_image.shape
        x_image = x_image.flatten(2).transpose(1, 2)  # [B, num_patches, C]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # 图像专用 CLS
        x_image = torch.cat([cls_tokens, x_image], dim=1)
        x_image = x_image + self.image_position_embeddings + self.modality_type_embeddings(
            torch.ones(x_image.size()[:2], dtype=torch.long, device=x_image.device)
        )

        # 合并文本与图像嵌入为单流输入
        x_text = x_text + self.modality_type_embeddings(
            torch.zeros(x_text.size()[:2], dtype=torch.long, device=x_text.device)
        )
        return torch.cat([x_text, x_image], dim=1)

class ViltModel(nn.Module):
    """
    单流 Transformer 模型,用于融合文本和图像输入,输出跨模态特征。
    """
    def __init__(self, config: ViltConfig):
        super().__init__()
        self.embeddings = ViltEmbeddings(config)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation=config.hidden_act
            ),
            num_layers=config.num_hidden_layers
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, input_ids, pixel_values):
        # 1. 构建输入嵌入
        embedding_output = self.embeddings(input_ids, pixel_values)

        # 2. 深度多层自注意力 Transformer 编码
        encoded_output = self.encoder(embedding_output.transpose(0, 1)).transpose(0, 1)

        # 3. 池化:取第一个 [CLS] token,用线性层 + tanh 变换
        cls_output = encoded_output[:, 0, :]  # batch x hidden
        pooled_output = self.activation(self.pooler(cls_output))
        return encoded_output, pooled_output

源码:https://github.com/dandelin/vilt.

总结

ViLT 作为首个完全去除卷积网络和区域监督的视觉-语言预训练模型,提出了一种极简而高效的架构:其通过将图像划分为 patch 并进行线性投影,使图像和文本在输入阶段即被统一处理,再通过单流 Transformer 实现深层跨模态交互,最终利用 [CLS] 向量进行任务预测。相比传统依赖目标检测器或 CNN 视觉编码器的模型,ViLT 显著降低了计算开销,同时保持了主流任务上的竞争性能。其主要优势在于架构轻量、推理速度快、易于扩展,而劣势则在于去除区域信息后在精细对象级任务上略显不足。未来研究可进一步探索更有效的视觉掩码策略、增强跨模态对齐机制、以及在大规模数据或更大模型(如 ViLT-L/H)上的扩展能力,以实现兼顾效率与表达力的更强多模态预训练模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2405579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营第60期第六十天打卡

大家好,今天因为有数学建模比赛的校赛,今天的文章可能会简单一点,望大家原谅,我们昨天主要讲的是并查集的题目,我们复习了并查集的功能,我们昨天的题目其实难度不小,尤其是后面的有向图&#xf…

数据结构——D/串

一、串的定义和基本操作  1. 串的定义   1)串的概念   组成结构: 串是由零个或多个字符组成的有限序列,记为 S′a1a2⋯an′Sa_1a_2\cdots a_nS′a1​a2​⋯an′​&#x…

AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年6月8日第102弹

从今天开始,咱们还是暂时基于旧的模型进行预测,好了,废话不多说,按照老办法,重点8-9码定位,配合三胆下1或下2,杀1-2个和尾,再杀4-5个和值,可以做到100-300注左右。 (1)定…

【第九篇】 SpringBoot测试补充篇

简介 本文介绍了SpringBoot测试中的五项关键技术:测试类专用属性加载、 测试类专用Bean配置、 表现层测试方法、测试类事务回滚控制、配置文件随机数据设置)。这些技术可以有效隔离测试环境,确保测试数据不影响生产环境,同时提供了…

springcloud SpringAmqp消息队列 简单使用

这期只是针对springBoot/Cloud 在使用SpringAmqp消息队列的时候遇到的坑。 前提 如果没有安装RabbitMQ是无法连接成功的!所以前提是你要安装好RabbitMQ。 docker 安装命令 # 拉取docker镜像 docker pull rabbitmq:management# 创建容器 docker run -id --namera…

Framework开发之IMS逻辑浅析1--关键线程及作用

关键线程:EventHub,InputReader,InputDispatcher EventHub: 由于Android继承Linux,Linux的思想是一切皆文件,而输入的类型不止一种(触碰,写字笔,键盘等),每种类型都对应一种驱动设备,而每个硬件驱动设备又对应Linux的一个目录文件…

系统思考:跳出症状看全局

明天将为华为全球采购认证管理部的伙伴们带来一场关于系统思考的深度课程!通过经典的啤酒游戏经营决策沙盘,一起沉浸式体验如何从全局视角看待问题,发现单点最优并不等于全局最优。 这不仅是一次简单的课程,更是一次洞察系统背后…

DeepSeek R1 V2 深度探索:开源AI编码新利器,效能与创意并进

最近,AI界迎来了一位神秘的“突袭者”——DeepSeek团队悄无声息地发布了其推理模型DeepSeek R1的重磅升级版V2(具体型号R1-0528)。这款基于MIT许可的开源模型,在原版R1的基础上进行了多项令人瞩目的改进,正以其强大的潜…

surfer15安装

安装文件 安装包和破解文件 安装 破解及汉化 打开软件

Python训练营---DAY48

DAY 48 随机函数与广播机制 知识点回顾: 随机张量的生成:torch.randn函数卷积和池化的计算公式(可以不掌握,会自动计算的)pytorch的广播机制:加法和乘法的广播机制 ps:numpy运算也有类似的广播机…

debian12拒绝海外ip连接

确保 nftables 已安装: Debian 12 默认使用 nftables 作为防火墙框架。检查是否安装: sudo apt update sudo apt install nftables启用并启动 nftables 服务 sudo systemctl enable nftables sudo systemctl start nftables下载maxmind数据库 将文件解…

70年使用权的IntelliJ IDEA Ultimate安装教程

安装Java环境 下载Java Development Kit (JDK) 从Oracle官网或OpenJDK。推荐选择JDK 11或更高版本。 运行下载的安装程序,按照提示完成安装。注意记录JDK的安装路径(如C:\Program Files\Java\jdk-11.0.15)。 配置环境变量: 右键…

MySQL的日志

就相当于人的日记本,记录每天发生的事,可以对数据进行追踪 一、错误日志 也就是存放错误信息的 二、二进制日志-binlog 在低版本的MySQL中,二进制日志是不会默认开启的 存放除了查询语句的其他语句 三、查询日志 查询日志会记录客户端的所…

低功耗高安全:蓝牙模块在安防系统中的应用方案

随着物联网(IoT)和智能家居的快速发展,安防行业正迎来前所未有的技术革新。蓝牙模块作为一种低功耗、高稳定性的无线通信技术,凭借其低成本、易部署和智能化管理等优势,在安防领域发挥着越来越重要的作用。本文将探讨蓝牙模块在安防系统中的应…

C++定长内存块的实现

内存池 内存池是指程序预先从操作系统 申请一块足够大内存 ,此后,当程序中需要申请内存的时候,不是直接向操作系统申请,而是 直接从内存池中获取 ; 同理,当 **程序释放内存 **的时候,并不真正将…

Unity使用代码分析Roslyn Analyzers

一、创建项目(注意这里不要选netstandard2.1会有报错) 二、NuGet上安装Microsoft.CodeAnalysis.CSharp 三、实现[Partial]特性标注的类,结构体,record必须要partial关键字修饰 需要继承DiagnosticAnalyzer 注意一定要加特性Diagn…

大数据CSV导入MySQL

CSV Import MySQL 源码主要特性技术栈快速开始1. 环境要求2. 构建项目3. 使用方式交互式模式命令行模式编程方式使用 核心组件1. CsvService2. DatabaseService3. CsvImportService 数据类型映射性能优化1. 连接池优化2. 批量操作优化3. MySQL配置优化 配置说明application.yml…

MySQL 索引优化(Explain执行计划) 详细讲解

🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 MySQL 索引优化(Explain执行计划…

Cad 反应器 cad c#二次开发

在 AutoCAD C# 二次开发中,DocumentCollectionEventHandler 是一个委托(delegate),用于处理与 AutoCAD 文档集合(DocumentCollection)相关的事件。它属于 AutoCAD .NET API 的事件处理机制,本质…

【websocket】安装与使用

websocket安装与使用 1. 介绍2. 安装3. websocketpp常用接口4. Websocketpp使用4.1 服务端4.2 客户端 1. 介绍 WebSocket 是从 HTML5 开始支持的一种网页端和服务端保持长连接的 消息推送机制。 传统的 web 程序都是属于 “一问一答” 的形式,即客户端给服务器发送…