基于Stable Diffusion XL模型进行文本生成图像的训练

news2025/5/11 23:02:56

基于Stable Diffusion XL模型进行文本生成图像的训练

flyfish

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --report_to="wandb" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir="sdxl-naruto-model" \
  --push_to_hub

环境变量部分

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
  • MODEL_NAME:指定预训练模型的名称或路径。这里使用的是 stabilityai/stable-diffusion-xl-base-1.0,也就是Stable Diffusion XL的基础版本1.0。
  • VAE_NAME:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix 是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。
  • DATASET_NAME:指定训练所使用的数据集名称或路径。这里使用的是 lambdalabs/naruto-blip-captions,是一个包含火影忍者相关图像及其描述的数据集。

accelerate launch 命令参数部分

accelerate launch train_text_to_image_sdxl.py \

这行代码使用 accelerate 工具来启动 train_text_to_image_sdxl.py 脚本,accelerate 可以帮助我们在多GPU、TPU等环境下进行分布式训练。

脚本参数部分

  • --pretrained_model_name_or_path=$MODEL_NAME:指定预训练模型的名称或路径,这里使用前面定义的 MODEL_NAME 环境变量。
  • --pretrained_vae_model_name_or_path=$VAE_NAME:指定预训练VAE模型的名称或路径,使用前面定义的 VAE_NAME 环境变量。
  • --dataset_name=$DATASET_NAME:指定训练数据集的名称或路径,使用前面定义的 DATASET_NAME 环境变量。
  • --enable_xformers_memory_efficient_attention:启用 xformers 库的内存高效注意力机制,能减少训练过程中的内存占用。
  • --resolution=512 --center_crop --random_flip
    • --resolution=512:将输入图像的分辨率统一调整为512x512像素。
    • --center_crop:对图像进行中心裁剪,使其达到指定的分辨率。
    • --random_flip:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
  • --proportion_empty_prompts=0.2:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。
  • --train_batch_size=1:每个训练批次包含的样本数量为1。
  • --gradient_accumulation_steps=4 --gradient_checkpointing
    • --gradient_accumulation_steps=4:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。
    • --gradient_checkpointing:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
  • --max_train_steps=10000:最大训练步数为10000步。
  • --use_8bit_adam:使用8位Adam优化器,能减少内存占用。
  • --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
    • --learning_rate=1e-06:学习率设置为1e-6。
    • --lr_scheduler="constant":学习率调度器设置为常数,即训练过程中学习率保持不变。
    • --lr_warmup_steps=0:学习率预热步数为0,即不进行学习率预热。
  • --mixed_precision="fp16":使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。
  • --report_to="wandb":将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。
  • --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
    • --validation_prompt="a cute Sundar Pichai creature":指定验证时使用的文本提示,这里是“一个可爱的桑达尔·皮查伊形象”。
    • --validation_epochs 5:每5个训练轮次进行一次验证。
  • --checkpointing_steps=5000:每5000步保存一次模型的检查点。
  • --output_dir="sdxl-naruto-model":指定训练好的模型的输出目录为 sdxl-naruto-model
  • --push_to_hub:将训练好的模型推送到Hugging Face模型库。

离线环境运行

# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹

# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"

# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir="sdxl-naruto-model"

移除需要外网连接的参数:去掉 --report_to="wandb"--push_to_hub 参数,因为 wandb 需要外网连接来上传训练指标,--push_to_hub 则需要外网连接把模型推送到Hugging Face模型库。

推理

from diffusers import DiffusionPipeline
import torch

model_path = "you-model-id-goes-here" # <-- 替换为你的模型路径
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

训练后的文件夹结构

.
├── checkpoint-10000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── model.safetensors
├── text_encoder_2
│   ├── config.json
│   └── model.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model-00001-of-00002.safetensors
│   ├── diffusion_pytorch_model-00002-of-00002.safetensors
│   └── diffusion_pytorch_model.safetensors.index.json
└── vae
    ├── config.json
    └── diffusion_pytorch_model.safetensors

LoRA训练

accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --seed=42 \
  --output_dir="sd-naruto-model-lora-sdxl" \
  --validation_prompt="cute dragon creature"

推理

from diffusers import DiffusionPipeline
import torch

sdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"

pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)

prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

LoRA训练后的文件夹结构

├── checkpoint-1000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-1500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-2000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
└── pytorch_lora_weights.safetensors

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2373490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Facebook的元宇宙新次元:社交互动如何改变?

科技的浪潮正将我们推向一个全新的时代——元宇宙时代。Facebook&#xff0c;这个全球最大的社交网络平台&#xff0c;已经宣布将公司名称更改为 Meta&#xff0c;全面拥抱元宇宙概念。那么&#xff0c;元宇宙究竟是什么&#xff1f;它将如何改变我们的社交互动方式呢&#xff…

概统期末复习--速成

随机事件及其概率 加法公式 推三个的时候ABC&#xff0c;夹逼准则 减法准则 除法公式 相互独立定义 两种分析 两个解法 古典概型求概率&#xff08;排列组合&#xff09; 分步相乘、分类相加 全概率公式和贝叶斯公式 两阶段问题 第一个小概率*A在小概率的概率。。。累计 …

n8n系列(1)初识n8n:工作流自动化平台概述

1. 引言 随着各类自动化工具的涌现,n8n作为一款开源的工作流自动化平台,凭借其灵活性、可扩展性和强大的集成能力,正在获得越来越多技术团队的青睐。 本文作为n8n系列的开篇,将带您全面了解这个强大的自动化平台,探索其起源、特性以及与其他工具的差异,帮助您判断n8n是否…

QT6 源(82):阅读与注释日历类型 QCalendar,本类并未完结,儒略历,格里高利历原来就是公历,

&#xff08;1&#xff09;本代码来自于头文件 qcalendar . h &#xff1a; #ifndef QCALENDAR_H #define QCALENDAR_H#include <limits>#include <QtCore/qglobal.h> #include <QtCore/qlocale.h> #include <QtCore/qstring.h> #include <QtCore/…

CVE体系若消亡将如何影响网络安全防御格局

CVE体系的核心价值与当前危机 由MITRE运营的通用漏洞披露&#xff08;CVE&#xff09;项目的重要性不容低估。25年来&#xff0c;它始终是网络安全专业人员理解和缓解安全漏洞的基准参照系。通过提供标准化的漏洞命名与分类方法&#xff0c;这套体系为防御者建立了理解、优先级…

OpenKylin安装Elastic Search8

一、环境准备 Java安装 安装过程此处不做赘述&#xff0c;使用以下命令检查是否安装成功。 java -version 注意&#xff1a;Elasticsearch 自 7.0 版本起内置了 OpenJDK&#xff0c;无需单独安装。但如需自定义 JDK&#xff0c;可设置 JAVA_HOME。 二、安装Elasticsearch …

【ARM AMBA AHB 入门 3 -- AHB 总线介绍】

请阅读【ARM AMBA 总线 文章专栏导读】 文章目录 AHB Bus 简介AHB Bus 构成AHB BUS 工作机制AHB 传输阶段 AHB InterfacesAHB仲裁信号 AHB 数据访问零等待传输(no waitstatetransfer)等待传输(transfers with wait states)多重传送(multipletransfer)--Pipeline AHB 控制信号 A…

多模态大模型中的视觉分词器(Tokenizer)前沿研究介绍

文章目录 引言MAETok背景方法介绍高斯混合模型&#xff08;GMM&#xff09;分析模型架构 实验分析总结 FlexTok背景方法介绍模型架构 实验分析总结 Emu3背景方法介绍模型架构训练细节 实验分析总结 InternVL2.5背景方法介绍模型架构 实验分析总结 LLAVA-MINI背景方法介绍出发点…

sqli-labs靶场第二关——数字型

一&#xff1a;查找注入类型&#xff1a; 输入 ?id1--与第一关的差别&#xff1a;报错; 说明不是字符型 渐进测试&#xff1a;?id1--&#xff0c;结果正常&#xff0c;说明是数字型 二&#xff1a;判断列数和回显位 ?id1 order by 3-- 正常&#xff0c; 说明有三列&am…

[模型选择与调优]机器学习-part4

七 模型选择与调优 1 交叉验证 (1) 保留交叉验证HoldOut HoldOut Cross-validation&#xff08;Train-Test Split&#xff09; 在这种交叉验证技术中&#xff0c;整个数据集被随机地划分为训练集和验证集。根据经验法则&#xff0c;整个数据集的近70%被用作训练集&#xff…

【计算机网络-数据链路层】以太网、MAC地址、MTU与ARP协议

&#x1f4da; 博主的专栏 &#x1f427; Linux | &#x1f5a5;️ C | &#x1f4ca; 数据结构 | &#x1f4a1;C 算法 | &#x1f152; C 语言 | &#x1f310; 计算机网络 上篇文章&#xff1a;传输层-TCP协议TCP核心机制与可靠性保障 下篇文章&#xff1a; 网络…

Kafka 与 RabbitMQ、RocketMQ 有何不同?

一、不同的诞生背景&#xff0c;塑造了不同的“性格” 名称 背景与目标 产品定位 Kafka 为了解决 LinkedIn 的日志收集瓶颈&#xff0c;强调吞吐与持久化 更像一个“可持久化的分布式日志系统” RabbitMQ 出自金融通信协议 AMQP 的实现&#xff0c;强调协议标准与广泛适…

【MATLAB源码-第277期】基于matlab的AF中继系统仿真,AF和直传误码率对比、不同中继位置误码率对比、信道容量、中继功率分配以及终端概率。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 在AF&#xff08;放大转发&#xff09;中继通信系统中&#xff0c;信号的传输质量和效率受到多个因素的影响&#xff0c;理解这些因素对于系统的优化至关重要。AF中继通信的基本架构由发射端、中继节点和接收端组成。发射端负…

webRtc之指定摄像头设备绿屏问题

摘要&#xff1a;最近发现&#xff0c;在使用navigator.mediaDevices.getUserMedia({ deviceId: ‘xxx’}),指定设备的时候&#xff0c;video播放总是绿屏&#xff0c;发现关闭浏览器硬件加速不会出现&#xff0c;但显然这不是一个最好的方案; 播放后张这样 修复后 上代码 指定…

2023年03月青少年软件编程(图形化)等级考试四级编程题

求和 1.准备工作 &#xff08;1&#xff09;保留舞台中的小猫角色和白色背景。 2.功能实现 &#xff08;1&#xff09;计算1&#xff5e;100中&#xff0c;可以被3整除的数之和&#xff1b; &#xff08;2&#xff09;说出被3整除的数之和。 标准答案&#xff1a; 参考程序&…

ensp的华为小实验

1.先进行子网划分 2.进行接口的IP地址配置和ospf的简易配置&#xff0c;先做到全网小通 3.进行ospf优化 对区域所有区域域间路由器进行一个汇总 对区域1进行优化 对区域2.3进行nssa设置 4.对ISP的路由进行协议配置 最后ping通5.5.5.5

ragflow报错:KeyError: ‘\n “序号“‘

环境&#xff1a; ragflowv 0.17.2 问题描述&#xff1a; ragflow报错&#xff1a;KeyError: ‘\n “序号”’ **1. 推荐表&#xff08;输出json格式&#xff09;** [{"},{},{"},{} ]raceback (most recent call last): May 08 20:06:09 VM-0-2-ubuntu ragflow-s…

FHE与后量子密码学

1. 引言 近年来&#xff0c;关于 后量子密码学&#xff08;PQC, Post-Quantum Cryptography&#xff09; 的讨论愈发热烈。这是因为安全专家担心&#xff0c;一旦有人成功研发出量子计算机&#xff0c;会发生什么可怕的事情。由于 Shor 算法的存在&#xff0c;量子计算机将能够…

CSS: 选择器与三大特性

标签选择器 标签选择器就是选择一些HTML的不同标签&#xff0c;由于它们的标签需求不同&#xff0c;所以CSS需要设置标签去选择它们&#xff0c;为满足它们的需求给予对应的属性 基础选择器 标签选择器 <!DOCTYPE html> <head><title>HOME</title>…

M0基础篇之ADC

本节课使用到的例程 一、Single模式例程基本配置的解释 在例程中我们只使用到了PA25这一个通道&#xff0c;因此我们使用的是Single这个模式&#xff0c;也就是我们在配置模式的时候使用的是单一转换。 进行多个通道的测量我们可以使用Sequence这个模式。 二、Single模式例程基…