Text Data Drift Detection on the IMDB Movie Review Dataset (Seldon Alibi Detect) (4)

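Summary: We detect drift in text data with the Maximum Mean Discrepancy (MMD) and Kolmogorov-Smirnov (K-S) detectors. This example focuses on covariate drift Δp(x), because detecting drift in the predicted label distribution works the same way as for other modalities (see the K-S and MMD drift examples on CIFAR-10).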
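For intuition about what the K-S detector measures on each feature, here is a minimal sketch of the two-sample K-S statistic: the largest gap between the empirical CDFs of a reference sample and a test sample. The function name `ks_statistic` is illustrative only and is not part of Alibi Detect, which also derives a p-value from this statistic.

```python
from bisect import bisect_right


def ks_statistic(x_ref, x_test):
    """Two-sample K-S statistic: largest gap between the two empirical CDFs."""
    ref, test = sorted(x_ref), sorted(x_test)
    gap = 0.0
    # the gap can only change at observed values, so checking those is enough
    for v in ref + test:
        cdf_ref = bisect_right(ref, v) / len(ref)
        cdf_test = bisect_right(test, v) / len(test)
        gap = max(gap, abs(cdf_ref - cdf_test))
    return gap


print(ks_statistic([1, 2, 3, 4], [1, 2, 3, 4]))  # identical samples
print(ks_statistic([1, 2, 3, 4], [3, 4, 5, 6]))  # shifted sample
```

A statistic near 0 means the two samples are distributed similarly on that feature; a value near 1 means they barely overlap.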


print_sentence(X_train[0], token2word)


cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
[1415   33    6   22   12  215   28   77   52    5   14  407   16   82
    2    8    4  107  117 5952   15  256    4    2    7 3766    5  723
   36   71   43  530  476   26  400  317   46    7    4    2 1029   13
  104   88    4  381   15  297   98   32 2071   56   26  141    6  194
 7486   18    4  226   22   21  134  476   26  480    5  144   30 5535
   18   51   36   28  224   92   25  104    4  226   65   16   38 1334
   88   12   16  283    5   16 4472  113  103   32   15   16 5345   19
  178   32]


model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2,
          shuffle=True, validation_data=(X_test, y_test))


Epoch 1/2
782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474
Epoch 2/2
782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451


embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
x_emb = embedding(X_train[:5])
print(x_emb.shape)


(5, 100, 256)

shape = tuple(x_emb.shape[1:])
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)

As before, we create the reference, H0 and perturbed datasets. We also test against the Reuters news topic classification dataset.

X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
    X_word[words[i]] = {}
    for p in perc_chg:
        X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')

# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
    get_dataset(dataset='reuters', max_len=max_len)[1:]
# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
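The `inject_word` helper used above is defined earlier in this series and is not reproduced here. As a rough, hypothetical sketch of the idea only (not the original implementation), a perturbation of this kind could overwrite a given percentage of each sequence with the chosen token while leaving padding untouched. The name `inject_token_sketch`, the `pad_token=0` convention and the random choice of positions are all assumptions.

```python
import random


def inject_token_sketch(token, sequences, perc_chg, pad_token=0, seed=0):
    """Hypothetical stand-in for `inject_word`: overwrite perc_chg % of the
    positions in each sequence with `token`, at random non-padding positions."""
    rng = random.Random(seed)
    perturbed = []
    for seq in sequences:
        seq = list(seq)
        n_chg = round(perc_chg / 100 * len(seq))  # number of tokens to overwrite
        candidates = [i for i, t in enumerate(seq) if t != pad_token]
        for i in rng.sample(candidates, min(n_chg, len(candidates))):
            seq[i] = token
        perturbed.append(seq)
    return perturbed


# one padded sequence of length 100: 20 padding tokens followed by 80 word tokens
demo = [[0] * 20 + list(range(1, 81))]
out = inject_token_sketch(token=9999, sequences=demo, perc_chg=5.0)
print(out[0].count(9999))
```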


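from functools import partial
from alibi_detect.cd import KSDrift
from alibi_detect.cd.tensorflow import preprocess_drift

# define preprocess_batch_fn to convert list of str's to np.ndarray to be processed by `model`
def convert_list(X: list):
    return np.array(X)

# define preprocessing function
# NOTE: the arguments are cut off in the source; `model=uae` and
# `preprocess_batch_fn=convert_list` are assumed from the definitions above
preprocess_fn = partial(preprocess_drift, model=uae, preprocess_batch_fn=convert_list)

# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)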


preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))


Drift? No!
p-value: [0.18111965 0.50035924 0.5360543  0.722555   0.2406036  0.02925058 0.43243074 0.12050407 0.722555   0.60991895 0.19951835 0.60991895 0.50035924 0.79439443 0.722555   0.64755726 0.40047103 0.34099194 0.1338343  0.10828251 0.64755726 0.9995433  0.9540582  0.9134755 0.40047103 0.1640792  0.40047103 0.64755726 0.9134755  0.7590978 0.5726548  0.722555  ]
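One of the per-feature p-values above (0.02925058) is below 0.05, yet no drift is flagged. As far as I know, `KSDrift` by default combines the feature-wise tests with a Bonferroni correction, which divides the significance level by the number of features. The sketch below reproduces that decision rule in plain Python on the p-values printed above. The helper `bonferroni_drift` is illustrative and is not the library's own code.

```python
# per-feature p-values copied from the H0 output above (32 encoded features)
p_vals_h0 = [
    0.18111965, 0.50035924, 0.5360543, 0.722555, 0.2406036, 0.02925058,
    0.43243074, 0.12050407, 0.722555, 0.60991895, 0.19951835, 0.60991895,
    0.50035924, 0.79439443, 0.722555, 0.64755726, 0.40047103, 0.34099194,
    0.1338343, 0.10828251, 0.64755726, 0.9995433, 0.9540582, 0.9134755,
    0.40047103, 0.1640792, 0.40047103, 0.64755726, 0.9134755, 0.7590978,
    0.5726548, 0.722555,
]


def bonferroni_drift(p_vals, p_val=0.05):
    """Flag drift if any feature's p-value falls below p_val / n_features."""
    threshold = p_val / len(p_vals)
    return int(min(p_vals) < threshold), threshold


is_drift, threshold = bonferroni_drift(p_vals_h0)
print('corrected threshold:', threshold)
print('smallest p-value:', min(p_vals_h0))
print('drift flagged:', bool(is_drift))
```

With 32 features the corrected threshold is 0.05 / 32 ≈ 0.0016, so a single p-value of about 0.029 is not enough to flag drift.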


for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))


Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709  0.7590978  0.99870795 0.9995433  0.9801618  0.9134755
 0.82795686 0.99870795 0.9882611  0.8879386  0.9801618  0.79439443
 0.85929435 0.96887016 0.9134755  0.996931   0.5726548  0.93558097
 0.9882611  0.99870795 0.93558097 0.96887016 0.85929435 0.9882611
 0.93558097 0.996931   0.996931   0.96887016 0.9882611  0.96887016
 0.8879386  0.996931  ]
Word: fantastic -- % perturbed: 5.0
Drift? No!
p-value: [0.85929435 0.06155144 0.9540582  0.79439443 0.43243074 0.6852314
 0.722555   0.9134755  0.28769323 0.996931   0.60991895 0.19951835
 0.43243074 0.64755726 0.722555   0.8879386  0.18111965 0.18111965
 0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108
 0.722555   0.46576622 0.07762147 0.8879386  0.05464633 0.10828251
 0.03327804 0.9801618 ]
Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.8879386  0.99870795 0.9801618  0.99870795 0.99870795
 0.9134755  0.93558097 0.8879386  0.9995433  0.93558097 0.996931
 0.99999607 0.9995433  0.99870795 0.9801618  0.99870795 0.9801618
 0.8879386  0.996931   0.9134755  0.996931   0.7590978  0.99365413
 0.9540582  0.99870795 0.99870795 0.9998709  0.9801618  0.64755726
 0.9999727  0.8879386 ]
Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.9882611  0.6852314  0.79439443 0.60991895 0.28769323 0.3699725
 0.28769323 0.6852314  0.79439443 0.31356168 0.99870795 0.85929435
 0.34099194 0.34099194 0.8879386  0.996931   0.96887016 0.96887016
 0.9540582  0.722555   0.19951835 0.9995433  0.3699725  0.722555
 0.1338343  0.9134755  0.5360543  0.26338065 0.85929435 0.2406036
 0.31356168 0.6852314 ]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.93558097 0.996931   0.85929435 0.9540582  0.50035924 0.64755726
 0.82795686 0.85929435 0.82795686 0.9882611  0.82795686 0.9540582
 0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555
 0.93558097 0.93558097 0.64755726 0.99365413 0.5726548  0.9998709
 0.93558097 0.96887016 0.9995433  0.99365413 0.7590978  0.93558097
 0.9882611  0.9134755 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01
 3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01
 2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01
 1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01
 9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02
 6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01
 8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02
 1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.996931   0.9801618  0.96887016 0.79439443 0.79439443 0.5726548
 0.82795686 0.996931   0.43243074 0.93558097 0.79439443 0.82795686
 0.06919032 0.3699725  0.96887016 0.9540582  0.5360543  0.6852314
 0.60991895 0.79439443 0.9540582  0.9801618  0.40047103 0.5726548
 0.82795686 0.8879386  0.9540582  0.9134755  0.99365413 0.60991895
 0.82795686 0.79439443]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02
 1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01
 1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03
 1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01
 1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03
 7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02
 4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05
 5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]

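This detector is less sensitive than the Transformer-based K-S drift detector. The embeddings here were trained from scratch for only 2 epochs, on a small dataset, with a simple model and a cross-entropy loss. A pre-trained BERT model, by contrast, captures the semantics of the data much better.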

Samples from the Reuters dataset:

preds_ood = cd.predict(X_ood)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))


Drift? Yes!
p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]
