从头训练一个神经网络!教它学会莫奈风格作画!⛵

简介: 本文使用 GAN(生成对抗网络)进行AI绘画。torchgan是基于PyTorch的一个GAN工具库,本文讲解搭建DCGAN神经网络,并应用于『莫奈』风格绘画的全过程。
899a0111c0bfb0dcb12f0c6c090a6885.png
💡 作者: 韩信子@ ShowMeAI
📘 深度学习实战系列https://www.showmeai.tech/tutorials/42
📘 PyTorch 实战系列https://www.showmeai.tech/tutorials/44
📘 本文地址https://www.showmeai.tech/article-detail/324
📢 声明:版权所有,转载请联系平台与作者并注明出处
📢 收藏 ShowMeAI查看更多精彩内容
203c4eff8588b35327d5a455bcd08888.png

如今 AI 艺术创作能力越来越强大,在艺术作画上表现也异常惊人,大家在ShowMeAI的文章 📘AI绘画 | 使用Hugging Face发布的diffuser模型快速绘画 和 📘AI绘画 | 使用Disco Diffusion基于文本约束绘画 了解新技术进展和 AI 作画效果展示,其中 OpenAI 的 DALL-E 2 和 Google 的 ImageGen 等项目基于文本提示作画的结果和真实艺术家的成品难辨真假。

但上述效果好的大厂项目通常是付费而非开源的,即使有少数开源项目,也远远超出了本地电脑的计算能力(至少对于那些电脑没有 GPU 的宝宝来说)。 本篇我们用不同于 diffuser 模型的另外一种方法:GAN(生成对抗网络)来完成AI作画。

本篇内容中ShowMeAI将带大家来使用 GAN 生成对抗网络完成莫奈风格作画。

0e46c9cc288f660e03cbbd130b25653e.png

💡 GAN简介

我们本篇使用到的技术是GAN,中文名是『生成对抗网络』,它由两个部分组成:

  • 生成器:『生成器』负责生成所需的内容(在当前场景下是图像),未经训练的生成器随机生成的效果类似噪声,但随着训练过程推进,生成器会产出越来越逼真的结果,直至『判别器』无法分辨真实图像与AI绘制的图像。
  • 判别器:『判别器』负责监督生成器学习,它将真实图像与生成器生成的图像进行比较,检测和分辨真假。随着训练过程推进,它越来越有分辨能力,并督促生成器不断优化。

下图是一个简易的 GAN 示意图:

dba44ccc10feaf938c9d73d961bf744b.png

自2014年第一个 GAN 被研究者提出,经过多年它已经有非常长足的进步,产生越来越好的结果。在本教程中,ShowMeAI将基于 Pytorch 基础上的一个 GAN 工具库 torchgan 完成一个 DCGAN 并应用于莫奈风格的图像绘制任务上。

f0ba85d14a5beb49f1572a493c26eb7f.png

💡 数据集&数据处理

本篇使用到的数据集来源于著名大师莫奈的画作,我们基于这些优秀的画作,让神经网络学习和尝试产生类似的内容。法国画家 📘克劳德·莫奈 生活在 19 世纪,他的画作可以在 📘https://www.wikiart.org/en/claude-monet 获取。

093cebb8ef59f36ebf8834fcdc79219a.png

因为希望模型学习到的信息更充分,我们还扩充使用了很多类似大师风格的图像,更大的数据量可以使训练过程更容易。我们人类有很多背景知识先验知识,例如天空是蓝色的,树木是绿色的,但从神经网络的角度来看,任何图像都只是一个 RGB 数组,更多的数据可以帮助它们掌握这些基本规律。

1a9afb92b84278e6634f5d7e137c914c.png

关于数据处理与神经网络的详细原理知识,大家可以查看ShowMeAI制作的深度学习系列教程和对应文章

不过,即使采用了外观相似的图像,数据量依旧有点小。我们将使『数据增强』技术——它通过对图像的变换来构建新的图像达到数据扩增的效果。

我们创建一个自定义 Dataset 类,借助于 pytorch 的 transforms 功能,可以轻松完成数据扩增中的各种变换:

import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
import globimg_size = 256
class ImagesDataset(data.Dataset):
    def __init__(self, images_path: str):
        self.files = glob.glob(images_path)
        self.images = [None] * self.__len__()

    def __len__(self):
        return 1000

    def __getitem__(self, index):
        if self.images[index] is None:
            self.images[index] = self.generate_image()
        return self.images[index]

    def generate_image(self):
        index = random.randint(0, len(self.files) - 1)
        img = Image.open(self.files[index]).convert('RGB')
        transform = transforms.Compose(
           [transforms.Resize(img_size + img_size//2),                                                   
            transforms.RandomCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()])
        resized = transform(img)
        return resized, index

上面的代码中,我们将所有图像调整为稍大的尺寸,然后应用随机裁剪和翻转构建新的输出图像。 对于莫奈的画,只使用了水平翻转和裁剪比较稳妥,但对于现代艺术样本,垂直翻转或随机旋转可能也是适用的。

我们随机取一点数据集,做可视化和验证有效性:

import torchvision.utils as vutils
import matplotlib.pyplot as plt
def show_images(batch):
    plt.figure(figsize=(12, 12))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(batch, padding=2, 
                            normalize=True).cpu(), (1, 2, 0)))
    plt.show()dataset = ImagesDataset(images_path="Paintings/Monet/*.png")
dataloader = data.DataLoader(dataset, batch_size=batch_size,  
                             shuffle=True)
batch = next(iter(dataloader))
show_images(batch[0][:64])

运行代码后,我们可以看到如下结果:

f954dc21203a3f8758adc7ed50de152a.png

莫奈创作了大约 2500 幅画作,当然完整的画作集中可能会包含不同的内容物,大家看可以稍作筛选。

💡 构建神经网络

我们准备好数据集后,下一步就开始创建神经网络模型了。我们基于 📘torchgan 工具库,构建 GAN 并不复杂:

import torch
from torchgan.models import *
from torchgan.losses import *
dcgan_network = {
    "generator": {
        "name": DCGANGenerator,
        "args": {
            "encoding_dims": 100,
            "step_channels": 40,
            "out_channels": 3,
            "out_size": img_size,
            "nonlinearity": nn.LeakyReLU(0.3),
            "last_nonlinearity": nn.Tanh()
        },
        "optimizer": {"name": Adam, 
                      "args": {"lr": 0.0005, "betas": (0.5, 0.999)}}
    },
    "discriminator": {
        "name": DCGANDiscriminator,
        "args": {
            "in_channels": 3,
            "in_size": img_size,
            "step_channels": 40,
            "nonlinearity": nn.LeakyReLU(0.3),
            "last_nonlinearity": nn.LeakyReLU(0.2)
        },
        "optimizer": {"name": Adam, 
                      "args": {"lr": 0.0006, "betas": (0.5, 0.999)}}
    }
}

lsgan_losses = [LeastSquaresGeneratorLoss(),
                LeastSquaresDiscriminatorLoss()]

我们通过配置的方式,通过字典对网络结构和参数进行了设置。我们这里定义的DCGAN模型包含一个生成器 DCGANGenerator 和一个判别器 DCGANDiscriminator

下一步我们训练网络:

# 使用GPU或者CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")batch_size = 64

# 迭代轮次
epochs = 2000

# 训练器
trainer = Trainer(dcgan_network, lsgan_losses, 
                  sample_size=batch_size, epochs=epochs,
                  device=device,
                  recon="./torchgan_images",
                  checkpoints="./torchgan_model/gan2",
                  log_dir="./torchgan_logs",
                  retain_checkpoints=2)
# 训练
trainer(dataloader)
trainer.complete()

上述代码中涉及的参数都可以调整,每一轮训练后会在 torchgan_images 文件夹中生成图像样本,训练得到的模型保存在 torchgan_model 文件夹中(模型文件不小,对于 256x256 的小尺寸图像,它的大小约为 440M Bytes),但我们仅在磁盘上保留最后 2 个模型 checkpoint。

训练过程中的中间数据日志记录在 torchgan_logs 文件夹下,我们可以通过 TensorBoard 工具实时查看训练中间状态,只需要运行 tensorboard -logdir torchgan_logs 命令即可,运行后我们可以在浏览器界面的 http://localhost:6006 URL中查看中间训练过程,如下图所示:

ffb187638e938f2cf012034ec6e3c58d.png

💡 训练与优化

GAN的训练过程是比较缓慢的,大家可能需要一些耐心。 GeForce RTX 3060 显卡 GPU + Ryzen 9 CPU 的设备上,对尺寸为 256x256 的图像数据集进行 2000 次训练大约需要 4 小时。

整个训练过程中,可以看到神经网络逐步生成越来越好的图像,我们把不同阶段的生产效果做成动图,如下所示:

054c030c47fe0790e7c1460ff78aae40.gif

有兴趣大家可以试着调整一下输入参数,也可以采集和提供更多的训练图片,效果可能会更好。

💡 总结

对比之前 ShowMeAI 提到过的 diffuser 模型,我们这里使用 📘DALL-E Mini 的在线版本 也生成了莫奈画作的图像,如下所示:

9b0fea8946d2c79680ab0c81f03bca72.png

我们的DCGAN代码生成的结果分辨率会弱一点:

74bd1e773d9a2ca1802c9049115f61ec.png

DALL-E Mini 的模型结构做过调整,且在数百万张图像进行过训练,比我们几个小时训练完的小模型效果好是正常的。大家如果采集更多的数据,尝试不同模型参数,结果可能会更好,快来一起试一试吧。

参考资料

e9190f41b8de4af38c8a1a0c96f0513b~tplv-k3u1fbpfcp-zoom-1.image

目录
相关文章
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
80 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
1月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
54 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
1月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
45 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 算法 TensorFlow
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
学习率是深度学习中的关键超参数,它影响模型的训练进度和收敛性,过大或过小的学习率都会对网络训练产生负面影响,需要通过适当的设置和调整策略来优化。
258 0
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
|
1月前
|
机器学习/深度学习 算法
【机器学习】揭秘反向传播:深度学习中神经网络训练的奥秘
【机器学习】揭秘反向传播:深度学习中神经网络训练的奥秘
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
3月前
|
机器学习/深度学习
|
3月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
43 0
|
3月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
26 0
|
2天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【10月更文挑战第38天】本文将探讨网络安全与信息安全的重要性,包括网络安全漏洞、加密技术和安全意识等方面。我们将通过代码示例和实际操作来展示如何保护网络和信息安全。无论你是个人用户还是企业,都需要了解这些知识以保护自己的网络安全和信息安全。

热门文章

最新文章