【Stable Diffusion 1.5 】在 Unet 中每个 Cross Attention 块中的张量变化过程

news2025/6/6 13:46:57

系列文章目录


文章目录

  • 系列文章目录
  • 前言
      • 特征图和注意力图的尺寸差异原因
      • 在Break-a-Scene中的具体实现
      • 总结


前言

特征图 (Latent) 尺寸和注意力图(attention map)尺寸在扩散模型中有差异,是由于模型架构和注意力机制的特性决定的。
在这里插入图片描述

特征图和注意力图的尺寸差异原因

  1. 不同的功能目的

    • 特征图(Feature Maps):承载图像的语义和视觉特征,维持空间结构
    • 注意力图(Attention Maps):表示不同位置之间的关联强度,是一种关系矩阵
  2. UNet架构中的特征图尺寸
    在U-Net中,特征图的尺寸在不同层级有变化:

    • 输入图像通常是 512×512 或 256×256
    • 下采样路径(Encoder):尺寸逐渐缩小 (512→256→128→64→32→16…)
    • 上采样路径(Decoder):尺寸逐渐增大 (16→32→64→128→256→512…)

    在Break-a-Scene代码中,我们看到特征图尺寸被下采样到64×64:

    downsampled_mask = F.interpolate(input=max_masks, size=(64, 64))
    
  3. 注意力机制中的尺寸计算
    注意力机制处理的是"token"之间的关系,其中:

    • 自注意力(Self-Attention):特征图中的每个位置视为一个token
    • 交叉注意力(Cross-Attention):文本序列中的token与特征图中的位置建立关联

    如果特征图尺寸是h×w,则自注意力矩阵的尺寸是(hw)×(hw),这是一个平方关系

    在代码中,注意力图通常被下采样到16×16:

    GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16))
    
  4. 计算效率考虑

    • 注意力计算的复杂度是O(n²),其中n是token数量
    • 对于64×64的特征图,如果直接计算自注意力,需要处理4096×4096的矩阵
    • 为了降低计算量,通常在较低分辨率(如16×16)的特征图上计算注意力,这样只需处理256×256的矩阵

在Break-a-Scene中的具体实现

在Break-a-Scene中,这些尺寸差异体现在:

  1. 两种不同的损失计算

    a. 掩码损失(Masked Loss):应用在64×64的 Latent 上

    max_masks = torch.max(batch["instance_masks"], axis=1).values
    downsampled_mask = F.interpolate(input=max_masks, size=(64, 64))
    model_pred = model_pred * downsampled_mask
    target = target * downsampled_mask
    

    b. 注意力损失(Attention Loss):应用在16×16的注意力图上

    GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16))
    agg_attn = self.aggregate_attention(res=16, from_where=("up", "down"), is_cross=True, select=batch_idx)
    
  2. 注意力存储的筛选

    在存储注意力图时,只保留小尺寸的注意力图:

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32**2:  # 只保存小于或等于32×32的注意力图
            self.step_store[key].append(attn)
        return attn
    
  3. 注意力聚合

    在聚合不同层的注意力时,确保只使用匹配目标分辨率的注意力图:

    def aggregate_attention(self, res: int, from_where: List[str], is_cross: bool, select: int):
        # ...
        num_pixels = res**2
        for location in from_where:
            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
                if item.shape[1] == num_pixels:  # 只选择匹配分辨率的注意力图
                    cross_maps = item.reshape(self.args.train_batch_size, -1, res, res, item.shape[-1])[select]
                    out.append(cross_maps)
        # ...
    

总结

特征图和注意力图尺寸的差异主要是因为:

  1. 它们在模型中的功能不同
  2. 注意力计算的计算复杂度要求在较低分辨率上进行
  3. UNet架构中的不同层级有不同的特征图尺寸
  4. 为了平衡精度和计算效率,Break-a-Scene使用不同分辨率的特征图和注意力图来计算不同类型的损失

这种设计使得Break-a-Scene能够有效地学习token与图像区域之间的对应关系,同时保持计算效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2398835.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL - Windows 中 MySQL 禁用开机自启,并在需要时手动启动

Windows 中 MySQL 禁用开机自启&#xff0c;并在需要时手动启动 打开服务管理器&#xff1a;在底部搜索栏输入【services.msc】 -> 点击【服务】 打开 MySQL 服务的属性管理&#xff1a;找到并右击 MySQL 服务 -> 点击【属性】 此时的 MySQL 服务&#xff1a;正在运行&a…

OpenCV CUDA模块霍夫变换------在 GPU 上执行概率霍夫变换检测图像中的线段端点类cv::cuda::HoughSegmentDetector

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::cuda::HoughSegmentDetector 是 OpenCV 的 CUDA 模块中一个非常重要的类&#xff0c;它用于在 GPU 上执行 概率霍夫变换&#xff08;Probabi…

ck-editor5的研究 (5):优化-页面离开时提醒保存,顺便了解一下 Editor的生命周期 和 6大编辑器类型

前言 经过前面的 4 篇内容&#xff0c;我们已经慢慢对 CKEditor5 熟悉起来了。这篇文章&#xff0c;我们就来做一个优化&#xff0c;顺便再补几个知识点&#xff1a; 当用户离开时页面时&#xff0c;提醒他保存数据了解一下 CKEditor5 的 六大编辑器类型了解一下 editor 实例对…

[3D GISMesh]三角网格模型中的孔洞修补算法

&#x1f4d0; 三维网格模型空洞修复技术详解 三维网格模型在扫描、重建或传输过程中常因遮挡、噪声或数据丢失产生空洞&#xff08;即边界非闭合区域&#xff09;&#xff0c;影响模型的完整性与可用性。空洞修复&#xff08;Hole Filling&#xff09;是计算机图形学和几何处…

11.2 java语言执行浅析3美团面试追魂七连问

美团面试追魂七连问&#xff1a;关于Object o New Object() ,1请解释一下对象的创建过程(半初始化) 2,加问DCL要不要volatile 问题(指令重排) 3.对象在内存中的存储布局(对象与数组的存储不同),4.对象头具体包括什么.5.对象怎么定位.6.对象怎么分配(栈-线程本地-Eden-Old)7.在…

MySQL 全量、增量备份与恢复

一.MySQL 数据库备份概述 备份的主要目的是灾难恢复&#xff0c;备份还可以测试应用、回滚数据修改、查询历史数据、审计等。之前已经学习过如何安装 MySQL&#xff0c;本小节将从生产运维的角度了解备份恢复的分类与方法。 1 数据备份的重要性 在企业中数据的价值至关…

MonoPCC:用于内窥镜图像单目深度估计的光度不变循环约束|文献速递-深度学习医疗AI最新文献

Title 题目 MonoPCC: Photometric-invariant cycle constraint for monocular depth estimation of endoscopic images MonoPCC&#xff1a;用于内窥镜图像单目深度估计的光度不变循环约束 01 文献速递介绍 单目内窥镜是胃肠诊断和手术的关键医学成像工具&#xff0c;但其…

SpringAI系列 - MCP篇(三) - MCP Client Boot Starter

目录 一、Spring AI Mcp集成二、Spring AI MCP Client Stater三、spring-ai-starter-mcp-client-webflux集成示例3.1 maven依赖3.2 配置说明3.3 集成Tools四、通过SSE连接MCP Server五、通过STDIO连接MCP Server六、通过JSON文件配置STDIO连接一、Spring AI Mcp集成 Spring AI…

【深度学习新浪潮】以Dify为例的大模型平台的对比分析

我们从核心功能、适用群体、易用性、可扩展性和安全性五个维度展开对比分析: 一、核心功能对比 平台核心功能多模型支持插件与工具链Dify低代码开发、RAG增强、Agent自律执行、企业级安全支持GPT-4/5、Claude、Llama3、Gemini及开源模型(如Qwen-VL-72B),支持混合模型组合可…

Asp.net core 使用EntityFrame Work

安装以下Nuget 包 Microsoft.EntityFrameworkCore.Tools Microsoft.EntityFrameworkCore.Design Microsoft.AspNetCore.Diagnostics.EntityFrameworkCore Microsoft.EntityFrameworkCore.SqlServer或者Npgsql.EntityFrameworkCore.PostgreSQL 安装完上述Nuget包之后,在appset…

AI Coding 资讯 2025-06-03

Prompt工程 RAG-MCP&#xff1a;突破大模型工具调用瓶颈&#xff0c;告别Prompt膨胀 大语言模型(LLM)在工具调用时面临Prompt膨胀和决策过载两大核心挑战。RAG-MCP创新性地引入检索增强生成技术&#xff0c;通过外部工具向量索引和动态检索机制&#xff0c;仅将最相关的工具信…

2024年12月 C/C++(三级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:最近的斐波那契数 斐波那契数列 Fn 的定义为:对 n ≥ 0 有 Fn+2 = Fn+1 + Fn,初始值为 F0 = 0 和 F1 = 1。所谓与给定的整数 N 最近的斐波那契数是指与 N 的差之绝对值最小的斐波那契数。 本题就请你为任意给定的整数 N 找出与之最…

NeRF PyTorch 源码解读 - 体渲染

文章目录 1. 体渲染公式推导1.1. T ( t ) T(t) T(t) 的推导1.2. C ( r ) C(r) C(r) 的推导 2. 体渲染公式离散化3. 代码解读 1. 体渲染公式推导 如下图所示&#xff0c;渲染图像上点 P P P 的颜色值 c c c 是累加射线 O P → \overrightarrow{OP} OP 在近平面和远平面范围…

SpringBoot 数据库批量导入导出 Xlsx文件的导入与导出 全量导出 数据库导出表格 数据处理 外部数据

介绍 poi-ooxml 是 Apache POI 项目中的一个库&#xff0c;专门用于处理 Microsoft Office 2007 及以后版本的文件&#xff0c;特别是 Excel 文件&#xff08;.xlsx 格式&#xff09;和 Word 文件&#xff08;.docx 格式&#xff09;。 在管理系统中需要对数据库的数据进行导…

解决:install via Git URL失败的问题

为解决install via Git URL失败的问题&#xff0c;修改安全等级security_level的config.ini文件&#xff0c;路径如下&#xff1a; 还要重启&#xff1a; 1.reset 2.F5刷新页面 3.关机服务器&#xff0c;再开机&#xff08;你也可以省略&#xff0c;试试&#xff09; 4.Wind…

OpenCV CUDA模块特征检测------创建Harris角点检测器的GPU实现接口cv::cuda::createHarrisCorner

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 该函数创建一个 基于 Harris 算法的角点响应计算对象&#xff0c;专门用于在 GPU 上进行高效计算。 它返回的是一个 cv::Ptrcv::cuda::Cornernes…

【氮化镓】钝化层对p-GaN HEMT阈值电压的影响

2021年5月13日,中国台湾阳明交通大学的Shun-Wei Tang等人在《Microelectronics Reliability》期刊发表了题为《Investigation of the passivation-induced VTH shift in p-GaN HEMTs with Au-free gate-first process》的文章。该研究基于二次离子质谱(SIMS)、光致发光(PL)…

C++:优先级队列

目录 1. 概念 2. 特征 3. 优先级队列的使用 1. 概念 优先级队列虽然名字有队列二字&#xff0c;但根据队列特性来说优先级队列不满足先进先出这个特征&#xff0c;优先级队列的底层是用堆来实现的。 优先级队列是一种容器适配器&#xff0c;就是将特定容器类封装作为其底层…

睡眠分期 html

截图 代码 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>睡眠图表</title><script src…

Github 2025-05-29 Go开源项目日报Top9

根据Github Trendings的统计,今日(2025-05-29统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Go项目9Assembly项目1Ollama: 本地大型语言模型设置与运行 创建周期:248 天开发语言:Go协议类型:MIT LicenseStar数量:42421 个Fork数量:27…