【扩散模型系列3】DiT开源项目

news2025/6/11 8:09:25

文章目录

  • DiT原始项目
  • Fast-DiT readme
    • Sampling
    • Training
      • 训练之前的准备
      • 训练DiT
      • PyTorch 训练结果
      • 改进训练效果
    • Evaluation (FID, Inception Score, etc.)
  • 总结

DiT原始项目

该项目仅针对DiT训练,并未包含VAE 的训练

项目地址

论文主页

Fast-DiT readme

该项目仅针对DiT训练,并未包含VAE 的训练

项目地址

该项目是基于论文Scalable Diffusion Models with Transformers的pytorch改进实现

包含:

  • PyTorch的改进实现 和原始的DiT实现

  • 在ImageNet (512x512 and 256x256)预训练的条件分类的 DiT 模型;

  • 一个独立的 Hugging Face Space 和 Colab notebook,用于运行预训练的DiT-XL/2模型

  • 改进的DiT 训练脚本 和一些 训练建议

启动
首先,下载开源代码

git clone https://github.com/chuanyangjin/fast-DiT.git
cd DiT

我们提供了 environment.yml 文件,可创建Conda 虚拟环境。

如果想要在本地CPU运行预训练的模型,可以在文件中删除cudatoolkitpytorch-cuda 相关的依赖项。

conda env create -f environment.yml
conda activate DiT

Sampling

在这里插入图片描述

预训练的DiT checkpoints

你可以使用预训练模型样例sample.py

预训练的DiT模型权重会根据使用的模型自动下载。

根据输入模型尺寸的不同,脚本中进行了不同的参数设置转换(256x256 and 512x512),比如针对512x512 DiT-XL/2 模型,你可以使用以下命令:

python sample.py --image-size 512 --seed 1

为了更加方便,我们的预训练模型也可直接进行下载:

DiT ModelImage ResolutionFID-50KInception ScoreGflops
XL/2256x2562.27278.24119
XL/2512x5123.04240.82525

自定义的DiT checkpoints

如果你想训练一个新的DiT 模型,可使用 train.py (see below)。

你可以增加一个参数–ckpt 使用你自己的checkpoint进行演示。比如运行一个 256x256 DiT-L/4 模型,可使用以下命令 :

If you’ve trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt

Training

训练之前的准备

在一个GPU节点上抽取ImageNet 特征:

torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features

训练DiT

We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.

我们提供了一个DiT的训练脚本 train.py。该训练脚本可以被用作训练条件分类的 DiT 模型,但是也可修改后用于其他类型的条件。

To launch DiT-XL/2 (256x256) training with 1 GPUs on one node:

accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features

To launch DiT-XL/2 (256x256) training with N GPUs on one node:

accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features

或者,你可以选择提取并训练文件夹training options中的脚本。

PyTorch 训练结果

我们训练了 DiT-XL/2 和 DiT-B/4 模型

我们使用PyTorch训练脚本从头开始训练DiT-XL/2和DiT-B/4模型,以验证它重现了原始的JAX结果,达到数十万次训练迭代。

在我们的实验中,在合理的随机变化范围内,pytorch训练的模型与jax训练的模型相比,给出了类似(有时略好)的结果。一些数据点如下:

DiT ModelTrain StepsFID-50K (JAX Training)FID-50K (PyTorch Training)PyTorch Global Training Seed
XL/2400K19.518.142
B/4400K68.468.942
B/4400K68.468.3100

这些模型在256x256分辨率下进行训练; 我们使用8x A100来训练XL/2,使用4x A100来训练B/4。注意,这里的FID是使用mse VAE解码器,在没有指导的情况下(cfg-scale=1),通过250个DDPM采样步骤计算得到的。

改进训练效果

与原始实现相比,实现了一些训练加速和节省内存的特征,包括梯度检查点、混合精度训练和预提取的VAE特征,在DiT-XL/2上的速度提高了95%,内存减少了60%。一些数据点使用A100的全局批处理大小为128:

gradient checkpointingmixed precision trainingfeature pre-extractiontraining speedmemory
-out of memory
0.43 steps/sec44045 MB
0.56 steps/sec40461 MB
0.84 steps/sec27485 MB

Evaluation (FID, Inception Score, etc.)

我们提供一个’ sample_ddp.py '脚本,它可以并行地从DiT模型中对大量图像进行采样。

这个脚本生成一个样本文件夹和. npz文件一样,可以直接与ADM的TensorFlow评估套件一起使用,以计算FID, Inception分数和其他指标。

例如,要在N个gpu上从预训练的DiT-XL/2模型中采样50K张图像,运行:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

补充建议,详情请参见 sample_ddp.py

总结

开源项目仅能下载DiT-XL模型
微软开源了DiT-B模型,下载链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【毕业】 医药药店销售管理系统

1、引言 设计结课作业,课程设计无处下手,网页要求的总数量太多?没有合适的模板?数据库,java,python,vue,html作业复杂工程量过大?毕设毫无头绪等等一系列问题。你想要解决的问题&am…

OpenJDK 目前主要发展方向

Loom:得赶紧解决 synchronized pin 线程的问题(据说 Java 23 会解决,现在有预览版)。各个 Java 库需要改造原来使用 ThreadLocal 的方式:如果是为了穿参数,则可以使用 ScopedLocal;如果是对象池…

【leetcode热题】寻找旋转排序数组中的最小值 II

难度: 困难通过率: 38.7%题目链接:. - 力扣(LeetCode) 题目描述 假设按照升序排序的数组在预先未知的某个点上进行了旋转。 ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1,2] )。 请找出其中最小的…

激光打标机红光与激光不重合:原因及解决方案

激光打标机红光和激光不在一个位置的问题可能由多种原因导致。以下是一些可能的原因和解决方法: 1. 激光器光路调整不当:激光器光路调整不当会导致激光束偏移,从而使红光与激光不重合。解决方法是重新调整激光器的光路,确保激光束…

Session登陆实践

Session登陆实践 Session登录是一种常见的Web应用程序身份验证和状态管理机制。当用户成功登录到应用程序时,服务器会为其创建一个会话(session),并在会话中存储有关用户的信息。这样,用户在与应用程序交互的整个会话…

C语言逗号运算符(,)

在C语言中,逗号运算符(,)用于在表达式中分隔多个子表达式,并按照从左到右的顺序依次计算这些子表达式。逗号运算符的运算结果是最后一个子表达式的值。 逗号运算符的底层行为是依次计算每个子表达式,并将每个子表达式…

SSM框架,MyBatis-Plus的学习(下)

条件构造器 使用MyBatis-Plus的条件构造器,可以构建灵活高效的查询条件,可以通过链式调用来组合多个条件。 条件构造器的继承结构 Wrapper : 条件构造抽象类,最顶端父类 AbstractWrapper : 用于查询条件封装&#xf…

广度优先搜索和深度优先搜索

广度优先搜索 广度优先搜索(Breadth-First-Search,BFS)类似于二叉树的层序遍历算法(借助队列),其基本思想是:首先访问起始顶点,接着由v出发,依次访问v的各个未访问过的邻…

git命令行提交——github

1. 克隆仓库至本地 git clone 右键paste(github仓库地址) cd 仓库路径(进入到仓库内部准备提交文件等操作) 2. 查看main分支 git branch(列出本地仓库中的所有分支) 3. 创建新分支(可省…

纪年哥的文物挽救木牌

左(江南制造局,曾国藩书天道酬勤,李鸿章少荃印,光绪三十四年制造) 中(汉阳兵工厂,民国二十六年制造,公元1937年七月七日,抗日战争全面爆发) 右(…

linux、windows 动态库与静态库的实现

动态库与静态库的实现 在使用keil的时候遇到这样一个事情,我调用了一个函数,只有函数声明,但是我想查看函数的实现却不行,为什么会这样,这不来了嘛, 我们在使用printf函数等,都是加上头文件直接调用&…

HarmonyOS NEXT应用开发案例——列表编辑实现

介绍 本示例介绍用过使用ListItem组件属性swipeAction实现列表左滑编辑效果的功能。 该场景多用于待办事项管理、文件管理、备忘录的记录管理等。 效果图预览 使用说明: 点击添加按钮,选择需要添加的待办事项。长按待办事项,点击删除后&am…

考研408 2014年第41题(二叉树带权路径长度【WPL】)

function.h(结构体)&#xff1a; // // Created by legion on 2024/3/5. //#ifndef INC_14_4_TREE_FUNCTION_H #define INC_14_4_TREE_FUNCTION_H #include <stdio.h> #include <stdlib.h>typedef int BiElemType; typedef struct BiTNode{BiElemType weight;//直…

【Python】Python Astar算法生成最短路径GPS轨迹

简介 最短路径问题是计算机科学中一个经典问题&#xff0c;它涉及找到图中两点之间距离最短的路徑。在实际应用中&#xff0c;最短路径算法用于解决广泛的问题&#xff0c;例如导航、物流和网络优化。 步骤 1&#xff1a;加载道路网络数据 要计算最短路径&#xff0c;我们需…

【Python】装饰器函数

专栏文章索引&#xff1a;Python 原文章&#xff1a;装饰器函数基础_装饰函数-CSDN博客 目录 1. 学习装饰器的基础 2.最简单的装饰器 3.闭包函数装饰器 4.装饰器将传入的函数中的值大写 5. 装饰器的好处 6. 多个装饰器的执行顺序 7. 装饰器传递参数 8. 结语 1. 学习装饰…

【UE5】创建蓝图

创建GamePlay需要的相关蓝图 项目资源文末百度网盘自取 在 内容游览器 文件夹中创建文件夹&#xff0c;命名为 Blueprints &#xff0c;用来放这个项目的所有蓝图(Blueprint) 在 Blueprints 文件夹下新建文件夹 GamePlay ,用存放GamePlay相关蓝图 在 Blueprints 文件夹下创建文…

Java17 --- SpringCloud初始项目创建

目录 一、cloud项目创建 1.1、项目编码规范 1.2、注解生效激活 1.3、导入父工程maven的pom依赖 二、创建子工程并导入相关pom依赖 2.1、相关配置文件 2.1.1、数据库配置文件内容 2.1.2、自动生成文件配置内容 三、创建微服务8001子工程 3.1、导入相关pom依赖 3.…

利用IDEA创建Java项目使用Servlet工具

【文件】-【项目结构】 【模块】-【依赖】-【】-【JAR】 找到Tomcat的安装路径打开【lib】找到【servlet.jar】点击【确定】 勾选上jar,然后【应用】-【确定】 此时新建文件可以发现多了一个Servlet&#xff0c;我们点击会自动创建一个继承好的Servlet类

对比学习概念与如何标注标签

对比学习公式讲述 对比学习倾向于将同一图像的转换视图之间的一致性最大化&#xff0c;而将不同图像的转换视图之间的一致性最小化。令是一个输出特征空间的卷积神经网络。一个图像x的两个增广图像补丁通过进行映射&#xff0c;生成一个查询特征q和一个关键特征k。此外&#x…

ospf静态路由实验简述

1、ospf静态路由实验简述 实验拓扑图 实验命令 r2: sys sysname r2 undo info enable int loopb 0 ip add 2.2.2.2 32 quit int e0/0/0 ip add 23.1.1.2 24 quit ospf 1 area 0 network 23.1.1.0 0.0.0.255 network 2.2.2.2 0.0.0.0 ret r3: sys sysname r3 undo info enable …