动手玩玩头像动漫化

简介: 动手玩玩头像动漫化

前言


很久没有弄点好玩的东西了,逛逛github,想到最近看到那种动漫化的头像很可爱,于是决定也去找找相关的项目玩玩,毕竟直接调用百度啥的api太没意思了。


正文


之前写过一些有关GAN的博客,大家应该对GAN有了基本的了解。最基础的内容就是基于零和博弈的思想使得生成器生成的假图像逼真的可以足够骗过判别器,更本质就是生成器从对抗中不断学习,学习到了真实图像中的数据分布。但是真正动手之后会发现训练GAN其实还是比较困难的,难点就在于难以收敛,而且还有模式崩塌的情况出现。

再说回这次的目的,想要实现头像动漫化,这在概念上应该是图像的风格迁移,那最基本就应该会想到内容损失风格损失,当然对具体的问题会提出更多不同的损失。在github上搜索下就会发现这几个名字


  • CartoonGAN
  • AnimeGAN
  • AnimeGAN2再去看看相关的论文和github链接,会发现上面的顺序就是逐渐优化的过程,具体的论文阅读部分我会放在论文专栏,今天的注意力还是集中在实现上。我最终还是选择了SOTA模型AnimeGAN2-tensorflow


当然也有pytorch版本AnimeGAN2-pytorch

我选择的是pytorch版本,因为tensorflow是1.X版本实现,而我环境已经全部转为TF2了,而且越来越觉得TF复杂而杂乱的api很不喜欢,更多的网络实现都是基于keras,有时候又失去了TF本身的灵活性。好啦,又扯远了,现在开始正式动手实现。


开始实现


进去AnimeGAN2-pytorch,如果只需要实现的话不需要下载整个项目,只需要下载model.py


import torch
from torch import nn
import torch.nn.functional as F
class ConvNormLReLU(nn.Sequential):
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError
        super(ConvNormLReLU, self).__init__(
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True)
        )
class InvertedResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super(InvertedResBlock, self).__init__()
        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
        # dw
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pw
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
        self.layers = nn.Sequential(*layers)
    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out
class Generator(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64)
        )
        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128)
        )
        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )
        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128)
        )
        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3)
        )
        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )
    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]
        out = self.block_b(out)
        out = self.block_c(out)
        if align_corners:
            out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)
        if align_corners:
            out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)
        out = self.out_layer(out)
        return out
复制代码


其中包含了生成器代码,这也是生成动漫风格图片的关键部分。 然后就是需要去对应的网盘下载不同风格下训练好的模型,可能无法访问,所以我把我下载的放到网盘分享。模型链接  密码: n78v


然后就是自己写个生成图像的代码,当然可以根据原项目中的test_faces.ipynb来修改 下面是我的生成代码,需要提前将模型放在同一目录下,并且创建samples文件夹用来存放生成图片


import os
import cv2
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
from model import Generator
def load_image(path, size=None):
    image = image2tensor(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))
    w, h = image.shape[-2:]
    if w != h:
        crop_size = min(w, h)
        left = (w - crop_size) // 2
        right = left + crop_size
        top = (h - crop_size) // 2
        bottom = top + crop_size
        image = image[:, :, left:right, top:bottom]
    if size is not None and image.shape[-1] != size:
        image = torch.nn.functional.interpolate(image, (size, size), mode="bilinear", align_corners=True)
    return image
def image2tensor(image):
    image = torch.FloatTensor(image).permute(2, 0, 1).unsqueeze(0) / 255.
    return (image - 0.5) / 0.5
def tensor2image(tensor):
    tensor = tensor.clamp_(-1., 1.).detach().squeeze().permute(1, 2, 0).cpu().numpy()
    return tensor * 0.5 + 0.5
def imshow(img, size=5, cmap='jet'):
    plt.figure(figsize=(size, size))
    plt.imshow(img, cmap=cmap)
    plt.axis('off')
    plt.show()
if __name__ == '__main__':
    device = 'cuda'
    torch.set_grad_enabled(False)
    image_size = 300
    img=input("")
    model = Generator().eval().to(device)
    ckpt = torch.load(f"./new.pth", map_location=device)
    model.load_state_dict(ckpt)
    result=[]
    image = load_image(f"./face/{img}", image_size)
    output = model(image.to(device))
    result.append(torch.cat([image, output.cpu()], 3))
    result = torch.cat(result, 2)
    imshow(tensor2image(result), 40)
    cv2.imwrite(f'./samples/new+{img}', cv2.cvtColor(255 * tensor2image(result), cv2.COLOR_BGR2RGB))
复制代码


device中可以选择cpu或者cuda(使用GPU)


可能遇到的问题


当电脑中pytorch版本低于1.6时,载入模型时torch.load会报错,但是受限于显卡算力好像不支持安装高版本torch,那只能找其他解决办法。


找个笔记本,一般都可以安装高版本pytorch(大于1.6就行),配好环境之后运行下面代码,     用新生成的模型文件去载入,也就是我代码中的new.pth修改模型文件的代码


import torch
weight = torch.load("高版本的模型地址")
torch.save(weight, '自定义新的模型地址', _use_new_zipfile_serialization=False)
复制代码


效果


image.png

效果还是比较不错的,生成图片的速度也比较快,但是我感觉还是存在一些问题


  • 由于代码中会对图片进行截取,所以有可能丢失原图中的信息
  • 可能是训练数据集的问题,传入光照不好的图片出现的效果很不好,轮廓曲线十分模糊。


目录
相关文章
|
9月前
|
iOS开发 Python Windows
|
域名解析 小程序 Linux
朋友圈超火的盲盒交友小程序,完整搭建教程及源码分享~(多图)
朋友圈超火的盲盒交友小程序,完整搭建教程及源码分享~(多图)
朋友圈超火的盲盒交友小程序,完整搭建教程及源码分享~(多图)
|
9月前
|
Python Windows
|
9月前
|
自然语言处理 Python Windows
|
11月前
|
人工智能 Serverless 异构计算
【有奖体验】叮!你有一张 3D 卡通头像请查收
立即体验基于函数计算部署【图生图】一键部署 3D 卡通风格模型,秒生成属于自己的 3D 卡通图!
|
12月前
|
小程序
女神节你也能自己动手制作一个漂亮的微信小游戏
嗨!大家好,我是小蚂蚁。 这是我之前制作的一个非常漂亮的微信小游戏,你可以给予它进行改编,然后自己制作一个小游戏送给你想送的人。 我发现这篇文章每年至少可以发四次,情人节一次,女神节一次,520一次,七夕一次[阴险]。 今年的我做了点儿改进,增加了一首背景音乐,是经典的《卡农》钢琴曲,希望你喜欢。
71 0
|
12月前
|
小程序
|
机器学习/深度学习 人工智能 开发者
尼日利亚学生使用 PAI 打造卡通头像神器|学习笔记
快速学习尼日利亚学生使用 PAI 打造卡通头像神器。
64 0
uiu
|
JavaScript 前端开发
制作别踩白块网页小游戏
制作别踩白块网页小游戏
uiu
116 0
制作别踩白块网页小游戏
|
小程序 API 开发者
基于wxapp的圣诞帽头像小程序【完整项目源码】
基于wxapp的圣诞帽头像小程序【完整项目源码】
基于wxapp的圣诞帽头像小程序【完整项目源码】