使用GAN进行异常检测

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,5000CU*H 3个月
简介: 自从基于Stable Diffusion的生成模型大火以后,基于GAN的研究越来越少了,但是这并不能说明他就没有用了。异常检测是多个研究领域面临的重要问题,包括金融、医疗保健和网络安全。

检测和正确分类未见的异常是一个具有挑战性的问题,多年来已经以许多不同的方式解决了这个问题。而今天我们要介绍一种基于GAN的异常检测方法,GAN是一种深度学习模型,可以学习生成与给定数据集相似的真实数据样本。GAN的这一特性表明它们可以成功地用于异常检测,以前的基于GAN的生成模型都是使用GAN的生成器,而异常检测则是需要使用GAN的鉴别器。

GAN简介

生成对抗网络(GANs)是一类用于无监督机器学习的人工智能算法。它们是由Ian Goodfellow和他的同事在2014年推出的。GANs由生成器和鉴别器两个神经网络组成,它们通过对抗性训练同时进行训练。

生成器:GAN的这一部分负责生成新的数据实例。它将随机噪声作为输入,并将其转换成理想情况下与真实数据无法区分的数据。

鉴别器:GAN的这一部分充当分类器。它被训练来区分真实数据和由生成器生成的合成数据。

生成器旨在生成合成数据,这些数据非常令人信服,以至于判别器无法区分真实数据和生成数据。而鉴别器同时经过训练,变得更善于区分真实数据和生成数据。

训练的目标是生成器创建的数据越来越真实,而鉴别器在区分差异方面变得更加熟练。这种对抗过程会一直持续下去,直到生成器生成的数据基本上与真实数据无法区分。

当生成器生成高度真实的数据,而鉴别器无法可靠地将其与真实数据区分开来时,平衡点代表GAN的成功训练。

将GAN用于异常检测

生成对抗网络(GANs)可以通过训练它们生成正常或典型的数据分布来用于异常检测。

对于生成模型,我们一般使用GAN的方法是,使用GAN的生成器来学习普通数据的底层模式,并通过鉴别器来对其进行强化训练,最后得到一个非常强大的生成器模型

而对于异常检测来说,我们使用GAN的生成器组件来学习普通数据的底层模式,用来生成类似于正态分布的合成数据样本,然后得到一个强大的鉴别器(分类模型),这个模型就可以作为我们异常检测的模型来进行使用。

以下是GAN用于异常检测的步骤概述:

1、正常数据训练:

使用数据的正常或典型实例(例如,正常图像,正常传感器读数等)的数据集来训练GAN。生成器学习生成模拟正常数据分布的合成样本,鉴别器被训练以区分真实数据和合成数据。

2、合成数据的生成:

使用训练好的生成器生成一组合成数据样本。这些合成样本应该与训练数据中的正常实例相似,但是我们不需要这个部分的模型。

3、异常检测:

将GAN生成的合成数据与原始正常数据相结合。使用传统的异常检测技术或简单的阈值方法来识别明显偏离预期分布的实例。与真实数据和合成数据都不相似的实例被认为是潜在的异常。(这是一种简单方法)

4、鉴别器作为异常检测器:

鉴别器重新用作异常检测器。在异常检测阶段将其应用于真实数据和合成数据。鉴别器分类为真实的实例可能被认为是正常的,而分类为合成的实例可能被标记为潜在的异常。(这是单独使用鉴别器进行异常检测的方法)

代码示例

构建一个完整的生成对抗网络(GAN)包括几个组成部分,包括定义生成器和鉴别器架构,指定损失函数和设置训练循环。下面是一个使用Pytorch进行构建的简单实例

 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader, TensorDataset
 import numpy as np
 import matplotlib.pyplot as plt

 # Define the generator model
 class Generator(nn.Module):
     def __init__(self, latent_dim):
         super(Generator, self).__init__()
         self.model = nn.Sequential(
             nn.Linear(latent_dim, 256),
             nn.ReLU(),
             nn.BatchNorm1d(256),
             nn.Linear(256, 512),
             nn.ReLU(),
             nn.BatchNorm1d(512),
             nn.Linear(512, 784),
             nn.Sigmoid(),
             nn.Unflatten(1, (28, 28, 1))
         )

     def forward(self, x):
         return self.model(x)

 # Define the discriminator model
 class Discriminator(nn.Module):
     def __init__(self, img_shape):
         super(Discriminator, self).__init__()
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(np.prod(img_shape), 512),
             nn.ReLU(),
             nn.Linear(512, 256),
             nn.ReLU(),
             nn.Linear(256, 1),
             nn.Sigmoid()
         )

     def forward(self, x):
         return self.model(x)

 # Define the GAN model
 class GAN(nn.Module):
     def __init__(self, generator, discriminator):
         super(GAN, self).__init__()
         self.generator = generator
         self.discriminator = discriminator

     def forward(self, x):
         x = self.generator(x)
         x = self.discriminator(x)
         return x

 # Function to compile models
 def compile_models(generator, discriminator, gan, latent_dim):
     d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
     g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

     criterion = nn.BCELoss()

     discriminator.compile(optimizer=d_optimizer, loss=criterion)
     gan.compile(optimizer=g_optimizer, loss=criterion)

 # Function to generate random noise for the generator
 def generate_latent_points(latent_dim, batch_size):
     return torch.randn(batch_size, latent_dim)

 # Function to train the GAN
 def train_gan(generator, discriminator, gan, dataset, latent_dim, epochs, batch_size):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     generator.to(device)
     discriminator.to(device)
     gan.to(device)

     dataset = torch.tensor(dataset, dtype=torch.float32).to(device)
     dataloader = DataLoader(TensorDataset(dataset), batch_size=batch_size, shuffle=True)

     criterion = nn.BCELoss()

     for epoch in range(epochs):
         for batch_data in dataloader:
             real_data = batch_data[0].to(device)
             batch_size = real_data.size(0)

             noise = generate_latent_points(latent_dim, batch_size).to(device)
             generated_data = generator(noise)

             labels_real = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
             labels_fake = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)

             d_loss_real = criterion(discriminator(real_data), labels_real)
             d_loss_fake = criterion(discriminator(generated_data.detach()), labels_fake)

             d_loss = 0.5 * (d_loss_real + d_loss_fake)

             discriminator.zero_grad()
             d_loss.backward()
             discriminator_optimizer.step()

             noise = generate_latent_points(latent_dim, batch_size).to(device)
             labels_gan = torch.ones((batch_size, 1), dtype=torch.float32).to(device)

             g_loss = criterion(gan(noise), labels_gan)

             generator.zero_grad()
             g_loss.backward()
             generator_optimizer.step()

             print(f"Epoch {epoch + 1}/{epochs}, Batch {batch}/{len(dataloader)}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

 # Function to generate and plot synthetic data
 def generate_and_plot(generator, latent_dim, examples=10):
     generator.eval()
     noise = generate_latent_points(latent_dim, examples)
     generated_data = generator(noise).detach().cpu().numpy()

     for i in range(examples):
         plt.subplot(2, 5, i + 1)
         plt.imshow(generated_data[i, 0, :, :], cmap='gray_r')
         plt.axis('off')

     plt.show()

 # Example usage
 latent_dim = 100
 img_shape = (28, 28, 1)

 # Build and compile the models
 generator = Generator(latent_dim)
 discriminator = Discriminator(img_shape)
 gan = GAN(generator, discriminator)
 compile_models(generator, discriminator, gan, latent_dim)

 # Load and preprocess your dataset (e.g., MNIST)
 (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
 train_images = train_images / 127.5 - 1.0  # Normalize images to the range [-1, 1]
 train_images = np.expand_dims(train_images, axis=-1)

 # Train the GAN
 train_gan(generator, discriminator, gan, train_images, latent_dim, epochs=100, batch_size=64)

 # Generate and plot synthetic data
 generate_and_plot(generator, latent_dim)

以上实例基于MNIST数据集

基于GAN异常检测的研究进展

虽然GAN在生成领域没有那么大放异彩,但是在其他领域还是有很多人在研究

https://www.sciencedirect.com/science/article/abs/pii/S0925231221019482

综述了gan在异常检测中的应用。作者讨论了gan在异常检测中的优点和局限性,并对该领域的最新研究进行了概述。他们还指出了该领域研究的挑战和未来方向。

https://www.researchgate.net/publication/362320139_Anomaly_detection_methods_based_on_GAN_a_survey

这个调查关注gan在医疗时间序列异常检测中的应用,并取得了很好的结果。

https://asp-eurasipjournals.springeropen.com/articles/10.1186/s13634-022-00943-7

这篇论文提出了一种结合USAD生成对抗训练架构和卷积自编码器(CAE)的检测模型,通过生成对抗训练的正态数据分布和提高从数据中提取特征的能力来增强对抗训练过程中的稳定性。作者在几个基准数据集上证明了他们提出的模型优于现有方法。

https://arxiv.org/pdf/2310.00335.pdf

将gan应用于发电厂的异常检测。作者使用不同的增强技术,如自编码器和主成分分析来提高gan在异常检测方面的性能。

https://www.mdpi.com/1424-8220/23/1/355

提出了一种新的基于注意力特征融合的编码器-解码器GAN模块,用于GAN工业图像的异常检测。作者在几个基准数据集上证明了他们提出的方法优于现有方法。

https://ieeexplore.ieee.org/document/10043696

探讨了在生物医学成像中使用gan进行异常检测。作者介绍了使用gan进行异常检测的概述,并研究了最先进的基于gan的生物医学成像异常检测方法。他们证明了基于gan的方法在几个基准数据集上优于传统方法。

总结

可以看到GAN的研究还在继续,并且GAN的问题也还存在:

确保生成器和鉴别器之间的良好平衡是至关重要的。如果生成器太弱,它可能无法准确捕获正态数据分布。如果它太强,它可能无法产生不同的合成样本。

将GAN应用于异常检测可能是一种强大的方法,特别是在标记异常数据稀缺的情况下,因为GAN可以学习表示正态数据分布,而无需显式标记异常。但是训练阶段必须仔细进行,训练数据质量对结果至关重要,GAN已经在各种数据上表现出优于传统的异常检测的表现,我们期待他有更好的发展吧。

https://avoid.overfit.cn/post/cc6a7b7c18d04bd7ac3aa15d55520e57

目录
相关文章
|
6月前
|
机器学习/深度学习 数据采集 算法
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
解码癌症预测的密码:可解释性机器学习算法SHAP揭示XGBoost模型的预测机制
307 0
|
1月前
|
机器学习/深度学习 调度 知识图谱
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
近年来,深度神经网络成为时间序列预测的主流方法。自监督学习通过从未标记数据中学习,能够捕获时间序列的长期依赖和局部特征。TimeDART结合扩散模型和自回归建模,创新性地解决了时间序列预测中的关键挑战,在多个数据集上取得了最优性能,展示了强大的泛化能力。
75 0
TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
|
1月前
|
机器学习/深度学习 运维 监控
深度学习之异常检测
基于深度学习的异常检测是一项重要的研究领域,主要用于识别数据中的异常样本或行为。异常检测广泛应用于多个领域,如网络安全、金融欺诈检测、工业设备预测性维护、医疗诊断等。
117 2
|
4月前
|
机器学习/深度学习 运维 数据挖掘
无监督学习在异常检测中的应用
【7月更文挑战第14天】无监督学习在异常检测中的应用具有重要意义,其可以帮助我们发现数据中的潜在异常模式,提高异常检测的效率和准确性。通过不断的研究和探索,我们可以进一步完善无监督学习方法在异常检测中的应用,为实际应用提供更加可靠和有效的解决方案。
|
6月前
|
机器学习/深度学习 数据采集 人工智能
【机器学习】怎样检测到线性回归模型中的过拟合?
【5月更文挑战第17天】【机器学习】怎样检测到线性回归模型中的过拟合?
|
6月前
|
机器学习/深度学习 运维
深度学习实现自编码器Autoencoder神经网络异常检测心电图ECG时间序列
深度学习实现自编码器Autoencoder神经网络异常检测心电图ECG时间序列
|
6月前
|
运维 数据挖掘 Python
探索LightGBM:监督式聚类与异常检测
探索LightGBM:监督式聚类与异常检测【2月更文挑战第3天】
118 1
|
机器学习/深度学习 存储 算法
图像识别1:基于相似性度量的二分类实验
图像识别1:基于相似性度量的二分类实验
97 0
|
机器学习/深度学习 算法 数据库
图像识别2:图像多分类实验
图像识别2:图像多分类实验
72 0
|
机器学习/深度学习 自然语言处理
【深度学习】实验17 使用GAN生成手写数字样本
【深度学习】实验17 使用GAN生成手写数字样本
130 0