检测和正确分类未见的异常是一个具有挑战性的问题,多年来已经以许多不同的方式解决了这个问题。而今天我们要介绍一种基于GAN的异常检测方法,GAN是一种深度学习模型,可以学习生成与给定数据集相似的真实数据样本。GAN的这一特性表明它们可以成功地用于异常检测,以前的基于GAN的生成模型都是使用GAN的生成器,而异常检测则是需要使用GAN的鉴别器。
GAN简介
生成对抗网络(GANs)是一类用于无监督机器学习的人工智能算法。它们是由Ian Goodfellow和他的同事在2014年推出的。GANs由生成器和鉴别器两个神经网络组成,它们通过对抗性训练同时进行训练。
生成器:GAN的这一部分负责生成新的数据实例。它将随机噪声作为输入,并将其转换成理想情况下与真实数据无法区分的数据。
鉴别器:GAN的这一部分充当分类器。它被训练来区分真实数据和由生成器生成的合成数据。
生成器旨在生成合成数据,这些数据非常令人信服,以至于判别器无法区分真实数据和生成数据。而鉴别器同时经过训练,变得更善于区分真实数据和生成数据。
训练的目标是生成器创建的数据越来越真实,而鉴别器在区分差异方面变得更加熟练。这种对抗过程会一直持续下去,直到生成器生成的数据基本上与真实数据无法区分。
当生成器生成高度真实的数据,而鉴别器无法可靠地将其与真实数据区分开来时,平衡点代表GAN的成功训练。
将GAN用于异常检测
生成对抗网络(GANs)可以通过训练它们生成正常或典型的数据分布来用于异常检测。
对于生成模型,我们一般使用GAN的方法是,使用GAN的生成器来学习普通数据的底层模式,并通过鉴别器来对其进行强化训练,最后得到一个非常强大的生成器模型
而对于异常检测来说,我们使用GAN的生成器组件来学习普通数据的底层模式,用来生成类似于正态分布的合成数据样本,然后得到一个强大的鉴别器(分类模型),这个模型就可以作为我们异常检测的模型来进行使用。
以下是GAN用于异常检测的步骤概述:
1、正常数据训练:
使用数据的正常或典型实例(例如,正常图像,正常传感器读数等)的数据集来训练GAN。生成器学习生成模拟正常数据分布的合成样本,鉴别器被训练以区分真实数据和合成数据。
2、合成数据的生成:
使用训练好的生成器生成一组合成数据样本。这些合成样本应该与训练数据中的正常实例相似,但是我们不需要这个部分的模型。
3、异常检测:
将GAN生成的合成数据与原始正常数据相结合。使用传统的异常检测技术或简单的阈值方法来识别明显偏离预期分布的实例。与真实数据和合成数据都不相似的实例被认为是潜在的异常。(这是一种简单方法)
4、鉴别器作为异常检测器:
鉴别器重新用作异常检测器。在异常检测阶段将其应用于真实数据和合成数据。鉴别器分类为真实的实例可能被认为是正常的,而分类为合成的实例可能被标记为潜在的异常。(这是单独使用鉴别器进行异常检测的方法)
代码示例
构建一个完整的生成对抗网络(GAN)包括几个组成部分,包括定义生成器和鉴别器架构,指定损失函数和设置训练循环。下面是一个使用Pytorch进行构建的简单实例
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
# Define the generator model
class Generator(nn.Module):
def __init__(self, latent_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Linear(512, 784),
nn.Sigmoid(),
nn.Unflatten(1, (28, 28, 1))
)
def forward(self, x):
return self.model(x)
# Define the discriminator model
class Discriminator(nn.Module):
def __init__(self, img_shape):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(img_shape), 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# Define the GAN model
class GAN(nn.Module):
def __init__(self, generator, discriminator):
super(GAN, self).__init__()
self.generator = generator
self.discriminator = discriminator
def forward(self, x):
x = self.generator(x)
x = self.discriminator(x)
return x
# Function to compile models
def compile_models(generator, discriminator, gan, latent_dim):
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()
discriminator.compile(optimizer=d_optimizer, loss=criterion)
gan.compile(optimizer=g_optimizer, loss=criterion)
# Function to generate random noise for the generator
def generate_latent_points(latent_dim, batch_size):
return torch.randn(batch_size, latent_dim)
# Function to train the GAN
def train_gan(generator, discriminator, gan, dataset, latent_dim, epochs, batch_size):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
gan.to(device)
dataset = torch.tensor(dataset, dtype=torch.float32).to(device)
dataloader = DataLoader(TensorDataset(dataset), batch_size=batch_size, shuffle=True)
criterion = nn.BCELoss()
for epoch in range(epochs):
for batch_data in dataloader:
real_data = batch_data[0].to(device)
batch_size = real_data.size(0)
noise = generate_latent_points(latent_dim, batch_size).to(device)
generated_data = generator(noise)
labels_real = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
labels_fake = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)
d_loss_real = criterion(discriminator(real_data), labels_real)
d_loss_fake = criterion(discriminator(generated_data.detach()), labels_fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)
discriminator.zero_grad()
d_loss.backward()
discriminator_optimizer.step()
noise = generate_latent_points(latent_dim, batch_size).to(device)
labels_gan = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
g_loss = criterion(gan(noise), labels_gan)
generator.zero_grad()
g_loss.backward()
generator_optimizer.step()
print(f"Epoch {epoch + 1}/{epochs}, Batch {batch}/{len(dataloader)}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")
# Function to generate and plot synthetic data
def generate_and_plot(generator, latent_dim, examples=10):
generator.eval()
noise = generate_latent_points(latent_dim, examples)
generated_data = generator(noise).detach().cpu().numpy()
for i in range(examples):
plt.subplot(2, 5, i + 1)
plt.imshow(generated_data[i, 0, :, :], cmap='gray_r')
plt.axis('off')
plt.show()
# Example usage
latent_dim = 100
img_shape = (28, 28, 1)
# Build and compile the models
generator = Generator(latent_dim)
discriminator = Discriminator(img_shape)
gan = GAN(generator, discriminator)
compile_models(generator, discriminator, gan, latent_dim)
# Load and preprocess your dataset (e.g., MNIST)
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images / 127.5 - 1.0 # Normalize images to the range [-1, 1]
train_images = np.expand_dims(train_images, axis=-1)
# Train the GAN
train_gan(generator, discriminator, gan, train_images, latent_dim, epochs=100, batch_size=64)
# Generate and plot synthetic data
generate_and_plot(generator, latent_dim)
以上实例基于MNIST数据集
基于GAN异常检测的研究进展
虽然GAN在生成领域没有那么大放异彩,但是在其他领域还是有很多人在研究
https://www.sciencedirect.com/science/article/abs/pii/S0925231221019482
综述了gan在异常检测中的应用。作者讨论了gan在异常检测中的优点和局限性,并对该领域的最新研究进行了概述。他们还指出了该领域研究的挑战和未来方向。
https://www.researchgate.net/publication/362320139_Anomaly_detection_methods_based_on_GAN_a_survey
这个调查关注gan在医疗时间序列异常检测中的应用,并取得了很好的结果。
https://asp-eurasipjournals.springeropen.com/articles/10.1186/s13634-022-00943-7
这篇论文提出了一种结合USAD生成对抗训练架构和卷积自编码器(CAE)的检测模型,通过生成对抗训练的正态数据分布和提高从数据中提取特征的能力来增强对抗训练过程中的稳定性。作者在几个基准数据集上证明了他们提出的模型优于现有方法。
https://arxiv.org/pdf/2310.00335.pdf
将gan应用于发电厂的异常检测。作者使用不同的增强技术,如自编码器和主成分分析来提高gan在异常检测方面的性能。
https://www.mdpi.com/1424-8220/23/1/355
提出了一种新的基于注意力特征融合的编码器-解码器GAN模块,用于GAN工业图像的异常检测。作者在几个基准数据集上证明了他们提出的方法优于现有方法。
https://ieeexplore.ieee.org/document/10043696
探讨了在生物医学成像中使用gan进行异常检测。作者介绍了使用gan进行异常检测的概述,并研究了最先进的基于gan的生物医学成像异常检测方法。他们证明了基于gan的方法在几个基准数据集上优于传统方法。
总结
可以看到GAN的研究还在继续,并且GAN的问题也还存在:
确保生成器和鉴别器之间的良好平衡是至关重要的。如果生成器太弱,它可能无法准确捕获正态数据分布。如果它太强,它可能无法产生不同的合成样本。
将GAN应用于异常检测可能是一种强大的方法,特别是在标记异常数据稀缺的情况下,因为GAN可以学习表示正态数据分布,而无需显式标记异常。但是训练阶段必须仔细进行,训练数据质量对结果至关重要,GAN已经在各种数据上表现出优于传统的异常检测的表现,我们期待他有更好的发展吧。
https://avoid.overfit.cn/post/cc6a7b7c18d04bd7ac3aa15d55520e57