9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,5000CU*H 3个月
简介: 生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。

生成对抗网络(GANs)的训练效果很大程度上取决于其损失函数的选择。本研究首先介绍经典GAN损失函数的理论基础,随后使用PyTorch实现包括原始GAN、最小二乘GAN(LS-GAN)、Wasserstein GAN(WGAN)及带梯度惩罚的WGAN(WGAN-GP)在内的多种损失函数。
生成对抗网络(GANs)的工作原理堪比一场精妙的艺术创作过程——生成器(Generator)扮演创作者角色,不断生成作品;判别器(Discriminator)则如同严苛的评论家,持续提供改进建议。这种对抗学习机制促使两个网络在竞争中共同进步。判别器向生成器提供反馈的方式——即损失函数的设计——对整个网络的学习表现有着决定性影响。

GAN的基本原理与经典损失函数

1、原始GAN

Goodfellow等人于2014年提出的原始GAN采用极小极大博弈(Minimax Game)框架,其损失函数可表述为:

其中:

  • D(x)表示判别器对输入x判定为真实样本的概率
  • G(z)表示生成器将随机噪声z转换为合成图像的函数
  • p_{data}(x)表示真实数据分布
  • p_z(z)表示噪声先验分布,通常为标准正态分布

原始GAN在理论上试图最小化生成分布与真实分布之间的Jensen-Shannon散度(JS散度),但在实际训练中存在梯度消失、模式崩溃和训练不稳定等问题。这些局限性促使研究者开发了多种改进的损失函数。

PyTorch实现

 import torch
import torch.nn as nn

# 原始GAN损失函数实现
class OriginalGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # 真实样本的目标标签为1.0
        real_labels = torch.ones_like(real_output, device=self.device)
        # 生成样本的目标标签为0.0
        fake_labels = torch.zeros_like(fake_output, device=self.device)

        # 计算判别器对真实样本的损失
        real_loss = self.criterion(real_output, real_labels)
        # 计算判别器对生成样本的损失
        fake_loss = self.criterion(fake_output, fake_labels)

        # 总损失为两部分之和
        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # 生成器希望判别器将生成样本判断为真实样本
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
         return g_loss

2、非饱和损失函数(Non-Saturating Loss)

为解决原始GAN中生成器梯度消失问题,Goodfellow提出了非饱和损失,将生成器的目标函数修改为:

这种修改保持了相同的最优解,但提供了更强的梯度信号,特别是在训练初期生成样本质量较差时,有效改善了学习效率。非饱和损失通过直接最大化判别器对生成样本的预测概率,而不是最小化判别器正确分类的概率,从而避免了在生成器表现不佳时梯度趋近于零的问题。

PyTorch实现

 class NonSaturatingGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # 与原始GAN相同
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)

        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # 非饱和损失:直接最大化log(D(G(z)))
        target_labels = torch.ones_like(fake_output, device=self.device)
        # 注意这里使用的是相同的BCE损失,但目标是让D将G(z)判断为真
        g_loss = self.criterion(fake_output, target_labels)
         return g_loss

GAN变体实现与原理分析

3、最小二乘GAN(LS-GAN)

LS-GAN通过用最小二乘损失替代标准GAN中的二元交叉熵损失,有效改善了训练过程:

这一修改使得模型在训练过程中梯度变化更为平滑,显著降低了训练不稳定性。LS-GAN的主要优势在于能够有效减轻模式崩溃问题(即生成器仅产生有限类型样本的现象),同时促进学习过程的连续性与稳定性,使模型能够更加渐进地学习数据分布特征。理论上,LS-GAN试图最小化Pearson \chi^2 散度,这对于分布重叠较少的情况提供了更好的训练信号。

PyTorch实现

 class LSGANLoss:
    def __init__(self, device):
        self.device = device
        # LS-GAN使用MSE损失而非BCE损失
        self.criterion = nn.MSELoss()

    def discriminator_loss(self, real_output, fake_output):
        # 真实样本的目标值为1.0
        real_labels = torch.ones_like(real_output, device=self.device)
        # 生成样本的目标值为0.0
        fake_labels = torch.zeros_like(fake_output, device=self.device)

        # 计算真实样本的MSE损失
        real_loss = self.criterion(real_output, real_labels)
        # 计算生成样本的MSE损失
        fake_loss = self.criterion(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # 生成器希望生成的样本被判别为真实样本
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
         return g_loss

4、Wasserstein GAN(WGAN)

WGAN通过引入Wasserstein距离(也称为地球移动者距离)作为分布差异度量,从根本上改变了GAN的训练机制:

其中\mathcal{D}是所有满足1-Lipschitz约束的函数集合。与传统GAN关注样本真假二分类不同,WGAN评估的是生成分布与真实分布之间的距离,这一方法提供了更为连续且有意义的梯度信息。WGAN能够显著改善梯度传播问题,有效防止判别器过度主导训练过程,同时大幅减轻模式崩溃现象,提高生成样本的多样性。

原始WGAN通过权重裁剪(weight clipping)实现Lipschitz约束,具体做法是将判别器参数限制在某个固定范围内,如[-c, c],但这种方法可能会限制网络容量并导致病态行为。

PyTorch实现

 class WGANLoss:
    def __init__(self, device, clip_value=0.01):
        self.device = device
        self.clip_value = clip_value

    def discriminator_loss(self, real_output, fake_output):
        # WGAN的判别器(称为critic)直接最大化真实样本和生成样本输出的差值
        # 注意这里没有使用sigmoid激活
        d_loss = -torch.mean(real_output) + torch.mean(fake_output)
        return d_loss

    def generator_loss(self, fake_output):
        # 生成器希望最大化critic对生成样本的评分
        g_loss = -torch.mean(fake_output)
        return g_loss

    def weight_clipping(self, critic):
        # 权重裁剪,限制critic参数范围
        for p in critic.parameters():
             p.data.clamp_(-self.clip_value, self.clip_value)

5、带梯度惩罚的WGAN(WGAN-GP)

WGAN-GP是对WGAN的进一步优化,通过引入梯度惩罚项来满足Lipschitz连续性约束:

其中\hat{x}是真实样本和生成样本之间的随机插值点。这一改进避免了原始WGAN中权重裁剪可能带来的容量限制和训练不稳定问题。梯度惩罚使模型训练过程更加稳定,同时减少了生成图像中的伪影,提高了最终生成结果的质量与真实度。WGAN-GP已成为许多高质量图像生成任务的首选损失函数。

PyTorch实现

 class WGANGP:
    def __init__(self, device, lambda_gp=10):
        self.device = device
        self.lambda_gp = lambda_gp

    def discriminator_loss(self, real_output, fake_output, real_samples, fake_samples, discriminator):
        # 基本的Wasserstein距离
        d_loss = -torch.mean(real_output) + torch.mean(fake_output)

        # 计算梯度惩罚
        # 在真实和生成样本之间随机插值
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        interpolates = alpha * real_samples + (1 - alpha) * fake_samples
        interpolates.requires_grad_(True)

        # 计算判别器对插值样本的输出
        d_interpolates = discriminator(interpolates)

        # 计算梯度
        fake_outputs = torch.ones_like(d_interpolates, device=self.device, requires_grad=False)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]

        # 计算梯度L2范数
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        # 添加梯度惩罚项
        d_loss = d_loss + self.lambda_gp * gradient_penalty

        return d_loss

    def generator_loss(self, fake_output):
        # 与WGAN相同
        g_loss = -torch.mean(fake_output)
         return g_loss

6、条件生成对抗网络(CGAN)

CGAN通过在生成器和判别器中引入条件信息y(如类别标签),实现对生成过程的控制:

CGAN能够生成特定类别的样本,大大增强了模型的实用性,特别是在医学影像等需要精确控制生成内容的应用场景中。通过条件控制,CGAN可以引导生成过程,使得生成结果满足特定的语义或结构要求,为个性化内容生成提供了可靠技术支持。

PyTorch实现

 class CGANLoss:
    def __init__(self, device):
        self.device = device
        self.criterion = nn.BCELoss()

    def discriminator_loss(self, real_output, fake_output):
        # 条件GAN的判别器损失与原始GAN相似,只是输入增加了条件信息
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)

        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        return d_loss

    def generator_loss(self, fake_output):
        # 与原始GAN相似
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)
        return g_loss

# CGAN的网络结构示例
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, n_classes, img_shape):
        super(ConditionalGenerator, self).__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(n_classes, n_classes)

        self.model = nn.Sequential(
            # 输入是噪声向量与条件拼接后的向量
            nn.Linear(latent_dim + n_classes, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # 条件嵌入
        c = self.label_emb(labels)
        # 拼接噪声和条件
        x = torch.cat([z, c], 1)
        # 生成图像
        img = self.model(x)
        img = img.view(img.size(0), *self.img_shape)
         return img

7、信息最大化GAN(InfoGAN)

InfoGAN在无监督学习框架下实现了对生成样本特定属性的控制,其核心思想是最大化潜在编码c与生成样本G(z,c)之间的互信息:

其中Q是一个辅助网络,用于近似后验分布P(c|x),而I(c;G(z,c))表示互信息。InfoGAN能够在无监督的情况下学习数据的解耦表示,对于医学图像分析中的特征提取和异常检测具有潜在价值。

PyTorch实现

 class InfoGANLoss:
    def __init__(self, device, lambda_info=1.0):
        self.device = device
        self.criterion = nn.BCELoss()
        self.lambda_info = lambda_info
        # 对于离散潜变量使用交叉熵损失
        self.discrete_criterion = nn.CrossEntropyLoss()
        # 对于连续潜变量使用高斯分布负对数似然
        self.continuous_criterion = nn.MSELoss()

    def discriminator_loss(self, real_output, fake_output):
        # 判别器损失与原始GAN相同
        real_labels = torch.ones_like(real_output, device=self.device)
        fake_labels = torch.zeros_like(fake_output, device=self.device)

        real_loss = self.criterion(real_output, real_labels)
        fake_loss = self.criterion(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        return d_loss

    def generator_info_loss(self, fake_output, q_discrete, q_continuous, c_discrete, c_continuous):
        # 生成器损失部分(欺骗判别器)
        target_labels = torch.ones_like(fake_output, device=self.device)
        g_loss = self.criterion(fake_output, target_labels)

        # 互信息损失部分
        # 离散潜变量的互信息损失
        info_disc_loss = self.discrete_criterion(q_discrete, c_discrete)
        # 连续潜变量的互信息损失
        info_cont_loss = self.continuous_criterion(q_continuous, c_continuous)

        # 总损失
        total_loss = g_loss + self.lambda_info * (info_disc_loss + info_cont_loss)

         return total_loss, info_disc_loss, info_cont_loss

8、能量基础GAN(EBGAN)

EBGAN将判别器视为能量函数,而非传统的概率函数,其损失函数为:

其中[·]^+表示\max(0,·),m是边界参数。EBGAN通过能量视角重新诠释GAN训练过程,为模型设计提供了新的思路,尤其适合处理具有复杂分布的医学数据。EBGAN的判别器不再输出概率值,而是输出能量分数,真实样本的能量应当低于生成样本。

PyTorch实现

 class EBGANLoss:
    def __init__(self, device, margin=10.0):
        self.device = device
        self.margin = margin

    def discriminator_loss(self, real_energy, fake_energy):
        # 判别器的目标是降低真实样本的能量,提高生成样本的能量(直到边界值)
        # 对生成样本的损失使用hinge loss
        hinge_loss = torch.mean(torch.clamp(self.margin - fake_energy, min=0))
        # 总损失
        d_loss = torch.mean(real_energy) + hinge_loss
        return d_loss

    def generator_loss(self, fake_energy):
        # 生成器的目标是降低生成样本的能量
        g_loss = torch.mean(fake_energy)
         return g_loss

9、f-GAN

f-GAN是一种基于f-散度的GAN框架,可以统一多种GAN变体:

其中f^*是凸函数f的Fenchel共轭。通过选择不同的f函数,f-GAN可以实现对不同散度的优化,如KL散度、JS散度、Hellinger距离等,为特定应用场景提供了更灵活的选择。f-GAN为GAN提供了一个统一的理论框架,使研究者能够根据具体任务需求设计最适合的散度度量。

PyTorch实现

 class FGANLoss:
    def __init__(self, device, divergence_type='kl'):
        self.device = device
        self.divergence_type = divergence_type

    def activation_function(self, x):
        # 不同散度对应的激活函数
        if self.divergence_type == 'kl':  # KL散度
            return x
        elif self.divergence_type == 'js':  # JS散度
            return torch.log(1 + torch.exp(x))
        elif self.divergence_type == 'hellinger':  # Hellinger距离
            return 1 - torch.exp(-x)
        elif self.divergence_type == 'total_variation':  # 总变差距离
            return 0.5 * torch.tanh(x)
        else:
            return x  # 默认为KL散度

    def conjugate_function(self, x):
        # 不同散度的Fenchel共轭
        if self.divergence_type == 'kl':
            return torch.exp(x - 1)
        elif self.divergence_type == 'js':
            return -torch.log(2 - torch.exp(x))
        elif self.divergence_type == 'hellinger':
            return x / (1 - x)
        elif self.divergence_type == 'total_variation':
            return x
        else:
            return torch.exp(x - 1)  # 默认为KL散度

    def discriminator_loss(self, real_output, fake_output):
        # 判别器损失
        # 注意:在f-GAN中,通常D的输出需要经过激活函数处理
        activated_real = self.activation_function(real_output)
        d_loss = -torch.mean(activated_real) + torch.mean(self.conjugate_function(fake_output))
        return d_loss

    def generator_loss(self, fake_output):
        # 生成器损失
        activated_fake = self.activation_function(fake_output)
        g_loss = -torch.mean(activated_fake)
         return g_loss

总结

本文通过详细分析GAN的经典损失函数及其多种变体,揭示了不同类型损失函数各自的优势:LS-GAN训练稳定性好,WGAN-GP生成图像清晰度高,而条件类GAN如CGAN则在可控性方面表现突出。

这介绍代码对于相关领域的GAN应用具有重要参考价值。未来研究可进一步探索损失函数组合优化策略,以及针对特定图像模态的自适应损失函数设计。

目录
相关文章
|
11天前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
97 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
4天前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
55 17
|
1月前
|
存储 自然语言处理 PyTorch
从零开始用Pytorch实现LLaMA 4的混合专家(MoE)模型
近期发布的LLaMA 4模型引入混合专家(MoE)架构,以提升效率与性能。尽管社区对其实际表现存在讨论,但MoE作为重要设计范式再次受到关注。本文通过Pytorch从零实现简化版LLaMA 4 MoE模型,涵盖数据准备、分词、模型构建(含词元嵌入、RoPE、RMSNorm、多头注意力及MoE层)到训练与文本生成全流程。关键点包括MoE层实现(路由器、专家与共享专家)、RoPE处理位置信息及RMSNorm归一化。虽规模小于实际LLaMA 4,但清晰展示MoE核心机制:动态路由与稀疏激活专家,在控制计算成本的同时提升性能。完整代码见链接,基于FareedKhan-dev的Github代码修改而成。
63 9
从零开始用Pytorch实现LLaMA 4的混合专家(MoE)模型
|
1月前
|
机器学习/深度学习 数据可视化 机器人
比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
扩散模型和流匹配是生成高分辨率数据(如图像和机器人轨迹)的先进技术。扩散模型通过逐步去噪生成数据,其代表应用Stable Diffusion已扩展至机器人学领域形成“扩散策略”。流匹配作为更通用的方法,通过学习时间依赖的速度场将噪声转化为目标分布,适用于图像生成和机器人轨迹生成,且通常以较少资源实现更快生成。 本文深入解析流匹配在图像生成中的应用,核心思想是将图像视为随机变量的实现,并通过速度场将源分布转换为目标分布。文中提供了一维模型训练实例,展示了如何用神经网络学习速度场,以及使用最大均值差异(MMD)改进训练效果。与扩散模型相比,流匹配结构简单,资源需求低,适合多模态分布生成。
88 13
比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
|
1月前
|
机器学习/深度学习 编解码 PyTorch
从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
本文介绍了一种基于扩散模型的文本到视频生成系统,详细展示了模型架构、训练流程及生成效果。通过3D U-Net结构和多头注意力机制,模型能够根据文本提示生成高质量视频。
75 1
从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
|
3月前
|
机器学习/深度学习 算法 安全
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
423 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
|
4月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
311 66
|
7月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
1002 2
|
9月前
|
机器学习/深度学习 并行计算 PyTorch
优化技巧与策略:提高 PyTorch 模型训练效率
【8月更文第29天】在深度学习领域中,PyTorch 是一个非常流行的框架,被广泛应用于各种机器学习任务中。然而,随着模型复杂度的增加以及数据集规模的增长,如何有效地训练这些模型成为了一个重要的问题。本文将介绍一系列优化技巧和策略,帮助提高 PyTorch 模型训练的效率。
794 0
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
181 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers