在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(上)

简介: 方法Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p​ 和\mu_{q}μ q​ 的2个分布p和q之间的基于距离的度量:

方法

Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入μp\mu_{p}μpμq\mu_{q}μq的2个分布p和q之间的基于距离的度量:

MMD(F,p,q)=∣∣μp−μq∣∣F2MMD(F, p, q) = || \mu_{p} - \mu_{q} ||^2_{F}MMD(F,p,q)=μpμqF2

应用核技巧后,我们可以从两个分布的样本中计算出MMD2MMD^2MMD2的无偏估计。

默认情况下,我们使用径向基函数内核,但用户可以自由地将自己的首选内核传递给检测器。我们通过对MMD2MMD^2MMD2值的置换测试获得p值。该方法也在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中描述。


后端

该方法在 PyTorch 和 TensorFlow 框架中实现,支持 CPU 和 GPU 。对于这两个框架,Alibi Detect还支持各种现成的预处理步骤,并在本文中进行了说明。然而,Alibi Detect不会为您安装PyTorch。了解如何执行此操作查看PyTorch文档


数据集

CIFAR10 由 60,000 个 32 x 32 RGB 图像组成,平均分布在 10 个类别中。 我们在 CIFAR-10-C 数据集 (Hendrycks & Dietterich, 2019) 上评估漂移检测器。 CIFAR-10-C 中的实例受到不同严重程度的各种类型的噪声、模糊、亮度等的破坏和干扰,导致分类模型性能逐渐下降。 我们还检查具有类不平衡的原始测试集的漂移。

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi_detect.cd import MMDDrift
from alibi_detect.models.tensorflow import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model
from alibi_detect.saving import save_detector, load_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c
复制代码


加载数据

原始CIFAR-10数据:

# (50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# (50000,), (10000,)
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)
复制代码


对于CIFAR-10-C,我们可以从以下5个严重级别的损坏类型中进行选择:

corruptions = corruption_types_cifar10c()
print(corruptions)
# <class 'list'>
复制代码


运行结果:

['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']
复制代码

让我们选择损坏级别 5 的损坏子集。每种损坏类型都包含对所有原始测试集图像的扰动。

# 高斯噪声、运动模糊、亮度、马赛克
corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255
# (40000, 32, 32, 3)
复制代码


我们将原始测试集拆分为参考数据集和MMD检验 H0 下不应拒绝的数据集。我们还按损坏类型拆分损坏的数据:

np.random.seed(0)
n_test = X_test.shape[0]
# (5000,)
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref,y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]
print(X_ref.shape, X_h0.shape)
复制代码


运行结果:

(5000, 32, 32, 3) (5000, 32, 32, 3)
复制代码



# check that the classes are more or less balanced
# 返回 分类 及 分类的次数
classes, counts_ref = np.unique(y_ref, return_counts=True)
# 返回 分类 及 分类的次数
counts_h0 = np.unique(y_h0, return_counts=True)[1]
print('Class Ref H0')
for cl, cref, ch0 in zip(classes, counts_ref, counts_h0):
    assert cref + ch0 == n_test // 10
    print('{}     {} {}'.format(cl, cref, ch0))
复制代码


运行结果:

Class Ref H0
0     472 528
1     510 490
2     498 502
3     492 508
4     501 499
5     495 505
6     493 507
7     501 499
8     516 484
9     522 478
复制代码



# 4
n_corr = len(corruption)
# 按损坏类型拆分损坏的数据
X_c = [X_corr[i * n_test:(i + 1) * n_test] for i in range(n_corr)]
复制代码


我们可以为每个损坏类型可视化相同的实例:

i = 4
# 10000
n_test = X_test.shape[0]
plt.title('Original')
plt.axis('off')
plt.imshow(X_test[i])
plt.show()
for _ in range(len(corruption)):
    plt.title(corruption[_])
    plt.axis('off')
    plt.imshow(X_corr[n_test * _+ i])
    plt.show()
复制代码


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


网络异常,图片无法展示
|


我们还可以验证,在这个受干扰的数据集上,CIFAR-10上的分类模型的性能显著下降:

dataset = 'cifar10'
model = 'resnet32'
clf = fetch_tf_model(dataset, model)
acc = clf.evaluate(scale_by_instance(X_test), y_test, batch_size=128, verbose=0)[1]
print('Test set accuracy:')
print('Original {:.4f}'.format(acc))
clf_accuracy = {'original': acc}
for _ in range(len(corruption)):
    acc = clf.evaluate(scale_by_instance(X_c[_]), y_test, batch_size=128, verbose=0)[1]
    clf_accuracy[corruption[_]] = acc
    print('{} {:.4f}'.format(corruption[_], acc))
复制代码


运行结果:

Test set accuracy:
Original 0.9278
gaussian_noise 0.2208
motion_blur 0.6339
brightness 0.8913
pixelate 0.3666
复制代码


鉴于性能下降,我们检测不良的数据漂移非常重要!

使用 TensorFlow 后端检测漂移

首先,我们使用 TensorFlow 框架尝试使用漂移检测器进行预处理和 MMD 计算步骤。

我们正在尝试使用多元 MMD 置换测试检测高维 (32x32x3) 数据上的数据漂移。 因此,首先应用降维是有意义的。 在 Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift 中也使用了一些降维方法:随机初始化的编码器(论文中的UAEUntrained AutoEncoder)、BBSD(使用分类器的 softmax 输出的黑盒移位检测) 和 PCA(使用 scikit-learn)。


随机编码器

首先我们尝试随机初始化的编码器:

from functools import partial
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer, Reshape
from alibi_detect.cd.tensorflow import preprocess_drift
tf.random.set_seed(0)
# define encoder
encoding_dim = 32
encoder_net = tf.keras.Sequential(
  [
      InputLayer(input_shape=(32, 32, 3)),
      Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
      Flatten(),
      Dense(encoding_dim,)
  ]
)
# 定义预处理
# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=encoder_net, batch_size=512)
# 初始化
# initialise drift detector
cd = MMDDrift(X_ref, backend='tensorflow', p_val=.05,
              preprocess_fn=preprocess_fn, n_permutations=100)
# 保存、加载
# we can also save/load an initialised detector
filepath = 'my_path'  # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)
复制代码


让我们检查检测器是否认为漂移发生在不同的测试集上,以及预测调用的时间:

from timeit import default_timer as timer
labels = ['No!', 'Yes!']
def make_predictions(cd, x_h0, x_corr, corruption):
    t = timer()
    preds = cd.predict(x_h0)
    dt = timer() - t
    # 没有损坏
    print('No corruption')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print(f'p-value: {preds["data"]["p_val"]:.3f}')
    print(f'Time (s) {dt:.3f}')
    # 损坏列表
    if isinstance(x_corr, list):
        for x, c in zip(x_corr, corruption):
            t = timer()
            preds = cd.predict(x)
            dt = timer() - t
            print('')
            print(f'Corruption type: {c}')
            print('Drift? {}'.format(labels[preds['data']['is_drift']]))
            print(f'p-value: {preds["data"]["p_val"]:.3f}')
            print(f'Time (s) {dt:.3f}')
复制代码



make_predictions(cd, X_h0, X_c, corruption)
复制代码


运行结果:

No corruption
Drift? No!
p-value: 0.680
Time (s) 2.217
Corruption type: gaussian_noise
Drift? Yes!
p-value: 0.000
Time (s) 6.074
Corruption type: motion_blur
Drift? Yes!
p-value: 0.000
Time (s) 6.031
Corruption type: brightness
Drift? Yes!
p-value: 0.000
Time (s) 6.019
Corruption type: pixelate
Drift? Yes!
p-value: 0.000
Time (s) 6.010
复制代码


正如预期的那样,仅在损坏的数据集上检测到漂移。

相关文章
|
6月前
|
PyTorch 算法框架/工具
【IOU实验】即插即用!对bubbliiiing的yolo系列代码替换iou计算函数做比对实验(G_C_D_S-IOU)
【IOU实验】即插即用!对bubbliiiing的yolo系列代码替换iou计算函数做比对实验(G_C_D_S-IOU)
103 0
【IOU实验】即插即用!对bubbliiiing的yolo系列代码替换iou计算函数做比对实验(G_C_D_S-IOU)
YOLOv3的NMS参数调整对模型的准确率和召回率分别有什么影响?
YOLOv3的NMS参数调整对模型的准确率和召回率分别有什么影响?
|
6月前
|
机器学习/深度学习 数据可视化 算法
支持向量回归SVR拟合、预测回归数据和可视化准确性检查实例
支持向量回归SVR拟合、预测回归数据和可视化准确性检查实例
|
6月前
|
机器学习/深度学习 算法
R语言非参数方法:使用核回归平滑估计和K-NN(K近邻算法)分类预测心脏病数据
R语言非参数方法:使用核回归平滑估计和K-NN(K近邻算法)分类预测心脏病数据
|
6月前
|
存储 数据可视化 计算机视觉
基于YOLOv8的自定义数据姿势估计
基于YOLOv8的自定义数据姿势估计
|
机器学习/深度学习 Serverless 计算机视觉
NeRF 模型评价指标PSNR,MS-SSIM, LPIPS 详解和python实现
NeRF 模型评价指标PSNR,MS-SSIM, LPIPS 详解和python实现
2469 0
|
算法 数据挖掘
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(二)
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(二)
166 1
|
机器学习/深度学习 算法 前端开发
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(一)
简单涨点 | Flow-Mixup: 对含有损坏标签的多标签医学图像进行分类(优于Mixup和Maniflod Mixup)(一)
236 1
|
机器学习/深度学习 文字识别 算法
深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测
深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测
YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样
这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。
615 0
YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样