pytorch基于AnimeFace128数据集训练DCGAN

简介: 基于AnimeFace128数据集,使用PyTorch构建DCGAN生成动漫人脸。包含生成器与判别器网络设计、数据加载及训练流程,通过对抗学习生成64×64清晰图像。

基础引入

数据集来自魔搭:https://www.modelscope.cn/datasets/yanghaitao/AnimeFace128/files

import os

from PIL import Image
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 预处理层: 100 -> 4*4*1024
        latent_dim = 100
        self.linear1 = nn.Linear(in_features=100, out_features=4*4*1024)
        # view重组: 4*4*1024 -> 1024,4,4
        # 网络组合
        self.model_blocks = nn.Sequential(
            # 第一层网络:1024,4,4 -> 512,8,8
            nn.Upsample(scale_factor=2), # 1024,4,4 -> 1024,8,8
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), # 1024,8,8 -> 512,8,8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # 第二层网络: 512,8,8 -> 256,16,16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 第三层网络: 256,16,16 -> 128,32,32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 第四层网络:128,32,32 -> 3,64,64
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.Tanh()
        )
    # 前向传播方法
    def forward(self, z):
        z = self.linear1(z)
        z = z.view(z.shape[0], 1024, 4, 4)
        img = self.model_blocks(z)
        return img

判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model_blocks = nn.Sequential(
            # 输入: 3,64,64
            # 第一层网络:3,64,64 -> 128,32,32
            nn.Conv2d(3, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第二层网络:128,32,32 -> 256,16,16
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第三层网络:256,16,16 -> 512,8,8
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第四层网络:512,8,8 -> 1024,4,4
            nn.Conv2d(512, 1024, 3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5)
        )
        # 展开: 1024,4,4 -> 4*4*1024
        # 输出层
        self.output = nn.Sequential(
            nn.Linear(in_features=4*4*1024, out_features=1),
            nn.Sigmoid()
        )
    # 前向传播
    def forward(self, x):
        """ x : batch,channel,w,h -> 64,3,64,64"""
        y = self.model_blocks(x)
        y = y.view(x.shape[0], -1)
        y = self.output(y)
        return y

数据加载器

# 数据加载器
class ImgDataset(Dataset):
    # 初始化
    def __init__(self, root_dir, transform=None):
        self.transform=transform
        # 获取所有图片路径
        self.img_paths = []
        for filename in os.listdir(root_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.img_paths.append(os.path.join(root_dir, filename))
        # 打印图片总数
        print(f"All : {len(self.img_paths)}")
    # 长度方法
    def __len__(self):
        return len(self.img_paths)
    # 提取其中一张图片的数据
    def __getitem__(self, idx):
        # 找到加载图片
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB')
        # 转换
        if self.transform:
            img = self.transform(img)
        # 返回图片数据及其分类,所有图片只有一个分类
        return img, 0

训练过程

def get_model_instance(device, g_lr, d_lr, b1, b2):
    # 图片生成器
    generator = Generator().to(device)
    # 图片判别器
    discriminator = Discriminator().to(device)
    # 优化器
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))
    # 损失函数 AnimeFace
    criterion = torch.nn.BCELoss()
    return generator, discriminator, optimizer_g, optimizer_d, criterion

def get_data_loader(root_dir, batch_size):
    # 图片数据转换:所有图片均为128*128, Resize为 64*64
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # 图片加载器
    dataset = ImgDataset(root_dir=root_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        prefetch_factor=4,
        persistent_workers=True,
        drop_last=True
    )
    return dataloader


def append_logs(data, path='dcgan_logs.txt'):
    with open(path, 'a', encoding='utf-8') as f:
        print(data)
        f.write(data + '\n')


def train_model(dataloader, discriminator, optimizer_d, generator, optimizer_g, criterion, epochs, device):
    append_logs(f"start : {datetime.now()}")
    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            # 内部批次
            in_batch_size = real_imgs.shape[0]
            # 图片数据加载到设备
            real_imgs = real_imgs.to(device)
            # 真假标签
            real_labels = torch.ones(in_batch_size, 1, device=device) * 0.8
            fake_labels = torch.ones(in_batch_size, 1, device=device) * 0.2

            # 训练判别器
            optimizer_d.zero_grad()
            # 真实损失
            d_loss_real = criterion(discriminator(real_imgs), real_labels)
            # 生成假图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            with torch.no_grad():
                fake_imgs = generator(z)
            fake_imgs_detach = fake_imgs.detach()
            # 假图片损失
            d_loss_fake = criterion(discriminator(fake_imgs_detach), fake_labels)
            # 总的判别器损失为:
            d_loss = (d_loss_real + d_loss_fake) / 2
            # 反向传播与梯度更新
            d_loss.backward()
            optimizer_d.step()

            # 生成图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            fake_imgs = generator(z)
            # 生成器训练
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # 每100批次记录日志
            if i % 100 == 0:
                append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每轮结束打印
        append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每5轮输出一次样例图片
        if epoch % 2 == 0:
            save_image(
                fake_imgs.data[:25],
                f"sample_{epoch}.png",
                nrow=5,
                normalize=True
            )
            # 保存模型
            torch.save(generator.state_dict(), f'generator{epoch}.pth')
            print(f'generator{epoch}.pth saved')
            torch.save(discriminator.state_dict(), f'discriminator{epoch}.pth')
            print(f'discriminator{epoch}.pth saved')
    append_logs(f"end : {datetime.now()}")

基本参数

base_path = "/opt/notebook/dcgan/images_resized"

batch_size = 64
g_lr = 0.0002
d_lr = 0.0002
b1 = 0.5
b2 = 0.999
epochs = 50
device = torch.device('cuda:0')
print(device)

(generator,
 discriminator,
 optimizer_g,
 optimizer_d,
 criterion) = get_model_instance(device, g_lr, d_lr, b1, b2)

dataloader = get_data_loader(base_path, batch_size)

train_model(
    dataloader,
    discriminator, optimizer_d,
    generator, optimizer_g,
    criterion,
    epochs,
    device
)

# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
print('generator.pth saved')

生成器和判别器相互博弈,两个损失此消彼长不稳定,关键看生成的图片质量如何。

相关文章
|
算法 Java 调度
种群进化+邻域搜索的混合算法(GA+TS)求解作业车间调度问题(JSP)-算法介绍
种群进化+邻域搜索的混合算法(GA+TS)求解作业车间调度问题(JSP)-算法介绍
887 0
种群进化+邻域搜索的混合算法(GA+TS)求解作业车间调度问题(JSP)-算法介绍
|
机器学习/深度学习 编解码 Unix
超分数据集概述和超分经典网络模型总结
超分数据集概述和超分经典网络模型总结
999 1
|
网络安全 Python Windows
ImportError: DLL load failed while importing _ssl: 找不到指定的模块。
找到Anaconda3\pkgs\python-3.8.12-h900ac77_2_cpython\DLLs下的_ssl.pyd文件,查阅在该环境上安装的python版本号,下载python寻找对应的_ssl.pyd覆盖到上述目录中,即可解决问题。
2208 0
|
算法 编译器 调度
HLS-指令使用指南(三)
HLS-指令使用指南
3094 0
HLS-指令使用指南(三)
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
930 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
|
网络协议 Linux 开发工具
Centos7 /etc/sysconfig/network-scripts/ifcfg-<interface>网络配置
自动化网络配置:NetworkManager 可以自动检测网络连接,并根据网络环境自动配置网络。这使得用户可以无需手动配置即可连接到网络。 支持多种网络连接:NetworkManager 支持多种网络连接,包括有线、无线、VPN、Wi-Fi 热点等。这使得用户可以根据需要选择合适的网络连接。 提供图形化和命令行工具:NetworkManager 提供了图形化工具和命令行工具,用户可以根据自己的喜好选择使用。
2585 4

热门文章

最新文章