【博士每天一篇文献-算法】Seeing is believing_ Brain-inspired modular training for mechanistic interpretability

简介: 这篇文章提出了一种模仿大脑结构和功能的训练正则化方法,称为大脑启发的模块化训练(BIMT),通过在几何空间中嵌入神经元并增加与连接长度成比例的正则化项来促进神经网络的模块化和稀疏化,增强了网络的可解释性,并在多种任务和数据集上验证了其有效性。

阅读时间:2023-12-10

1 介绍

年份:2023
作者:刘子明,MIT
期刊: Entropy
引用量:22
提出一种模仿大脑结构和功能的训练正则化方法。该方法通过在将神经元映射到几何空间中(2D或3D),并在损失函数中增加一个考虑距离、连接成本、偏置连接成本的正则化项,实现神经网络的模块化和稀疏化,在全连接网络、CNN、Transformer终验证了有效性。并在多种数据集上进行了可视化分析,包括符号回归数据集、Two Moon数据集、加法数据集、置换群S4数据集、MNIST数据集。

2 创新点

(1)大脑启发的模块化训练(BIMT):提出了一种新的神经网络训练方法,该方法模仿大脑的模块化结构,通过在几何空间中嵌入神经元并增加与神经元连接长度成比例的损失函数项,来鼓励网络形成模块化结构。
(2) 损失函数的创新:在标准的损失函数中加入了一个新的正则化项,该项与神经元之间连接的长度和权重的绝对值成正比。这种设计受到进化生物学中最小连接成本理论的启发,旨在促进网络的局部性,即相关联的神经元在几何空间中彼此靠近。
(3) 模块性和可解释性的结合:BIMT不仅提高了网络的模块化,而且增强了网络的可解释性。通过模块化,网络能够以更直观的方式展示其内部结构和决策过程,使得网络的行为更容易被人类理解和解释。
(4) 几何空间中的神经元嵌入:BIMT通过在几何空间中定义神经元的位置,引入了一种新的网络结构表示方式。这种方法允许网络在保持连接的局部性的同时,自然地形成模块化结构。
(5) 实验验证:作者在多种任务上测试了BIMT方法,包括符号公式、分类任务和算法任务。
(6) 网络结构的可视化:通过可视化技术,如连接图,展示了BIMT训练的网络结构,使得模块化结构和决策边界等特征可以直观地被观察到。
(7) 扩展到不同数据类型和网络架构:论文还探讨了BIMT方法在不同类型的数据(如图像)和网络架构(如变换器)上的应用,显示了BIMT的通用性和扩展性。
(8) 性能与可解释性的平衡:尽管BIMT可能会导致性能的小幅下降,但它提供了一种在保持网络性能的同时提高可解释性的途径,这对于构建更可靠和安全的AI系统具有重要意义。

3 相关研究

3.1 模块化

  1. Pfeiffer J, Ruder S, Vulić I, et al. Modular deep learning[J]. arXiv preprint arXiv:2302.11529, 2023.
  2. Hod S, Casper S, Filan D, et al. Detecting modularity in deep neural networks[J]. 2021. 【代码
  3. Csordás R, van Steenkiste S, Schmidhuber J. Are neural nets modular? inspecting functional modularity through differentiable weight masks[J]. arXiv preprint arXiv:2010.02066, 2020. 【代码
  4. Kirsch L, Kunze J, Barber D. Modular networks: Learning to decompose neural computation[J]. Advances in neural information processing systems, 2018, 31.
  5. Azam F. Biologically inspired modular neural networks[D]. Virginia Polytechnic Institute and State University, 2000.

3.2 剪枝

  1. Han S, Pool J, Tran J, et al. Learning both weights and connections for efficient neural network[J]. Advances in neural information processing systems, 2015, 28.
  2. Anwar S, Hwang K, Sung W. Structured pruning of deep convolutional neural networks[J]. ACM Journal on Emerging Technologies in Computing Systems (JETC), 2017, 13(3): 1-18. 【代码
  3. Blalock D, Gonzalez Ortiz J J, Frankle J, et al. What is the state of neural network pruning?[J]. Proceedings of machine learning and systems, 2020, 2: 129-146. 【代码
  4. Frankle J, Carbin M. The lottery ticket hypothesis: Finding sparse, trainable neural networks[J]. arXiv preprint arXiv:1803.03635, 2018. 【代码

4 算法

(1)第一步: 将网络嵌入到几何空间中
将整个网络嵌入到一个空间中,其中第i层的第j个神经元位于
$ r_{ij} $​。如果这是二维欧几里得空间,同一神经元层中的神经元共享相同的y坐标,并在x ∈ [0, A](A > 0)中均匀间隔。不同的神经元层通过距离y* > 0垂直分隔,因此
$r_{ij} \equiv (x_{ij}, y_{ij}) = (A_j/n_i, iy*) $
使用L1范数,得到
$ d_{ijk} = |x_{i−1,j} − x_{ik}| + y* $
但也可以使用其他向量范数。例如,L2范数给出
$ d_{ijk} = \sqrt{|xi−1,j − xik|^2 + y*^2} $
(2)第二步:添加正则化项

  • $ \ell^w$是权重连接成本(connection cost for weight parameters),它是与网络中每个神经元连接的长度(即距离)与其权重的绝对值的乘积的总和。这个正则化项鼓励神经网络中的神经元之间的连接尽可能地局部化,即相关联的神经元在几何空间中彼此靠近。其定义如下:
    $\ell_w = \sum_{i=1}^{L} \sum_{j=1}^{n_i} \sum_{k=1}^{n_{i+1}} d_{ijk} |w_{ijk}| $
    其中 $ d_{ijk} $​是第 i-1层的第 j个神经元与第i层的第k个神经元之间的距离,$ w{ijk} $是这两个神经元之间的权重。
  • $ \ell^b $是偏置连接成本(connection cost for bias parameters),它是与每个偏置参数的绝对值与一个常数 $ y^* $的乘积的总和。这个正则化项鼓励偏置参数的稀疏性。其定义如下:
    $ \ell_b = \sum_{i=1}^{L} \sum_{j=1}^{n_i} y* |bij| $
    其中 $ b_{ij} $是第i层的第j个神经元的偏置项。

最终损失函数如下,
$ \ell = \ell_{pred} + \lambda (\ell_w + \ell_b) $
其中 ℓ p r e d \ell_{pred} ℓpred​是预测损失,λ是正则化项的强度。论文中设y* = 1,只留下两个超参数λ和A。将A设为0相当于标准L1正则化,只鼓励稀疏性。A > 0除了稀疏性外,还鼓励局部性。
(3)第三步:交换神经元以获得更好的局部性
交换同一神经元层中的两个神经元(即交换相应的输入/输出权重),避免梯度下降可能会陷入不良的局部最小值。然而,尝试所有可能的排列是极其昂贵的。为每个神经元(i, j)分配一个分数sij,以表示其重要性:
$ ∣ s_{ij} = \sum_{p=1}^{n_{i-1}} |w_{ipj}| + \sum_{q=1}^{n_{i+1}} |w_{i+1, jq}| $
这是进入和离开第i层第j个神经元的权重(绝对值)的总和,在同一层中根据它们的分数对神经元进行排序,并定义分数最高的k个神经元为“重要”神经元。对于每个重要神经元,将其与同一层中导致 ℓw​减少最大的神经元交换。由于交换在计算上相当昂贵,需要O(nkL)次计算,仅在每S≫1个训练步骤中实现交换。我们还允许输入和输出神经元的交换。

5 实验分析

(1)不同的正则化方法对比
image.png

image.png
使用不同技术对回归问题进行训练时神经网络的连接图,通过模块化和稀疏性提高了网络的可解释性。但是精度下降。

(2)符号回归数据集
image.png

  1. 独立性 (Independence): 展示了一个网络结构,其中有两个独立的模块处理互不相关的输入变量。例如,一个模块处理(x1, x3),另一个处理(x2, x4),它们分别计算不同的输出(y1, y2)。
  2. 特征共享 (Feature Sharing): 展示了一个网络结构,其中某些输入特征(x1, x2, x3)被多个输出(y1, y2, y3)共享。网络通过共享特征来提高效率,减少重复计算。
  3. 组合性 (Compositionality): 展示了一个网络结构,其中计算输出y需要先计算一个中间量I,然后使用这个中间量来计算最终结果。网络中有一个特定的神经元专门负责计算这个中间量I,显示了模块化处理的层次结构。

BIMT能够地发现并利用符号公式中的模块化结构,从而提高网络的可解释性,并在一定程度上保持或提高性能。
(3)Two Moon数据集
image.png
可视化网络是如何依赖权重来进行分类的。在第一和第二阶段,需要两个神经元作为输出来判断两个类别,在第三阶段中,虽然只有一个神经元输出值,在输出层,激活函数(如softmax函数)可以将神经元的激活值转换为概率形式,从而允许模型为每个输入样本输出两个类别的概率。即使网络结构不对称,只要权重和激活函数适当,模型仍然可以区分两个类别。
(4)加法数据集
image.png

模59的加法预测任务,即从两个数a和b预测它们的和c。网络通过嵌入层接收a和b的表示,并使用两层隐藏层进行预测。左一是三个并行的模块,这表明网络通过模块化的方式进行学习和预测。
中间图可视化了网络中的嵌入表示,发现嵌入向量在二维和三维空间中形成了特定的结构,如圆形和蝴蝶结形状,这有助于直观理解网络如何处理不同的输入。
右一说明敲除任何一个模块都会显著降低性能,这表明模块之间通过类似多数投票的机制协同工作。其中“All but A,B,C”是只保留模块化后是神经元,性能能达到100%并且网络具有可解释性。“Knockout A,B,C”是移除A,B,C后,利用冗余的神经元进行预测,准确率只有1.69%。
(5)置换群S4数据集

image.png
其中神经元22作为一个符号神经元,它根据排列是偶数或奇数来激活,输出1或-1。这有助于区分不同的排列类型。BIMT在帮助神经网络学习模块化、稀疏化结构以及捕捉和利用数学上的群论结构方面的有效性。这些特性不仅提高了网络的可解释性,还可能增强其在复杂任务上的性能。
image.png
凯莱图(Cayley graphs)是一种用于表示群的代数结构的图形工具,能够将群的元素和它们之间的关系直观地展现出来。通过凯莱图来可视化在图6中活跃的神经元对应的群元素的活动,不同的颜色用于表示神经元激活的不同值,绿色表示+1(激活或正权重),橙色表示-1(抑制或负权重),没有圆圈或特定颜色可能表示0(无连接或权重为零)。
活跃的神经元可能与特定的群元素或群的子集相关联。揭示了神经元如何响应群的不同元素,以及它们如何组合来执行特定的计算任务。
(6)transformer的用于线性回归
image.png
BIMT 训练后的网络可以直观地识别出哪些神经元是活跃的,并且它们是如何编码权重信息的,提供了更好的解释性。右侧的散点图展示了Res2层中的活跃神经元共同编码了一些隐含信息,散点图是规律的,权重标量(颜色)的分布也是有规律的,这表明多个神经元以非线性的方式共同编码了权重信息。
(7)CNN网络图片分类

image.png
将神经网络嵌入到3D欧几里得空间中,以保持输入图像的局部结构。BIMT学习到了去除始终为零的外围像素,导致输入层的感知野缩小。观察到中间层的大多数权重为负值,而输出层的大多数权重为正值,这可能表明中间层采取了与模式匹配不同的策略。

6 思考

(1)这篇论文的创新点和《Spatially embedded recurrent neural networks reveal widespread links between structural and functional neuroscience findings》在正则化项上有一些区别。
在seRNN中正则化项只是考虑了距离和连接连接成本的乘积,BIMT在seRNN正则项的基础上是加上了偏置连接成本的正则项,并且在算法中考了交换神经元位置。此外,BIMIT中计算神经元的距离是以L1范数(即曼哈顿距离),而seRNN是以L2范数(欧式距离)。


seRNN的正则项

$ \ell = \ell_{pred} + \lambda (\ell_w + \ell_b) $
其中 $ \ell_w = \sum_{i=1}^{L} \sum_{j=1}^{n_i} \sum_{k=1}^{n_{i+1}} d_{ijk} |w_{ijk}| $, $ \ell_b = \sum_{i=1}^{L} \sum_{j=1}^{n_i} y* |bij|$
BIMT的正则项
$ d_{ijk} = \sqrt{\sum_{m=1}^{M} (r_{i-1,j}^m - r_{i,k}^m)^2} $​
欧式距离(L2 范数)
$ d_{ijk} = \sum_{m=1}^{M} |r_{i-1,j}^m - r_{i,k}^m| $
曼哈顿距离(L1 范数)

(2)模块化正则项和剪枝有什么异同?
模块化的正则项促进网络的模块化,使得网络能够自然形成处理不同任务或特征的独立模块。增强网络的可解释性,使得网络结构更易于理解和分析。正则项是一种软性约束,它通过损失函数来间接影响网络结构。它并不直接减少网络的参数数量,而是通过优化过程鼓励形成模块化的网络结构。
剪枝(Pruning)是减少网络的复杂度,去除不重要的连接或神经元,以提高效率。通过减少参数数量来防止过拟合,加速网络的推理过程。剪枝是通过设定阈值,将小于该阈值的权重置为零,实质上移除了这些连接。可以是后训练过程,即在网络训练完成后进行剪枝。剪枝是一种硬性约束,直接减少网络的参数数量。它通常关注于减少网络的稀疏性,并不特别强调网络的模块化结构。
相同点是两者都可以提高网络的效率和可解释性,都可以在训练过程中或之后应用。
不同点是模块化正则项更侧重于形成模块化的网络结构,而剪枝侧重于减少网络大小和参数数量。模块化正则项通过损失函数中的正则项来间接影响网络结构,剪枝则是通过直接移除权重来减少网络复杂度。模块化正则项可能不会显著减少参数数量,但会促进模块化;剪枝会直接减少网络的参数数量,但可能不会形成明显的模块化结构。

目录
相关文章
|
4月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
87 12
|
4月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
292 9
|
4月前
|
机器学习/深度学习 算法 机器人
【博士每天一篇文献-算法】改进的PNN架构Lifelong learning with dynamically expandable networks
本文介绍了一种名为Dynamically Expandable Network(DEN)的深度神经网络架构,它能够在学习新任务的同时保持对旧任务的记忆,并通过动态扩展网络容量和选择性重训练机制,有效防止语义漂移,实现终身学习。
65 9
|
4月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks
本文提出了一种基于任务条件超网络(Hypernetworks)的持续学习模型,通过超网络生成目标网络权重并结合正则化技术减少灾难性遗忘,实现有效的任务顺序学习与长期记忆保持。
55 4
|
4月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
103 3
|
1天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
102 80
|
20天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
26天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
6天前
|
供应链 算法 调度
排队算法的matlab仿真,带GUI界面
该程序使用MATLAB 2022A版本实现排队算法的仿真,并带有GUI界面。程序支持单队列单服务台、单队列多服务台和多队列多服务台三种排队方式。核心函数`func_mms2`通过模拟到达时间和服务时间,计算阻塞率和利用率。排队论研究系统中顾客和服务台的交互行为,广泛应用于通信网络、生产调度和服务行业等领域,旨在优化系统性能,减少等待时间,提高资源利用率。
|
14天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。