方法
Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入μp\mu_{p}μp和μq\mu_{q}μq的2个分布p和q之间的基于距离的度量:
MMD(F,p,q)=∣∣μp−μq∣∣F2MMD(F, p, q) = || \mu_{p} - \mu_{q} ||^2_{F}MMD(F,p,q)=∣∣μp−μq∣∣F2
应用核技巧后,我们可以从两个分布的样本中计算出MMD2MMD^2MMD2的无偏估计。
默认情况下,我们使用径向基函数内核,但用户可以自由地将自己的首选内核传递给检测器。我们通过对MMD2MMD^2MMD2值的置换测试获得p值。该方法也在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中描述。
后端
该方法在 PyTorch 和 TensorFlow 框架中实现,支持 CPU 和 GPU 。对于这两个框架,Alibi Detect还支持各种现成的预处理步骤,并在本文中进行了说明。然而,Alibi Detect不会为您安装PyTorch。了解如何执行此操作查看PyTorch文档。
数据集
CIFAR10 由 60,000 个 32 x 32 RGB 图像组成,平均分布在 10 个类别中。 我们在 CIFAR-10-C 数据集 (Hendrycks & Dietterich, 2019) 上评估漂移检测器。 CIFAR-10-C 中的实例受到不同严重程度的各种类型的噪声、模糊、亮度等的破坏和干扰,导致分类模型性能逐渐下降。 我们还检查具有类不平衡的原始测试集的漂移。
from functools import partial import matplotlib.pyplot as plt import numpy as np import tensorflow as tf from alibi_detect.cd import MMDDrift from alibi_detect.models.tensorflow import scale_by_instance from alibi_detect.utils.fetching import fetch_tf_model from alibi_detect.saving import save_detector, load_detector from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c 复制代码
加载数据
原始CIFAR-10数据:
# (50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1) (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data() X_train = X_train.astype('float32') / 255 X_test = X_test.astype('float32') / 255 # (50000,), (10000,) y_train = y_train.astype('int64').reshape(-1,) y_test = y_test.astype('int64').reshape(-1,) 复制代码
对于CIFAR-10-C,我们可以从以下5个严重级别的损坏类型中进行选择:
corruptions = corruption_types_cifar10c() print(corruptions) # <class 'list'> 复制代码
运行结果:
['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur'] 复制代码
让我们选择损坏级别 5 的损坏子集。每种损坏类型都包含对所有原始测试集图像的扰动。
# 高斯噪声、运动模糊、亮度、马赛克 corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate'] X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True) X_corr = X_corr.astype('float32') / 255 # (40000, 32, 32, 3) 复制代码
我们将原始测试集拆分为参考数据集和MMD检验 H0 下不应拒绝的数据集。我们还按损坏类型拆分损坏的数据:
np.random.seed(0) n_test = X_test.shape[0] # (5000,) idx = np.random.choice(n_test, size=n_test // 2, replace=False) idx_h0 = np.delete(np.arange(n_test), idx, axis=0) X_ref,y_ref = X_test[idx], y_test[idx] X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0] print(X_ref.shape, X_h0.shape) 复制代码
运行结果:
(5000, 32, 32, 3) (5000, 32, 32, 3) 复制代码
# check that the classes are more or less balanced # 返回 分类 及 分类的次数 classes, counts_ref = np.unique(y_ref, return_counts=True) # 返回 分类 及 分类的次数 counts_h0 = np.unique(y_h0, return_counts=True)[1] print('Class Ref H0') for cl, cref, ch0 in zip(classes, counts_ref, counts_h0): assert cref + ch0 == n_test // 10 print('{} {} {}'.format(cl, cref, ch0)) 复制代码
运行结果:
Class Ref H0 0 472 528 1 510 490 2 498 502 3 492 508 4 501 499 5 495 505 6 493 507 7 501 499 8 516 484 9 522 478 复制代码
# 4 n_corr = len(corruption) # 按损坏类型拆分损坏的数据 X_c = [X_corr[i * n_test:(i + 1) * n_test] for i in range(n_corr)] 复制代码
我们可以为每个损坏类型可视化相同的实例:
i = 4 # 10000 n_test = X_test.shape[0] plt.title('Original') plt.axis('off') plt.imshow(X_test[i]) plt.show() for _ in range(len(corruption)): plt.title(corruption[_]) plt.axis('off') plt.imshow(X_corr[n_test * _+ i]) plt.show() 复制代码
我们还可以验证,在这个受干扰的数据集上,CIFAR-10上的分类模型的性能显著下降:
dataset = 'cifar10' model = 'resnet32' clf = fetch_tf_model(dataset, model) acc = clf.evaluate(scale_by_instance(X_test), y_test, batch_size=128, verbose=0)[1] print('Test set accuracy:') print('Original {:.4f}'.format(acc)) clf_accuracy = {'original': acc} for _ in range(len(corruption)): acc = clf.evaluate(scale_by_instance(X_c[_]), y_test, batch_size=128, verbose=0)[1] clf_accuracy[corruption[_]] = acc print('{} {:.4f}'.format(corruption[_], acc)) 复制代码
运行结果:
Test set accuracy: Original 0.9278 gaussian_noise 0.2208 motion_blur 0.6339 brightness 0.8913 pixelate 0.3666 复制代码
鉴于性能下降,我们检测不良的数据漂移非常重要!
使用 TensorFlow 后端检测漂移
首先,我们使用 TensorFlow 框架尝试使用漂移检测器进行预处理和 MMD 计算步骤。
我们正在尝试使用多元 MMD 置换测试检测高维 (32x32x3) 数据上的数据漂移。 因此,首先应用降维是有意义的。 在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中也使用了一些降维方法:随机初始化的编码器(论文中的UAE或Untrained AutoEncoder)、BBSD(使用分类器的 softmax 输出的黑盒移位检测) 和 PCA(使用 scikit-learn)。
随机编码器
首先我们尝试随机初始化的编码器:
from functools import partial from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer, Reshape from alibi_detect.cd.tensorflow import preprocess_drift tf.random.set_seed(0) # define encoder encoding_dim = 32 encoder_net = tf.keras.Sequential( [ InputLayer(input_shape=(32, 32, 3)), Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu), Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu), Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu), Flatten(), Dense(encoding_dim,) ] ) # 定义预处理 # define preprocessing function preprocess_fn = partial(preprocess_drift, model=encoder_net, batch_size=512) # 初始化 # initialise drift detector cd = MMDDrift(X_ref, backend='tensorflow', p_val=.05, preprocess_fn=preprocess_fn, n_permutations=100) # 保存、加载 # we can also save/load an initialised detector filepath = 'my_path' # change to directory where detector is saved save_detector(cd, filepath) cd = load_detector(filepath) 复制代码
让我们检查检测器是否认为漂移发生在不同的测试集上,以及预测调用的时间:
from timeit import default_timer as timer labels = ['No!', 'Yes!'] def make_predictions(cd, x_h0, x_corr, corruption): t = timer() preds = cd.predict(x_h0) dt = timer() - t # 没有损坏 print('No corruption') print('Drift? {}'.format(labels[preds['data']['is_drift']])) print(f'p-value: {preds["data"]["p_val"]:.3f}') print(f'Time (s) {dt:.3f}') # 损坏列表 if isinstance(x_corr, list): for x, c in zip(x_corr, corruption): t = timer() preds = cd.predict(x) dt = timer() - t print('') print(f'Corruption type: {c}') print('Drift? {}'.format(labels[preds['data']['is_drift']])) print(f'p-value: {preds["data"]["p_val"]:.3f}') print(f'Time (s) {dt:.3f}') 复制代码
make_predictions(cd, X_h0, X_c, corruption) 复制代码
运行结果:
No corruption Drift? No! p-value: 0.680 Time (s) 2.217 Corruption type: gaussian_noise Drift? Yes! p-value: 0.000 Time (s) 6.074 Corruption type: motion_blur Drift? Yes! p-value: 0.000 Time (s) 6.031 Corruption type: brightness Drift? Yes! p-value: 0.000 Time (s) 6.019 Corruption type: pixelate Drift? Yes! p-value: 0.000 Time (s) 6.010 复制代码
正如预期的那样,仅在损坏的数据集上检测到漂移。