TensorFlow 2 和 Keras 高级深度学习:6~10(1)https://developer.aliyun.com/article/1426950
我们提供函数来构建Discriminator[0]
和Discriminator[1]
(dis0
和dis1
)。 dis0
判别器类似于 GAN 判别器,除了特征向量输入和辅助网络Q[0]
,其恢复z[0]
。 gan.py
中的构造器函数用于创建dis0
:
dis0 = gan.discriminator(inputs, num_codes=z_dim)
dis1
判别器由三层 MLP 组成,如清单 6.2.3 所示。 最后一层将区分为真假f[1]
。Q[1]
网络共享dis1
的前两层。 其第三层回收z[1]
。
“列表 6.2.3”:stackedgan-mnist-6.2.1.py
def build_discriminator(inputs, z_dim=50): """Build Discriminator 1 Model
Classifies feature1 (features) as real/fake image and recovers the input noise or latent code (by minimizing entropy loss)
# Arguments inputs (Layer): feature1 z_dim (int): noise dimensionality
# Returns dis1 (Model): feature1 as real/fake and recovered latent code """
# input is 256-dim feature1 x = Dense(256, activation='relu')(inputs) x = Dense(256, activation='relu')(x)
# first output is probability that feature1 is real f1_source = Dense(1)(x) f1_source = Activation('sigmoid', name='feature1_source')(f1_source)
# z1 reonstruction (Q1 network) z1_recon = Dense(z_dim)(x) z1_recon = Activation('tanh', name='z1')(z1_recon)
discriminator_outputs = [f1_source, z1_recon] dis1 = Model(inputs, discriminator_outputs, name='dis1') return dis1
有了所有可用的构建器函数,StackedGAN 就会在“列表 6.2.4”中进行组装。 在训练 StackedGAN 之前,对编码器进行了预训练。 请注意,我们已经在对抗模型训练中纳入了三个生成器损失函数(对抗,条件和熵)。Q
网络与判别器模型共享一些公共层。 因此,其损失函数也被纳入判别器模型训练中。
“列表 6.2.4”:stackedgan-mnist-6.2.1.py
def build_and_train_models(): """Load the dataset, build StackedGAN discriminator, generator, and adversarial models. Call the StackedGAN train routine. """ # load MNIST dataset (x_train, y_train), (x_test, y_test) = mnist.load_data()
# reshape and normalize images image_size = x_train.shape[1] x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) x_train = x_train.astype('float32') / 255
x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) x_test = x_test.astype('float32') / 255
# number of labels num_labels = len(np.unique(y_train)) # to one-hot vector y_train = to_categorical(y_train) y_test = to_categorical(y_test)
model_name = "stackedgan_mnist" # network parameters batch_size = 64 train_steps = 10000 lr = 2e-4 decay = 6e-8 input_shape = (image_size, image_size, 1) label_shape = (num_labels, ) z_dim = 50 z_shape = (z_dim, ) feature1_dim = 256 feature1_shape = (feature1_dim, )
# build discriminator 0 and Q network 0 models inputs = Input(shape=input_shape, name='discriminator0_input') dis0 = gan.discriminator(inputs, num_codes=z_dim) # [1] uses Adam, but discriminator converges easily with RMSprop optimizer = RMSprop(lr=lr, decay=decay) # loss fuctions: 1) probability image is real (adversarial0 loss) # 2) MSE z0 recon loss (Q0 network loss or entropy0 loss) loss = ['binary_crossentropy', 'mse'] loss_weights = [1.0, 10.0] dis0.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer, metrics=['accuracy']) dis0.summary() # image discriminator, z0 estimator
# build discriminator 1 and Q network 1 models input_shape = (feature1_dim, ) inputs = Input(shape=input_shape, name='discriminator1_input') dis1 = build_discriminator(inputs, z_dim=z_dim ) # loss fuctions: 1) probability feature1 is real # (adversarial1 loss) # 2) MSE z1 recon loss (Q1 network loss or entropy1 loss) loss = ['binary_crossentropy', 'mse'] loss_weights = [1.0, 1.0] dis1.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer, metrics=['accuracy']) dis1.summary() # feature1 discriminator, z1 estimator
# build generator models feature1 = Input(shape=feature1_shape, name='feature1_input') labels = Input(shape=label_shape, name='labels') z1 = Input(shape=z_shape, name="z1_input") z0 = Input(shape=z_shape, name="z0_input") latent_codes = (labels, z0, z1, feature1) gen0, gen1 = build_generator(latent_codes, image_size) gen0.summary() # image generator gen1.summary() # feature1 generator
# build encoder models input_shape = (image_size, image_size, 1) inputs = Input(shape=input_shape, name='encoder_input') enc0, enc1 = build_encoder((inputs, feature1), num_labels) enc0.summary() # image to feature1 encoder enc1.summary() # feature1 to labels encoder (classifier) encoder = Model(inputs, enc1(enc0(inputs))) encoder.summary() # image to labels encoder (classifier)
data = (x_train, y_train), (x_test, y_test) train_encoder(encoder, data, model_name=model_name)
# build adversarial0 model = # generator0 + discriminator0 + encoder0 optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5) # encoder0 weights frozen enc0.trainable = False # discriminator0 weights frozen dis0.trainable = False gen0_inputs = [feature1, z0] gen0_outputs = gen0(gen0_inputs) adv0_outputs = dis0(gen0_outputs) + [enc0(gen0_outputs)] # feature1 + z0 to prob feature1 is # real + z0 recon + feature0/image recon adv0 = Model(gen0_inputs, adv0_outputs, name="adv0") # loss functions: 1) prob feature1 is real (adversarial0 loss) # 2) Q network 0 loss (entropy0 loss) # 3) conditional0 loss loss = ['binary_crossentropy', 'mse', 'mse'] loss_weights = [1.0, 10.0, 1.0] adv0.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer, metrics=['accuracy']) adv0.summary()
# build adversarial1 model = # generator1 + discriminator1 + encoder1 # encoder1 weights frozen enc1.trainable = False # discriminator1 weights frozen dis1.trainable = False gen1_inputs = [labels, z1] gen1_outputs = gen1(gen1_inputs) adv1_outputs = dis1(gen1_outputs) + [enc1(gen1_outputs)] # labels + z1 to prob labels are real + z1 recon + feature1 recon adv1 = Model(gen1_inputs, adv1_outputs, name="adv1") # loss functions: 1) prob labels are real (adversarial1 loss) # 2) Q network 1 loss (entropy1 loss) # 3) conditional1 loss (classifier error) loss_weights = [1.0, 1.0, 1.0] loss = ['binary_crossentropy', 'mse', 'categorical_crossentropy'] adv1.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer, metrics=['accuracy']) adv1.summary()
# train discriminator and adversarial networks models = (enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1) params = (batch_size, train_steps, num_labels, z_dim, model_name) train(models, data, params)
最后,训练函数与典型的 GAN 训练相似,不同之处在于我们一次只训练一个 GAN(即GAN[0]
然后是GAN[0]
)。 代码显示在“列表 6.2.5”中。 值得注意的是,训练顺序为:
Discriminator[1]
和Q[1]
网络通过最小化判别器和熵损失Discriminator[0]
和Q[0]
网络通过最小化判别器和熵损失Adversarial[1]
网络通过最小化对抗性,熵和条件损失Adversarial[0]
网络通过最小化对抗性,熵和条件损失
“列表 6.2.5”:stackedgan-mnist-6.2.1.py
def train(models, data, params): """Train the discriminator and adversarial Networks
Alternately train discriminator and adversarial networks by batch. Discriminator is trained first with real and fake images, corresponding one-hot labels and latent codes. Adversarial is trained next with fake images pretending to be real, corresponding one-hot labels and latent codes. Generate sample images per save_interval.
# Arguments models (Models): Encoder, Generator, Discriminator, Adversarial models data (tuple): x_train, y_train data params (tuple): Network parameters
""" # the StackedGAN and Encoder models enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1 = models # network parameters batch_size, train_steps, num_labels, z_dim, model_name = params # train dataset (x_train, y_train), (_, _) = data # the generator image is saved every 500 steps save_interval = 500
# label and noise codes for generator testing z0 = np.random.normal(scale=0.5, size=[16, z_dim]) z1 = np.random.normal(scale=0.5, size=[16, z_dim]) noise_class = np.eye(num_labels)[np.arange(0, 16) % num_labels] noise_params = [noise_class, z0, z1] # number of elements in train dataset train_size = x_train.shape[0] print(model_name, "Labels for generated images: ", np.argmax(noise_class, axis=1))
for i in range(train_steps): # train the discriminator1 for 1 batch # 1 batch of real (label=1.0) and fake feature1 (label=0.0) # randomly pick real images from dataset rand_indexes = np.random.randint(0, train_size, size=batch_size) real_images = x_train[rand_indexes] # real feature1 from encoder0 output real_feature1 = enc0.predict(real_images) # generate random 50-dim z1 latent code real_z1 = np.random.normal(scale=0.5, size=[batch_size, z_dim]) # real labels from dataset real_labels = y_train[rand_indexes]
# generate fake feature1 using generator1 from # real labels and 50-dim z1 latent code fake_z1 = np.random.normal(scale=0.5, size=[batch_size, z_dim]) fake_feature1 = gen1.predict([real_labels, fake_z1])
# real + fake data feature1 = np.concatenate((real_feature1, fake_feature1)) z1 = np.concatenate((fake_z1, fake_z1))
# label 1st half as real and 2nd half as fake y = np.ones([2 * batch_size, 1]) y[batch_size:, :] = 0
# train discriminator1 to classify feature1 as # real/fake and recover # latent code (z1). real = from encoder1, # fake = from genenerator1 # joint training using discriminator part of # advserial1 loss and entropy1 loss metrics = dis1.train_on_batch(feature1, [y, z1]) # log the overall loss only log = "%d: [dis1_loss: %f]" % (i, metrics[0])
# train the discriminator0 for 1 batch # 1 batch of real (label=1.0) and fake images (label=0.0) # generate random 50-dim z0 latent code fake_z0 = np.random.normal(scale=0.5, size=[batch_size, z_dim]) # generate fake images from real feature1 and fake z0 fake_images = gen0.predict([real_feature1, fake_z0]) # real + fake data x = np.concatenate((real_images, fake_images)) z0 = np.concatenate((fake_z0, fake_z0)) # train discriminator0 to classify image # as real/fake and recover latent code (z0) # joint training using discriminator part of advserial0 loss # and entropy0 loss metrics = dis0.train_on_batch(x, [y, z0]) # log the overall loss only (use dis0.metrics_names) log = "%s [dis0_loss: %f]" % (log, metrics[0])
# adversarial training # generate fake z1, labels fake_z1 = np.random.normal(scale=0.5, size=[batch_size, z_dim]) # input to generator1 is sampling fr real labels and # 50-dim z1 latent code gen1_inputs = [real_labels, fake_z1]
# label fake feature1 as real y = np.ones([batch_size, 1])
# train generator1 (thru adversarial) by fooling i # the discriminator # and approximating encoder1 feature1 generator # joint training: adversarial1, entropy1, conditional1 metrics = adv1.train_on_batch(gen1_inputs, [y, fake_z1, real_labels]) fmt = "%s [adv1_loss: %f, enc1_acc: %f]" # log the overall loss and classification accuracy log = fmt % (log, metrics[0], metrics[6])
# input to generator0 is real feature1 and # 50-dim z0 latent code fake_z0 = np.random.normal(scale=0.5, size=[batch_size, z_dim]) gen0_inputs = [real_feature1, fake_z0]
# train generator0 (thru adversarial) by fooling # the discriminator and approximating encoder1 imag # source generator joint training: # adversarial0, entropy0, conditional0 metrics = adv0.train_on_batch(gen0_inputs, [y, fake_z0, real_feature1]) # log the overall loss only log = "%s [adv0_loss: %f]" % (log, metrics[0])
print(log) if (i + 1) % save_interval == 0: generators = (gen0, gen1) plot_images(generators, noise_params=noise_params, show=False, step=(i + 1), model_name=model_name)
# save the modelis after training generator0 & 1 # the trained generator can be reloaded for # future MNIST digit generation gen1.save(model_name + "-gen1.h5") gen0.save(model_name + "-gen0.h5")
tf.keras
中 StackedGAN 的代码实现现已完成。 训练后,可以评估生成器的输出以检查合成 MNIST 数字的某些属性是否可以以与我们在 InfoGAN 中所做的类似的方式进行控制。
StackedGAN 的生成器输出
在对 StackedGAN 进行 10,000 步训练之后,Generator[0]
和Generator[1]
模型被保存在文件中。 Generator[0]
和Generator[1]
堆叠在一起可以合成以标签和噪声代码z[0]
和z[1]
为条件的伪造图像。
StackedGAN 生成器可以通过以下方式进行定性验证:
- 从两个噪声代码
z[0]
和z[1]
的离散标签从 0 变到 9,从正态分布中采样,均值为 0.5,标准差为 1.0。 结果显示在“图 6.2.9”中。 我们可以看到 StackedGAN 离散代码可以控制生成器生成的数字:
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --digit=0
- 至
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --digit=9
- [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KUD8zxRt-1681704311652)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_16.png)]
图 6.2.9:当离散代码从 0 变为 9 时,StackedGAN 生成的图像。z0
和z1
均从正态分布中采样,平均值为 0,标准差为 0.5。 - 如下所示,将第一噪声码
z[0]
从 -4.0 到 4.0 的恒定向量变为从 0 到 9 的数字。 第二噪声代码z[1]
被设置为零向量。 “图 6.2.10”显示第一个噪声代码控制数字的粗细。 例如,对于数字 8:
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --z0=0 --z1=0 --p0 --digit=8
- [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A6w2h8Wp-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_17.png)]
图 6.2.10:使用 StackedGAN 作为第一个噪声代码z0
生成的图像,对于数字 0 到 9,其向量从 -4.0 到 4.0 不变。z0
似乎控制着每个数字的粗细。 - 如下所示,对于数字 0 到 9,从 -1.0 到 1.0 的恒定向量变化第二噪声代码
z[1]
。 将第一噪声代码z[0]
设置为零向量。“图 6.2.11”显示第二个噪声代码控制旋转(倾斜),并在一定程度上控制手指的粗细。 例如,对于数字 8:
python3 stackedgan-mnist-6.2.1.py --generator0=stackedgan_mnist-gen0.h5 --generator1=stackedgan_mnist-gen1.h5 --z0=0 --z1=0 --p1 --digit=8
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-363w44bd-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_18.png)]
图 6.2.11:由 StackedGAN 生成的图像作为第二个噪声代码z1
从 0 到 9 的恒定向量 -1.0 到 1.0 变化。z1
似乎控制着每个数字的旋转(倾斜)和笔划粗细
“图 6.2.9”至“图 6.2.11”证明 StackedGAN 提供了对生成器输出属性的附加控制。 控件和属性为(标签,哪个数字),(z0
,数字粗细)和(z1
,数字倾斜度)。 从此示例中,我们可以控制其他可能的实验,例如:
- 从当前数量 2 增加栈中的元素数量
- 像在 InfoGAN 中一样,减小代码
z[0]
和z[1]
的尺寸
“图 6.2.12”显示了 InfoGAN 和 StackedGAN 的潜在代码之间的区别:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-S88zFsW8-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_06_19.png)]
图 6.2.12:不同 GAN 的潜在表示
解开代码的基本思想是对损失函数施加约束,以使仅特定属性受代码影响。 从结构上讲,与 StackedGAN 相比,InfoGAN 更易于实现。 InfoGAN 的训练速度也更快。
4. 总结
在本章中,我们讨论了如何解开 GAN 的潜在表示。 在本章的前面,我们讨论了 InfoGAN 如何最大化互信息以迫使生成器学习解纠缠的潜向量。 在 MNIST 数据集示例中,InfoGAN 使用三种表示形式和一个噪声代码作为输入。 噪声以纠缠的形式表示其余的属性。 StackedGAN 以不同的方式处理该问题。 它使用一堆编码器 GAN 来学习如何合成伪造的特征和图像。 首先对编码器进行训练,以提供特征数据集。 然后,对编码器 GAN 进行联合训练,以学习如何使用噪声代码控制生成器输出的属性。
在下一章中,我们将着手一种新型的 GAN,它能够在另一个域中生成新数据。 例如,给定马的图像,GAN 可以将其自动转换为斑马的图像。 这种 GAN 的有趣特征是无需监督即可对其进行训练,并且不需要成对的样本数据。
5. 参考
Xi Chen et al.: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2016 (http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf).
Xun Huang et al. Stacked Generative Adversarial Networks. IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Vol. 2, 2017 (http://openaccess.thecvf.com/content_cvpr_2017/papers/Huang_Stacked_Generative_Adversarial_CVPR_2017_paper.pdf).
七、跨域 GAN
在计算机视觉,计算机图形学和图像处理中,许多任务涉及将图像从一种形式转换为另一种形式。 灰度图像的着色,将卫星图像转换为地图,将一位艺术家的艺术品风格更改为另一位艺术家,将夜间图像转换为白天,将夏季照片转换为冬天只是几个例子。 这些任务被称为跨域迁移,将成为本章的重点。 源域中的图像将迁移到目标域,从而生成新的转换图像。
跨域迁移在现实世界中具有许多实际应用。 例如,在自动驾驶研究中,收集公路现场驾驶数据既费时又昂贵。 为了在该示例中覆盖尽可能多的场景变化,将在不同的天气条件,季节和时间中遍历道路,从而为我们提供了大量不同的数据。 使用跨域迁移,可以通过转换现有图像来生成看起来真实的新合成场景。 例如,我们可能只需要在夏天从一个区域收集道路场景,在冬天从另一地方收集道路场景。 然后,我们可以将夏季图像转换为冬季,并将冬季图像转换为夏季。 在这种情况下,它将必须完成的任务数量减少了一半。
现实的合成图像的生成是 GAN 擅长的领域。 因此,跨域翻译是 GAN 的应用之一。 在本章中,我们将重点介绍一种流行的跨域 GAN 算法,称为 CycleGAN [2]。 与其他跨域迁移算法(例如 pix2pix [3])不同,CycleGAN 不需要对齐的训练图像即可工作。 在对齐的图像中,训练数据应该是由源图像及其对应的目标图像组成的一对图像; 例如,卫星图像和从该图像得出的相应地图。
CycleGAN 仅需要卫星数据图像和地图。 这些地图可以来自其他卫星数据,而不必事先从训练数据中生成。
在本章中,我们将探讨以下内容:
- CycleGAN 的原理,包括其在
tf.keras
中的实现 - CycleGAN 的示例应用,包括使用 CIFAR10 数据集对灰度图像进行着色和应用于 MNIST 数字和街景门牌号码(SVHN) [1]数据集的样式迁移
让我们开始讨论 CycleGAN 背后的原理。
1. CycleGAN 的原理
将图像从一个域转换到另一个域是计算机视觉,计算机图形学和图像处理中的常见任务。“图 7.1.1”显示了边缘检测,这是常见的图像转换任务:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DKJSNQev-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_01.png)]
图 7.1.1:对齐图像对的示例:使用 Canny 边缘检测器的左,原始图像和右,变换后的图像。 原始照片是作者拍摄的。
在此示例中,我们可以将真实照片(左)视为源域中的图像,将边缘检测的照片(右)视为目标域中的样本。 还有许多其他具有实际应用的跨域翻译过程,例如:
- 卫星图像到地图
- 脸部图像到表情符号,漫画或动画
- 身体图像到头像
- 灰度照片的着色
- 医学扫描到真实照片
- 真实照片到画家的绘画
在不同领域中还有许多其他示例。 例如,在计算机视觉和图像处理中,我们可以通过发明一种从源图像中提取特征并将其转换为目标图像的算法来执行翻译。 坎尼边缘算子就是这种算法的一个例子。 但是,在很多情况下,翻译对于手工工程师而言非常复杂,因此几乎不可能找到合适的算法。 源域分布和目标域分布都是高维且复杂的。
解决图像翻译问题的一种方法是使用深度学习技术。 如果我们具有来自源域和目标域的足够大的数据集,则可以训练神经网络对转换进行建模。 由于必须在给定源图像的情况下自动生成目标域中的图像,因此它们必须看起来像是来自目标域的真实样本。 GAN 是适合此类跨域任务的网络。 pix2pix [3]算法是跨域算法的示例。
pix2pix 算法与条件 GAN(CGAN)[4]相似,我们在“第 4 章”,“生成对抗网络(GAN)”。 我们可以回想起在 CGAN 中,除了z
噪声输入之外,诸如单热向量之类的条件会限制生成器的输出。 例如,在 MNIST 数字中,如果我们希望生成器输出数字 8,则条件为单热向量[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
。 在 pix2pix 中,条件是要翻译的图像。 生成器的输出是翻译后的图像。 通过优化 CGAN 损失来训练 pix2pix 算法。 为了使生成的图像中的模糊最小化,还包括 L1 损失。
类似于 pix2pix 的神经网络的主要缺点是训练输入和输出图像必须对齐。“图 7.1.1”是对齐的图像对的示例。 样本目标图像是从源生成的。 在大多数情况下,对齐的图像对不可用或无法从源图像生成,也不昂贵,或者我们不知道如何从给定的源图像生成目标图像。 我们拥有的是来自源域和目标域的样本数据。“图 7.1.2”是来自同一向日葵主题上源域(真实照片)和目标域(范高的艺术风格)的数据示例。 源图像和目标图像不一定对齐。
与 pix2pix 不同,CycleGAN 会学习图像翻译,只要源数据和目标数据之间有足够的数量和差异即可。 无需对齐。 CycleGAN 学习源和目标分布,以及如何从给定的样本数据中将源分布转换为目标分布。 无需监督。 在“图 7.1.2”的上下文中,我们只需要数千张真实向日葵的照片和数千张梵高向日葵画的照片。 在训练了 CycleGAN 之后,我们可以将向日葵的照片转换成梵高的画作:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wFIkXXou-1681704311653)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_02.png)]
图 7.1.2:未对齐的图像对示例:左侧为菲律宾大学沿着大学大道的真实向日葵照片,右侧为伦敦国家美术馆的梵高的向日葵, 英国。 原始照片由作者拍摄。
下一个问题是:我们如何建立可以从未配对数据中学习的模型? 在下一部分中,我们将构建一个使用正向和反向循环 GAN 的 CycleGAN,以及一个循环一致性检查,以消除对配对输入数据的需求。
CycleGAN 模型
“图 7.1.3”显示了 CycleGAN 的网络模型:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kiIRicws-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_03.png)]
图 7.1.3:CycleGAN 模型包含四个网络:生成器G
,生成器F
,判别器D[y]
和判别器D[x]
让我们逐个讨论“图 7.1.3”。 让我们首先关注上层网络,即转发周期 GAN。 如下图“图 7.1.4”所示,正向循环 CycleGAN 的目标是学习以下函数:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gsjmUDa2-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_001.png)] (Equation 7.1.1)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NOQ9T982-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_04.png)]
图 7.1.4:伪造y
的 CycleGAN 生成器G
“公式 7.1.1”只是假目标数据y'
的生成器G
。 它将数据从源域x
转换为目标域y
。
要训练生成器,我们必须构建 GAN。 这是正向循环 GAN,如图“图 7.1.5”所示。 该图表明,它类似于“第 4 章”,“生成对抗网络(GANs)”中的典型 GAN,由生成器G
和判别器D[y]
组成,它可以以相同的对抗方式进行训练。通过仅利用源域中的可用实际图像x
和目标域中的实际图像y
,进行无监督学习。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SC3RgcGJ-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_05.png)]
图 7.1.5:CycleGAN 正向循环 GAN
与常规 GAN 不同,CycleGAN 施加了周期一致性约束,如图“图 7.1.6”所示。 前向循环一致性网络可确保可以从伪造的目标数据中重建真实的源数据:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CBQIl6q7-1681704311654)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_004.png)] (Equation 7.1.2)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QlWNFqm4-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_06.png)]
图 7.1.6:CycleGAN 循环一致性检查
通过最小化正向循环一致性 L1 损失来完成:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ekYw53po-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_005.png)] (Equation 7.1.3)
周期一致性损失使用 L1 或平均绝对误差(MAE),因为与 L2 或均方误差(MSE)相比,它通常导致较少的模糊图像重建。
循环一致性检查表明,尽管我们已将源数据x
转换为域y
,但x
的原始特征仍应保留在y
中并且可恢复。 网络F
只是我们将从反向循环 GAN 借用的另一个生成器,如下所述。
CycleGAN 是对称的。 如图“图 7.1.7”所示,后向循环 GAN 与前向循环 GAN 相同,但将源数据x
和目标数据y
的作用逆转。 现在,源数据为y
,目标数据为x
。 生成器G
和F
的作用也相反。F
现在是生成器,而G
恢复输入。 在正向循环 GAN 中,生成器F
是用于恢复源数据的网络,而G
是生成器。
Backward Cycle GAN 生成器的目标是合成:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8l4O360i-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_006.png)] (Equation 7.1.2)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fLcG2jW4-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_07.png)]
图 7.1.7:CycleGAN 向后循环 GAN
这可以通过对抗性训练反向循环 GAN 来完成。 目的是让生成器F
学习如何欺骗判别器D[x]
。
此外,还具有类似的向后循环一致性,以恢复原始源y
:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mvyIpDFe-1681704311655)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_008.png)] (Equation 7.1.4)
这是通过最小化后向循环一致性 L1 损失来完成的:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rxcsZxBP-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_009.png)] (Equation 7.1.5)
总而言之,CycleGAN 的最终目标是使生成器G
学习如何合成伪造的目标数据y'
,该伪造的目标数据y'
会在正向循环中欺骗识别器D[y]
。 由于网络是对称的,因此 CycleGAN 还希望生成器F
学习如何合成伪造的源数据x'
,该伪造的源数据可以使判别器D[x]
在反向循环中蒙蔽。 考虑到这一点,我们现在可以将所有损失函数放在一起。
让我们从 GAN 部分开始。 受到最小二乘 GAN(LSGAN) [5]更好的感知质量的启发,如“第 5 章”,“改进的 GAN” 中所述,CycleGAN 还使用 MSE 作为判别器和生成器损失。 回想一下,LSGAN 与原始 GAN 之间的差异需要使用 MSE 损失,而不是二进制交叉熵损失。
CycleGAN 将生成器-标识符损失函数表示为:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YumlqfLU-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_014.png)] (Equation 7.1.6)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vcjHdscP-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_015.png)] (Equation 7.1.7)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D1jqJXMo-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_016.png)] (Equation 7.1.8)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oXpqJ2Ij-1681704311656)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_017.png)] (Equation 7.1.9)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fIfSu1cM-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_018.png)] (Equation 7.1.10)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VK9QUtYW-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_019.png)] (Equation 7.1.11)
损失函数的第二组是周期一致性损失,可以通过汇总前向和后向 GAN 的贡献来得出:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1dxsG5PQ-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_020.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-apnJZhCA-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_021.png)] (Equation 7.1.12)
CycleGAN 的总损失为:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3TdRzKpk-1681704311657)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_022.png)] (Equation 7.1.13)
CycleGAN 建议使用以下权重值λ1 = 1.0
和λ2 = 10.0
,以更加重视循环一致性检查。
训练策略类似于原始 GAN。 “算法 7.1.1”总结了 CycleGAN 训练过程。
“算法 7.1.1”:CycleGAN 训练
对n
训练步骤重复上述步骤:
- 通过使用真实的源数据和目标数据训练前向循环判别器,将
L_forward_GAN^(D)
降至最低。 实际目标数据的小批量y
标记为 1.0。 伪造的目标数据y' = G(x)
的小批量标记为 0.0。 - 通过使用真实的源数据和目标数据训练反向循环判别器,将
L_backward_GAN^(D)
最小化。 实际源数据的小批量x
标记为 1.0。 一小部分伪造的源数据x' = F(y)
被标记为 0.0。 - 通过训练对抗网络中的前向周期和后向周期生成器,将
L_GAN^(D)
和L_cyc
最小化。 伪造目标数据的一个小批量y' = G(x)
被标记为 1.0。 一小部分伪造的源数据x' = F(y)
被标记为 1.0。 判别器的权重被冻结。
在神经样式迁移问题中,颜色组合可能无法成功地从源图像迁移到伪造目标图像。 此问题显示在“图 7.1.8”中:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TtNbGPMZ-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_08.png)]
图 7.1.8:在样式迁移过程中,颜色组合可能无法成功迁移。 为了解决此问题,将恒等损失添加到总损失函数中
为了解决这个问题,CycleGAN 建议包括正向和反向循环身份损失函数:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-akghAJMo-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_033.png)] (Equation 7.1.14)
CycleGAN 的总损失变为:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-abCvM3he-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_034.png)] (Equation 7.1.15)
其中λ3 = 0.5
。 在对抗训练中,身份损失也得到了优化。“图 7.1.9”重点介绍了实现身份正则器的 CycleGAN 辅助网络:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jn1AhAKg-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_09.png)]
图 7.1.9:具有身份正则化网络的 CycleGAN 模型,图像左侧突出显示
在下一个部分,我们将在tf.keras
中实现 CycleGAN。
使用 Keras 实现 CycleGAN
我们来解决,这是 CycleGAN 可以解决的简单问题。 在“第 3 章”,“自编码器”中,我们使用了自编码器为 CIFAR10 数据集中的灰度图像着色。 我们可以记得,CIFAR10 数据集包含 50,000 个训练过的数据项和 10,000 个测试数据样本,这些样本属于 10 个类别的32 x 32
RGB 图像。 我们可以使用rgb2gray
(RGB)将所有彩色图像转换为灰度图像,如“第 3 章”,“自编码器”中所述。
接下来,我们可以将灰度训练图像用作源域图像,将原始彩色图像用作目标域图像。 值得注意的是,尽管数据集是对齐的,但我们 CycleGAN 的输入是彩色图像的随机样本和灰度图像的随机样本。 因此,我们的 CycleGAN 将看不到训练数据对齐。 训练后,我们将使用测试的灰度图像来观察 CycleGAN 的表现。
如前几节所述,要实现 CycleGAN,我们需要构建两个生成器和两个判别器。 CycleGAN 的生成器学习源输入分布的潜在表示,并将该表示转换为目标输出分布。 这正是自编码器的功能。 但是,类似于“第 3 章”,“自编码器”中讨论的典型自编码器,使用的编码器会对输入进行下采样,直到瓶颈层为止,此时解码器中的处理过程相反。
由于在编码器和解码器层之间共享许多低级特征,因此该结构不适用于某些图像转换问题。 例如,在着色问题中,灰度图像的形式,结构和边缘与彩色图像中的相同。 为了解决这个问题,CycleGAN 生成器使用 U-Net [7]结构,如图“图 7.1.10”所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YCKH97o8-1681704311658)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_10.png)]
图 7.1.10:在 Keras 中实现正向循环生成器G
。 产生器是包括编码器和解码器的 U 网络[7]。
在 U-Net 结构中,编码器层的输出e[ni]
与解码器层的输出d[i]
,其中n = 4
是编码器/解码器的层数,i = 1, 2, 3
是共享信息的层号。
我们应该注意,尽管该示例使用n = 4
,但输入/输出尺寸较大的问题可能需要更深的编码器/解码器层。 通过 U-Net 结构,可以在编码器和解码器之间自由迁移特征级别的信息。
编码器层由Instance Normalization(IN)-LeakyReLU-Conv2D
组成,而解码器层由IN-ReLU-Conv2D
组成。 编码器/解码器层的实现如清单 7.1.1 所示,而生成器的实现如列表 7.1.2 所示。
实例规范化(IN)是每个数据(即 IN 是图像或每个特征的 BN)。 在样式迁移中,重要的是标准化每个样本而不是每个批量的对比度。 IN 等于,相当于对比度归一化。 同时,BN 打破了对比度标准化。
记住在使用 IN 之前先安装tensorflow-addons
:
$ pip install tensorflow-addons
“列表 7.1.1”:cyclegan-7.1.1.py
def encoder_layer(inputs, filters=16, kernel_size=3, strides=2, activation='relu', instance_norm=True): """Builds a generic encoder layer made of Conv2D-IN-LeakyReLU IN is optional, LeakyReLU may be replaced by ReLU """
conv = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')
x = inputs if instance_norm: x = InstanceNormalization(axis=3)(x) if activation == 'relu': x = Activation('relu')(x) else: x = LeakyReLU(alpha=0.2)(x) x = conv(x) return x
def decoder_layer(inputs, paired_inputs, filters=16, kernel_size=3, strides=2, activation='relu', instance_norm=True): """Builds a generic decoder layer made of Conv2D-IN-LeakyReLU IN is optional, LeakyReLU may be replaced by ReLU Arguments: (partial) inputs (tensor): the decoder layer input paired_inputs (tensor): the encoder layer output provided by U-Net skip connection & concatenated to inputs. """
conv = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')
x = inputs if instance_norm: x = InstanceNormalization(axis=3)(x) if activation == 'relu': x = Activation('relu')(x) else: x = LeakyReLU(alpha=0.2)(x) x = conv(x) x = concatenate([x, paired_inputs]) return x
将移至生成器实现中:
“列表 7.1.2”:cyclegan-7.1.1.py
Keras 中的生成器实现:
def build_generator(input_shape, output_shape=None, kernel_size=3, name=None): """The generator is a U-Network made of a 4-layer encoder and a 4-layer decoder. Layer n-i is connected to layer i.
Arguments: input_shape (tuple): input shape output_shape (tuple): output shape kernel_size (int): kernel size of encoder & decoder layers name (string): name assigned to generator model
Returns: generator (Model): """
inputs = Input(shape=input_shape) channels = int(output_shape[-1]) e1 = encoder_layer(inputs, 32, kernel_size=kernel_size, activation='leaky_relu', strides=1) e2 = encoder_layer(e1, 64, activation='leaky_relu', kernel_size=kernel_size) e3 = encoder_layer(e2, 128, activation='leaky_relu', kernel_size=kernel_size) e4 = encoder_layer(e3, 256, activation='leaky_relu', kernel_size=kernel_size)
d1 = decoder_layer(e4, e3, 128, kernel_size=kernel_size) d2 = decoder_layer(d1, e2, 64, kernel_size=kernel_size) d3 = decoder_layer(d2, e1, 32, kernel_size=kernel_size) outputs = Conv2DTranspose(channels, kernel_size=kernel_size, strides=1, activation='sigmoid', padding='same')(d3)
generator = Model(inputs, outputs, name=name)
return generator
CycleGAN 的判别器类似于原始 GAN 判别器。 输入图像被下采样数次(在此示例中为 3 次)。 最后一层是Dense
(1)层,它预测输入为实数的可能性。 除了不使用 IN 之外,每个层都类似于生成器的编码器层。 然而,在大图像中,用一个数字将图像计算为真实图像或伪图像会导致参数效率低下,并导致生成器的图像质量较差。
解决方案是使用 PatchGAN [6],该方法将图像划分为补丁网格,并使用标量值网格来预测补丁是真实概率。“图 7.1.11”显示了原始 GAN 判别器和2 x 2
PatchGAN 判别器之间的比较:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-89c04LPV-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_11.png)]
图 7.1.11:GAN 与 PatchGAN 判别器的比较
在此示例中,面片不重叠且在其边界处相遇。 但是,通常,补丁可能会重叠。
我们应该注意,PatchGAN 并没有在 CycleGAN 中引入一种新型的 GAN。 为了提高生成的图像质量,如果使用2 x 2
PatchGAN,则没有四个输出可以区分,而没有一个输出可以区分。 损失函数没有变化。 从直觉上讲,这是有道理的,因为如果图像的每个面片或部分看起来都是真实的,则整个图像看起来会更加真实。
“图 7.1.12”显示了tf.keras
中实现的判别器网络。 下图显示了判别器确定输入图像或色块为彩色 CIFAR10 图像的可能性:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aQ9IgfTU-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_12.png)]
图 7.1.12:目标标识符D[y]
在tf.keras
中的实现。 PatchGAN 判别器显示在右侧
由于输出图像只有32 x 32
RGB 时较小,因此表示该图像是真实的单个标量就足够了。 但是,当使用 PatchGAN 时,我们也会评估结果。“列表 7.1.3”显示了判别器的函数构建器:
“列表 7.1.3”:cyclegan-7.1.1.py
tf.keras
中的判别器实现:
def build_discriminator(input_shape, kernel_size=3, patchgan=True, name=None): """The discriminator is a 4-layer encoder that outputs either a 1-dim or a n x n-dim patch of probability that input is real
Arguments: input_shape (tuple): input shape kernel_size (int): kernel size of decoder layers patchgan (bool): whether the output is a patch or just a 1-dim name (string): name assigned to discriminator model
Returns: discriminator (Model): """
inputs = Input(shape=input_shape) x = encoder_layer(inputs, 32, kernel_size=kernel_size, activation='leaky_relu', instance_norm=False) x = encoder_layer(x, 64, kernel_size=kernel_size, activation='leaky_relu', instance_norm=False) x = encoder_layer(x, 128, kernel_size=kernel_size, activation='leaky_relu', instance_norm=False) x = encoder_layer(x, 256, kernel_size=kernel_size, strides=1, activation='leaky_relu', instance_norm=False)
# if patchgan=True use nxn-dim output of probability # else use 1-dim output of probability if patchgan: x = LeakyReLU(alpha=0.2)(x) outputs = Conv2D(1, kernel_size=kernel_size, strides=2, padding='same')(x) else: x = Flatten()(x) x = Dense(1)(x) outputs = Activation('linear')(x)
discriminator = Model(inputs, outputs, name=name)
return discriminator
使用生成器和判别器生成器,我们现在可以构建 CycleGAN。“列表 7.1.4”显示了构建器函数。 与上一节中的讨论一致,实例化了两个生成器g_source = F
和g_target = G
以及两个判别器d_source = D[x]
和d_target = D[y]
。 正向循环为x' = F(G(x)) = reco_source = g_source(g_target(source_input))
。反向循环为y' = G(F(y)) = reco_target = g_target(g_source (target_input))
。
对抗模型的输入是源数据和目标数据,而输出是D[x]
和D[y]
的输出以及重构的输入x'
和y'
。 在本示例中,由于由于灰度图像和彩色图像中通道数之间的差异,因此未使用身份网络。 对于 GAN 和循环一致性损失,我们分别使用建议的λ1 = 1.0
和λ2 = 10.0
损失权重。 与前几章中的 GAN 相似,我们使用 RMSprop 作为判别器的优化器,其学习率为2e-4
,衰减率为6e-8
。 对抗的学习率和衰退率是判别器的一半。
“列表 7.1.4”:cyclegan-7.1.1.py
tf.keras
中的 CycleGAN 构建器:
def build_cyclegan(shapes, source_name='source', target_name='target', kernel_size=3, patchgan=False, identity=False ): """Build the CycleGAN
1) Build target and source discriminators 2) Build target and source generators 3) Build the adversarial network
Arguments: shapes (tuple): source and target shapes source_name (string): string to be appended on dis/gen models target_name (string): string to be appended on dis/gen models kernel_size (int): kernel size for the encoder/decoder or dis/gen models patchgan (bool): whether to use patchgan on discriminator identity (bool): whether to use identity loss
Returns: (list): 2 generator, 2 discriminator, and 1 adversarial models """
source_shape, target_shape = shapes lr = 2e-4 decay = 6e-8 gt_name = "gen_" + target_name gs_name = "gen_" + source_name dt_name = "dis_" + target_name ds_name = "dis_" + source_name
# build target and source generators g_target = build_generator(source_shape, target_shape, kernel_size=kernel_size, name=gt_name) g_source = build_generator(target_shape, source_shape, kernel_size=kernel_size, name=gs_name) print('---- TARGET GENERATOR ----') g_target.summary() print('---- SOURCE GENERATOR ----') g_source.summary()
# build target and source discriminators d_target = build_discriminator(target_shape, patchgan=patchgan, kernel_size=kernel_size, name=dt_name) d_source = build_discriminator(source_shape, patchgan=patchgan, kernel_size=kernel_size, name=ds_name) print('---- TARGET DISCRIMINATOR ----') d_target.summary() print('---- SOURCE DISCRIMINATOR ----') d_source.summary()
optimizer = RMSprop(lr=lr, decay=decay) d_target.compile(loss='mse', optimizer=optimizer, metrics=['accuracy']) d_source.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
d_target.trainable = False d_source.trainable = False
# build the computational graph for the adversarial model # forward cycle network and target discriminator source_input = Input(shape=source_shape) fake_target = g_target(source_input) preal_target = d_target(fake_target) reco_source = g_source(fake_target)
# backward cycle network and source discriminator target_input = Input(shape=target_shape) fake_source = g_source(target_input) preal_source = d_source(fake_source) reco_target = g_target(fake_source)
# if we use identity loss, add 2 extra loss terms # and outputs if identity: iden_source = g_source(source_input) iden_target = g_target(target_input) loss = ['mse', 'mse', 'mae', 'mae', 'mae', 'mae'] loss_weights = [1., 1., 10., 10., 0.5, 0.5] inputs = [source_input, target_input] outputs = [preal_source, preal_target, reco_source, reco_target, iden_source, iden_target] else: loss = ['mse', 'mse', 'mae', 'mae'] loss_weights = [1., 1., 10., 10.] inputs = [source_input, target_input] outputs = [preal_source, preal_target, reco_source, reco_target]
# build adversarial model adv = Model(inputs, outputs, name='adversarial') optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5) adv.compile(loss=loss, loss_weights=loss_weights, optimizer=optimizer, metrics=['accuracy']) print('---- ADVERSARIAL NETWORK ----') adv.summary()
return g_source, g_target, d_source, d_target, adv
TensorFlow 2 和 Keras 高级深度学习:6~10(3)https://developer.aliyun.com/article/1426952