使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下

简介: 生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。

相关资料链接为:使用PyTorch构建GAN生成对抗

本次训练代码使用了本地GPU计算。

文章的上篇讲解了数据集class和鉴别器class,下面将会继续建立生成器class,并完成鉴别器与生成器的对抗。


1 转置卷积


生成器的结构应与鉴别器相逆,因此生成器不再使用卷积操作,而是使用卷积的逆向操作,我们称之为转置卷积(transposed convolution)。

与普通的卷积不同,转置卷积中的卷积核是从输入图像的底部向上卷的,而不是从顶部向下卷。这是通过交换卷积核的行和列,并在每次卷积时对卷积核进行逆时针旋转180度来实现的。

转置卷积的步骤如下:


对卷积核进行行列交换,将其变为与输入图像矩阵相同的大小。

对卷积核进行逆时针旋转180度。

将卷积核应用于输入图像。

因此,转置卷积不同于普通卷积,它不需要对卷积核进行任何变换,因此不需要学习任何参数。此外,转置卷积的输出图像的大小一般大于输入图像的大小,因此可以用于图像放大等应用。




d9738bc27aaa4e329797723c74f24e39.png



图片来源:《Python生成对抗网络编程》


转置卷积用到的函数是nn.ConvTranspose2d。其中的参数有in_channels:输入通道数;out_channels:输出通道数;kernel_size:卷积核的大小;stride:卷积步长;padding:卷积边缘填充;output_padding:输出大小增加的填充;groups:卷积核分组数;bias:是否使用偏置。


2 生成器修改


此次代码中,生成器的总体结构为:输入 — 全连接层 — 转置卷积层1 — 转置卷积层2 — 转置卷积层3 — 输出

卷积层后没有接新的全连接层而直接输出,是为了保障将能够将局部特征直接生成最终图像。


class Generator(nn.Module):
    def __init__(self):
        # 修改后的网络结构
        self.model = nn.Sequential(
            # 输入是一个一维数组
            nn.Linear(100, 3 * 11 * 11),
            nn.LeakyReLU(0.2),
            # 转换成四维
            View((1, 3, 11, 11)),
            nn.ConvTranspose2d(3, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 256, kernel_size=8, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 3, kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.Sigmoid()
        )



结构讲解:


生成器的输入仍然保留为100个,使用全连接的方式,连接至(3*11*11)维度的张量,激活函数使用漏斗形非线性激活函数LeakReLu。

第一个转置卷积层接受了3个11*11的输入,卷积核大小为8,步长为2,不补全,则输出为28*28。激活函数同样使用LeakReLu。

转置卷积输出的计算方法:

输出 = 输入 + (输入 − 1 ) × (步长 − 1 ) + (卷积核 − 1 ) − 补全 × 2 输出=输入+(输入-1)×(步长-1)+(卷积核-1)-补全×2

输出=输入+(输入−1)×(步长−1)+(卷积核−1)−补全×2

上面的4部分代表了原始输入的大小、因步长而增加的大小、因卷积核而增加的大小(卷积核在边缘时为其增加的空白像素)、因补全而降低的大小(是否补全与补全在函数中定义,本处未定义默认不补全,此处补全的含义与词语含义相反)。

第二个转置卷积层接受了256个28*28的输入,卷积核大小为8,步长为2,不补全,则输出为62*62。激活函数再次使用LeakReLu。

第三个转置卷积层接受了256个28*28的输入,卷积核大小为8,这里卷积核个数恢复为3(如输出通道的RGB颜色相匹配),步长为2,不补全,则输出为128*128。使用常见的sigmoid函数进行输出。

3 生成器测试


在不进行训练的情况下,直接生成图片,已检查生成器class是否存在明显BUG。


G = Generator()
G.to(device)
output = G.forward(generate_random_seed(100))
img = output.detach().permute(0,2,3,1).view(128,128,3).cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


输出如下:




20041c73edb84cd1bfa129b9ea6c7cec.png


因为增加了转置卷积,所以与单纯的随机生成图片不同,边缘颜色的色彩鲜明都要明显弱于中心。


4 训练模型


训练部分与普通的GAN模型相比不需任何修改,此处不再贴代码,有需要可以直接下载查看。

鉴别器和生成器在训练后期,均呈现稳态震荡。




7fae3a01e675448db1b837251b57941d.png2be0de4a07b643979a77c17bc61aaee8.png1e9080e9ff65466d808899dc9279d836.png



与非卷积相比,卷积GAN生成的图片具备明显的五官特征。但五官匹配并不完美,增加训练集与训练次数后在一定程度上可以解决此问题。


下面查看使用卷积后,内存消耗是否有优化


print(torch.cuda.memory_allocated(device) / (1024*1024*1024))
print(torch.cuda.max_memory_allocated(device) / (1024*1024*1024))


结果分别为:

2da9470e49f04e169bcb778c9104c61e.png



可以看出与之前超过1个G的内存占用相比,优化明显。

本案例的完整代码链接可点此下载或文末留言申请。

————————————————


相关文章
|
2月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
15天前
|
机器学习/深度学习 数据采集 PyTorch
构建你的第一个PyTorch神经网络模型
【4月更文挑战第17天】本文介绍了如何使用PyTorch构建和训练第一个神经网络模型。首先,准备数据集,如MNIST。接着,自定义神经网络模型`SimpleNet`,包含两个全连接层和ReLU激活函数。然后,定义交叉熵损失函数和SGD优化器。训练模型涉及多次迭代,计算损失、反向传播和参数更新。最后,测试模型性能,计算测试集上的准确率。这是一个基础的深度学习入门示例,为进一步探索复杂项目打下基础。
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于全连接网络构建RNN并生成人名
【PyTorch实战演练】基于全连接网络构建RNN并生成人名
24 0
|
2月前
|
PyTorch 算法框架/工具 Python
Pytorch构建网络模型时super(__class__, self).__init__()的作用
Pytorch构建网络模型时super(__class__, self).__init__()的作用
10 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch用GAN生成手写数字实例(附代码)
基于Pytorch用GAN生成手写数字实例(附代码)
30 0
|
2月前
|
机器学习/深度学习 算法 大数据
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
31 3
|
2月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的机器学习Regression问题实例(附源码)
基于Pytorch的机器学习Regression问题实例(附源码)
33 1
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
67 2
|
3月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)