使用GAN进行异常检测

简介: 自从基于Stable Diffusion的生成模型大火以后,基于GAN的研究越来越少了,但是这并不能说明他就没有用了。异常检测是多个研究领域面临的重要问题,包括金融、医疗保健和网络安全。

检测和正确分类未见的异常是一个具有挑战性的问题,多年来已经以许多不同的方式解决了这个问题。而今天我们要介绍一种基于GAN的异常检测方法,GAN是一种深度学习模型,可以学习生成与给定数据集相似的真实数据样本。GAN的这一特性表明它们可以成功地用于异常检测,以前的基于GAN的生成模型都是使用GAN的生成器,而异常检测则是需要使用GAN的鉴别器。

GAN简介

生成对抗网络(GANs)是一类用于无监督机器学习的人工智能算法。它们是由Ian Goodfellow和他的同事在2014年推出的。GANs由生成器和鉴别器两个神经网络组成,它们通过对抗性训练同时进行训练。

生成器:GAN的这一部分负责生成新的数据实例。它将随机噪声作为输入,并将其转换成理想情况下与真实数据无法区分的数据。

鉴别器:GAN的这一部分充当分类器。它被训练来区分真实数据和由生成器生成的合成数据。

生成器旨在生成合成数据,这些数据非常令人信服,以至于判别器无法区分真实数据和生成数据。而鉴别器同时经过训练,变得更善于区分真实数据和生成数据。

训练的目标是生成器创建的数据越来越真实,而鉴别器在区分差异方面变得更加熟练。这种对抗过程会一直持续下去,直到生成器生成的数据基本上与真实数据无法区分。

当生成器生成高度真实的数据,而鉴别器无法可靠地将其与真实数据区分开来时,平衡点代表GAN的成功训练。

将GAN用于异常检测

生成对抗网络(GANs)可以通过训练它们生成正常或典型的数据分布来用于异常检测。

对于生成模型,我们一般使用GAN的方法是,使用GAN的生成器来学习普通数据的底层模式,并通过鉴别器来对其进行强化训练,最后得到一个非常强大的生成器模型

而对于异常检测来说,我们使用GAN的生成器组件来学习普通数据的底层模式,用来生成类似于正态分布的合成数据样本,然后得到一个强大的鉴别器(分类模型),这个模型就可以作为我们异常检测的模型来进行使用。

以下是GAN用于异常检测的步骤概述:

1、正常数据训练:

使用数据的正常或典型实例(例如,正常图像,正常传感器读数等)的数据集来训练GAN。生成器学习生成模拟正常数据分布的合成样本,鉴别器被训练以区分真实数据和合成数据。

2、合成数据的生成:

使用训练好的生成器生成一组合成数据样本。这些合成样本应该与训练数据中的正常实例相似,但是我们不需要这个部分的模型。

3、异常检测:

将GAN生成的合成数据与原始正常数据相结合。使用传统的异常检测技术或简单的阈值方法来识别明显偏离预期分布的实例。与真实数据和合成数据都不相似的实例被认为是潜在的异常。(这是一种简单方法)

4、鉴别器作为异常检测器:

鉴别器重新用作异常检测器。在异常检测阶段将其应用于真实数据和合成数据。鉴别器分类为真实的实例可能被认为是正常的,而分类为合成的实例可能被标记为潜在的异常。(这是单独使用鉴别器进行异常检测的方法)

代码示例

构建一个完整的生成对抗网络(GAN)包括几个组成部分,包括定义生成器和鉴别器架构,指定损失函数和设置训练循环。下面是一个使用Pytorch进行构建的简单实例

 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader, TensorDataset
 import numpy as np
 import matplotlib.pyplot as plt

 # Define the generator model
 class Generator(nn.Module):
     def __init__(self, latent_dim):
         super(Generator, self).__init__()
         self.model = nn.Sequential(
             nn.Linear(latent_dim, 256),
             nn.ReLU(),
             nn.BatchNorm1d(256),
             nn.Linear(256, 512),
             nn.ReLU(),
             nn.BatchNorm1d(512),
             nn.Linear(512, 784),
             nn.Sigmoid(),
             nn.Unflatten(1, (28, 28, 1))
         )

     def forward(self, x):
         return self.model(x)

 # Define the discriminator model
 class Discriminator(nn.Module):
     def __init__(self, img_shape):
         super(Discriminator, self).__init__()
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(np.prod(img_shape), 512),
             nn.ReLU(),
             nn.Linear(512, 256),
             nn.ReLU(),
             nn.Linear(256, 1),
             nn.Sigmoid()
         )

     def forward(self, x):
         return self.model(x)

 # Define the GAN model
 class GAN(nn.Module):
     def __init__(self, generator, discriminator):
         super(GAN, self).__init__()
         self.generator = generator
         self.discriminator = discriminator

     def forward(self, x):
         x = self.generator(x)
         x = self.discriminator(x)
         return x

 # Function to compile models
 def compile_models(generator, discriminator, gan, latent_dim):
     d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
     g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

     criterion = nn.BCELoss()

     discriminator.compile(optimizer=d_optimizer, loss=criterion)
     gan.compile(optimizer=g_optimizer, loss=criterion)

 # Function to generate random noise for the generator
 def generate_latent_points(latent_dim, batch_size):
     return torch.randn(batch_size, latent_dim)

 # Function to train the GAN
 def train_gan(generator, discriminator, gan, dataset, latent_dim, epochs, batch_size):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     generator.to(device)
     discriminator.to(device)
     gan.to(device)

     dataset = torch.tensor(dataset, dtype=torch.float32).to(device)
     dataloader = DataLoader(TensorDataset(dataset), batch_size=batch_size, shuffle=True)

     criterion = nn.BCELoss()

     for epoch in range(epochs):
         for batch_data in dataloader:
             real_data = batch_data[0].to(device)
             batch_size = real_data.size(0)

             noise = generate_latent_points(latent_dim, batch_size).to(device)
             generated_data = generator(noise)

             labels_real = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
             labels_fake = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)

             d_loss_real = criterion(discriminator(real_data), labels_real)
             d_loss_fake = criterion(discriminator(generated_data.detach()), labels_fake)

             d_loss = 0.5 * (d_loss_real + d_loss_fake)

             discriminator.zero_grad()
             d_loss.backward()
             discriminator_optimizer.step()

             noise = generate_latent_points(latent_dim, batch_size).to(device)
             labels_gan = torch.ones((batch_size, 1), dtype=torch.float32).to(device)

             g_loss = criterion(gan(noise), labels_gan)

             generator.zero_grad()
             g_loss.backward()
             generator_optimizer.step()

             print(f"Epoch {epoch + 1}/{epochs}, Batch {batch}/{len(dataloader)}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

 # Function to generate and plot synthetic data
 def generate_and_plot(generator, latent_dim, examples=10):
     generator.eval()
     noise = generate_latent_points(latent_dim, examples)
     generated_data = generator(noise).detach().cpu().numpy()

     for i in range(examples):
         plt.subplot(2, 5, i + 1)
         plt.imshow(generated_data[i, 0, :, :], cmap='gray_r')
         plt.axis('off')

     plt.show()

 # Example usage
 latent_dim = 100
 img_shape = (28, 28, 1)

 # Build and compile the models
 generator = Generator(latent_dim)
 discriminator = Discriminator(img_shape)
 gan = GAN(generator, discriminator)
 compile_models(generator, discriminator, gan, latent_dim)

 # Load and preprocess your dataset (e.g., MNIST)
 (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
 train_images = train_images / 127.5 - 1.0  # Normalize images to the range [-1, 1]
 train_images = np.expand_dims(train_images, axis=-1)

 # Train the GAN
 train_gan(generator, discriminator, gan, train_images, latent_dim, epochs=100, batch_size=64)

 # Generate and plot synthetic data
 generate_and_plot(generator, latent_dim)

以上实例基于MNIST数据集

基于GAN异常检测的研究进展

虽然GAN在生成领域没有那么大放异彩,但是在其他领域还是有很多人在研究

https://www.sciencedirect.com/science/article/abs/pii/S0925231221019482

综述了gan在异常检测中的应用。作者讨论了gan在异常检测中的优点和局限性,并对该领域的最新研究进行了概述。他们还指出了该领域研究的挑战和未来方向。

https://www.researchgate.net/publication/362320139_Anomaly_detection_methods_based_on_GAN_a_survey

这个调查关注gan在医疗时间序列异常检测中的应用,并取得了很好的结果。

https://asp-eurasipjournals.springeropen.com/articles/10.1186/s13634-022-00943-7

这篇论文提出了一种结合USAD生成对抗训练架构和卷积自编码器(CAE)的检测模型,通过生成对抗训练的正态数据分布和提高从数据中提取特征的能力来增强对抗训练过程中的稳定性。作者在几个基准数据集上证明了他们提出的模型优于现有方法。

https://arxiv.org/pdf/2310.00335.pdf

将gan应用于发电厂的异常检测。作者使用不同的增强技术,如自编码器和主成分分析来提高gan在异常检测方面的性能。

https://www.mdpi.com/1424-8220/23/1/355

提出了一种新的基于注意力特征融合的编码器-解码器GAN模块,用于GAN工业图像的异常检测。作者在几个基准数据集上证明了他们提出的方法优于现有方法。

https://ieeexplore.ieee.org/document/10043696

探讨了在生物医学成像中使用gan进行异常检测。作者介绍了使用gan进行异常检测的概述,并研究了最先进的基于gan的生物医学成像异常检测方法。他们证明了基于gan的方法在几个基准数据集上优于传统方法。

总结

可以看到GAN的研究还在继续,并且GAN的问题也还存在:

确保生成器和鉴别器之间的良好平衡是至关重要的。如果生成器太弱,它可能无法准确捕获正态数据分布。如果它太强,它可能无法产生不同的合成样本。

将GAN应用于异常检测可能是一种强大的方法,特别是在标记异常数据稀缺的情况下,因为GAN可以学习表示正态数据分布,而无需显式标记异常。但是训练阶段必须仔细进行,训练数据质量对结果至关重要,GAN已经在各种数据上表现出优于传统的异常检测的表现,我们期待他有更好的发展吧。

https://avoid.overfit.cn/post/cc6a7b7c18d04bd7ac3aa15d55520e57

目录
相关文章
|
算法 搜索推荐 图计算
图计算中的社区发现算法是什么?请解释其作用和常用算法。
图计算中的社区发现算法是什么?请解释其作用和常用算法。
530 0
|
机器学习/深度学习 运维 监控
深度学习之异常检测
基于深度学习的异常检测是一项重要的研究领域,主要用于识别数据中的异常样本或行为。异常检测广泛应用于多个领域,如网络安全、金融欺诈检测、工业设备预测性维护、医疗诊断等。
1169 2
|
3月前
|
数据采集 机器学习/深度学习 自然语言处理
98_数据增强:提升LLM微调效果的关键技术
在大语言模型(LLM)的微调过程中,数据质量与数量往往是决定最终性能的关键因素。然而,获取高质量、多样化且标注准确的训练数据却常常面临诸多挑战:数据标注成本高昂、领域特定数据稀缺、数据分布不均等问题都会直接影响微调效果。在这种背景下,数据增强技术作为一种能够有效扩充训练数据并提升其多样性的方法,正发挥着越来越重要的作用。
|
定位技术
ENVI: 如何创建GLT文件并基于GLT对图像进行几何校正?
ENVI: 如何创建GLT文件并基于GLT对图像进行几何校正?
1680 0
|
SQL 存储 分布式计算
流批一体技术简介
本文由阿里云 Flink 团队苏轩楠老师撰写,旨在向 Flink 用户整体介绍 Flink 流批一体的技术和挑战。
51439 3
流批一体技术简介
|
机器学习/深度学习 运维 数据挖掘
无监督学习在异常检测中的应用
【7月更文挑战第14天】无监督学习在异常检测中的应用具有重要意义,其可以帮助我们发现数据中的潜在异常模式,提高异常检测的效率和准确性。通过不断的研究和探索,我们可以进一步完善无监督学习方法在异常检测中的应用,为实际应用提供更加可靠和有效的解决方案。
|
机器学习/深度学习 人工智能 搜索推荐
语音识别技术的现状与未来展望
【6月更文挑战第15天】**语音识别技术现状与未来:** 随AI发展,语音识别精度与速度大幅提升,应用广泛,从手机助手到智能家居。深度学习驱动技术进步,跨语言及多模态交互成为新趋势。未来,精度、鲁棒性将增强,深度学习将进一步融合,个性化和情感化交互将提升用户体验。跨领域融合与生态共建将推动技术普及,为各行业带来更多智能解决方案。但同时也需关注技术伦理和社会影响。
1190 2
|
监控 安全 网络安全
防止 DDOS 攻击的7个技巧
防止 DDOS 攻击的7个技巧
5451 0
|
安全 网络协议 算法
电脑病毒木马的清除和防范方法
电脑病毒木马的清除和防范方法
3247 0
电脑病毒木马的清除和防范方法