图像生成:Pytorch实现一个简单的对抗生成网络模型

news2025/6/24 19:37:35

图像生成:Pytorch实现一个简单的对抗生成网络模型

  • 前言
  • 相关介绍
  • 具体步骤
    • 准备并读取数据集
    • 定义生成器
    • 定义判别器
    • 定义损失函数
    • 定义优化器
    • 开始训练
    • 完整代码
  • 训练生成的图片

前言

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

相关介绍

对抗生成网络(Adversarial Generative Networks,简称GANs)是由伊恩·古德费洛(Ian Goodfellow)等人在2014年提出的深度学习框架,主要用于无监督学习中的数据生成任务。GAN的设计灵感来源于博弈论中的极小极大游戏,在机器学习领域开辟了一种全新的生成模型方法。

GAN包含两个主要组成部分:

  1. 生成器(Generator, G)

    • 生成器G是一个神经网络,它的功能是从随机噪声向量(通常来自高斯分布或其他先验分布)出发,通过一系列变换来生成新的数据样本,这些样本应该尽可能接近于训练数据的真实分布。
    • G的目标是尽可能地“欺骗”判别器,使其无法区分生成的数据和真实数据。
  2. 判别器(Discriminator, D)

    • 判别器D也是一个神经网络,它的作用是接收输入的样本,并将其分类为“真”(来自真实数据分布)或“假”(来自生成器产生的分布)。
    • D的目标是准确地区分真实数据和伪造数据,从而提高自身鉴别能力。

训练过程中,GAN采用了极小极大博弈策略,具体步骤如下:

  • 先固定判别器参数,更新生成器参数以使得生成的数据更有可能被判别器误认为真实数据;
  • 再固定生成器参数,更新判别器参数以更好地区分真实数据和生成器生成的伪数据。

这种交替优化的方式促使两个网络性能不断提升,直到达到纳什均衡,此时生成器能够生成非常逼真的新样本,而判别器再也无法有效区分真实数据和生成数据。

GAN在图像生成、图像到图像转换、视频生成、音频合成等多个领域都取得了显著成果,并且随着研究的深入,衍生出了许多改进版本和变体,比如条件GAN(Conditional GANs)、 Wasserstein GAN(WGAN)、CycleGAN等,进一步提高了生成效果并拓展了应用范围。
在这里插入图片描述

具体步骤

准备并读取数据集

以mnist数据集为例。

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

定义损失函数

# Loss function
adversarial_loss = torch.nn.BCELoss()

定义优化器

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

开始训练

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        # 生成器损失,目的是生成的图像,逼近真是图像,是判别器认为是真
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        # 真实图像在判别器的目标函数损失,即判别器应该判为真
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # 生成器生成出来的图像在判别器的目标函数损失,即判别器应该判为假
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # 判别器损失,目的是判别生成的图像是假
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

完整代码

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述

训练生成的图片

在这里插入图片描述

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584395.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RTSP/Onvif视频安防监控平台EasyNVR调用接口返回匿名用户名和密码的原因排查

视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。平台拓展性强、支持二次开发与集成,可应用在景区、校园、水利、社区、工地等场…

坚持十天做完Python入门编程100题第三天加班

坚持十天做完Python入门编程100题第三天加班 第24题 扫描文件列表第25题 如何将字典转换成JSON并写入json文件?第26题 JSON转换成字典 第24题 扫描文件列表 如何扫描当前目录下的文件列表?解析:可以使用python内置的glob模块,用法…

C++设计模式:单例模式(十)

1、单例设计模式 单例设计模式,使用的频率比较高,整个项目中某个特殊的类对象只能创建一个 并且该类只对外暴露一个public方法用来获得这个对象。 单例设计模式又分懒汉式和饿汉式,同时对于懒汉式在多线程并发的情况下存在线程安全问题 饿汉…

《深入Linux内核架构》第2章 进程管理和调度 (2)

目录 2.4 进程管理相关的系统调用 2.4.1 进程复制 2.4.2 内核线程 2.4.3 启动新程序 2.4.4 退出进程 本专栏文章将有70篇左右,欢迎关注,订阅后续文章。 2.4 进程管理相关的系统调用 2.4.1 进程复制 1. _do_fork函数 fork vfork clone都最终调用_…

微信小程序转盘抽奖

场景: 在微信小程序里面开展抽奖活动使用转盘抽奖;类似下图(图片来自百度) 方法: 使用lukcy-canvas组件 在 微信小程序 中使用 | 基于 Js / TS / Vue / React / 微信小程序 / uni-app / Taro 的【大转盘 & 九宫…

【Qt踩坑】ARM 编译Qt5.14.2源码-QtWebEngine

1.下载源码 下载网站:Index of /new_archive/qt/5.14/5.14.2/single 2.QWebEngine相关依赖 sudo apt-get install flex libicu-dev libxslt-dev sudo apt-get install libssl-dev libxcursor-dev libxcomposite-dev libxdamage-dev libxrandr-dev sudo apt-get …

dyld: Library not loaded: @rpath/SDK.framework/SDK错误问题

关于导入三方SDK.framework之后,启动崩溃之后如下报错的解决方式: 截屏2020-10-14 上午9.55.09.png 在正常导入framework之后,做如图示操作, image.png 以上步骤之后,重新启动运行xcode,即可成功运行。

Harmony鸿蒙南向驱动开发-PIN

PIN即管脚控制器,用于统一管理各SoC的管脚资源,对外提供管脚复用功能。 基本概念 PIN是一个软件层面的概念,目的是为了统一对各SoC的PIN管脚进行管理,对外提供管脚复用功能,配置PIN管脚的电气特性。 SoC(…

ChatGPT在地学,自然科学等了领域应用教程

原文链接:ChatGPT在地学,自然科学等了领域应用教程https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247600722&idx2&sn291ea8c935b1d9b1459170baa9057053&chksmfa820bb5cdf582a39086e5ee9596ab020784fa78ac7dc49ced4969e28817c3f0…

MAC: 自己制作https的ssl证书(自己签发免费ssl证书)(OPENSSL生成SSL自签证书)

MAC: 自己制作https的ssl证书(自己签发免费ssl证书)(OPENSSL生成SSL自签证书) 前言 现在https大行其道, ssl又是必不可少的环节. 今天就教大家用开源工具openssl自己生成ssl证书的文件和私钥 环境 MAC电脑 openssl工具自行搜索安装 正文 1、终端执行命令 //生成rsa私钥&…

探索艺术的新领域——3D线上艺术馆如何改变艺术作品的传播方式

在数字化时代的浪潮下,3D线上艺术馆成为艺术家们展示和传播自己作品的新平台。不仅突破了地域和物理空间的限制,还提供了全新的互动体验。 一、无界限的展示空间:艺术家的新展示平台 3D线上艺术馆通过数字化技术,为艺术家提供了一…

虚拟货币:数字金融时代的新工具

在数字化时代的到来之后,虚拟货币逐渐成为了一种广为人知的金融工具。虚拟货币是一种数字化的资产,它不像传统货币那样由政府或中央银行发行和监管。相反,虚拟货币通过密码学技术和分布式账本技术来实现去中心化的发行和交易。 虚拟货币的代…

机器学习和深度学习-- 李宏毅(笔记与个人理解)Day10

Day 10 Genaral GUidance training Loss 不够的case Loss on Testing data over fitting 为什么over fitting 留到下下周哦~~ 期待 solve CNN卷积神经网络 Bias-Conplexiy Trade off cross Validation how to split? N-fold Cross Validation mismatch 这节课总体听下来比较…

大厂MVP技术JAVA架构师培养

课程介绍 这是一个很强悍的架构师涨薪计划课程,课程由专家级MVP讲师进行教学,分为是一个章节进行分解式面试及讲解,不仅仅是面试,更像是一个专业的架构师研讨会课程。课程内容从数据结构与算法、Spring Framwork、JVM原理、 JUC并…

环信 IM 客户端将适配鸿蒙 HarmonyOS

自华为推出了自主研发操作系统鸿蒙 HarmonyOS 后,国内许多应用软件开始陆续全面兼容和接入鸿蒙操作系统。环信 IM 客户端计划将全面适配统鸿蒙 HarmonyOS ,助力开发者快速实现社交娱乐、语聊房、在线教育、智能硬件、社交电商、在线金融、线上医疗等广泛…

代码学习记录40---动态规划

随想录日记part40 t i m e : time: time: 2024.04.10 主要内容:今天开始要学习动态规划的相关知识了,今天的内容主要涉及: 买卖股票的最佳时机加强版。 123.买卖股票的最佳时机III 188.买卖股票的最佳时机…

代码随想录--数组--有序数组的平方

题目 给你一个按 非递减顺序 排序的整数数组 nums,返回 每个数字的平方 组成的新数组,要求也按 非递减顺序 排序。 示例 1: 输入:nums [-4,-1,0,3,10] 输出:[0,1,9,16,100] 解释:平方后,数组…

【CSS】一篇文章讲清楚screen、window和html元素的位置:top、left、width、height

一个Web网页从内到外的顺序是: 元素div,ul,table... → 页面body → 浏览器window → 屏幕screen 分类详情屏幕screen srceen.width - 屏幕的宽度 screen.height - 屏幕的高度(屏幕未缩放时,表示屏幕分辨率) screen.availLeft …

云手机解决海外社媒运营的诸多挑战

随着海外社交媒体运营的兴起,如何有效管理多个账户成为了一项挑战。云手机作为一种新兴的解决方案,为海外社媒运营带来了前所未有的便利。 云手机的基本原理是基于云计算和虚拟化技术,允许用户在物理手机之外创建和使用多个虚拟手机。这种创新…

【开发篇】十三、JVM基础参数设置与垃圾回收器的选择

文章目录 1、-Xmx 和 –Xms2、-XX:MaxMetaspaceSize 和 –XX:MetaspaceSize3、-Xss4、不建议改的参数5、其他参数6、选择GC回收器的调试思路7、CMS的并发模式失败现象的解决8、调优案例 GC问题解决方式: 优化JVM基础参数,避免频繁Full GC减少对象的产生…