在 PyTorch 中实现可解释的神经网络模型

简介: 在 PyTorch 中实现可解释的神经网络模型

动动发财的小手,点个赞吧!

目的

深度学习系统缺乏可解释性对建立人类信任构成了重大挑战。这些模型的复杂性使人类几乎不可能理解其决策背后的根本原因。

深度学习系统缺乏可解释性阻碍了人类的信任。

为了解决这个问题,研究人员一直在积极研究新的解决方案,从而产生了重大创新,例如基于概念的模型。这些模型不仅提高了模型的透明度,而且通过在训练过程中结合高级人类可解释的概念(如“颜色”或“形状”),培养了对系统决策的新信任感。因此,这些模型可以根据学习到的概念为其预测提供简单直观的解释,从而使人们能够检查其决策背后的原因。这还不是全部!它们甚至允许人类与学习到的概念进行交互,让我们能够控制最终的决定。

基于概念的模型允许人类检查深度学习预测背后的推理,并让我们重新控制最终决策。

这篇博文中,我们将深入研究这些技术,并为您提供使用简单的 PyTorch 接口实现最先进的基于概念的模型的工具。通过实践经验,您将学习如何利用这些强大的模型来增强可解释性并最终校准人类对您的深度学习系统的信任。

概念瓶颈模型

在这个介绍中,我们将深入探讨概念瓶颈模型。这模型在 2020 年国际机器学习会议上发表的一篇论文中介绍,旨在首先学习和预测一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务:

通过遵循这种方法,我们可以将预测追溯到提供解释的概念,例如“输入对象是一个{apple},因为它是{spherical}和{red}。”

概念瓶颈模型首先学习一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务。

实现

为了说明概念瓶颈模型,我们将重新审视著名的 XOR 问题,但有所不同。我们的输入将包含两个连续的特征。为了捕捉这些特征的本质,我们将使用概念编码器将它们映射为两个有意义的概念,表示为“A”和“B”。我们任务的目标是预测“A”和“B”的异或 (XOR)。通过这个例子,您将更好地理解概念瓶颈如何在实践中应用,并见证它们在解决具体问题方面的有效性。

我们可以从导入必要的库并加载这个简单的数据集开始:

import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

接下来,我们实例化一个概念编码器以将输入特征映射到概念空间,并实例化一个任务预测器以将概念映射到任务预测:

concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(10, 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, c.shape[1]),
    torch.nn.Sigmoid(),
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1], 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, 1),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

然后我们通过优化概念和任务的交叉熵损失来训练网络:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()
loss_form_y = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(2001):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_pred)

    # update loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.2*task_loss

    loss.backward()
    optimizer.step()

训练模型后,我们评估其在测试集上的性能:

c_pred = concept_encoder(x_test)
y_pred = task_predictor(c_pred)

concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
task_accuracy = accuracy_score(y_test, y_pred > 0)

现在,在几个 epoch 之后,我们可以观察到概念和任务在测试集上的准确性都非常好(~98% 的准确性)!

由于这种架构,我们可以通过根据输入概念查看任务预测器的响应来为模型预测提供解释,如下所示:

c_different = torch.FloatTensor([0, 1])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

c_equal = torch.FloatTensor([1, 1])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

这会产生例如 f([0,1])=1 和 f([1,1])=0 ,如预期的那样。这使我们能够更多地了解模型的行为,并检查它对于任何相关概念集的行为是否符合预期,例如,对于互斥的输入概念 [0,1] 或 [1,0],它返回的预测y=1。

概念瓶颈模型通过将预测追溯到概念来提供直观的解释。

淹没在准确性与可解释性的权衡中

概念瓶颈模型的主要优势之一是它们能够通过揭示概念预测模式来为预测提供解释,从而使人们能够评估模型的推理是否符合他们的期望。

然而,标准概念瓶颈模型的主要问题是它们难以解决复杂问题!更一般地说,他们遇到了可解释人工智能中众所周知的一个众所周知的问题,称为准确性-可解释性权衡。实际上,我们希望模型不仅能实现高任务性能,还能提供高质量的解释。不幸的是,在许多情况下,当我们追求更高的准确性时,模型提供的解释往往会在质量和忠实度上下降,反之亦然。

在视觉上,这种权衡可以表示如下:

可解释模型擅长提供高质量的解释,但难以解决具有挑战性的任务,而黑盒模型以提供脆弱和糟糕的解释为代价来实现高任务准确性。

为了在具体设置中说明这种权衡,让我们考虑一个概念瓶颈模型,该模型应用于要求稍高的基准,即“三角学”数据集:

x, c, y = datasets.trigonometry(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

在该数据集上训练相同的网络架构后,我们观察到任务准确性显着降低,仅达到 80% 左右。

概念瓶颈模型未能在任务准确性和解释质量之间取得平衡。

这就引出了一个问题:我们是永远被迫在准确性和解释质量之间做出选择,还是有办法取得更好的平衡?

相关文章
|
3天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Keras构建一个简单的神经网络模型
使用Keras构建一个简单的神经网络模型
|
2天前
|
网络协议 前端开发 网络安全
网络通信基础(网络通信基本概念+TCP/IP 模型)
网络通信基础(网络通信基本概念+TCP/IP 模型)
|
2天前
|
机器学习/深度学习 自然语言处理 数据可视化
揭秘深度学习模型中的“黑箱”:理解与优化网络决策过程
【5月更文挑战第28天】 在深度学习领域,神经网络因其卓越的性能被广泛应用于图像识别、自然语言处理等任务。然而,这些复杂的模型往往被视作“黑箱”,其内部决策过程难以解释。本文将深入探讨深度学习模型的可解释性问题,并提出几种方法来揭示和优化网络的决策机制。我们将从模型可视化、敏感性分析到高级解释框架,一步步剖析模型行为,旨在为研究者提供更透明、可靠的深度学习解决方案。
|
2天前
|
存储 缓存 网络协议
网络 (基础概念, OSI 七层模型, TCP/IP 五层模型)
网络 (基础概念, OSI 七层模型, TCP/IP 五层模型)
10 1
|
8天前
|
存储 网络协议 Linux
【Linux 网络】网络基础(一)(局域网、广域网、网络协议、TCP/IP结构模型、网络传输、封装和分用)-- 详解(下)
【Linux 网络】网络基础(一)(局域网、广域网、网络协议、TCP/IP结构模型、网络传输、封装和分用)-- 详解(下)
|
8天前
|
存储 网络协议 安全
【Linux 网络】网络基础(一)(局域网、广域网、网络协议、TCP/IP结构模型、网络传输、封装和分用)-- 详解(上)
【Linux 网络】网络基础(一)(局域网、广域网、网络协议、TCP/IP结构模型、网络传输、封装和分用)-- 详解(上)
|
8天前
|
机器学习/深度学习 算法 计算机视觉
基于yolov2深度学习网络模型的鱼眼镜头中人员检测算法matlab仿真
该内容是一个关于基于YOLOv2的鱼眼镜头人员检测算法的介绍。展示了算法运行的三张效果图,使用的是matlab2022a软件。YOLOv2模型结合鱼眼镜头畸变校正技术,对鱼眼图像中的人员进行准确检测。算法流程包括图像预处理、网络前向传播、边界框预测与分类及后处理。核心程序段加载预训练的YOLOv2检测器,遍历并处理图像,检测到的目标用矩形标注显示。
|
10天前
|
机器学习/深度学习 人工智能 算法
食物识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
食物识别系统采用TensorFlow的ResNet50模型,训练了包含11类食物的数据集,生成高精度H5模型。系统整合Django框架,提供网页平台,用户可上传图片进行食物识别。效果图片展示成功识别各类食物。[查看演示视频、代码及安装指南](https://www.yuque.com/ziwu/yygu3z/yhd6a7vai4o9iuys?singleDoc#)。项目利用深度学习的卷积神经网络(CNN),其局部感受野和权重共享机制适于图像识别,广泛应用于医疗图像分析等领域。示例代码展示了一个使用TensorFlow训练的简单CNN模型,用于MNIST手写数字识别。
32 3
|
13天前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
25 1
|
13天前
|
机器学习/深度学习 PyTorch 算法框架/工具
神经网络基本概念以及Pytorch实现,多线程编程面试题
神经网络基本概念以及Pytorch实现,多线程编程面试题