stable diffusion论文解读

news2025/5/25 18:30:21

High-Resolution Image Synthesis with Latent Diffusion Models

论文背景

LDM是Stable Diffusion模型的奠基性论文

于2022年6月在CVPR上发表
 

img


传统生成模型具有局限性:

  • 扩散模型(DM)通过逐步去噪生成图像,质量优于GAN,但直接在像素空间操作导致高计算开销。
  • 随着分辨率提升,扩散模型的优化和推理成本呈指数级增长,限制了实际应用

如DDPM生成的图像分辨率普遍不超过256×256,而LDM生成的图像分辨率可以超过1024×1024.
而LDM通过将扩散过程迁移至潜在空间,解决了传统模型的计算瓶颈,同时保持生成质量与灵活性

论文框架方法

论文中框架示意图如图所示:
 

img


在训练阶段:

  • 预训练自动编码器(AE)和条件生成编码器(如clip)
  • 输入图片x,经过自动编码器压缩到隐空间ε(x)=z
  • 随机采样时间步T,对Z进行加噪到ZTZT
  • 对右边框里的条件进行条件编码τθ(y)τθ(y)和ZTZT一起输入UNet网络中
  • 进行交叉注意力计算,其中ZTZT作为Q向量,τθ(y)τθ(y)作为K,V向量计算注意力,这样做是让图像的每个位置根据文本的语义来决定关注哪些部分
  • 最后Unet输出两个向量,一个是无条件预测噪声,一个是文本预测噪声。

无条件预测噪声输入是空字符串

  • 使用CFG计算最终预测噪声ϵguided(zt,t,τθ(y))=ϵθ(zt,t,τθ(y))+s⋅(ϵθ(zt,t,τθ(y))−ϵθ(zt,t,∅))ϵguided(zt,t,τθ(y))=ϵθ(zt,t,τθ(y))+s⋅(ϵθ(zt,t,τθ(y))−ϵθ(zt,t,∅))
  • 使用损失函数进行反向传播计算

    img

在生成阶段:

  • 以随机噪声ZTZT作为起点
  • 输入文本作为条件,编码后一起进入Unet进行交叉注意力计算
  • 输出预测噪声ϵguided(zt,t,τθ(y))ϵguided(zt,t,τθ(y))
  • 使用调度器进行逐步去噪计算(如DDPM,DDIM)成为ZT−1ZT−1
  • 重复以上过程,直到Z
  • 通过自动编码器的解码器部分把Z迁移到像素空间,D(z),即生成图像

交叉注意力机制中的维度变换

图像编码后变成 C=4, H'=64, W'=64,展平后作为Q(zt⇒Q∈R(H′W′)×dzt⇒Q∈R(H′W′)×d),文本通过编码器的编码表示为c=[t1,t2,...,tL]⇒Embedding∈RL×dc=[t1,t2,...,tL]⇒Embedding∈RL×d,K和V表示为K,V∈RL×dK,V∈RL×d,计算注意力权重A=softmax(QK⊤√d)∈R(H′W′)×LA=softmax(QK⊤d)∈R(H′W′)×L,输出为Attention(Q,K,V)=A⋅V∈R(H′W′)×dAttention(Q,K,V)=A⋅V∈R(H′W′)×d

        # 潜在空间输入(prepare_latents生成)
latents.shape = (batch_size * num_images_per_prompt, 4, H//8, W//8)

# 文本嵌入处理(encode_prompt输出)
prompt_embeds.shape = (batch_size, max_sequence_length, embedding_dim)

# IP适配器图像嵌入处理
image_embeds[0].shape = (batch_size * num_images_per_prompt, num_images, emb_dim)

# UNet输入/输出维度
latent_model_input.shape = [batch*2, 4, H//8, W//8]  # 当启用CFG时
noise_pred.shape = [batch*2, 4, H//8, W//8]          # UNet输出噪声预测
假设参数设置
prompt = "一只坐在月球上的猫"
height = 512
width = 512
num_images_per_prompt = 1
guidance_scale = 7.5
batch_size = 1  # 根据prompt长度自动确定

# 关键计算步骤演示
# ---------------------------
# 步骤1:潜在空间(latents)维度计算
latents_shape = (
    batch_size * num_images_per_prompt,  # 1*1=1
    4,  # UNet输入通道数
    height // 8,  # 512/8=64
    width // 8    # 512/8=64
)
print(f"潜在空间维度: {latents_shape}")  # -> (1, 4, 64, 64)

# 步骤2:文本编码维度(假设使用CLIP模型)
prompt_embeds_shape = (
    batch_size, 
    77,  # CLIP最大序列长度
    768  # CLIP文本编码维度
) 
print(f"文本嵌入维度: {prompt_embeds_shape}")  # -> (1, 77, 768)

# 步骤3:CFG处理后的嵌入
if guidance_scale > 1:
    prompt_embeds = torch.cat([negative_embeds, positive_embeds])
    print(f"CFG嵌入维度: {prompt_embeds.shape}")  # -> (2, 77, 768)

# 步骤4:UNet输入维度(假设启用CFG)
latent_model_input = torch.cat([latents] * 2)
print(f"UNet输入维度: {latent_model_input.shape}")  # -> (2, 4, 64, 64)

# 步骤5:噪声预测输出
noise_pred = unet(latent_model_input, ...)[0]
print(f"噪声预测维度: {noise_pred.shape}")  # -> (2, 4, 64, 64)

# 步骤6:CFG调整后的噪声
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(f"调整后噪声维度: {noise_pred.shape}")  # -> (1, 4, 64, 64)

# 最终输出图像
image = vae.decode(latents / vae.config.scaling_factor)[0]
print(f"输出图像维度: {image.shape}")  # -> (1, 3, 512, 512)
    def cross_attention(query, key, value):
    # 输入维度说明
    # query: 来自潜在噪声 [batch=2, 4*64*64=16384] → 投影为 [2, 16384, 768]
    # key/value: 来自文本嵌入 [2, 77, 768]
    
    # 步骤1:计算注意力分数
    attention_scores = torch.matmul(
        query,  # [2, 16384, 768] 
        key.transpose(-1, -2)  # [2, 768, 77] → 转置后维度
    )  # 矩阵乘法结果 → [2, 16384, 77]
    
    # 步骤2:计算注意力权重
    attention_probs = torch.softmax(
        attention_scores,  # [2, 16384, 77]
        dim=-1  # 对最后一个维度(文本标记维度)做归一化
    )  # 保持维度 [2, 16384, 77]
    
    # 步骤3:应用注意力到value
    output = torch.matmul(
        attention_probs,  # [2, 16384, 77]
        value  # [2, 77, 768]
    )  # 结果维度 → [2, 16384, 768]
    
    # 步骤4:重塑为潜在空间维度
    output = output.view(2, 4, 64, 64, 768)  # 恢复空间结构
    output = output.permute(0, 4, 1, 2, 3)  # [2, 768, 4, 64, 64]
    output = self.to_out(output)  # 通过最后的线性层投影回4通道
    return output  # [2, 4, 64, 64]

数据集以及指标介绍

数据集介绍

img

CelebA-HQ 256 × 256数据集,是一个大规模的人脸属性数据集,拥有超过200K张名人图片,每张图片都有40个属性注释(如身份,年龄、表情、发型等)。

img

从Flickr网站爬取的人脸数据集集合,涵盖多样化的年龄、种族、表情、配饰(如眼镜、帽子)等属性

img

这两个数据集都是LSUN(大规模场景理解)数据集的子集,两个数据集分别表示教堂和卧室场景的数据,包含教堂建筑的不同视角、结构和环境条件,覆盖多样化的卧室场景,包括不同装修风格、家具布局和光照条件

指标介绍

IS分数介绍

Inception Score 的定义为:
IS(G)=exp(Ex∼pg[DKL(p(y|x)∥p(y))])IS(G)=exp⁡(Ex∼pg[DKL(p(y|x)‖p(y))])

x~pg:生成图像样本来自生成模型的分布 。

p(y|x):通过预训练分类器(如Inception v3)对生成图像的类别预测概率分布。

p(y):预测类别的边缘分布。类别可以是猫,狗,猪等诸如此类的动物。
其中如果生成图像明确、质量高,则p(y|x)的熵就会比较低,如果生成图像比较多样,则p(y)的熵就会较高,体现在公式中则IS分数会较高。

FID分数介绍

主要是计算生成图像分布和真实图像分布在特征空间中的距离
公式FID=∥μr−μg∥22+Tr(Σr+Σg−2(ΣrΣg)12)FID=‖μr−μg‖22+Tr(Σr+Σg−2(ΣrΣg)12)

μr,Σrμr,Σr:真实图像分布的均值和协方差矩阵。

μg,Σgμg,Σg:生成图像分布的均值和协方差矩阵。

∥μr−μg∥22‖μr−μg‖22:欧几里得距离的平方。

TrTr:矩阵的迹。

(ΣrΣg)12(ΣrΣg)12:协方差矩阵的乘积的平方根。

两个分布的均值和协方差越低,FID越低,生成图像质量越接近生成的图像

prec和recall

这里的指标和一般理解的不一样。

会先用Inception网络分别提取真实图像和生成图像的特征点

用集合的角度解释:

Precision ≈ 生成图像中,有多少落在真实图像分布的“支持区域”里(真实性)

Recall ≈ 真实图像中,有多少被生成图像的“支持区域”覆盖(多样性)

实验分析

img

研究不同下采样因子f对生成图像质量和训练效率的影响

下采样因子:指的是自动编码器中的参数。

可以看到下采样因子为4或8时表现最好。因为如果因子过小,会导致维度高,计算缓慢,因子过大,会损失很多信息,导致最后生成图像生成质量较差

后续的实验将基于此展开

img

在这个实验里可以看到,LDM在CelebA-HQ中取得了最优的FID分数,在其他数据集上的表现也是中规中矩。

img

这个实验里展示了LDM在类别生成任务中的表现,可以看到使用cfg引导的LDM展现出了非常优秀的性能,在FID和IS分数上表现优异,虽然recall略低,但是使用的参数量也大幅减少了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2385520.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络(3)——传输层

1.概述 1.1 传输层的服务和协议 (1)传输层为允许在不同主机(Host)上的进程提供了一种逻辑通信机制 (2)端系统(如手机、电脑)运行传输层协议 发送方:将来自应用层的消息进行封装并向下提交给 网络层接收方:将接收到的Segment进行组装并向上提交给应用层 …

LangChain构建RAG的对话应用

目录 Langchain是什么? LangSmith是什么? ​编辑 使用Python构建并使用AI大模型 数据解析器 提示模版 部署 记忆功能 Chat History -- 记忆 代码执行流程: 流式输出 构建向量数据库和检索器 检索器 代码执行流程 LLM使用检索器…

目标检测DN-DETR(2022)详细解读

文章目录 gt labels 和gt boxes加噪query的构造attention maskIS(InStability)指标 在DAB-Detr的基础上,进一步分析了Detr收敛速度慢的原因:二分图匹配的不稳定性(也就是说它的目标在频繁地切换,特别是在训…

嵌入式培训之系统编程(四)进程

一、进程的基本概念 (一)定义 进程是一个程序执行的过程(也可以说是正在运行的程序),会去分配内存资 源,cpu的调度,它是并发的 (二)PCB块 1、PCB是一个结构体&#x…

天文数据处理:基于CUDA的射电望远镜图像实时去噪算法(开源FAST望远镜数据处理代码解析)

一、射电天文数据处理的挑战与CUDA加速的必要性 作为全球最大的单口径射电望远镜,中国天眼(FAST)每秒产生38GB原始观测数据,经预处理后生成数千万张图像。这些数据中蕴含的脉冲星、中性氢等天体信号常被高斯白噪声、射频干扰&…

VS编码访问Mysql数据库

安装 MySQL Connector/C 的开发包 libmysqlcppconn-dev是 MySQL Connector/C 的开发包,它的主要用途是让 C 开发者能够方便地在应用程序中与 MySQL 数据库进行交互。它提供了以下功能: 数据库连接:通过标准的 C 接口连接到 MySQL 数据库。S…

一周学会Pandas2 Python数据处理与分析-Pandas2数据合并与对比-pd.merge():数据库风格合并

锋哥原创的Pandas2 Python数据处理与分析 视频教程: 2025版 Pandas2 Python数据处理与分析 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili pd.merge():数据库风格合并 **核心功能**:基于列值(类似 SQL JOIN)合…

CodeBuddy 实现图片转素描手绘工具

本文所使用的 CodeBuddy 免费下载链接:腾讯云代码助手 CodeBuddy - AI 时代的智能编程伙伴 前言 最近在社交媒体上,各种素描风格的图片火得一塌糊涂,身边不少朋友都在分享自己的 “素描照”,看着那些黑白线条勾勒出的独特韵味&a…

3.8.2 利用RDD计算总分与平均分

在本次实战中,我们利用Spark的RDD完成了成绩文件的总分与平均分计算任务。首先,准备了包含学生成绩的文件并上传至HDFS。接着,通过交互式方式逐步实现了成绩的读取、解析、总分计算与平均分计算,并最终输出结果。此外,…

29-FreeRTOS事件标志组

一、概述 事件是一种实现任务间通信的机制,主要用于实现多任务间的同步,但事件通信只能是事件类型的通信,无数据传输。与信号量不同的是,它可以实现一对多,多对多的同步。 即一个任务可以等待多个事件的发生&#xff1…

「EMD/EEMD/VMD 信号分解方法 ——ECG信号处理-第十四课」2025年5月23日

一、引言 上一节,我们介绍了希尔伯特黄变换(HHT)及其经验模态分解(EMD)的相关内容,这一节,我们继续拓展EMD分解技术,补充介绍集合经验模态分解(Ensemble Empirical Mode …

二叉树层序遍历6

INT_MIN的用法&#xff1a; INT_MIN是C/C 中的一个宏常量 &#xff0c;在 <limits.h> &#xff08;C 中也可使用 <climits> &#xff09;头文件中定义&#xff0c;代表 int 类型能表示的最小整数值 。其用法主要体现在以下方面&#xff1a; 1.初始化变量 …

【论文精读】2023 AAAI--FastRealVSR现实世界视频超分辨率(RealWorld VSR)

文章目录 一、摘要二、Method2.1 现象&#xff08;问题&#xff09;--对应文中隐状态的分析&#xff08;Analysis of Hidden State&#xff09;2.2 怎么解决 --对应文中Framework2.2.1 整体流程&#xff1a;2.2.2 HSA模块怎么工作&#xff1f;2.2.2.1 隐藏状态池2.2.2.2 选择性…

IPython 常用魔法命令

文章目录 IPython 魔法命令&#xff08;Magic Commands&#xff09;一、系统与文件操作1. %ls2. %cd​​和%pwd3. %%writefile​​4. %run 二、性能分析与计时1. %timeit2. %prun​​3. ​​%%timeit 三、代码处理与交互1. %load2. ​​%edit3. ​​%store 四、调试与诊断2. ​…

Java虚拟机 - 程序计数器和虚拟机栈

运行时数据结构 Java运行时数据区程序计数器为什么需要程序计数器执行流程虚拟机栈虚拟机栈作用虚拟机栈核心结构运行机制 Java运行时数据区 首先介绍Java运行时数据之前&#xff0c;我们要了解&#xff0c;对于计算机来说&#xff0c;内存是非常重要的资源&#xff0c;因为内…

新能源汽车产业链图谱分析

1. 产业定义 新能源汽车是指采用非常规的车用燃料作为动力来源&#xff0c;综合车辆的动力控制和驱动方面的先进技术&#xff0c;形成的具有新技术、新结构、技术原理先进的汽车。 新能源车包括四大类型&#xff1a;混合动力电动汽车&#xff08;HEV&#xff09;、纯电动汽车…

如何在PyCharm2025中设置conda的多个Python版本

前言 体验的最新版本的PyCharm(Community)2025.1.1&#xff0c;发现和以前的版本有所不同。特别是使用Anaconda中的多个版本的Python的时候。 关于基于Anaconda中多个Python版本的使用&#xff0c;以及对应的Pycharm&#xff08;2023版&#xff09;的使用&#xff0c;可以参考…

maven快速上手

之前我们项目如果要用到其他额外的jar包&#xff0c;需要自己去官网下载并且导入。但是有maven后&#xff0c;直接在maven的pom.xml文件里用代码配置即可&#xff0c;配置好后maven会自动帮我们联网下载并且会自动导入该jar包 在右边的maven中&#xff0c;我们可以看到下载安装…

cplex12.9 安装教程以及下载

cplex 感觉不是很好找&#xff0c;尤其是教育版&#xff0c;我这里提供一个版本&#xff0c;在下面的图可以看到&#xff0c;不仅可以配置matlab&#xff0c;也可以配置vs,现在拿vs2017来测试一下&#xff0c;具体文件的文件有需要的可以复制下面的链接获取 我用网盘分享了「c…

甘特图实例 dhtmlxGantt.js

本文介绍了如何使用dhtmlxGantt库创建一个基础的甘特图示例&#xff0c;并对其进行汉化和自定义配置。首先&#xff0c;通过引入dhtmlxgantt.css和dhtmlxgantt.js文件初始化甘特图。接着&#xff0c;通过设置gantt.i18n.setLocale("cn")实现核心文本的汉化&#xff0…