手把手 | 初学者如何用Chainer为漫画上色 深度学习帮你逆袭漫画家(附代码)-阿里云开发者社区

开发者社区> 大数据文摘> 正文

手把手 | 初学者如何用Chainer为漫画上色 深度学习帮你逆袭漫画家(附代码)

简介:


0?wx_fmt=gif

最近一直有人说深度学习(Deep Learning)的附加价值高,于是我也在一两个月前开始学习chainer了。机会难得就想试着用chainer做一些各种各样的尝试,比如写个给线描上色的小程序之类的。

线描上色这个任务的性质是监督式学习(supervised learning),因此需要大量的线稿和上完色的图片,越多越好。

这次是用opencv用角色的画像把线稿生成了出来。生成例子如下:

0?wx_fmt=jpeg

0?wx_fmt=jpeg

收集了角色们的画像,将其转化为线稿之后数据集(dataset)就完成了。这次用了约60万张图片。

关于神经网络(neural network)的结构,我用了一种叫做U-net的网络。它的特点是会把卷积(convolution)和反卷积(deconvolution)的层混合着连接在一起。这样就可以做到参照着一开始的线稿来上色。(译者注:有兴趣的可以看一下unet的论文以及文末提供的神经网络代码。)

这样生成出来的图像和原来上完色的图像对比然后取平方差,神经网络的训练让平方差最小即可。

只要保证网络每一层的输入和下一层的输出相匹配就大体ok,但是自己定义的数据如何制作这方面,因为没有什么例子所以可能对各位来说有些难懂。

花了整整一晚上时间使劲用data喂饱了neural net(掩面)的结果↓

0?wx_fmt=jpeg

嗯~神经网络这边大概是想说“肌肤的颜色我大概能搞懂但是除此之外的实在不知道啊,发色啊衣服颜色啊我当然不可能知道吧。”

这里要登场的是叫对抗网络(adversarial net)的神经网络,简称“怼”。

“怼”要做的是学习真正的图像和被神经网络生成出来的图像之间的颜色差别,然后找出两个图像中的那个叛徒。

所以如果神经网络一直生成老照片那样颜色的图像的话“怼”只要学一会儿就能准确的找出哪一张是神经网络生成的。

但是如果“怼”太用力的话上色的神经网络会拼命反抗导致上色失败请多加注意。

0?wx_fmt=jpeg

这样的颜色都已经说不上是线稿上色后的东西而更加接近艺术了。(嘛,顺着这条道走,把上色用神经网络怼成艺术生也不是不可以。。。现在暂时还是回到学习和原画之间的差别上面吧)

0?wx_fmt=jpeg

呼,上色终于完成了!照着这个势头下去再接着干吧。

第一阶段是学习了128x128的图像,第二阶段是给512x512的图像学习上色。以下是没“怼”过的训练结果↓

0?wx_fmt=jpeg

还不错哟

0?wx_fmt=jpeg

不错不错。

0?wx_fmt=jpeg

把实际的线稿拿来喂给神经网络如何?我从pixiv上借用来了线稿类的画

0?wx_fmt=jpeg

(因为神经网络大部分都是卷积神经网络(cnn),宽高比有一些变化也没关系)

0?wx_fmt=jpeg

好棒!

变成了彩色的怪物了,嘛这种怪物也是有的。

平安收工。

说回来,果然还是会想亲手在线稿上上一些色吧?于是稍微改变一下输入,和一阶段不同的是在原来的线稿之外多加了三个输入层(rgb),给神经网络一些用色上的提示吧。

总之:

0?wx_fmt=jpeg

茶色的头发淡蓝的水手服和藏青的裙子,之类的要求也可以提了。

稍微霸气的像这样画上一笔也是可以的。

0?wx_fmt=jpeg0?wx_fmt=jpeg

不管是大概的提示也好非常用心的每个细节都提示也好效果都不错。(可能有些难懂,就是用不同的颜色在各种地方点一下来提示,比如下面)

0?wx_fmt=jpeg

这样我也从工程师升职成画师了!

至此,我觉得线稿的自动上色和带提示的上色已经做的还不错了。

虽然还不如画师们认真画出来的,如果想随随便便涂个色还是非常方便的(译者:比如在经费不够的情况下。。。)。

漫画之类的也是,比起用网点贴纸(Screen tone)还是大致的上一个色比较快速方便的。(这次的神经网络非常擅长给肤色上色。。。我想说的各位懂得吧~)

顺便补充一下,弱点还是有几个的。

例如同时用对抗网络和上色提示一起训练的时候,上色提示会干涉到对抗网络,有时会导致训练结果不稳定。

0?wx_fmt=jpeg

↑明明只是想给泳装上一个不同的颜色,结果其他部分的颜色也跟着变了。

如果只是作为一个简单的上色工具的话,只加提示来训练神经网络可能会更加稳定。

另外,如果线稿的线太粗或者太细的情况下,线会崩坏掉导致结果不怎么样的情况也有,仔细的给了上色提示但是没有反应在结果上的情况也有。

不同的细节都用同一个神经网络来对付虽然比较厉害,但是作为工具使用的时候需要根据用途来做一些调整。

借鉴的线稿原画:

「【プリンセスロワイヤル】パンドラ」/「鉛筆工房【IRITH】」[pixiv] (http://www.pixiv.net/member_illust.php?mode=medium&illust;_id=31274285)

線画詰め(http://www.pixiv.net/member_illust.php?mode=manga&illust;_id=43369404)

「改弐」/「炬燵魂」[pixiv](http://www.pixiv.net/member_illust.php?mode=medium&illust;_id=56689287)

「めりくり線画」/「タマコ」[pixiv](http://www.pixiv.net/member_illust.php?mode=medium&illust;_id=40487409)

「泰1」/「20100301」[pixiv](http://www.pixiv.net/member_illust.php?mode=medium&illust;_id=10552795)

※生成的线稿的训练材料和原画一时半会儿找不到,实在抱歉,还请多包涵。


顺便,这次的神经网络第一阶段和第二阶段的构造都是一样的,基本上感觉如下:


unet.py

class UNET(chainer.Chain):
def __init__(self):
super(UNET, self).__init__(
c0 = L.Convolution2D(4, 32, 3, 1, 1),
c1 = L.Convolution2D(32, 64, 4, 2, 1),
c2 = L.Convolution2D(64, 64, 3, 1, 1),
c3 = L.Convolution2D(64, 128, 4, 2, 1),
c4 = L.Convolution2D(128, 128, 3, 1, 1),
c5 = L.Convolution2D(128, 256, 4, 2, 1),
c6 = L.Convolution2D(256, 256, 3, 1, 1),
c7 = L.Convolution2D(256, 512, 4, 2, 1),
c8 = L.Convolution2D(512, 512, 3, 1, 1),

dc8 = L.Deconvolution2D(1024, 512, 4, 2, 1),
dc7 = L.Convolution2D(512, 256, 3, 1, 1),
dc6 = L.Deconvolution2D(512, 256, 4, 2, 1),
dc5 = L.Convolution2D(256, 128, 3, 1, 1),
dc4 = L.Deconvolution2D(256, 128, 4, 2, 1),
dc3 = L.Convolution2D(128, 64, 3, 1, 1),
dc2 = L.Deconvolution2D(128, 64, 4, 2, 1),
dc1 = L.Convolution2D(64, 32, 3, 1, 1),
dc0 = L.Convolution2D(64, 3, 3, 1, 1),

bnc0 = L.BatchNormalization(32),
bnc1 = L.BatchNormalization(64),
bnc2 = L.BatchNormalization(64),
bnc3 = L.BatchNormalization(128),
bnc4 = L.BatchNormalization(128),
bnc5 = L.BatchNormalization(256),
bnc6 = L.BatchNormalization(256),
bnc7 = L.BatchNormalization(512),
bnc8 = L.BatchNormalization(512),

bnd8 = L.BatchNormalization(512),
bnd7 = L.BatchNormalization(256),
bnd6 = L.BatchNormalization(256),
bnd5 = L.BatchNormalization(128),
bnd4 = L.BatchNormalization(128),
bnd3 = L.BatchNormalization(64),
bnd2 = L.BatchNormalization(64),
bnd1 = L.BatchNormalization(32)
)

def calc(self,x, test = False):
e0 = F.relu(self.bnc0(self.c0(x), test=test))
e1 = F.relu(self.bnc1(self.c1(e0), test=test))
e2 = F.relu(self.bnc2(self.c2(e1), test=test))
e3 = F.relu(self.bnc3(self.c3(e2), test=test))
e4 = F.relu(self.bnc4(self.c4(e3), test=test))
e5 = F.relu(self.bnc5(self.c5(e4), test=test))
e6 = F.relu(self.bnc6(self.c6(e5), test=test))
e7 = F.relu(self.bnc7(self.c7(e6), test=test))
e8 = F.relu(self.bnc8(self.c8(e7), test=test))

d8 = F.relu(self.bnd8(self.dc8(F.concat([e7, e8])), test=test))
d7 = F.relu(self.bnd7(self.dc7(d8), test=test))
d6 = F.relu(self.bnd6(self.dc6(F.concat([e6, d7])), test=test))
d5 = F.relu(self.bnd5(self.dc5(d6), test=test))
d4 = F.relu(self.bnd4(self.dc4(F.concat([e4, d5])), test=test))
d3 = F.relu(self.bnd3(self.dc3(d4), test=test))
d2 = F.relu(self.bnd2(self.dc2(F.concat([e2, d3])), test=test))
d1 = F.relu(self.bnd1(self.dc1(d2), test=test))
d0 = self.dc0(F.concat([e0, d1]))

return d0

“怼”

adv.py

class DIS(chainer.Chain):
def __init__(self):
super(DIS, self).__init__(
c1 = L.Convolution2D(3, 32, 4, 2, 1),
c2 = L.Convolution2D(32, 32, 3, 1, 1),
c3 = L.Convolution2D(32, 64, 4, 2, 1),
c4 = L.Convolution2D(64, 64, 3, 1, 1),
c5 = L.Convolution2D(64, 128, 4, 2, 1),
c6 = L.Convolution2D(128, 128, 3, 1, 1),
c7 = L.Convolution2D(128, 256, 4, 2, 1),
l8l = L.Linear(None, 2, wscale=0.02*math.sqrt(8*8*256)),

bnc1 = L.BatchNormalization(32),
bnc2 = L.BatchNormalization(32),
bnc3 = L.BatchNormalization(64),
bnc4 = L.BatchNormalization(64),
bnc5 = L.BatchNormalization(128),
bnc6 = L.BatchNormalization(128),
bnc7 = L.BatchNormalization(256),
)

def calc(self,x, test = False):
h = F.relu(self.bnc1(self.c1(x), test=test))
h = F.relu(self.bnc2(self.c2(h), test=test))
h = F.relu(self.bnc3(self.c3(h), test=test))
h = F.relu(self.bnc4(self.c4(h), test=test))
h = F.relu(self.bnc5(self.c5(h), test=test))
h = F.relu(self.bnc6(self.c6(h), test=test))
h = F.relu(self.bnc7(self.c7(h), test=test))
return self.l8l(h)

原文发布时间为:2017-03-04

本文来自云栖社区合作伙伴“大数据文摘”,了解相关信息可以关注“BigDataDigest”微信公众号

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

分享:
大数据文摘
使用钉钉扫一扫加入圈子
+ 订阅

官方博客
官网链接