在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(4)

简介: 我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x),因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。


我们来看一个实例:

print_sentence(X_train[0], token2word)
复制代码


运行结果:

cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
[1415   33    6   22   12  215   28   77   52    5   14  407   16   82
    2    8    4  107  117 5952   15  256    4    2    7 3766    5  723
   36   71   43  530  476   26  400  317   46    7    4    2 1029   13
  104   88    4  381   15  297   98   32 2071   56   26  141    6  194
 7486   18    4  226   22   21  134  476   26  480    5  144   30 5535
   18   51   36   28  224   92   25  104    4  226   65   16   38 1334
   88   12   16  283    5   16 4472  113  103   32   15   16 5345   19
  178   32]
复制代码


定义和训练一个简单的模型:

model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2,
          shuffle=True, validation_data=(X_test, y_test))
复制代码


运行结果:

Epoch 1/2
782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474
Epoch 2/2
782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451
复制代码


从训练好的模型中提取嵌入层并结合UAE预处理步骤:

embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
x_emb = embedding(X_train[:5])
print(x_emb.shape)
复制代码


运行结果:

(5, 100, 256)
复制代码



tf.random.set_seed(0)
shape = tuple(x_emb.shape[1:])
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)
复制代码


同样,创建参考、H0 和扰动数据集。 还针对Reuters新闻主题分类数据集进行测试。

X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
    X_word[words[i]] = {}
    for p in perc_chg:
        X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')
复制代码



# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
    get_dataset(dataset='reuters', max_len=max_len)[1:]
# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
复制代码


初始化检测器并检测漂移

from alibi_detect.cd.tensorflow import preprocess_drift
# define preprocess_batch_fn to convert list of str's to np.ndarray to be processed by `model`
def convert_list(X: list):
    return np.array(X)
# define preprocessing function
preprocess_fn = partial(preprocess_drift, 
    model=uae, 
    batch_size=128, 
    preprocess_batch_fn=convert_list)
# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)
复制代码



H0数据集:

preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
复制代码


运行结果:

Drift? No!
p-value: [0.18111965 0.50035924 0.5360543  0.722555   0.2406036  0.02925058 0.43243074 0.12050407 0.722555   0.60991895 0.19951835 0.60991895 0.50035924 0.79439443 0.722555   0.64755726 0.40047103 0.34099194 0.1338343  0.10828251 0.64755726 0.9995433  0.9540582  0.9134755 0.40047103 0.1640792  0.40047103 0.64755726 0.9134755  0.7590978 0.5726548  0.722555  ]
复制代码


扰动数据集:


for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
复制代码

运行结果:

Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709  0.7590978  0.99870795 0.9995433  0.9801618  0.9134755
 0.82795686 0.99870795 0.9882611  0.8879386  0.9801618  0.79439443
 0.85929435 0.96887016 0.9134755  0.996931   0.5726548  0.93558097
 0.9882611  0.99870795 0.93558097 0.96887016 0.85929435 0.9882611
 0.93558097 0.996931   0.996931   0.96887016 0.9882611  0.96887016
 0.8879386  0.996931  ]
Word: fantastic -- % perturbed: 5.0
Drift? No!
p-value: [0.85929435 0.06155144 0.9540582  0.79439443 0.43243074 0.6852314
 0.722555   0.9134755  0.28769323 0.996931   0.60991895 0.19951835
 0.43243074 0.64755726 0.722555   0.8879386  0.18111965 0.18111965
 0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108
 0.722555   0.46576622 0.07762147 0.8879386  0.05464633 0.10828251
 0.03327804 0.9801618 ]
Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.8879386  0.99870795 0.9801618  0.99870795 0.99870795
 0.9134755  0.93558097 0.8879386  0.9995433  0.93558097 0.996931
 0.99999607 0.9995433  0.99870795 0.9801618  0.99870795 0.9801618
 0.8879386  0.996931   0.9134755  0.996931   0.7590978  0.99365413
 0.9540582  0.99870795 0.99870795 0.9998709  0.9801618  0.64755726
 0.9999727  0.8879386 ]
Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.9882611  0.6852314  0.79439443 0.60991895 0.28769323 0.3699725
 0.28769323 0.6852314  0.79439443 0.31356168 0.99870795 0.85929435
 0.34099194 0.34099194 0.8879386  0.996931   0.96887016 0.96887016
 0.9540582  0.722555   0.19951835 0.9995433  0.3699725  0.722555
 0.1338343  0.9134755  0.5360543  0.26338065 0.85929435 0.2406036
 0.31356168 0.6852314 ]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.93558097 0.996931   0.85929435 0.9540582  0.50035924 0.64755726
 0.82795686 0.85929435 0.82795686 0.9882611  0.82795686 0.9540582
 0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555
 0.93558097 0.93558097 0.64755726 0.99365413 0.5726548  0.9998709
 0.93558097 0.96887016 0.9995433  0.99365413 0.7590978  0.93558097
 0.9882611  0.9134755 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01
 3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01
 2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01
 1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01
 9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02
 6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01
 8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02
 1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.996931   0.9801618  0.96887016 0.79439443 0.79439443 0.5726548
 0.82795686 0.996931   0.43243074 0.93558097 0.79439443 0.82795686
 0.06919032 0.3699725  0.96887016 0.9540582  0.5360543  0.6852314
 0.60991895 0.79439443 0.9540582  0.9801618  0.40047103 0.5726548
 0.82795686 0.8879386  0.9540582  0.9134755  0.99365413 0.60991895
 0.82795686 0.79439443]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02
 1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01
 1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03
 1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01
 1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03
 7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02
 4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05
 5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]
复制代码


该检测器不如基于 Transformer 的 K-S 漂移检测器灵敏。从头开始训练的 embeddings 只在一个小数据集和一个具有交叉熵损失函数的简单模型上训练了 2 个 epoch。 另一方面,预训练的 BERT 模型可以更好地捕捉数据的语义。

来自 Reuters 数据集的样本:

preds_ood = cd.predict(X_ood)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))
复制代码


运行结果:

Drift? Yes!
p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]
复制代码
相关文章
|
6月前
|
数据采集 数据可视化 数据格式
3D检测数据集 DAIR-V2X-V 转为Kitti格式 | 可视化
本文分享在DAIR-V2X-V数据集中,将标签转为Kitti格式,并可视化3D检测效果。
102 0
|
机器学习/深度学习 PyTorch TensorFlow
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(3)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
|
TensorFlow 算法框架/工具
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(2)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
|
自然语言处理 PyTorch TensorFlow
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(1)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样
这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。
501 0
YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样
|
数据可视化 PyTorch TensorFlow
在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(上)
方法 Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p ​ 和\mu_{q}μ q ​ 的2个分布p和q之间的基于距离的度量:
|
存储 PyTorch TensorFlow
在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(下)
方法 Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p ​ 和\mu_{q}μ q ​ 的2个分布p和q之间的基于距离的度量:
|
计算机视觉
CV之IS:利用pixellib库基于mask_rcnn_coco模型对《庆余年》片段实现实例分割简单代码全实现
CV之IS:利用pixellib库基于mask_rcnn_coco模型对《庆余年》片段实现实例分割简单代码全实现
CV之IS:利用pixellib库基于mask_rcnn_coco模型对《庆余年》片段实现实例分割简单代码全实现
|
算法 数据挖掘 Python
ML之NB:利用朴素贝叶斯NB算法(CountVectorizer+不去除停用词)对fetch_20newsgroups数据集(20类新闻文本)进行分类预测、评估
ML之NB:利用朴素贝叶斯NB算法(CountVectorizer+不去除停用词)对fetch_20newsgroups数据集(20类新闻文本)进行分类预测、评估
ML之NB:利用朴素贝叶斯NB算法(CountVectorizer+不去除停用词)对fetch_20newsgroups数据集(20类新闻文本)进行分类预测、评估