【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成

简介: 【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。
热爱写作,愿意让自己成为更好的人…


  👉引言💎


铭记于心
🎉✨🎉我唯一知道的,便是我一无所知🎉✨🎉


【深度学习实践(八)】对抗生成网络(GAN)之手写数字生成


一、🌹对抗生成网络


1 定义与背景 :


生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构,GAN的核心本质是通过对抗训练将随机噪声的分布拉近到真实的数据分布


2 基本结构:


  • GAN本身是一个不断博弈,识别真假的过程,下面通过手写数字生成案例 窥探GAN对抗生成网络的原理及操作流程:

image.png

  • 定义一个模型来作为生成器(图三中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像(生成噪声)
  • 定义一个分类器来作为判别器(图三中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签

并且,既然是神经网络,那么模型就可以根据 外界反馈 自行调整参数,也就是会根据 标签匹配结果进行相应的学习与调整, 训练完成后 可以达到 ** 生成以假乱真的 手写数字图片效果**


二、🌹模型训练


💎1 设置GPU


  • GPU能够为大量数据的运算提供算力支持
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
warnings.filterwarnings("ignore")             
plt.rcParams['font.sans-serif'] = ['SimHei']  
plt.rcParams['axes.unicode_minus'] = False    


💎2 构建GAN对抗网络生成器


def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])
    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)
    return Model(noise, img)


💎3 构造鉴别器


def build_discriminator():
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])
    img = layers.Input(shape=img_shape)
    validity = model(img)
  • 最后传入img以及model参数构造Model对象
    return Model(img, validity)
  • 鉴别器训练原理:通过对输入的图片进行鉴别,从而达到提升的效果
  • 生成器训练原理:通过鉴别器对其生成的图片进行鉴别,来实现提升


💎4 构造生成器


# 创建判别器
dis = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
dis.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
# 创建生成器                       
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input) 
#训练generate时候停止训练判别器
dis.trainable = False  
# 测试:对生成的假图片进行预测 
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)


💎5 训练模型


  • train_on_batch详解:
    keras在compile完模型后需要训练,除了常用的model.fit()与model.fit_generator外
    还有model.train_on_bantch作用:对一批样品进行单梯度更新,即对一个epoch中的一个样本进行一次训练
  • 使用train_on_batch优点:
    更精细自定义训练过程,更精准的收集 loss 和 metrics
    分布训练模型-GAN生成对抗神经网络的实现
    多GPU训练保存模型更加方便
def train(epochs, batch_size=128, sample_interval=50):
  • 加载数据
(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()
  • 将图片标准化到 [-1, 1] 区间内
train_images = (train_images - 127.5) / 127.5
  • 数据
train_images = np.expand_dims(train_images, axis=3)
  • 创建标签
true = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
  • 开始训练
for epoch in range(epochs): 
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict(noise)
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)

image.png

  • 动图展示
def compose_gif():
    # 图片地址
    data_dir = "F:/jupyter notebook/DL-100-days/code/images"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif",gif_images,fps=2)


🌹写在最后💖:

路漫漫其修远兮,吾将上下而求索!伙伴们,再见!🌹🌹🌹



相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
8天前
|
机器学习/深度学习 人工智能 TensorFlow
人工智能浪潮下的自我修养:从Python编程入门到深度学习实践
【10月更文挑战第39天】本文旨在为初学者提供一条清晰的道路,从Python基础语法的掌握到深度学习领域的探索。我们将通过简明扼要的语言和实际代码示例,引导读者逐步构建起对人工智能技术的理解和应用能力。文章不仅涵盖Python编程的基础,还将深入探讨深度学习的核心概念、工具和实战技巧,帮助读者在AI的浪潮中找到自己的位置。
|
3天前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习中的卷积神经网络(CNN)及其在图像识别中的应用
本文旨在通过深入浅出的方式,为读者揭示卷积神经网络(CNN)的神秘面纱,并展示其在图像识别领域的实际应用。我们将从CNN的基本概念出发,逐步深入到网络结构、工作原理以及训练过程,最后通过一个实际的代码示例,带领读者体验CNN的强大功能。无论你是深度学习的初学者,还是希望进一步了解CNN的专业人士,这篇文章都将为你提供有价值的信息和启发。
|
4天前
|
机器学习/深度学习 人工智能 网络架构
深入理解深度学习中的卷积神经网络(CNN)
深入理解深度学习中的卷积神经网络(CNN)
20 1
|
6天前
|
机器学习/深度学习 人工智能 算法框架/工具
深度学习中的卷积神经网络(CNN)入门
【10月更文挑战第41天】在人工智能的璀璨星空下,卷积神经网络(CNN)如一颗耀眼的新星,照亮了图像处理和视觉识别的路径。本文将深入浅出地介绍CNN的基本概念、核心结构和工作原理,同时提供代码示例,带领初学者轻松步入这一神秘而又充满无限可能的领域。
|
12天前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习中的卷积神经网络:从理论到实践
【10月更文挑战第35天】在人工智能的浪潮中,深度学习技术以其强大的数据处理能力成为科技界的宠儿。其中,卷积神经网络(CNN)作为深度学习的一个重要分支,在图像识别和视频分析等领域展现出了惊人的潜力。本文将深入浅出地介绍CNN的工作原理,并结合实际代码示例,带领读者从零开始构建一个简单的CNN模型,探索其在图像分类任务中的应用。通过本文,读者不仅能够理解CNN背后的数学原理,还能学会如何利用现代深度学习框架实现自己的CNN模型。
|
9天前
|
机器学习/深度学习 数据采集 自然语言处理
深入浅出深度学习:从理论到实践
【10月更文挑战第38天】本文旨在通过浅显易懂的语言和直观的代码示例,带领读者探索深度学习的奥秘。我们将从深度学习的基本概念出发,逐步深入到模型构建、训练以及应用实例,让初学者也能轻松入门。文章不仅介绍了深度学习的原理,还提供了实战操作指南,帮助读者在实践中加深理解。无论你是编程新手还是有一定基础的学习者,都能在这篇文章中找到有价值的内容。让我们一起开启深度学习之旅吧!
|
11天前
|
机器学习/深度学习 人工智能 算法框架/工具
深度学习中的卷积神经网络(CNN)及其在图像识别中的应用
【10月更文挑战第36天】探索卷积神经网络(CNN)的神秘面纱,揭示其在图像识别领域的威力。本文将带你了解CNN的核心概念,并通过实际代码示例,展示如何构建和训练一个简单的CNN模型。无论你是深度学习的初学者还是希望深化理解,这篇文章都将为你提供有价值的见解。
|
9天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
37 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
11天前
|
机器学习/深度学习 自然语言处理 语音技术
深度学习的奇妙之旅:从理论到实践
【10月更文挑战第36天】在本文中,我们将一起探索深度学习的神秘世界。我们将首先了解深度学习的基本概念和原理,然后通过一个简单的Python代码示例,学习如何使用深度学习库Keras进行图像分类。无论你是深度学习的初学者,还是有一定基础的学习者,都可以从这篇文章中获得新的知识和启示。
|
12天前
|
数据采集 网络协议 算法
移动端弱网优化专题(十四):携程APP移动网络优化实践(弱网识别篇)
本文从方案设计、代码开发到技术落地,详尽的分享了携程在移动端弱网识别方面的实践经验,如果你也有类似需求,这篇文章会是一个不错的实操指南。
32 1

热门文章

最新文章

下一篇
无影云桌面