第 3 期:逆过程建模与神经网络的作用(Reverse Process)

news2025/7/15 1:20:25

一、从正向扩散到逆向去噪:生成的本质

在上期中我们讲到,正向扩散是一个逐步加入噪声的过程,从原始图像 x_0到接近高斯分布的 x_T​:

而我们真正关心的,是从纯噪声中逐步还原原图的过程,也就是逆过程:

这个逆过程没有 closed form,我们只能用神经网络来近似学习它。

二、逆过程建模:从高斯中一步步采样

根据论文设定,我们假设每一步的逆过程仍是高斯分布:

也就是说:我们要学习的是每一步的均值和方差。

  • 方差 Σ_θ通常被固定或共享

  • 网络主要任务是输出 μ_θ,也就是引导去噪的方向

想象一下,你现在手上有一张全是雪点(噪声)的图片,你想一步一步去“擦掉”这些噪声,还原最初的图片,这就是神经网络的任务。

三、三种预测方式:预测 μ、ϵ 或 x_0​?

论文中探讨了三种不同的预测方式,来指导我们如何训练神经网络 ϵ_θ:

方式一:预测噪声 ϵ\epsilonϵ(默认使用)

利用公式:

我们可以反推:

训练时的损失函数:

也就是说我们训练神经网络来预测加进去的噪声,然后反推出干净图像。

方式二:直接预测 x_0

由上面的公式我们可以看到,如果我们预测出 x_0​,也能得到 ϵ 或 μ。

有些改进模型(如Guided Diffusion)使用这种方式,因为可以更直接地控制生成图像。

方式三:直接预测 μ_θ(x_t,t)

这种方式虽然看似最直接,但训练不如预测 ϵ稳定,因此实际使用中较少。

四、神经网络结构:用U-Net来建模 ϵ_θ(x_t,t)

DDPM中广泛使用 U-Net 结构来建模 ϵ_θ​,原因如下:

  • 图像到图像的任务中,U-Net有非常强的表现

  • 可融合多层语义信息(通过跳跃连接)

  • 可轻松嵌入时间步 ttt 信息(通过time embedding)

网络输入:

  • 噪声图像 x_t

  • 时间步编码 t

网络输出:

  • 同样大小的图像,预测噪声 ϵ

五、采样过程简述:从高斯恢复图像

当模型训练好之后,采样过程是这样的:

  1. 从高斯分布中采样

  2. 对 t=T,T−1,…,1:

    • 用网络预测 ϵ_θ(x_t,t)

    • 计算并加入随机项(保持多样性)

整个过程逐步“去除噪声”,最终得到 x_0,也就是生成图像。

代码演示:构造训练样本并训练模型

我们用 PyTorch 举例说明:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

# 超参数
T = 1000  # 扩散步数
beta = torch.linspace(1e-4, 0.02, T)  # 固定线性beta表
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

# 加噪函数 q(x_t | x_0)
def q_sample(x_0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_0)
    sqrt_alpha_bar = torch.sqrt(alpha_bar[t])[:, None, None, None]
    sqrt_one_minus = torch.sqrt(1 - alpha_bar[t])[:, None, None, None]
    return sqrt_alpha_bar * x_0 + sqrt_one_minus * noise

 网络结构(最小U-Net)

class SimpleDenoiseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 3, padding=1),
        )

    def forward(self, x, t):
        return self.net(x)

 训练核心逻辑

model = SimpleDenoiseModel().to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def get_loss(x_0, t):
    noise = torch.randn_like(x_0)
    x_t = q_sample(x_0, t, noise)
    noise_pred = model(x_t, t)
    return nn.MSELoss()(noise_pred, noise)

# 示例训练循环
for epoch in range(10):
    for x, _ in dataloader:
        x = x.to("cuda")
        t = torch.randint(0, T, (x.size(0),), device="cuda").long()
        loss = get_loss(x, t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

 可视化一个加噪过程

def show_noisy_images(x_0, steps=[0, 200, 400, 600, 800, 999]):
    fig, axes = plt.subplots(1, len(steps), figsize=(15, 2))
    for i, t in enumerate(steps):
        xt = q_sample(x_0, torch.tensor([t]))
        axes[i].imshow(xt[0][0].cpu(), cmap="gray")
        axes[i].set_title(f"t = {t}")
        axes[i].axis("off")
    plt.tight_layout()
    plt.show()

 

小结

关键点内容
学习目标模型学习预测给定x_t时的噪声 ϵ
网络输入x_t 和时间步 t
网络输出估计的 ϵ_θ(x_t,t)
损失函数MSE between 预测噪声 和 真实噪声
实际操作从 x_0采样,生成x_t,训练模型反推噪声

下一讲预告(第 4 期):

我们将深入解读为什么损失函数可以简化为预测噪声的 MSE,并且用变分下界(ELBO)的推导说明这个做法的理论基础!

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2338148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

健康养生:开启活力生活新篇章

在当代社会,熬夜加班、久坐不动、外卖快餐成为许多人的生活常态,随之而来的是各种亚健康问题。想要摆脱身体的疲惫与不适,健康养生迫在眉睫,它是重获活力、拥抱美好生活的关键。​ 应对不良饮食习惯带来的健康隐患,饮…

记录学习的第二十九天

还是力扣每日一题。 本来想着像昨天一样两个循环搞定的,就下面👇🏻 不过,结果肯定是超时啦,中等题是吧。 正确答案是上面的。 之后就做了ls题单第一部分,首先是定长滑窗问题 这种题都是有套路的&#xff0…

Express学习笔记(六)——前后端的身份认证

目录 1. Web 开发模式 1.1 服务端渲染的 Web 开发模式 1.2 服务端渲染的优缺点 1.3 前后端分离的 Web 开发模式 1.4 前后端分离的优缺点 1.5 如何选择 Web 开发模式 2. 身份认证 2.1 什么是身份认证 2.2 为什么需要身份认证 2.3 不同开发模式下的身份认证 3. Sessio…

leetcode 309. Best Time to Buy and Sell Stock with Cooldown

目录 题目描述 第一步,明确并理解dp数组及下标的含义 第二步,分析并理解递推公式 1.求dp[i][0] 2.求dp[i][1] 3.求dp[i][2] 第三步,理解dp数组如何初始化 第四步,理解遍历顺序 代码 题目描述 这道题与第122题的区别就是卖…

优化自旋锁的实现

在《C11实现一个自旋锁》介绍了分别使用TAS和CAS算法实现自旋锁的方案,以及它们的优缺点。TAS算法虽然实现简单,但是因为每次自旋时都要导致一场内存总线流量风暴,对全局系统影响很大,一般都要对它进行优化,以降低对全…

SS25001-多路复用开关板

1 概述 1.1 简介 多路复用开关板是使用信号继电器实现2线制的多路复用开关板卡;多路复用开关是一种可以将一个输入连接到多个输出或一个输出连接到多个输入的拓扑结构。这种拓扑通常用于扫描,适合将一系列通道自动连接到公共线路的的设备。多路复用开…

【AI News | 20250418】每日AI进展

AI Repos 1、exa-mcp-server AI助手通过Exa获得实时网络信息获取的能力,提供结构化的搜索结果,返回包括标题、URL以及内容片段在内的结构化结果;会把最近的搜索结果缓存为资源,下次再搜索相同的内容时可以直接使用缓存&#xff1…

Dify LLM大模型参数(一)

深入了解大语言模型(LLM)的参数设置 模型的参数对模型的输出效果有着至关重要的影响。不同的模型会拥有不同的参数,而这些参数的设置将直接影响模型的生成结果。以下是 DeepSeek 模型参数的详细介绍: 温度(Tempera…

展示数据可视化的魅力,如何通过图表、动画等形式让数据说话

在当今信息爆炸的时代,数据的量级和复杂性不断增加。如何从海量数据中提取有价值的信息,并将其有效地传达给用户,成为了一个重要的课题。数据可视化作为一种将复杂数据转化为直观图形、图表和动画的技术,能够帮助用户快速理解数据…

时序预测 | Matlab实现基于VMD-WOA-ELM和VMD-ELM变分模态分解结合鲸鱼算法优化极限学习机时间序列预测

时序预测 | Matlab实现基于VMD-WOA-ELM和VMD-ELM变分模态分解结合鲸鱼算法优化极限学习机时间序列预测 目录 时序预测 | Matlab实现基于VMD-WOA-ELM和VMD-ELM变分模态分解结合鲸鱼算法优化极限学习机时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab…

基于EasyX库开发的球球大作战游戏

目录 球球大作战 一、开发环境 二、流程图预览 三、代码逻辑 1、初始化时间 2、设置开始界面大小 3、设置开始界面 4、让玩家选择速度 5、设置玩家小球、人机小球、食物的属性 6、一次性把图绘制到界面里 7、进入死循环 8、移动玩家小球 9、移动人机 10、食物刷新…

《系统分析师-第三阶段—总结(一)》

背景 采用三遍读书法进行阅读,此阶段是第三遍。 过程 第一章 第二章 总结 在这个过程中,对导图的规范越来越清楚,开始结构化,找关系,找联系。

AI——K近邻算法

文章目录 一、什么是K近邻算法二、KNN算法流程总结三、Scikit-learn工具1、安装2、导入3、简单使用 三、距离度量1、欧式距离2、曼哈顿距离3、切比雪夫距离4、闵可夫斯基距离5、K值的选择6、KD树 一、什么是K近邻算法 如果一个样本在特征空间中的k个最相似(即特征空…

用 NLP + Streamlit,把问卷变成能说话的反馈

网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…

TCP/IP和UDP协议的发展历程

TCP/IP和UDP协议的发展历程 引言 互联网的发展史是人类技术创新的辉煌篇章,而在这一发展过程中,通信协议发挥了奠基性的作用。TCP/IP(传输控制协议/互联网协议)和UDP(用户数据报协议)作为互联网通信的基础…

Function Calling的时序图(含示例)

🧍 用户: 发起请求,输入 prompt(比如:“请告诉我北京的天气”)。 🟪 应用: 将用户输入的 prompt 和函数定义(包括函数名、参数结构等)一起发给 OpenAI。 …

若依框架修改左侧菜单栏默认选中颜色

1.variables.sacc中修改为想要的颜色 2.给目标设置使用的颜色

搜广推校招面经七十八

字节推荐算法 一、实习项目:多任务模型中的每个任务都是做什么?怎么确定每个loss的权重 这个根据实际情况来吧。如果实习时候用了moe,就可能被问到。 loss权重的话,直接根据任务的重要性吧。。。 二、特征重要性怎么判断的&…

广搜bfs-P1443 马的遍历

P1443 马的遍历 题目来源-洛谷 题意 要求马到达棋盘上任意一个点最少要走几步 思路 国际棋盘规则是马的走法是-日字形,也称走马日,即x,y一个是走两步,一个是一步 要求最小步数,所以考虑第一次遍历到的点即为最小步数&#xff…

强化学习算法系列(六):应用最广泛的算法——PPO算法

强化学习算法 (一)动态规划方法——策略迭代算法(PI)和值迭代算法(VI) (二)Model-Free类方法——蒙特卡洛算法(MC)和时序差分算法(TD) (三)基于动作值的算法——Sarsa算法与Q-Learning算法 (四…