生成对抗网络(GAN)

news2025/7/13 8:14:28

GAN简介

GAN思想是一种二人的零和博弈思想,GAN中有两个博弈者,一个生成器(G),一个判别器(D),这两个模型都有各自的输入和输出。具体功能如下:
生成器(G):输入一个随机噪声样本,通过生成器生成一个与真实样本无差的样本。
判别器(D):对输出模型进行打分,类似一个分类器,打分的对照样本是真实样本。
在这里插入图片描述

GAN的简易模型如下:

在这里插入图片描述

  • GAN训练
  • GAN训练一开始都是训练判别器的,目的是让判别器获得一个标准的,也就是说,让判别器看一堆好的图片,从而知道好图究竟是怎么样的,这些图片可以来自于各种地方,一般是将好的图片收集到数据库中,在从中抽取这些图片输入判别器。
  • 训练完判别器后,判别器已经有标准了,此时再来训练生成器,生成器的训练流程也比较简单,随便生成一组噪声给生成器即可,如生成符合正太分布的噪声。生成器会通过这些噪声生成一张图片,当然一开始这张图片是惨不忍睹的,这种图片过不论判别器这关。
  • 生成图片后,判别器就会判断这张图片是来自数据库的真实图片,还是来自于生成器的生成图片。如果判断是真实图片,就会给图片赋予一个较高的分数,如真实图片赋为1,如果判断是生成土图片,就赋予图片一个较低的分数,如生成图片赋予0,同时还会产生一个损失。

GAN的简要流程如下:

在这里插入图片描述

GAN公式如下:

在这里插入图片描述
生成器的目标就是让判别器无法判断是生成图片还是真实图片,换种说法就是,生成器的目标都是生成真实图片,至少让判别器认为是真实的,生成器一开始生成图片过于模糊抽象,判别器可以轻易的将其识别,生成器为了提高自己生成图片的能力,就要不断的学习,具体而言,就是找到自己生成图片与真实图片的差距。然后弥补这个差距。这就是所谓的差距,其实就是损失,也就是在高维空间中生成图片的概率分布与真实图片的概率分布的不同之处,具体而言,就是两个概率图片的 J S 散 度 JS散度 JS就是最小化生成图片的概率分布与真实图片的概率分布的 J S JS JS散度。

  • 生成器损失:判别器给生成图片赋予的分数和目标分数,的差距。
  • 判别器损失:其损失由两部分构成,判别器给真实图片赋予的分数和目标分数的差距。判别器给生成图片和目标分数的差距
    计算损失时候使用 t f . n n . s i g m o i d c r o s s e n t r o p y w i t h l o g i t s tf.nn.sigmoid_cross_entropy_with_logits tf.nn.sigmoidcrossentropywithlogits的方法,其对传入 l o g i t s logits logits参数,先使用 S i g m o i d 函 数 计 算 Sigmoid函数计算 Sigmoid,然后再计算它们的 c r o s s e n t r o p y cross entropy crossentropy交叉熵损失,同时该方法优化了交叉熵的计算方式,使得结果不会溢出。
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
from tensorflow.keras import layers

# mnist = tf.keras.datasets.mnist
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# # 可视化训练集输入特征的第一个元素
# plt.imshow(x_train[0], cmap='gray')  # 绘制灰度图
# plt.show()

(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
'''mnist中的reshape
x_image = tf.reshape(x, [-1, 28, 28, 1])
       这里是将一组图像矩阵x重建为新的矩阵,该新矩阵的维数为(a,28,28,1),其中-1表示a由实际情况来定'''
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images-127.5)/127.5      # -1~1

BATCH_SIZE = 256
BUFFER_SIZE = 60000

datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


# 生成器
def generator_model():  # 用100个随机数(噪音)生成手写数据集
    model = keras.Sequential()
    model.add(layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(layers.BatchNormalization())     # 规范化
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(28 * 28 * 1, use_bias=False, activation='tanh'))
    model.add(layers.BatchNormalization())

    model.add(layers.Reshape((28, 28, 1)))

    return model


# 判别器
def discriminator_model():  # 识别输入的图片
    model = keras.Sequential()
    model.add(layers.Flatten())

    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(1))

    return model


cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)


# 判别器损失
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss


# 生成器损失
def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out), fake_out)


generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

generator = generator_model()
discriminator = discriminator_model()

noise_dim = 100  # 即用100个随机数生成图片


def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out = discriminator(images, training=True)

        gen_image = generator(noise, training=True)
        fake_out = discriminator(gen_image, training=True)

        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))


def generate_plot_image(gen_model, test_noise):
    pre_images = gen_model(test_noise, training=False)
    plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap='gray')
        plt.axis('off')
    plt.show()


EPOCHS = 100  # 训练100次
num_exp_to_generate = 16  # 生成16张图片
seed = tf.random.normal([num_exp_to_generate, noise_dim])  # 16组随机数组,每组含100个随机数,用来生成16张图片。


def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
            print('.', end='')
        if epoch % 10 == 0:
            print('epoch: ', epoch)
            generate_plot_image(generator, seed)


train(datasets, EPOCHS)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

声门脉冲语音处理

对于 0<t<tpeak&#xff0c;gattack(t) 攻击部分&#xff0c;即上升分支的时间&#xff0c;时间 t 的范围从 0 秒到最大峰值时间 tpeak &#xff0c;图示例中选择为大约总长度的 35%&#xff0c;即 tpeak35%⋅T0&#xff0c;或者在样本 Lattack⌊35%⋅Lg⌉ 中&#xff0c…

2023年系统规划与设计管理师-第三章信息技术服务知识

一. 思维导图 二.IT 服务管理 (ITSM) 1. 什么是 IT 服务管理 (ITSM)&#xff1f; IT 服务管理 (ITSM) 包含一组策略和实践&#xff0c;这些策略和实践可用于为最终用户实施、交付和管理 IT 服务&#xff0c;以满足最终用户的既定需求和企业的既定目标。 在此定义中&#xff0…

otn 709帧结构

otn架构说明: 基于G.709接口,包括波分侧和客户侧,客户侧通常用于互联互通。 光通路净荷单元:OPU0/OPU1/OPU2/OPU3/OPU4/flex,主要用于完成业务同步或异步映射; 光通路数据单元:ODU0/ODU1/ODU2/ODU3/ODU4/ODU-flex,完成通道连接性能监测和子速率复用、 光通路传送单元…

POJ1008:玛雅日历

一、Description During his last sabbatical, professor M. A. Ya made a surprising discovery about the old Maya calendar. From an old knotted message, professor discovered that the Maya civilization used a 365 day long year, called Haab, which had 19 months.…

Netty学习笔记

文章目录二、Netty 入门2.1、概述2.1.1、Netty 是什么&#xff1f;2.1.2、Netty 的作者2.1.3、Netty 的地位2.1.4、Netty 的优势2.2、Hello World2.2.1、目标2.2.2、服务器端2.2.3、客户端2.2.4、流程梳理&#x1f4a1; 提示2.3、组件2.3.1、EventLoop&#x1f4a1; 优雅关闭演…

保姆级二进制安装高可用k8s集群文档(1.23.8)

保姆级二进制安装高可用k8s集群文档k8s搭建方式前期准备集群规划机器准备1、master vagrantfile2、master install.sh3、node vagrantfile4、node install.sh5、时间同步vagran 启动脚本vagrant up注意点安装conntrack 工具ipvs的安装VBoxManage snapshot 准备虚拟机快照ETCD部…

C语言编程作业参考答案

编程题参考答案 文章目录编程题参考答案week1_test选择结构-编程题循环结构上机练习数组编程函数编程2week1_test Write a program to output the average of 2 integers. #include <stdio.h>void main(){int a , b;double c;printf("Please enter 1 integers\n&q…

官网下载mysql 8.0.27及安装

https://www.mysql.com/downloads/&#xff0c;找到社区版下载链接MySQL Community (GPL) Downloads 1、 2、 3、 4、 5、

光谱异常样本检测分析

以近红外光谱为例&#xff0c;大部分光谱数据在不考虑分类问题时&#xff0c;在构建模型前需要对采集数据进行样本分析&#xff0c;以降低因生产过程异常、人为误操作和其他原因对软测量模型的影响&#xff0c;即异常样本检测分析。 按照定义&#xff0c;异常样本检测任务指的是…

k8s编程operator——(3) 自定义资源CRD.md

文章目录1、自定义资源的使用1.1 注册自定义资源1.2 使用自定义资源&#xff1a;1.3 Finalizers1.4 合法性验证2、如何操作自定义资源2.1 使用RestClient和DynamicClient来操作自定义资源对象2.2 使用sharedIndexInformer2.3 code-generator2.3.1 下载安装2.3.2 code-generator…

Ajax、Fetch、Axios三者的区别

1.Ajax&#xff08;Asynchronous JavaScript And XML&#xff09; Ajax 是一个技术统称&#xff0c;它很重要的特性之一就是让页面实现局部刷新。 特点&#xff1a; 局部刷新页面&#xff0c;无需重载整个页面。 简单来说&#xff0c;Ajax 是一种思想&#xff0c;XMLHttpReq…

毕业设计-基于机器学习的图片处理图片倾斜校正

前言 &#x1f4c5;大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,一边要为毕业设计耗费大量精力。近几年各个学校要求的毕设项目越来越难,有不少课题是研究生级别难度的,对本科同学来说是充满挑战。为帮助大家顺利通过和节省时间与精力投…

如何简单理解大数据

如何简单理解大数据 HDFS-存储 海量的数据存储 hadoop 只是一套工具的总称&#xff0c;它包含三部分&#xff1a;HDFS&#xff0c;Yarn&#xff0c;MapReduce&#xff0c;功能分别是分布式文件存储、资源调度和计算。 按理来说&#xff0c;这就足够了&#xff0c;就可以完成大…

matlab实现MCMC的马尔可夫转换MS- ARMA - GARCH模型估计

状态转换模型&#xff0c;尤其是马尔可夫转换&#xff08;MS&#xff09;模型&#xff0c;被认为是识别时间序列非线性的不错的方法。 估计非线性时间序列的方法是将MS模型与自回归移动平均 - 广义自回归条件异方差&#xff08;ARMA - GARCH&#xff09;模型相结合&#xff0c;…

Ubuntu22.04+Nvidia驱动+Cuda11.8+cudnn8.6

Ubuntu22.04Nvidia驱动Cuda11.8 一、准备环境 ubuntu 22.04nvidia显卡 这里使用的是RTX3060已安装Python3.10 二、安装pip3 # 安装 sudo apt install python3-pip # 升级 sudo pip3 install --upgrade pip # 如果要卸载&#xff0c;使用命令&#xff1a; sudo apt-get remov…

MySQL创建和管理表

基础知识 一条数据存储的过程 存储数据是处理数据的第一步 。只有正确地把数据存储起来&#xff0c;我们才能进行有效的处理和分析。否则&#xff0c;只能是一团乱麻&#xff0c;无从下手。 那么&#xff0c;怎样才能把用户各种经营相关的、纷繁复杂的数据&#xff0c;有序、…

ES6解析赋值

ES6中新增了一种数据处理方式&#xff0c;可以将数组和对象的值提取出来对变量进行赋值&#xff0c;这个过程时将一个数据结构分解成更小的部分&#xff0c;称之为解析。 1.对象解析赋值: 在ES5中&#xff0c;要将一个对象的属性提取出来&#xff0c;需要经过一下几个过程。 …

aws sdk 学习和使用aws-sdk-go

https://www.go-on-aws.com/infrastructure-as-go/cdk-go/sdk example&#xff0c;https://github.com/awsdocs/aws-doc-sdk-examples golang的安装&#xff0c;使用1.15之后默认开启GO15VENDOREXPERIMENT的版本 yum install golang -y go env -w GOPROXYhttps://goproxy.cn,…

智慧教室解决方案-最新全套文件

智慧教室解决方案-最新全套文件一、建设背景1、教育信息化2.0行动计划2、中国教育现代化20353、加快推进教育现代化实施方案二、建设思路互联网时代教学环境定义三、建设方案四、获取 - 智慧教室全套最新解决方案合集一、建设背景 1、教育信息化2.0行动计划 “网络学习空间覆…

【直播】-DRM和TZC400的介绍-2022/11/26

直播背景和内容 最近有两位SOC大佬再和我探讨TZC的设计&#xff0c;以及使用场景。也有几位软件工程师&#xff0c;在深入得学习安全技术&#xff0c;也问到了TZC相关的技术。 然后就搞了本次的直播&#xff0c;共计17人报名。 上线12位同学。(有ASIC大佬、有软件大佬、芯片严…