pytorch基于AnimeFace128数据集训练DCGAN

简介: 基于AnimeFace128数据集,使用PyTorch构建DCGAN生成动漫人脸。包含生成器与判别器网络设计、数据加载及训练流程,通过对抗学习生成64×64清晰图像。

基础引入

数据集来自魔搭:https://www.modelscope.cn/datasets/yanghaitao/AnimeFace128/files

import os

from PIL import Image
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 预处理层: 100 -> 4*4*1024
        latent_dim = 100
        self.linear1 = nn.Linear(in_features=100, out_features=4*4*1024)
        # view重组: 4*4*1024 -> 1024,4,4
        # 网络组合
        self.model_blocks = nn.Sequential(
            # 第一层网络:1024,4,4 -> 512,8,8
            nn.Upsample(scale_factor=2), # 1024,4,4 -> 1024,8,8
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), # 1024,8,8 -> 512,8,8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # 第二层网络: 512,8,8 -> 256,16,16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 第三层网络: 256,16,16 -> 128,32,32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 第四层网络:128,32,32 -> 3,64,64
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.Tanh()
        )
    # 前向传播方法
    def forward(self, z):
        z = self.linear1(z)
        z = z.view(z.shape[0], 1024, 4, 4)
        img = self.model_blocks(z)
        return img

判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model_blocks = nn.Sequential(
            # 输入: 3,64,64
            # 第一层网络:3,64,64 -> 128,32,32
            nn.Conv2d(3, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第二层网络:128,32,32 -> 256,16,16
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第三层网络:256,16,16 -> 512,8,8
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第四层网络:512,8,8 -> 1024,4,4
            nn.Conv2d(512, 1024, 3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5)
        )
        # 展开: 1024,4,4 -> 4*4*1024
        # 输出层
        self.output = nn.Sequential(
            nn.Linear(in_features=4*4*1024, out_features=1),
            nn.Sigmoid()
        )
    # 前向传播
    def forward(self, x):
        """ x : batch,channel,w,h -> 64,3,64,64"""
        y = self.model_blocks(x)
        y = y.view(x.shape[0], -1)
        y = self.output(y)
        return y

数据加载器

# 数据加载器
class ImgDataset(Dataset):
    # 初始化
    def __init__(self, root_dir, transform=None):
        self.transform=transform
        # 获取所有图片路径
        self.img_paths = []
        for filename in os.listdir(root_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.img_paths.append(os.path.join(root_dir, filename))
        # 打印图片总数
        print(f"All : {len(self.img_paths)}")
    # 长度方法
    def __len__(self):
        return len(self.img_paths)
    # 提取其中一张图片的数据
    def __getitem__(self, idx):
        # 找到加载图片
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB')
        # 转换
        if self.transform:
            img = self.transform(img)
        # 返回图片数据及其分类,所有图片只有一个分类
        return img, 0

训练过程

def get_model_instance(device, g_lr, d_lr, b1, b2):
    # 图片生成器
    generator = Generator().to(device)
    # 图片判别器
    discriminator = Discriminator().to(device)
    # 优化器
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))
    # 损失函数 AnimeFace
    criterion = torch.nn.BCELoss()
    return generator, discriminator, optimizer_g, optimizer_d, criterion

def get_data_loader(root_dir, batch_size):
    # 图片数据转换:所有图片均为128*128, Resize为 64*64
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # 图片加载器
    dataset = ImgDataset(root_dir=root_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        prefetch_factor=4,
        persistent_workers=True,
        drop_last=True
    )
    return dataloader


def append_logs(data, path='dcgan_logs.txt'):
    with open(path, 'a', encoding='utf-8') as f:
        print(data)
        f.write(data + '\n')


def train_model(dataloader, discriminator, optimizer_d, generator, optimizer_g, criterion, epochs, device):
    append_logs(f"start : {datetime.now()}")
    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            # 内部批次
            in_batch_size = real_imgs.shape[0]
            # 图片数据加载到设备
            real_imgs = real_imgs.to(device)
            # 真假标签
            real_labels = torch.ones(in_batch_size, 1, device=device) * 0.8
            fake_labels = torch.ones(in_batch_size, 1, device=device) * 0.2

            # 训练判别器
            optimizer_d.zero_grad()
            # 真实损失
            d_loss_real = criterion(discriminator(real_imgs), real_labels)
            # 生成假图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            with torch.no_grad():
                fake_imgs = generator(z)
            fake_imgs_detach = fake_imgs.detach()
            # 假图片损失
            d_loss_fake = criterion(discriminator(fake_imgs_detach), fake_labels)
            # 总的判别器损失为:
            d_loss = (d_loss_real + d_loss_fake) / 2
            # 反向传播与梯度更新
            d_loss.backward()
            optimizer_d.step()

            # 生成图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            fake_imgs = generator(z)
            # 生成器训练
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # 每100批次记录日志
            if i % 100 == 0:
                append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每轮结束打印
        append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每5轮输出一次样例图片
        if epoch % 2 == 0:
            save_image(
                fake_imgs.data[:25],
                f"sample_{epoch}.png",
                nrow=5,
                normalize=True
            )
            # 保存模型
            torch.save(generator.state_dict(), f'generator{epoch}.pth')
            print(f'generator{epoch}.pth saved')
            torch.save(discriminator.state_dict(), f'discriminator{epoch}.pth')
            print(f'discriminator{epoch}.pth saved')
    append_logs(f"end : {datetime.now()}")

基本参数

base_path = "/opt/notebook/dcgan/images_resized"

batch_size = 64
g_lr = 0.0002
d_lr = 0.0002
b1 = 0.5
b2 = 0.999
epochs = 50
device = torch.device('cuda:0')
print(device)

(generator,
 discriminator,
 optimizer_g,
 optimizer_d,
 criterion) = get_model_instance(device, g_lr, d_lr, b1, b2)

dataloader = get_data_loader(base_path, batch_size)

train_model(
    dataloader,
    discriminator, optimizer_d,
    generator, optimizer_g,
    criterion,
    epochs,
    device
)

# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
print('generator.pth saved')

生成器和判别器相互博弈,两个损失此消彼长不稳定,关键看生成的图片质量如何。

相关文章
|
2月前
|
人工智能 自然语言处理 数据挖掘
阿里云百炼支持哪些AI大模型?文本生成、图像生成、语音合成及视频编辑等模型整理
阿里云百炼支持通义千问、通义万相等自研模型及DeepSeek、Kimi、Llama等第三方大模型,覆盖文本生成、图像生成、语音合成、视频生成、向量计算等多类AI能力,助力开发者高效构建应用。新用户可免费领取最高5000万Tokens。
939 156
|
Linux
解决CentOS yum安装Mysql8提示“公钥尚未安装”或“密钥已安装,但是不适用于此软件包”的问题
解决CentOS yum安装Mysql8提示“公钥尚未安装”或“密钥已安装,但是不适用于此软件包”的问题
5909 0
|
7月前
|
机器学习/深度学习 算法 定位技术
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现裂缝的检测识别(C#代码UI界面版)
本项目基于YOLOv8模型与C#界面,结合Baumer工业相机,实现裂缝的高效检测识别。支持图像、视频及摄像头输入,具备高精度与实时性,适用于桥梁、路面、隧道等多种工业场景。
888 27
|
5月前
|
运维 Cloud Native 应用服务中间件
阿里云微服务引擎 MSE 及 API 网关 2025 年 9 月产品动态
阿里云微服务引擎 MSE 面向业界主流开源微服务项目, 提供注册配置中心和分布式协调(原生支持 Nacos/ZooKeeper/Eureka )、云原生网关(原生支持Higress/Nginx/Envoy,遵循Ingress标准)、微服务治理(原生支持 Spring Cloud/Dubbo/Sentinel,遵循 OpenSergo 服务治理规范)能力。API 网关 (API Gateway),提供 APl 托管服务,覆盖设计、开发、测试、发布、售卖、运维监测、安全管控、下线等 API 生命周期阶段。帮助您快速构建以 API 为核心的系统架构.满足新技术引入、系统集成、业务中台等诸多场景需要。
515 142
|
4月前
|
人工智能 自然语言处理 监控
110_微调数据集标注:众包与自动化
在大语言模型(LLM)的微调过程中,高质量的标注数据是模型性能提升的关键因素。随着模型规模的不断扩大和应用场景的日益多样化,如何高效、准确地创建大规模标注数据集成为了研究者和工程师面临的重要挑战。众包与自动化标注技术的结合,为解决这一挑战提供了可行的方案。
|
5月前
|
关系型数据库 Linux PHP
开源站群服务器方案:构建高效流量矩阵的全攻略
正在寻找高性价比、可控性强且功能强大的站群解决方案?小编将深度解析开源站群服务器方案,从核心优势、主流工具选型到部署实践,助您构建稳定、高效的站群流量体系。
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
837 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法

热门文章

最新文章