生成对抗网络(GAN)原理

news2025/5/23 5:37:29

生成对抗网络(GAN)原理

  • 介绍
    • 示例代码
    • 一、GAN 的基本结构
      • 1. 生成器(Generator,记作 G)
      • 2. 判别器(Discriminator,记作 D)
    • 二、对抗过程(博弈思想)
    • 三、训练过程
    • 四、存在的问题与改进方向
      • 1. 模式崩溃(Mode Collapse)
      • 2. 训练不稳定
      • 3. 衡量指标困难
    • 五、GAN 的改进与变种
    • 六、GAN 的应用
    • ✅ DCGAN 与普通 GAN 的区别
    • ✅ DCGAN 示例(基于 MNIST)
      • 🔧 安装依赖
      • 🧠 DCGAN 架构代码(Generator + Discriminator)
    • 🧪 使用说明

介绍

示例代码

生成对抗网络(Generative Adversarial Network,GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种深度生成模型。它通过两个神经网络之间的博弈(对抗)过程,学习数据的生成分布,从而生成以假乱真的数据(如图像、语音等)。GAN 是近年来生成模型领域的重要突破,广泛应用于图像生成、风格迁移、图像修复等任务中。


一、GAN 的基本结构

GAN 主要由两个部分组成:

1. 生成器(Generator,记作 G)

  • 目标:生成尽可能真实的数据,欺骗判别器。
  • 输入:随机噪声向量(一般从正态分布或均匀分布中采样)
  • 输出:“伪造”的样本,尽可能与真实样本相似。

2. 判别器(Discriminator,记作 D)

  • 目标:判断输入数据是真实的样本还是生成器生成的伪造样本。
  • 输入:真实样本或生成样本
  • 输出:一个概率值,表示输入是“真实”的概率。

二、对抗过程(博弈思想)

GAN 的训练过程是一个零和博弈(min-max game):

  • 生成器试图最小化判别器对生成样本的识别能力;
  • 判别器试图最大化识别真实样本与生成样本的能力。

这个过程可以表示为一个最优化问题:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中:

  • p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据的分布;
  • p z ( z ) p_z(z) pz(z) 是生成器输入噪声的分布(如高斯分布);
  • D ( x ) D(x) D(x) 是判别器输出 x 是真实数据的概率;
  • G ( z ) G(z) G(z) 是生成器输出的伪造样本。

三、训练过程

  1. 固定生成器 G,训练判别器 D

    • 给 D 一部分真实样本(标签为 1);
    • 给 D 一部分 G 生成的样本(标签为 0);
    • 通过交叉熵损失训练 D,使其能区分真假样本。
  2. 固定判别器 D,训练生成器 G

    • 通过 G 生成假样本;
    • D 会判断其为假;
    • G 的目标是欺骗 D,即最大化 D ( G ( z ) ) D(G(z)) D(G(z)),让 D 判错;
    • 通常优化的是 log ⁡ D ( G ( z ) ) \log D(G(z)) logD(G(z)) 的反函数,例如 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1D(G(z))) 或更稳定的变体(如使用 feature matching 或 Wasserstein loss)。
  3. 交替训练 D 和 G,直到生成器生成的样本无法被判别器区分为假(判别器输出接近 0.5)。


四、存在的问题与改进方向

1. 模式崩溃(Mode Collapse)

生成器只学会生成一小部分模式样本,导致多样性丢失。

2. 训练不稳定

D 和 G 的能力不均衡、学习率不合适等因素可能导致 GAN 训练震荡或失败。

3. 衡量指标困难

GAN 的损失函数不能很好地反映生成质量,因此通常使用 FID、IS 等指标辅助评估。


五、GAN 的改进与变种

为了克服原始 GAN 的不足,研究人员提出了许多变种:

名称简介
DCGAN使用卷积神经网络的 GAN,适合图像数据生成
WGANWasserstein GAN,引入 Wasserstein 距离,解决训练不稳定问题
WGAN-GP在 WGAN 基础上加上梯度惩罚项,提高训练稳定性
CGAN条件 GAN,可以控制生成样本的类别(如生成特定数字)
CycleGAN用于图像风格转换(如马<->斑马、夏天<->冬天)
StyleGAN高质量人脸图像生成的里程碑,支持精细控制生成风格
BigGAN更大规模、更高分辨率的图像生成模型

六、GAN 的应用

  1. 图像生成(人脸、卡通、合成图像)
  2. 图像风格转换(如照片转油画)
  3. 图像修复与超分辨率重建
  4. 数据增强(医学图像、遥感等领域)
  5. 音频合成、文本生成(结合 RNN、Transformer)
  6. 深度伪造(Deepfake)

下面是一个使用 PyTorch 实现的 DCGAN(Deep Convolutional GAN)完整示例,用于在 MNISTCIFAR-10 数据集上生成图像。与普通 GAN 相比,DCGAN 使用卷积神经网络结构,更适合图像处理,生成效果更清晰、真实。


✅ DCGAN 与普通 GAN 的区别

模块普通 GANDCGAN
生成器结构全连接层(MLP)转置卷积(ConvTranspose2d)
判别器结构全连接层卷积层(Conv2d)
激活函数ReLU + SigmoidLeakyReLU + Sigmoid / No sigmoid
正则化无 / BatchNorm1dBatchNorm2d

✅ DCGAN 示例(基于 MNIST)

🔧 安装依赖

pip install torch torchvision matplotlib

🧠 DCGAN 架构代码(Generator + Discriminator)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os

# 超参数
img_size = 64
channels = 1  # 改为 3 可用于 CIFAR-10
latent_dim = 100
batch_size = 128
lr = 0.0002
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 创建目录
os.makedirs("dcgan_images", exist_ok=True)

# 数据预处理(MNIST 被 resize 成 64x64 )
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True
)

# 生成器(使用转置卷积)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 输入是 latent_dim 向量,输出 1024
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), latent_dim, 1, 1)
        return self.model(z)

# 判别器(使用卷积)
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x).view(-1, 1).squeeze(1)

# 初始化模型
G = Generator().to(device)
D = Discriminator().to(device)

# 损失和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# 训练 DCGAN
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)

        # 标签
        valid = torch.ones(b_size, device=device)
        fake = torch.zeros(b_size, device=device)

        # ========== 训练判别器 ==========
        optimizer_D.zero_grad()
        real_loss = criterion(D(real_imgs), valid)

        z = torch.randn(b_size, latent_dim, device=device)
        gen_imgs = G(z)
        fake_loss = criterion(D(gen_imgs.detach()), fake)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ========== 训练生成器 ==========
        optimizer_G.zero_grad()
        g_loss = criterion(D(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # 保存生成图像
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        gen_imgs = G(z)
        grid = make_grid(gen_imgs, nrow=8, normalize=True)
        save_image(grid, f"dcgan_images/{epoch:03d}.png")

print("DCGAN 训练完成,图像保存在 dcgan_images 文件夹中。")

🧪 使用说明

  • 若想改用 彩色图像(如 CIFAR-10),需:

    • channels = 3
    • 使用 datasets.CIFAR10 替代 MNIST
    • 修改 transforms.Normalize([0.5]*3, [0.5]*3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2383655.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

centos8 配置网桥,并禁止kvm默认网桥

环境背景&#xff1a; 我使用vmware部署了一台kvm服务器&#xff0c;网络模式是nat。我想要kvm创建的虚拟机可以访问公网&#xff1b;所以kvm默认的地址不行&#xff0c;我必须使用nat地址才可以&#xff1b; 实现方式&#xff1a; 创建一个网桥&#xff0c;将本地的网络接口…

【Node.js】全栈开发实践

个人主页&#xff1a;Guiat 归属专栏&#xff1a;node.js 文章目录 1. Node.js 全栈开发概述1.1 全栈开发的优势1.2 Node.js 全栈开发技术栈 2. 开发环境搭建2.1 Node.js 和 npm 安装2.2 开发工具安装2.3 版本控制设置2.4 项目初始化流程 3. 后端开发 (Node.js)3.1 Express 框架…

ubuntu sh安装包的安装方式

ubuntu sh安装包的安装方式以Miniconda2为例 https://repo.anaconda.com/miniconda/ 如果需要python2.7版本可下载以下版本 Miniconda2-latest-Linux-x86_64.sh 打开终端输入安装命令 sudo sh Miniconda2-latest-Linux-x86_64.sh 然后按提示安装&#xff0c;注意安装位置 …

OpenAI宣布:核心API支持MCP,助力智能体开发

今天凌晨&#xff0c;OpenAI全资收购io的消息成为头条。同时&#xff0c;OpenAI还宣布其核心API——Responses API支持MCP服务。过去&#xff0c;开发智能体需通过函数调用与外部服务交互&#xff0c;过程复杂且延迟高。而今&#xff0c;Responses API支持MCP后&#xff0c;开发…

微服务中的 AKF 拆分原则:构建可扩展系统的核心方法论

在数字化浪潮的推动下&#xff0c;互联网应用规模呈指数级增长&#xff0c;传统单体架构逐渐暴露出难以扩展、维护成本高等问题&#xff0c;微服务架构应运而生并成为企业应对复杂业务场景的主流选择。然而&#xff0c;随着业务的不断扩张和用户量的持续增加&#xff0c;如何确…

vue element-plus 集成多语言

main.js中 // 引入i18n import i18n from /i18n/index 使用i18 app.use(i18n) 在App.vue中 <template><el-config-provider :locale"locale" namespace"el" size"small"><router-view /></el-config-provider> </tem…

如何测试JWT的安全性:全面防御JSON Web Token的安全漏洞

在当今的Web应用安全领域&#xff0c;JSON Web Token(JWT)已成为身份认证的主流方案&#xff0c;但OWASP统计显示&#xff0c;错误配置的JWT导致的安全事件占比高达42%。本文将系统性地介绍JWT安全测试的方法论&#xff0c;通过真实案例剖析典型漏洞&#xff0c;帮助我们构建全…

车载网关策略 --- 车载网关重置前的请求转发机制

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 钝感力的“钝”,不是木讷、迟钝,而是直面困境的韧劲和耐力,是面对外界噪音的通透淡然。 生活中有两种人,一种人格外在意别人的眼光;另一种人无论…

EtpBot:安卓自动化脚本开发神器

EtpBot 是什么&#xff1f; EtpBot是一款专为安卓设备设计的自动化脚本开发工具&#xff0c;支持用户通过编写脚本实现自动化操作。该模块提供了丰富的API接口&#xff0c;涵盖点击、滑动、输入、截图等常见操作&#xff0c;帮助开发者快速构建自动化任务。ETPBot支持多设备并行…

连锁企业管理系统对门店运营的促进作用

连锁企业管理系统通过整合数字化工具与流程优化&#xff0c;能从多维度提升门店运营效率与竞争力&#xff0c;以下是其对门店运营的具体促进作用&#xff1a; 一、数据化管理&#xff1a;精准决策与运营监控 实时数据同步与分析 系统可整合各门店销售数据、库存信息、客流统计…

现代生活健康养生新策略

在充满挑战的现代生活中&#xff0c;各种健康问题悄然来袭&#xff0c;亚健康状态困扰着不少人。摒弃中医概念&#xff0c;运用现代科学理念&#xff0c;也能找到行之有效的养生之道。​ 饮食上&#xff0c;遵循 “彩虹饮食法” 能让营养摄入更全面。不同颜色的蔬果富含不同的…

车载以太网网络测试-27【SOME/IP-SD简述】

文章目录 1 摘要2 SOME/IP-SD协议介绍2.1 定义与作用2.2 SOMEIP/SD协议通俗易懂的理解2.2.1 SOMEIP/SD协议是什么&#xff1f;2.2.2 通信流程&#xff08;简化&#xff09;2.2.3 车载功能示例2.2.4 类比理解 2.3 SOME/IP-SD报文结构2.3.1 Flags2.3.1.1 REBOOT (Bit 7)2.3.1.2 U…

【Redis8】最新安装版与手动运行版

1. 下载 Redis 百度网盘 2. 解压后直接运行 redis-server.exe 3. 使用安装版 双击 install_redis_service.bat 输入安装路径&#xff08;请提前创建好安装路径&#xff09;后直接回车下一步直接回车即可&#xff0c;因为是使用配置模板文件为默认解压出来的&#xff0c;然后…

Spring Boot 集成 Elasticsearch【实战】

前言&#xff1a; 上一篇我们简单分享了 Elasticsearch 的一些概念性的知识&#xff0c;本篇我们来分享 Elasticsearch 的实际运用&#xff0c;也就是在 Spring Booot 项目中使用 Elasticsearch。 Elasticsearch 系列文章传送门 Elasticsearch 基础篇【ES】 Elasticsearch …

06算法学习_58. 区间和

58. 区间和 06算法学习_58. 区间和题目描述&#xff1a;个人代码&#xff1a;学习思路&#xff1a;第一种写法&#xff1a;题解关键点&#xff1a; 个人学习时疑惑点解答&#xff1a; 06算法学习_58. 区间和 卡码网题目链接: 59. 螺旋矩阵 II 题目描述&#xff1a; 58. 区间…

Python爬虫之路(14)--playwright浏览器自动化

playwright 前言 ​ 你有没有在用 Selenium 抓网页的时候&#xff0c;体验过那种「明明点了按钮&#xff0c;它却装死不动」的痛苦&#xff1f;或者那种「刚加载完页面&#xff0c;它又刷新了」的抓狂&#xff1f;别担心&#xff0c;你不是一个人——那是 Selenium 在和现代前…

Python开启智能之眼:OpenCV+深度学习实战

开篇导言 场景痛点 "某汽车零部件厂每月因人工质检遗漏损失300万,直到部署了基于Python的视觉检测系统..." 传统质检效率低下、成本高昂 深度学习技术带来的产业变革 Python在视觉识别领域的独特优势 一、技术架构解析 1.1 系统组成模块 图表 代码 下载 检测结…

华为模拟器练习简单的拓扑图(3台路由器和2台pc)

1、题目要求 根据下图&#xff0c;pc1连通pc2&#xff0c;实现不同网段直接的互通 2、思路整理 2.1 根据图上的要求&#xff0c;为主机和路由器相连接的端口设置对应IP地址&#xff08;子网掩码都是24位&#xff09;,路由器连接pc的那个端口&#xff0c;是主机pc的网关 2.2 …

uniapp生成的app,关于跟其他设备通信的支持和限制

以下内容通过AI生成&#xff0c;这里做一下记录。 蓝牙 移动应用&#xff08;App&#xff09;通过蓝牙与其他设备通信&#xff0c;是通过分层协作实现的。 一、通信架构分层 应用层&#xff08;App&#xff09; 调用操作系统提供的蓝牙API&#xff08;如Android的BluetoothA…

Proxmox 主机与虚拟机全部断网问题排查与解决记录

Proxmox 主机与虚拟机全部断网问题排查与解决记录 关键词&#xff1a;Proxmox、e1000e、板载网卡、断网、网络桥接、Hardware Unit Hang、网卡挂死 背景 近期在使用 Proxmox VE 管理服务器时&#xff0c;遇到一个奇怪的问题&#xff1a;每当在某个虚拟机中执行某些操作&#x…