Let's look at an example:
print_sentence(X_train[0], token2word)
Output:
cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all

[1415 33 6 22 12 215 28 77 52 5 14 407 16 82 2 8 4 107 117 5952 15 256 4 2 7 3766 5 723 36 71 43 530 476 26 400 317 46 7 4 2 1029 13 104 88 4 381 15 297 98 32 2071 56 26 141 6 194 7486 18 4 226 22 21 134 476 26 480 5 144 30 5535 18 51 36 28 224 92 25 104 4 226 65 16 38 1334 88 12 16 283 5 16 4472 113 103 32 15 16 5345 19 178 32]
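print_sentence is a helper defined earlier in the post. A minimal sketch of what it might do, assuming token2word maps integer token ids back to words (ids missing from the vocabulary fall back to '<UNK>'):

import numpy as np

def print_sentence(tokens, token2word: dict) -> None:
    # decode the token ids back to words; unknown ids map to '<UNK>'
    print(' '.join(token2word.get(int(t), '<UNK>') for t in tokens))
    # also print the raw token ids that the model actually consumes
    print(np.array(tokens))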
Define and train a simple model:
model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2, shuffle=True, validation_data=(X_test, y_test))
Output:
Epoch 1/2
782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474
Epoch 2/2
782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451
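imdb_model is also defined earlier in the post. A minimal sketch consistent with the call above, i.e. an embedding layer followed by an LSTM and a softmax classifier trained with cross-entropy (the exact loss and optimizer are assumptions):

import tensorflow as tf

def imdb_model(X, num_words: int, emb_dim: int, lstm_dim: int, output_dim: int) -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(X.shape[1],))
    # the embedding layer ends up at model.layers[1], which is reused below
    x = tf.keras.layers.Embedding(num_words, emb_dim)(inputs)
    x = tf.keras.layers.LSTM(lstm_dim)(x)
    outputs = tf.keras.layers.Dense(output_dim, activation='softmax')(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',  # assumes integer class labels
                  metrics=['accuracy'])
    return model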
Extract the embedding layer from the trained model and combine it with the Untrained AutoEncoder (UAE) preprocessing step:
embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
x_emb = embedding(X_train[:5])
print(x_emb.shape)
Output:
(5, 100, 256)
tf.random.set_seed(0)

shape = tuple(x_emb.shape[1:])
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)
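As a quick sanity check: enc_dim is assumed to have been set earlier in the post, and the 32 feature-wise p-values printed below suggest enc_dim=32. Assuming UAE behaves like a regular tf.keras.Model, it maps each (100, 256) sequence of token embeddings to a single enc_dim-dimensional vector:

x_enc = uae(X_train[:5])
print(x_enc.shape)  # expected: (5, enc_dim), i.e. (5, 32) here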
As before, create the reference, H0 and perturbed datasets. We also test against the Reuters news topic classification dataset.
X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
    X_word[words[i]] = {}
    for p in perc_chg:
        X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')
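random_sample and inject_word are helpers defined earlier in the post. As a rough, hypothetical sketch, inject_word swaps a given percentage of each instance's tokens for the target word's token (the real helper may also respect padding positions, which this sketch ignores):

import numpy as np

def inject_word(token: int, X: np.ndarray, perc_chg: float, padding: str = 'first') -> list:
    # hypothetical sketch: replace perc_chg percent of the tokens of each
    # instance with `token`; `padding` is accepted but unused here
    X_pert = X.copy()
    n_chg = max(1, int(perc_chg * .01 * X.shape[1]))
    for i in range(X.shape[0]):
        idx = np.random.choice(X.shape[1], n_chg, replace=False)
        X_pert[i, idx] = token
    return X_pert.tolist()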
# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
    get_dataset(dataset='reuters', max_len=max_len)[1:]

# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
Initialize the detector and detect drift:
from functools import partial

from alibi_detect.cd.tensorflow import preprocess_drift

# define preprocess_batch_fn to convert a list of strings into the
# np.ndarray to be processed by `model`
def convert_list(X: list):
    return np.array(X)

# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=uae, batch_size=128,
                        preprocess_batch_fn=convert_list)

# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)
H0 dataset:
preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
Output:
Drift? No!
p-value: [0.18111965 0.50035924 0.5360543 0.722555 0.2406036 0.02925058 0.43243074 0.12050407 0.722555 0.60991895 0.19951835 0.60991895 0.50035924 0.79439443 0.722555 0.64755726 0.40047103 0.34099194 0.1338343 0.10828251 0.64755726 0.9995433 0.9540582 0.9134755 0.40047103 0.1640792 0.40047103 0.64755726 0.9134755 0.7590978 0.5726548 0.722555 ]
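KSDrift runs one K-S test per encoded feature (32 here). With the library's default Bonferroni correction, drift is only flagged when the smallest feature-wise p-value falls below the corrected threshold of p_val / n_features, which the prediction dictionary also exposes (assuming the standard alibi-detect return format):

# smallest feature-wise p-value vs. the Bonferroni-corrected threshold (0.05 / 32 here)
print(preds_h0['data']['p_val'].min(), preds_h0['data']['threshold'])

The minimum p-value above (0.029) sits well above 0.05 / 32, which is why no drift is flagged.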
Perturbed dataset:
for w, probas in X_word.items():
    for p, v in probas.items():
        preds = cd.predict(v)
        print('Word: {} -- % perturbed: {}'.format(w, p))
        print('Drift? {}'.format(labels[preds['data']['is_drift']]))
        print('p-value: {}'.format(preds['data']['p_val']))
        print('')
Output:
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709 0.7590978 0.99870795 0.9995433 0.9801618 0.9134755 0.82795686 0.99870795 0.9882611 0.8879386 0.9801618 0.79439443 0.85929435 0.96887016 0.9134755 0.996931 0.5726548 0.93558097 0.9882611 0.99870795 0.93558097 0.96887016 0.85929435 0.9882611 0.93558097 0.996931 0.996931 0.96887016 0.9882611 0.96887016 0.8879386 0.996931 ]

Word: fantastic -- % perturbed: 5.0
Drift? No!
p-value: [0.85929435 0.06155144 0.9540582 0.79439443 0.43243074 0.6852314 0.722555 0.9134755 0.28769323 0.996931 0.60991895 0.19951835 0.43243074 0.64755726 0.722555 0.8879386 0.18111965 0.18111965 0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108 0.722555 0.46576622 0.07762147 0.8879386 0.05464633 0.10828251 0.03327804 0.9801618 ]

Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.8879386 0.99870795 0.9801618 0.99870795 0.99870795 0.9134755 0.93558097 0.8879386 0.9995433 0.93558097 0.996931 0.99999607 0.9995433 0.99870795 0.9801618 0.99870795 0.9801618 0.8879386 0.996931 0.9134755 0.996931 0.7590978 0.99365413 0.9540582 0.99870795 0.99870795 0.9998709 0.9801618 0.64755726 0.9999727 0.8879386 ]

Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.9882611 0.6852314 0.79439443 0.60991895 0.28769323 0.3699725 0.28769323 0.6852314 0.79439443 0.31356168 0.99870795 0.85929435 0.34099194 0.34099194 0.8879386 0.996931 0.96887016 0.96887016 0.9540582 0.722555 0.19951835 0.9995433 0.3699725 0.722555 0.1338343 0.9134755 0.5360543 0.26338065 0.85929435 0.2406036 0.31356168 0.6852314 ]

Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.93558097 0.996931 0.85929435 0.9540582 0.50035924 0.64755726 0.82795686 0.85929435 0.82795686 0.9882611 0.82795686 0.9540582 0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555 0.93558097 0.93558097 0.64755726 0.99365413 0.5726548 0.9998709 0.93558097 0.96887016 0.9995433 0.99365413 0.7590978 0.93558097 0.9882611 0.9134755 ]

Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01 3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01 2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01 1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01 9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02 6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01 8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02 1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]

Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.996931 0.9801618 0.96887016 0.79439443 0.79439443 0.5726548 0.82795686 0.996931 0.43243074 0.93558097 0.79439443 0.82795686 0.06919032 0.3699725 0.96887016 0.9540582 0.5360543 0.6852314 0.60991895 0.79439443 0.9540582 0.9801618 0.40047103 0.5726548 0.82795686 0.8879386 0.9540582 0.9134755 0.99365413 0.60991895 0.82795686 0.79439443]

Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02 1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01 1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03 1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01 1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03 7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02 4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05 5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]
This detector is less sensitive than the Transformer-based K-S drift detector. The embeddings here were trained from scratch for only 2 epochs, on a small dataset, with a simple model and a cross-entropy loss. A pre-trained BERT model, on the other hand, captures the semantics of the data much better.
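For reference, a minimal sketch of how the pre-trained embedding behind the Transformer-based detector discussed earlier can be built with alibi-detect (the checkpoint name and layer selection follow the library's text drift example and are assumptions here):

from transformers import AutoTokenizer
from alibi_detect.models.tensorflow import TransformerEmbedding

model_name = 'bert-base-cased'  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
# embed text using the last 5 hidden states of the pre-trained model
embedding_bert = TransformerEmbedding(model_name, embedding_type='hidden_state',
                                      layers=[-5, -4, -3, -2, -1])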
Samples from the Reuters dataset:
preds_ood = cd.predict(X_ood)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))
Output:
Drift? Yes!
p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]