在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(上)

简介: 方法Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p​ 和\mu_{q}μ q​ 的2个分布p和q之间的基于距离的度量:

方法

Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入μp\mu_{p}μpμq\mu_{q}μq的2个分布p和q之间的基于距离的度量:

MMD(F,p,q)=∣∣μp−μq∣∣F2MMD(F, p, q) = || \mu_{p} - \mu_{q} ||^2_{F}MMD(F,p,q)=μpμqF2

应用核技巧后,我们可以从两个分布的样本中计算出MMD2MMD^2MMD2的无偏估计。

默认情况下,我们使用径向基函数内核,但用户可以自由地将自己的首选内核传递给检测器。我们通过对MMD2MMD^2MMD2值的置换测试获得p值。该方法也在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中描述。


后端

该方法在 PyTorch 和 TensorFlow 框架中实现,支持 CPU 和 GPU 。对于这两个框架,Alibi Detect还支持各种现成的预处理步骤,并在本文中进行了说明。然而,Alibi Detect不会为您安装PyTorch。了解如何执行此操作查看PyTorch文档


数据集

CIFAR10 由 60,000 个 32 x 32 RGB 图像组成,平均分布在 10 个类别中。 我们在 CIFAR-10-C 数据集 (Hendrycks & Dietterich, 2019) 上评估漂移检测器。 CIFAR-10-C 中的实例受到不同严重程度的各种类型的噪声、模糊、亮度等的破坏和干扰,导致分类模型性能逐渐下降。 我们还检查具有类不平衡的原始测试集的漂移。

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi_detect.cd import MMDDrift
from alibi_detect.models.tensorflow import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model
from alibi_detect.saving import save_detector, load_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c
复制代码


加载数据

原始CIFAR-10数据:

# (50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# (50000,), (10000,)
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)
复制代码


对于CIFAR-10-C,我们可以从以下5个严重级别的损坏类型中进行选择:

corruptions = corruption_types_cifar10c()
print(corruptions)
# <class 'list'>
复制代码


运行结果:

['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']
复制代码

让我们选择损坏级别 5 的损坏子集。每种损坏类型都包含对所有原始测试集图像的扰动。

# 高斯噪声、运动模糊、亮度、马赛克
corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255
# (40000, 32, 32, 3)
复制代码


我们将原始测试集拆分为参考数据集和MMD检验 H0 下不应拒绝的数据集。我们还按损坏类型拆分损坏的数据:

np.random.seed(0)
n_test = X_test.shape[0]
# (5000,)
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref,y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]
print(X_ref.shape, X_h0.shape)
复制代码


运行结果:

(5000, 32, 32, 3) (5000, 32, 32, 3)
复制代码



# check that the classes are more or less balanced
# 返回 分类 及 分类的次数
classes, counts_ref = np.unique(y_ref, return_counts=True)
# 返回 分类 及 分类的次数
counts_h0 = np.unique(y_h0, return_counts=True)[1]
print('Class Ref H0')
for cl, cref, ch0 in zip(classes, counts_ref, counts_h0):
    assert cref + ch0 == n_test // 10
    print('{}     {} {}'.format(cl, cref, ch0))
复制代码


运行结果:

Class Ref H0
0     472 528
1     510 490
2     498 502
3     492 508
4     501 499
5     495 505
6     493 507
7     501 499
8     516 484
9     522 478
复制代码



# 4
n_corr = len(corruption)
# 按损坏类型拆分损坏的数据
X_c = [X_corr[i * n_test:(i + 1) * n_test] for i in range(n_corr)]
复制代码


我们可以为每个损坏类型可视化相同的实例:

i = 4
# 10000
n_test = X_test.shape[0]
plt.title('Original')
plt.axis('off')
plt.imshow(X_test[i])
plt.show()
for _ in range(len(corruption)):
    plt.title(corruption[_])
    plt.axis('off')
    plt.imshow(X_corr[n_test * _+ i])
    plt.show()
复制代码


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


我们还可以验证,在这个受干扰的数据集上,CIFAR-10上的分类模型的性能显著下降:

dataset = 'cifar10'
model = 'resnet32'
clf = fetch_tf_model(dataset, model)
acc = clf.evaluate(scale_by_instance(X_test), y_test, batch_size=128, verbose=0)[1]
print('Test set accuracy:')
print('Original {:.4f}'.format(acc))
clf_accuracy = {'original': acc}
for _ in range(len(corruption)):
    acc = clf.evaluate(scale_by_instance(X_c[_]), y_test, batch_size=128, verbose=0)[1]
    clf_accuracy[corruption[_]] = acc
    print('{} {:.4f}'.format(corruption[_], acc))
复制代码


运行结果:

Test set accuracy:
Original 0.9278
gaussian_noise 0.2208
motion_blur 0.6339
brightness 0.8913
pixelate 0.3666
复制代码


鉴于性能下降,我们检测不良的数据漂移非常重要!

使用 TensorFlow 后端检测漂移

首先,我们使用 TensorFlow 框架尝试使用漂移检测器进行预处理和 MMD 计算步骤。

我们正在尝试使用多元 MMD 置换测试检测高维 (32x32x3) 数据上的数据漂移。 因此,首先应用降维是有意义的。 在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中也使用了一些降维方法:随机初始化的编码器(论文中的UAEUntrained AutoEncoder)、BBSD(使用分类器的 softmax 输出的黑盒移位检测) 和 PCA(使用 scikit-learn)。


随机编码器

首先我们尝试随机初始化的编码器:

from functools import partial
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer, Reshape
from alibi_detect.cd.tensorflow import preprocess_drift
tf.random.set_seed(0)
# define encoder
encoding_dim = 32
encoder_net = tf.keras.Sequential(
  [
      InputLayer(input_shape=(32, 32, 3)),
      Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
      Flatten(),
      Dense(encoding_dim,)
  ]
)
# 定义预处理
# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=encoder_net, batch_size=512)
# 初始化
# initialise drift detector
cd = MMDDrift(X_ref, backend='tensorflow', p_val=.05,
              preprocess_fn=preprocess_fn, n_permutations=100)
# 保存、加载
# we can also save/load an initialised detector
filepath = 'my_path'  # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)
复制代码


让我们检查检测器是否认为漂移发生在不同的测试集上,以及预测调用的时间:

from timeit import default_timer as timer
labels = ['No!', 'Yes!']
def make_predictions(cd, x_h0, x_corr, corruption):
    t = timer()
    preds = cd.predict(x_h0)
    dt = timer() - t
    # 没有损坏
    print('No corruption')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print(f'p-value: {preds["data"]["p_val"]:.3f}')
    print(f'Time (s) {dt:.3f}')
    # 损坏列表
    if isinstance(x_corr, list):
        for x, c in zip(x_corr, corruption):
            t = timer()
            preds = cd.predict(x)
            dt = timer() - t
            print('')
            print(f'Corruption type: {c}')
            print('Drift? {}'.format(labels[preds['data']['is_drift']]))
            print(f'p-value: {preds["data"]["p_val"]:.3f}')
            print(f'Time (s) {dt:.3f}')
复制代码



make_predictions(cd, X_h0, X_c, corruption)
复制代码


运行结果:

No corruption
Drift? No!
p-value: 0.680
Time (s) 2.217
Corruption type: gaussian_noise
Drift? Yes!
p-value: 0.000
Time (s) 6.074
Corruption type: motion_blur
Drift? Yes!
p-value: 0.000
Time (s) 6.031
Corruption type: brightness
Drift? Yes!
p-value: 0.000
Time (s) 6.019
Corruption type: pixelate
Drift? Yes!
p-value: 0.000
Time (s) 6.010
复制代码


正如预期的那样,仅在损坏的数据集上检测到漂移。

相关文章
|
机器学习/深度学习 监控
数据漂移、概念漂移以及如何监控它们(mona)
在机器学习模型监控的上下文中经常提到数据和概念漂移,但它们到底是什么以及如何检测到它们?此外,考虑到围绕它们的常见误解,是不惜一切代价避免数据和概念漂移的事情,还是在生产中训练模型的自然和可接受的后果?请仔细阅读,找出答案。在本文中,我们将提供模型漂移的细粒度细分,以及检测它们的方法以及处理它们时的最佳实践。
|
2月前
|
固态存储 IDE 开发工具
电脑无法识别固态硬盘怎么办?
本文详解固态硬盘(SSD)无法被电脑识别的常见问题及解决方法。涵盖硬件连接、BIOS设置、系统识别、驱动安装等方面,适用于新手与老用户。分析四种常见识别失败情况,并提供排查步骤与解决方案,助你快速定位问题并修复。
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
人工智能与情感计算:AI如何理解人类情感
人工智能与情感计算:AI如何理解人类情感
1446 20
|
机器学习/深度学习 人工智能 自然语言处理
一文速通半监督学习(Semi-supervised Learning):桥接有标签与无标签数据
一文速通半监督学习(Semi-supervised Learning):桥接有标签与无标签数据
828 0
|
存储 机器学习/深度学习 人工智能
深入浅出 AI 智能体(AI Agent)|技术干货
随着人工智能技术的发展,智能体(AI Agents)逐渐成为人与大模型交互的主要方式。智能体能执行任务、解决问题,并提供个性化服务。其关键组成部分包括规划、记忆和工具使用,使交互更加高效、自然。智能体的应用涵盖专业领域问答、资讯整理、角色扮演等场景,极大地提升了用户体验与工作效率。借助智能体开发平台,用户可以轻松打造定制化AI应用,推动AI技术在各领域的广泛应用与深度融合。
24618 1
|
机器学习/深度学习 开发框架 自然语言处理
深度学习中的自动学习率调整方法探索与应用
传统深度学习模型中,学习率的选择对训练效果至关重要,然而其调整通常依赖于经验或静态策略。本文探讨了现代深度学习中的自动学习率调整方法,通过分析不同算法的原理与应用实例,展示了这些方法在提高模型收敛速度和精度方面的潜力。 【7月更文挑战第14天】
217 3
|
机器学习/深度学习 数据采集 监控
深度学习之在线学习与适应
基于深度学习的在线学习与适应,旨在开发能够在不断变化的环境中实时学习和调整的模型,使其在面对新数据或新任务时能够迅速适应并维持高性能。
269 0
|
算法 数据可视化 数据挖掘
Barnes-Hut t-SNE:大规模数据的高效降维算法
Barnes-Hut t-SNE是一种针对大规模数据集的高效降维算法,它是t-SNE的变体,用于高维数据可视化。t-SNE通过保持概率分布相似性将数据从高维降至2D或3D。Barnes-Hut算法采用天体物理中的方法,将时间复杂度从O(N²)降低到O(NlogN),通过构建空间索引树和近似远距离交互来加速计算。在scikit-learn中可用,代码示例展示了如何使用该算法进行聚类可视化,成功分离出不同簇并获得高轮廓分数,证明其在大數據集上的有效性。
332 1
|
机器学习/深度学习 人工智能 自然语言处理
多任务学习的优势
【5月更文挑战第25天】多任务学习的优势
339 6
|
机器学习/深度学习 数据采集 Prometheus
机器学习模型监控工具:Evidently 与 Seldon Alibi 对比
每当我们训练和部署机器学习模型时,我们都希望确保该模型在生产中表现良好。 模型需要监控,因为现实世界中发生了我们在训练期间无法解释的事情。最明显的例子是当现实世界的数据偏离训练数据时,或者当我们遇到异常值时。我们使用监控来做出决策,例如:何时重新训练或何时获取新数据。