R3GAN训练自己的数据集

news2025/6/6 12:30:18

简介

简介:这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。

论文题目:The GAN is dead; long live the GAN! A Modern Baseline GAN

会议:NeurIPS 2024

源码地址:https://www.github.com/brownvc/R3GAN

本文在调试代码的时候对代码做了一些修改,如果有遇到报错的问题可以直接复制我这篇博客修改后的代码:R3GAN利用配置好的Pytorch训练自己的数据集-CSDN博客这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。 https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link

摘要:论文反驳了GANs难以训练的普遍观点,提出了一个理论有保障的现代GAN基线。首先,推导出一个良好行为的正则化相对论GAN损失函数,解决了模式丢弃和不收敛问题,并数学证明了其局部收敛性。其次,该损失函数允许丢弃所有经验性技巧,用现代架构替换常见GANs中的过时骨干网络。以StyleGAN2为例,展示了简化和现代化的路线图,产生了新的极简基线R3GAN。尽管简单,该方法在FFHQ、ImageNet、CIFAR和Stacked MNIST数据集上超越了StyleGAN2,与最先进的GANs和扩散模型相比表现优异。

模型结构

生成器架构

核心设计原则:

  • 基于现代化ResNet架构,摒弃VGG-like设计
  • 每个分辨率阶段包含一个过渡层和两个残差块
  • 采用分组卷积和倒置瓶颈设计

关键特性:

  • 无归一化层:避免批量归一化等数据相关的归一化
  • Fix-up初始化:零初始化每个残差块的最后一层卷积
  • 双线性插值:用于上采样,避免棋盘效应

鉴别器架构

设计特点:

  • 与生成器完全对称的架构
  • 相同的残差块结构和过渡层设计
  • 分类器头:全局4×4深度卷积 + 线性层

损失函数

相对论配对GAN损失 (RpGAN):

L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]

R1正则化:

R1(ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_D)

R2正则化:

R2(θ,ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_θ)

训练自己的数据集

1. 准备数据集

首先使用 dataset_tool.py 将您的图像数据转换为适合训练的格式:

# 从文件夹创建数据集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip

# 如果需要调整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \
    --resolution=256x256 --transform=center-crop

数据集要求:

  • 图像必须是正方形(如256x256, 512x512)
  • 分辨率必须是2的幂次(64, 128, 256, 512, 1024等)
  • 支持RGB或灰度图像
  • 可以是文件夹或ZIP格式

2. 创建自定义训练配置

train.py 中添加您自己的预设配置。参考现有预设,在 main() 函数中添加:

if opts.preset == 'YOUR_DATASET':
    # 网络架构参数
    WidthPerStage = [768, 768, 768, 512, 256]  # 每阶段宽度
    BlocksPerStage = [2, 2, 2, 2, 2]           # 每阶段块数
    CardinalityPerStage = [96, 96, 96, 48, 24] # 每阶段基数
    FP16Stages = [-1, -2, -3, -4]              # FP16优化的阶段
    NoiseDimension = 64                         # 噪声维度
    
    # 如果是条件生成(有类别标签)
    if opts.cond:
        c.G_kwargs.ConditionEmbeddingDimension = NoiseDimension
        c.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]
    
    # 训练调度参数
    ema_nimg = 500 * 1000      # EMA开始的图像数
    decay_nimg = 2e7           # 总衰减图像数
    
    # 各种调度器
    c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }
    c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }
    c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }
    c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }
    c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }

3. 开始训练

# 无条件生成(如人脸、风景等)
python train.py \
    --outdir=./training-runs \
    --data=./datasets/your_dataset.zip \
    --gpus=4 \
    --batch=256 \
    --mirror=1 \
    --aug=1 \
    --preset=YOUR_DATASET \
    --tick=1 \
    --snap=200

# 条件生成(有类别标签)
python train.py \
    --outdir=./training-runs \
    --data=./datasets/your_dataset.zip \
    --gpus=4 \
    --batch=256 \
    --mirror=1 \
    --aug=1 \
    --cond=1 \
    --preset=YOUR_DATASET \
    --tick=1 \
    --snap=200

4. 参数说明

  • --gpus: GPU数量
  • --batch: 总批次大小
  • --mirror: 是否启用水平翻转增强
  • --aug: 是否启用数据增强
  • --cond: 是否训练条件模型(需要标签)
  • --tick: 多少kimg输出一次进度
  • --snap: 多少tick保存一次模型

5. 生成图像

训练完成后,使用保存的模型生成图像:

# 生成8张图像
python gen_images.py \
    --seeds=0-7 \
    --outdir=generated_images \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

# 条件生成(指定类别)
python gen_images.py \
    --seeds=0-7 \
    --outdir=generated_images \
    --class=5 \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

6. 评估指标

python calc_metrics.py \
    --metrics=fid50k_full,kid50k_full \
    --data=./datasets/your_dataset.zip \
    --network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

7.报错指南

1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment

解决办法:在 train.py 中,NoiseDimension 只在特定的预设配置块中定义(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset 参数不匹配任何现有预设,这个变量就不会被定义,导致使用时出错。可以使用作者定义好的预先设置。

--preset=CIFAR10
--preset=FFHQ-64  
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64

2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".

解决办法:这个错误是因为R3GAN使用了自定义的CUDA操作符,需要C++编译器来编译。在Windows系统上缺少MSVC/GCC/CLANG编译器。

修改 torch_utils/custom_ops.py:找到 get_plugin 函数(大约第84行),在函数开头添加:

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    # 禁用所有自定义插件
    return None


def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
    # 强制使用 'ref' 实现
    impl = 'ref'

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2397834.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【容器docker】启动容器kibana报错:“message“:“Error: Cannot find module ‘./logs‘

说明: 1、服务器数据盘挂了,然后将以前的数据用rsync拷贝过去,启动容器kibana服务,报错信息如下图所示: 2、可能是拷贝docker文件夹,有些文件没有拷贝过去,导致无论是给文件夹授权用户kibana或者…

C#里与嵌入式系统W5500网络通讯(4)

怎么样修改W5500里的socket收发缓冲区呢? 需要进行下面的工作,首先要了解socket缓冲区的作用,接着了解缓冲区的硬件资源, 最后就是要了解自己的需求,比如自己需要哪个socket的收发送缓冲区多大。 硬件的寄存器为: 这是 W5500 数据手册中关于 Sn_RXBUF_SIZE(Socket n …

Spring boot集成milvus(spring ai)

服务器部署Milvus Run Milvus with Docker Compose (Linux) milvus版本可在docker-compose.yml中进行image修改 启动后,docker查看启动成功 spring boot集成milvus 参考了这篇文章 Spring AI开发RAG示例,理解RAG执行原理 但集成过程中遇到了一系列…

Visual Studio+SQL Server数据挖掘

这里写自定义目录标题 工具准备安装Visual studio 2017安装SQL Server安装SQL Server Management Studio安装analysis service SSMS连接sql serverVisual studio新建项目数据源数据源视图挖掘结构部署模型设置挖掘预测 部署易错点 工具准备 Visual studio 2017 analysis servi…

通过阿里云服务发送邮件

通过阿里云服务发送邮件 1. 整体描述2. 方案选择2.1 控制台发送2.2 API接口接入2.3 SMTP接口接入2.4 结论 3. 前期工作3.1 准备工作3.2 配置工作3.3 总结 4. 收费模式4.1 免费额度4.2 资源包4.3 按量付费 5. Demo开发5.1 选择SMTP服务器5.2 pom引用5.3 demo代码5.4 运行结果 6 …

Vad-R1:通过从感知到认知的思维链进行视频异常推理

文章目录 速览摘要1 引言2 相关工作视频异常检测与数据集视频多模态大语言模型具备推理能力的多模态大语言模型 3 方法:Vad-R13.1 从感知到认知的思维链(Perception-to-Cognition Chain-of-Thought)3.2 数据集:Vad-Reasoning3.3 A…

黑马Java面试笔记之MySQL篇(事务)

一. 事务的特性 事务的特性是什么?可以详细说一下吗? 事务是一组操作的集合,他是一个不可分割的工作单位,事务会把所有的操作作为一个整体一起向系统提交或撤销操作请求,即这些操作要么同时成功,要么同时失…

群辉(synology)NAS老机器连接出现网页端可以进入,但是本地访问输入一样的账号密码是出现错误时解决方案

群辉(synology)NAS老机器连接出现网页端可以进入,但是本地访问输入一样的账号密码是出现错误时解决方案 老机器 装的win7 系统 登入后端网页端的时候正常,但是本地访问登入时输入登入网页端一样的密码时候出现问题解决方案 1.登…

【深度学习】实验四 卷积神经网络CNN

实验四 卷积神经网络CNN 一、实验学时: 2学时 二、实验目的 掌握卷积神经网络CNN的基本结构;掌握数据预处理、模型构建、训练与调参;探索CNN在MNIST数据集中的性能表现; 三、实验内容 实现深度神经网络CNN。 四、主要实验步…

实现一个免费可用的文生图的MCP Server

概述 文生图模型为使用 Cloudflare Worker AI 部署 Flux 模型,是参照视频https://www.bilibili.com/video/BV1UbkcYcE24/?spm_id_from333.337.search-card.all.click&vd_source9ca2da6b1848bc903db417c336f9cb6b的复现Cursor MCP Server实现是参照文章https:/…

【手搓一个原生全局loading组件解决页面闪烁问题】

页面闪烁效果1 页面闪烁效果2 封装一个全局loading组件 class GlobalLoading extends HTMLElement {constructor() {super();this.attachShadow({ mode: open });}connectedCallback() {this.render();this.init();}render() {this.shadowRoot.innerHTML <style>.load…

CSS基础巩固-基础-选择

目录 CSS是如何工作的&#xff1f; 当浏览器遇到无法解析的CSS代码时 如何导入CSS样式&#xff1f; 改变元素的默认样式 选择 前缀符号&#xff08;后面会具体介绍&#xff09; 优先级 同时应用样式到多个类上 属性选择器 伪类 伪元素 关系选择器 后代选择器 子代…

一种在SQL Server中传递多行数据的方法

这是一种比较偷懒的方法&#xff0c;其实各种数据库对Json 支持的很好。sql server 、oracle都不错。所以可以直接传json declare 这是一个json varchar(max) set 这是一个json{"data":[{"code":"1","name":"啥1"},{"…

【Docker 从入门到实战全攻略(一):核心概念 + 命令详解 + 部署案例】

1. 是什么 Docker 是一个用于开发、部署和运行应用程序的开源平台&#xff0c;它使用 容器化技术 将应用及其依赖打包成独立的容器&#xff0c;确保应用在不同环境中一致运行。 2. Docker与虚拟机 2.1 Docker&#xff08;容器化&#xff09; 容器化是一种轻量级的虚拟化技术…

github 提交失败,连接不上

1. 第一种情况&#xff0c;开了加速器&#xff0c;导致代理错误 删除hosts文件里相关的github代理地址 2. 有些ip不支持22端口连接,改为443连接 ssh -vT gitgithub.com // 命令执行结果 OpenSSH_for_Windows_9.5p1, LibreSSL 3.8.2 debug1: C…

系统架构设计师(一):计算机系统基础知识

系统架构设计师&#xff08;一&#xff09;&#xff1a;计算机系统基础知识 引言计算机系统概述计算机硬件处理器处理器指令集常见处理器 存储器总线总线性能指标总线分类按照总线在计算机中所处的位置划分按照连接方式分类按照功能分类 接口接口分类 计算机软件文件系统文件类…

清理 pycharm 无效解释器

1. 起因&#xff0c; 目的: 经常使用 pycharm 来调试深度学习项目&#xff0c;每次新建虚拟环境&#xff0c;都是显示一堆不存在的名称&#xff0c;删也删不掉。 总觉得很烦&#xff0c;是个痛点。决定深入研究一下。 2. 先看效果 效果是能行&#xff0c;而且清爽多了。 3. …

手机如何压缩文件为 RAR 格式:详细教程与工具推荐

在如今这个数字化时代&#xff0c;手机已经成为我们生活中不可或缺的工具。随着我们使用手机的频率越来越高&#xff0c;手机中的文件也越来越多&#xff0c;照片、视频、文档等各种类型的文件不断占据着手机的存储空间。 据统计&#xff0c;普通用户的手机存储空间中&#xf…

Java 注解式限流教程(使用 Redis + AOP)

Java 注解式限流教程&#xff08;使用 Redis AOP&#xff09; 在上一节中&#xff0c;我们已经实现了基于 Redis 的请求频率控制。现在我们将进一步升级功能&#xff0c;使用 Spring AOP 自定义注解 实现一个更优雅、可复用的限流方式 —— 即通过 RateLimiter 注解&#xf…

C# XAML 基础:构建现代 Windows 应用程序的 UI 语言

在现代 Windows 应用程序开发中&#xff0c;XAML (eXtensible Application Markup Language) 扮演着至关重要的角色。作为一种基于 XML 的声明性语言&#xff0c;XAML 为 WPF (Windows Presentation Foundation)、UWP (Universal Windows Platform) 和 Xamarin.Forms 应用程序提…