【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成

简介: 【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。
热爱写作,愿意让自己成为更好的人…


  👉引言💎


铭记于心
🎉✨🎉我唯一知道的,便是我一无所知🎉✨🎉


【深度学习实践(八)】对抗生成网络(GAN)之手写数字生成


一、🌹对抗生成网络


1 定义与背景 :


生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构,GAN的核心本质是通过对抗训练将随机噪声的分布拉近到真实的数据分布


2 基本结构:


  • GAN本身是一个不断博弈,识别真假的过程,下面通过手写数字生成案例 窥探GAN对抗生成网络的原理及操作流程:

image.png

  • 定义一个模型来作为生成器(图三中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像(生成噪声)
  • 定义一个分类器来作为判别器(图三中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签

并且,既然是神经网络,那么模型就可以根据 外界反馈 自行调整参数,也就是会根据 标签匹配结果进行相应的学习与调整, 训练完成后 可以达到 ** 生成以假乱真的 手写数字图片效果**


二、🌹模型训练


💎1 设置GPU


  • GPU能够为大量数据的运算提供算力支持
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
warnings.filterwarnings("ignore")             
plt.rcParams['font.sans-serif'] = ['SimHei']  
plt.rcParams['axes.unicode_minus'] = False    


💎2 构建GAN对抗网络生成器


def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])
    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)
    return Model(noise, img)


💎3 构造鉴别器


def build_discriminator():
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])
    img = layers.Input(shape=img_shape)
    validity = model(img)
  • 最后传入img以及model参数构造Model对象
    return Model(img, validity)
  • 鉴别器训练原理:通过对输入的图片进行鉴别,从而达到提升的效果
  • 生成器训练原理:通过鉴别器对其生成的图片进行鉴别,来实现提升


💎4 构造生成器


# 创建判别器
dis = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
dis.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
# 创建生成器                       
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input) 
#训练generate时候停止训练判别器
dis.trainable = False  
# 测试:对生成的假图片进行预测 
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)


💎5 训练模型


  • train_on_batch详解:
    keras在compile完模型后需要训练,除了常用的model.fit()与model.fit_generator外
    还有model.train_on_bantch作用:对一批样品进行单梯度更新,即对一个epoch中的一个样本进行一次训练
  • 使用train_on_batch优点:
    更精细自定义训练过程,更精准的收集 loss 和 metrics
    分布训练模型-GAN生成对抗神经网络的实现
    多GPU训练保存模型更加方便
def train(epochs, batch_size=128, sample_interval=50):
  • 加载数据
(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()
  • 将图片标准化到 [-1, 1] 区间内
train_images = (train_images - 127.5) / 127.5
  • 数据
train_images = np.expand_dims(train_images, axis=3)
  • 创建标签
true = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
  • 开始训练
for epoch in range(epochs): 
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict(noise)
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)

image.png

  • 动图展示
def compose_gif():
    # 图片地址
    data_dir = "F:/jupyter notebook/DL-100-days/code/images"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif",gif_images,fps=2)


🌹写在最后💖:

路漫漫其修远兮,吾将上下而求索!伙伴们,再见!🌹🌹🌹



相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
11天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
143 55
|
18天前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
84 5
|
21天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
114 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
8天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于yolov4深度学习网络的公共场所人流密度检测系统matlab仿真,带GUI界面
本项目使用 MATLAB 2022a 进行 YOLOv4 算法仿真,实现公共场所人流密度检测。通过卷积神经网络提取图像特征,将图像划分为多个网格进行目标检测和识别,最终计算人流密度。核心程序包括图像和视频读取、处理和显示功能。仿真结果展示了算法的有效性和准确性。
53 31
|
20天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
116 30
|
15天前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如"How are you"、"I am fine"、"I love you"等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。
|
21天前
|
机器学习/深度学习 人工智能 自然语言处理
揭秘人工智能:深度学习的奥秘与实践
在本文中,我们将深入浅出地探索深度学习的神秘面纱。从基础概念到实际应用,你将获得一份简明扼要的指南,助你理解并运用这一前沿技术。我们避开复杂的数学公式和冗长的论述,以直观的方式呈现深度学习的核心原理和应用实例。无论你是技术新手还是有经验的开发者,这篇文章都将为你打开一扇通往人工智能新世界的大门。
|
18天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于深度学习网络的宝石类型识别算法matlab仿真
本项目利用GoogLeNet深度学习网络进行宝石类型识别,实验包括收集多类宝石图像数据集并按7:1:2比例划分。使用Matlab2022a实现算法,提供含中文注释的完整代码及操作视频。GoogLeNet通过其独特的Inception模块,结合数据增强、学习率调整和正则化等优化手段,有效提升了宝石识别的准确性和效率。
|
20天前
|
机器学习/深度学习 人工智能 自然语言处理
深入理解深度学习中的卷积神经网络(CNN)##
在当今的人工智能领域,深度学习已成为推动技术革新的核心力量之一。其中,卷积神经网络(CNN)作为深度学习的一个重要分支,因其在图像和视频处理方面的卓越性能而备受关注。本文旨在深入探讨CNN的基本原理、结构及其在实际应用中的表现,为读者提供一个全面了解CNN的窗口。 ##
|
20天前
|
机器学习/深度学习 存储 人工智能
探索深度学习的奥秘:从理论到实践的技术感悟
本文深入探讨了深度学习技术的核心原理、发展历程以及在实际应用中的体验与挑战。不同于常规摘要,本文旨在通过作者个人的技术实践经历,为读者揭示深度学习领域的复杂性与魅力,同时提供一些实用的技术见解和解决策略。
29 0