生成对抗网络

简介: 生成对抗网络

生成对抗网络(Generative Adversarial Network,简称GAN)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)组成,用于生成具有逼真度的新数据样本。生成器负责生成假样本,判别器负责区分真假样本,两者通过对抗训练来提高生成器生成逼真样本的能力。

 

下面是一个简单的基于TensorFlow的GAN实现示例,用于生成服从正态分布的随机数。

```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
 
# 参数设置
num_samples = 1000
latent_dim = 100
generator_hidden_units = 128
discriminator_hidden_units = 128
batch_size = 64
epochs = 10000
 
# 生成器
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(generator_hidden_units, input_shape=(latent_dim,), activation='relu'))
    model.add(tf.keras.layers.Dense(1, activation='linear'))
    return model
 
# 判别器
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(discriminator_hidden_units, input_shape=(1,), activation='relu'))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    return model
 
# 生成器损失函数
def generator_loss(fake_output):
    return tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output)
 
# 判别器损失函数
def discriminator_loss(real_output, fake_output):
    real_loss = tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output)
    fake_loss = tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
 
# 优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
 
# 实例化生成器和判别器
generator = make_generator_model()
discriminator = make_discriminator_model()
 
# 定义训练步骤
@tf.function
def train_step():
    noise = tf.random.normal([batch_size, latent_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_data = generator(noise, training=True)
 
        real_output = discriminator(tf.random.normal([batch_size, 1]), training=True)
        fake_output = discriminator(generated_data, training=True)
 
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
 
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
 
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
 
# 训练模型
for epoch in range(epochs):
    for _ in range(num_samples // batch_size):
        train_step()
 
# 生成数据
generated_data = generator(tf.random.normal([num_samples, latent_dim]), training=False)
 
# 可视化生成数据分布
plt.hist(np.reshape(generated_data.numpy(), (-1,)), bins=50, density=True)
plt.show()
```

在这个示例中,我们通过生成器生成服从正态分布的随机数,并使用判别器区分真实样本和生成样本。通过对生成器和判别器进行对抗训练,最终生成器可以生成逼真的数据样本。

 

补充一个关于生成对抗网络(GAN)的应用示例,用于生成手写数字图像。这里使用的是基于TensorFlow的Keras API来构建和训练GAN模型。

```python
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
 
# 加载MNIST数据集
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # 将图像标准化到[-1, 1]范围内
 
# 定义生成器模型
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Reshape((7, 7, 256)))
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model
 
# 定义判别器模型
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
 
# 实例化生成器和判别器
generator = make_generator_model()
discriminator = make_discriminator_model()
 
# 定义损失函数和优化器
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
 
# 判别器损失函数
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
 
# 生成器损失函数
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
 
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
 
# 定义训练步骤
noise_dim = 100
num_examples_to_generate = 16
 
@tf.function
def train_step(images):
    noise = tf.random.normal([batch_size, noise_dim])
 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
 
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
 
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
 
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
 
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
 
# 定义生成和保存图像函数
def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
 
# 定义训练函数
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
 
        if epoch % 10 == 0:
            generate_and_save_images(generator, epoch + 1, seed)
 
# 批量大小和周期数
batch_size = 256
epochs = 100
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(batch_size)
 
# 生成初始随机向量
seed = tf.random.normal([num_examples_to_generate, noise_dim])
 
# 训练模型
train(train_dataset, epochs)
```

在这个示例中,我们使用了MNIST手写数字数据集来训练GAN模型,生成手写数字图像。通过训练,生成器可以生成逼真的手写数字图像,展示了GAN在图像生成领域的强大能力。

相关文章
|
网络协议
aws-vpc-对等连接(不同vpc之间的内网互通)
aws-vpc-对等连接(不同vpc之间的内网互通)
1988 0
aws-vpc-对等连接(不同vpc之间的内网互通)
|
6月前
|
存储 安全 数据安全/隐私保护
阿里云服务器五代、六代、七代、八代实例简介及性能提升介绍
随着技术的不断进步,到2025年,阿里云服务器实例也经历了多代升级,从五代实例到最新的八代实例,每一代都在性能、稳定性、能效比等方面取得了显著提升。有的用户由于是初次接触阿里云服务器,所以不是很清楚阿里云服务器五代、六代、七代、八代实例有哪些,它们各自在云服务器性能上有哪些提升。本文将详细介绍阿里云服务器五代、六代、七代、八代实例的特点及性能提升,帮助用户更好地了解并选择适合自己的云服务器实例。
286 29
|
7月前
|
云安全 安全 网络协议
游戏服务器被攻击,游戏盾防护具有哪些作用
在数字化时代蓬勃发展,但也面临着黑客攻击、DDoS和CC攻击等网络安全威胁。游戏盾防护应运而生,专为游戏行业提供全面的网络安全解决方案,不仅有效防御大型DDoS攻击,还能精准抵御特有TCP协议的CC攻击,同时通过智能行为分析和业务安全防护,确保游戏服务器的稳定运行,提升用户体验,维护游戏生态和品牌声誉,助力游戏行业健康发展。
|
7月前
|
机器学习/深度学习 存储 人工智能
智能语音识别技术的深度剖析与应用前景####
本文深入探讨了智能语音识别技术的技术原理、关键技术突破及广泛应用场景,通过具体实例展现了该技术如何深刻改变我们的日常生活和工作方式。文章还分析了当前面临的挑战与未来发展趋势,为读者提供了一幅全面而深入的智能语音识别技术图景。 ####
|
Python
掌握Python中循环语句的精髓:基础用法与高级技巧
掌握Python中循环语句的精髓:基础用法与高级技巧
176 0
|
11月前
【threejs教程】让你的场景贴图变得多姿多彩:UV坐标详解
【8月更文挑战第6天】threejs教程:让你的场景贴图变得多姿多彩,UV坐标详解
504 5
【threejs教程】让你的场景贴图变得多姿多彩:UV坐标详解
|
机器学习/深度学习 人工智能 自然语言处理
|
11月前
|
Linux 开发工具 数据安全/隐私保护
在Linux中,如何添加和管理用户账户以及如何设置sudo权限?
在Linux中,如何添加和管理用户账户以及如何设置sudo权限?
关闭windows11(10)自动更新工具
关闭windows11(10)自动更新工具
|
存储 JSON 关系型数据库
深入探索MySQL中JSON数据的查询、转换及springboot中的应用
MySQL版本引入了对JSON数据类型的支持,这为我们处理和存储非结构化数据提供了新的可能性。通过灵活利用MySQL的JSON函数,我们可以实现高效的查询和转换操作,提取有用的数据,并将其转换为有意义的格式。本文将深入探索MySQL中JSON数据的查询与转换技巧,帮助您更好地利用这一功能。
845 0
深入探索MySQL中JSON数据的查询、转换及springboot中的应用
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等

登录插画

登录以查看您的控制台资源

管理云资源
状态一览
快捷访问