在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(4)

简介: 我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x),因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。


我们来看一个实例:

print_sentence(X_train[0], token2word)
复制代码


运行结果:

cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
[1415   33    6   22   12  215   28   77   52    5   14  407   16   82
    2    8    4  107  117 5952   15  256    4    2    7 3766    5  723
   36   71   43  530  476   26  400  317   46    7    4    2 1029   13
  104   88    4  381   15  297   98   32 2071   56   26  141    6  194
 7486   18    4  226   22   21  134  476   26  480    5  144   30 5535
   18   51   36   28  224   92   25  104    4  226   65   16   38 1334
   88   12   16  283    5   16 4472  113  103   32   15   16 5345   19
  178   32]
复制代码


定义和训练一个简单的模型:

model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2,
          shuffle=True, validation_data=(X_test, y_test))
复制代码


运行结果:

Epoch 1/2
782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474
Epoch 2/2
782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451
复制代码


从训练好的模型中提取嵌入层并结合UAE预处理步骤:

embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
x_emb = embedding(X_train[:5])
print(x_emb.shape)
复制代码


运行结果:

(5, 100, 256)
复制代码



tf.random.set_seed(0)
shape = tuple(x_emb.shape[1:])
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)
复制代码


同样,创建参考、H0 和扰动数据集。 还针对Reuters新闻主题分类数据集进行测试。

X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
    X_word[words[i]] = {}
    for p in perc_chg:
        X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')
复制代码



# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
    get_dataset(dataset='reuters', max_len=max_len)[1:]
# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
复制代码


初始化检测器并检测漂移

from alibi_detect.cd.tensorflow import preprocess_drift
# define preprocess_batch_fn to convert list of str's to np.ndarray to be processed by `model`
def convert_list(X: list):
    return np.array(X)
# define preprocessing function
preprocess_fn = partial(preprocess_drift, 
    model=uae, 
    batch_size=128, 
    preprocess_batch_fn=convert_list)
# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)
复制代码



H0数据集:

preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
复制代码


运行结果:

Drift? No!
p-value: [0.18111965 0.50035924 0.5360543  0.722555   0.2406036  0.02925058 0.43243074 0.12050407 0.722555   0.60991895 0.19951835 0.60991895 0.50035924 0.79439443 0.722555   0.64755726 0.40047103 0.34099194 0.1338343  0.10828251 0.64755726 0.9995433  0.9540582  0.9134755 0.40047103 0.1640792  0.40047103 0.64755726 0.9134755  0.7590978 0.5726548  0.722555  ]
复制代码


扰动数据集:


for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
复制代码

运行结果:

Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709  0.7590978  0.99870795 0.9995433  0.9801618  0.9134755
 0.82795686 0.99870795 0.9882611  0.8879386  0.9801618  0.79439443
 0.85929435 0.96887016 0.9134755  0.996931   0.5726548  0.93558097
 0.9882611  0.99870795 0.93558097 0.96887016 0.85929435 0.9882611
 0.93558097 0.996931   0.996931   0.96887016 0.9882611  0.96887016
 0.8879386  0.996931  ]
Word: fantastic -- % perturbed: 5.0
Drift? No!
p-value: [0.85929435 0.06155144 0.9540582  0.79439443 0.43243074 0.6852314
 0.722555   0.9134755  0.28769323 0.996931   0.60991895 0.19951835
 0.43243074 0.64755726 0.722555   0.8879386  0.18111965 0.18111965
 0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108
 0.722555   0.46576622 0.07762147 0.8879386  0.05464633 0.10828251
 0.03327804 0.9801618 ]
Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.8879386  0.99870795 0.9801618  0.99870795 0.99870795
 0.9134755  0.93558097 0.8879386  0.9995433  0.93558097 0.996931
 0.99999607 0.9995433  0.99870795 0.9801618  0.99870795 0.9801618
 0.8879386  0.996931   0.9134755  0.996931   0.7590978  0.99365413
 0.9540582  0.99870795 0.99870795 0.9998709  0.9801618  0.64755726
 0.9999727  0.8879386 ]
Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.9882611  0.6852314  0.79439443 0.60991895 0.28769323 0.3699725
 0.28769323 0.6852314  0.79439443 0.31356168 0.99870795 0.85929435
 0.34099194 0.34099194 0.8879386  0.996931   0.96887016 0.96887016
 0.9540582  0.722555   0.19951835 0.9995433  0.3699725  0.722555
 0.1338343  0.9134755  0.5360543  0.26338065 0.85929435 0.2406036
 0.31356168 0.6852314 ]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.93558097 0.996931   0.85929435 0.9540582  0.50035924 0.64755726
 0.82795686 0.85929435 0.82795686 0.9882611  0.82795686 0.9540582
 0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555
 0.93558097 0.93558097 0.64755726 0.99365413 0.5726548  0.9998709
 0.93558097 0.96887016 0.9995433  0.99365413 0.7590978  0.93558097
 0.9882611  0.9134755 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01
 3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01
 2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01
 1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01
 9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02
 6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01
 8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02
 1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.996931   0.9801618  0.96887016 0.79439443 0.79439443 0.5726548
 0.82795686 0.996931   0.43243074 0.93558097 0.79439443 0.82795686
 0.06919032 0.3699725  0.96887016 0.9540582  0.5360543  0.6852314
 0.60991895 0.79439443 0.9540582  0.9801618  0.40047103 0.5726548
 0.82795686 0.8879386  0.9540582  0.9134755  0.99365413 0.60991895
 0.82795686 0.79439443]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02
 1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01
 1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03
 1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01
 1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03
 7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02
 4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05
 5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]
复制代码


该检测器不如基于 Transformer 的 K-S 漂移检测器灵敏。从头开始训练的 embeddings 只在一个小数据集和一个具有交叉熵损失函数的简单模型上训练了 2 个 epoch。 另一方面,预训练的 BERT 模型可以更好地捕捉数据的语义。

来自 Reuters 数据集的样本:

preds_ood = cd.predict(X_ood)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))
复制代码


运行结果:

Drift? Yes!
p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]
复制代码
相关文章
|
3月前
|
机器学习/深度学习 JSON 数据可视化
YOLO11-pose关键点检测:训练实战篇 | 自己数据集从labelme标注到生成yolo格式的关键点数据以及训练教程
本文介绍了如何将个人数据集转换为YOLO11-pose所需的数据格式,并详细讲解了手部关键点检测的训练过程。内容涵盖数据集标注、格式转换、配置文件修改及训练参数设置,最终展示了训练结果和预测效果。适用于需要进行关键点检测的研究人员和开发者。
643 0
|
数据采集 数据可视化 数据格式
3D检测数据集 DAIR-V2X-V 转为Kitti格式 | 可视化
本文分享在DAIR-V2X-V数据集中,将标签转为Kitti格式,并可视化3D检测效果。
272 0
|
数据可视化 PyTorch TensorFlow
在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(上)
方法 Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p ​ 和\mu_{q}μ q ​ 的2个分布p和q之间的基于距离的度量:
|
机器学习/深度学习 自然语言处理 算法
Text to Image 文本生成图像定量评价指标分析笔记 Metric Value总结 IS、FID、R-prec等
Text to Image 文本生成图像定量评价指标分析笔记 Metric Value总结 IS、FID、R-prec等
Text to Image 文本生成图像定量评价指标分析笔记 Metric Value总结 IS、FID、R-prec等
|
TensorFlow 算法框架/工具
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(2)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
|
自然语言处理 PyTorch TensorFlow
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(1)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
|
机器学习/深度学习 PyTorch TensorFlow
在 IMDB 电影评论数据集上进行文本数据漂移检测(Seldon Alibi Detect)(3)
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。 在这个示例中,我们将专注于检测协变量漂移Δp(x)\Delta p(x)Δp(x), 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
|
存储 PyTorch TensorFlow
在 CIFAR-10 数据集上使用最大均值差异(MMD)漂移检测器(Seldon Alibi Detect)(下)
方法 Maximum Mean Discrepancy (MMD)检测器是一种基于核的多元2样本测试方法。MMD是基于再生核希尔伯特空间 F 中的平均嵌入\mu_{p}μ p ​ 和\mu_{q}μ q ​ 的2个分布p和q之间的基于距离的度量:
逆向将物体检测数据集生成labelme标注的数据
逆向将物体检测数据集生成labelme标注的数据
293 0
逆向将物体检测数据集生成labelme标注的数据