清华大学朱军详解珠算:贝叶斯深度学习的GPU库(附视频)

简介: 5 月 27-28 日,机器之心在北京 898 创新空间顺利主办了第一届全球机器智能峰会(GMIS 2017)。中国科学院自动化研究所复杂系统管理与控制国家重点实验室主任王飞跃为大会做了开幕式致辞。大会第一天,「LSTM 之父」Jürgen Schmidhuber、Citadel 首席人工智能官邓力、腾讯 AI Lab 副主任俞栋、英特尔 AIPG 数据科学部主任 Yinyin Liu、GE Transportation Digital Solutions CTO Wesly Mukai 等知名人工智能专家参与峰会,并通过主题演讲、圆桌论坛等形式从科学家、企业家、技术专家的视角对人工智能技术前

大会第一天下午,清华大学智能技术与系统国家重点实验室朱军发表了主题为《珠算:贝叶斯深度学习的 GPU 库》的演讲,他探讨分享了贝叶斯深度学习模型的计算平台:珠算。该平台由清华大学机器学习组开发,目前已经在 GitHub 上开源,参阅机器之心之前的报道《清华大学发布珠算:一个用于生成模型的 Python 库》。


珠算项目地址:https://github.com/thu-ml/zhusuan


在 GMIS 2017 大会上,朱军从深度学习谈起,对该项目进行了更加深入的介绍,同时还在深度生成模型、贝叶斯推理等更广泛方面分享了自己的思考。在这篇文章中,机器之心对朱军的演讲内容进行了整理,同时为了更便于阅读,也进行了适当的编辑。


1638101277(1).png点击查看原视频

以下是该演讲视频的主要内容:


谢谢机器之心的邀请,很高兴有这个机会和大家分享一下我们实验室做的计算平台,因为我们是实验室,不像公司里有那么多的人,但我们做的东西是属于比较前沿的。


微信图片_20211128200923.jpg我们研究的是贝叶斯深度学习,首先我跟大家分享一下为什么要关心贝叶斯深度学习。


贝叶斯深度学习


现在深度学习在各个领域里有很多用处。虽然 Deep Learning 非常好,但还不足够好。我们看一下大家都很熟知的 Deep Learning 还存在的两个问题:


微信图片_20211128200943.jpg


一个问题是(深度学习)可能不是很鲁棒。可能会存在这种所谓的对抗样本,这有一个简单的例子,比如你有一个建筑物的图片,你可以用一个训练很好的神经网络分类得很准确。但是,我们可以加一些噪声,这些噪声可能是人检测不到的,合成一个图片之后却可以完全误导这个网络,甚至能够按照你的意愿误导分到某一个类。这是非常不好的性质,尤其我们在关键领域用深度学习的时候——一旦遇到这种情况发生,可能就会有一些比较致命性的错误发生。所以我们就想提出一个问题:机器学习或者深度学习本身能不能像人一样犯错误?人可能更多的时候是更鲁棒的,人可能会犯错误,但是人犯的错误相对都是比较直观、比较合理一点的——可能有某种道理在里面。


另外一个问题是深度学习大部分情况下都被我们当成一个黑箱。所以现在有很多的工作,包括我们自己的工作,都是试图去解释深度学习学到了什么。这里我们列了一个去年做的 CNNVis 的工作,能展示卷积网络每一层是什么、层和层之间是怎么关联的。这个方法非常受欢迎,也从一个侧面说明了大家对这个问题关心的程度。


在我看来,Deep Learning 本身属于机器学习的一个极端,它用了大量的训练样本,用了大量的计算资源。结果是我们在很多任务下,在特定环境、特定数据集上可以得到非常高的准确度,当然背后也有我们对网络结构的人为调整。


另外一端是贝叶斯的学习方法,大家可能知道,2015 年的时候,在 AlphaGo 火之前,Science 有一篇文章就说怎么设计贝叶斯程序,在这种情况下可以用少量的训练样本帮助我们学非常精确的模型,当时展示的成果是这个贝叶斯程度可以(在手写体数字生成和识别任务上)通过视觉图灵测试。这从一个方面告诉我们:我们做学习的时候可以有不同的思路。


微信图片_20211128201000.jpg


这是学习范式的两个极端,两者之间就有很多的事情可以做。我们把中间称之为「贝叶斯深度学习(Bayesian Deep Learning)」。它既有贝叶斯本身的可解释性,可以从少量的数据里边来学习;另外又有 Deep Learning 非常强大的拟合能力。


微信图片_20211128201017.jpg


给大家看一个最近非常火的例子,叫深度生成模型(Deep Generative Models),这是典型的融合了深度学习和贝叶斯方法的模型。这里做了一个抽象:上面有一个隐含的变量,用 Z 表示;中间会经过一个深度神经网络,你可以根据你的任务选择不同的神经网络、不同的深度、不同的结构;下面是我们观察到的数据 X。这个场景有很多,比如对抗生成网络,可以生成高维的自然图片。实际上,Z 可以是非常随机的噪声,通过神经网络可以生成非常高质量的图片。


在这种框架下,我们可以做很多。比如可以给隐含变量设定某些结构信息,比如生成人脸时,有一些变量指代人的姿态,另外一些变量可能描述其他的特征,这两个放在一起我们就可以构建这样一个深度生成模型。


它同一列有同一姿态,可以变化其它变量来生成不同的图片。现在是非常受欢迎、非常强大的一种模型了。


下面用更形式化的方式进行描述。我们用概率模型来描述,比如对 Z 变量(隐含变量),我们会用 P(Z) 来描述它的先验分布;中间有一个参数化的神经网络做变换;最后生成我们想要的数据 X。在不同场景下,这个 Z 的含义可能不一样。比如:如果要生成医学图片,我们通常希望 Z 能够表达造成疾病的原因;而对于文本图片,我们可能希望理解背后的主题等等。


微信图片_20211128201042.jpg


这个模型其实非常直观,但是它的难点在于我们所谓的 Inference(推断),这个过程是反向过来的——在 Inference 过程中,观察一些 X,然后我们用一些推导工具推导出我们观察到的 Z 到底是什么。在这个过程中,我们要用到一个主要的公式——贝叶斯公式。


珠算


微信图片_20211128201100.jpg


那么珠算平台到底是起到什么作用呢?


我们都知道有很多公开的框架可以支持深度学习进行非常迅速的开发和原型设计,但目前还并没有很好的平台能支持贝叶斯深度学习。所以,我们构建了称之为珠算的平台。珠算平台可以支持我们进行深度学习,也可以支持贝叶斯推断,当然还可以是两者之间有机的融合。


大家知道,珠算或算盘是最古老的计算机器(calculating machine),被认为是中国的历史第五大发明。我们之所以取名为「珠算」,就是希望这个平台能够从某种意义上给传统算盘一种新的解释,同时还希望这个平台能够进行高效的计算。


微信图片_20211128201120.jpg


珠算是一个生成模型的 Python 库,构建于 TensorFlow 之上。珠算不像现有的主要是为监督学习而设计的深度学习库,它是一种扎根于贝叶斯推断并支持多种生成模型的软件库。珠算区别于其他平台的一个很大的特点,即可以深度地做贝叶斯推断,因此,也就可以很有效地支持深度生成模型。珠算平台可以在 GPU 上训练神经网络,同时我们可以在上面做概率建模和概率推断,带来好处有:可以利用无监督数据、可以做小样本学习、可以做不确定性的推理和决策、可以生成新的样本等等。


微信图片_20211128201140.jpg


为了做珠算平台,第一步是个抽象过程,需要把一类的模型能够抽象表达出来,在这里我们用贝叶斯网络。贝叶斯网络是在深度学习流行之前非常主流的方法,它是一种非常好的形式化方式,能非常直观地刻画模型。但是,与传统的贝叶斯网络不通,我们是深度融合了贝叶斯方法和深度神经网络的优点,因此,我们的贝叶斯网络有两类节点:随机的节点和确定性的节点。确定性的节点基本上对应了深度神经网络的非线性变换,而随机节点可以描述不确定性。珠算是完全支持这两种节点的。在确定性的节点上我们把 TensorFlow 的所有操作都继承了下来。我们可以像在 TensorFlow 上构建神经网络一样构建中间的一些模块。如上图所示,构建一个模型很直观。我们首先只需要初始化 BayesianNet 环境,然后按照直观写模型。


微信图片_20211128201202.jpg


这是一个具体的例子,如上图所示,我们需要生成手写体字符,这种情况下因为数据不是很高维,用简单的生成模型就够了,比如有一个 Z 变量,Z 是随机的,经过两层的全连接的神经网络,最后生成我们的 X,这种模型在珠算里面非常容易写。可以在起始化 BayesianNet 环境之后,就沿着箭头的方向来写。比如:我们说 Z 变量服从一个高斯分布(z = zs.Normal()),珠算平台中有正态分布函数可以刻画该分布。接下来是两层的全连接层(layers.fully_connected()),最下面是数据的生成,比如我们数据是二值的,那么可以用伯努力随机分布来刻画它,这是非常直观地写模型的框架。你可以根据自己的需要书写其他的生成模型。


微信图片_20211128201241.jpg


对于这种模型最难的实际上是推断部分,在机器学习里有两类的推断方法,一种是变分(Variational)方法,一种是蒙特卡罗模拟方法。对于变分方法来说,红色的点是我们的目标,在某个概率分布空间里面,但我们并不能直接计算。所以,变分方法主要是希望在某个简化的子集里找一个蓝色的点去逼近它,我们希望这个逼近是最优的,所以通常情况下要解决最优化问题。这里边有很多推导公并没有提到。对于 MCMC 方法来说,现在主流的解决方法是构造一些动力学方程,以达到模拟的效果,这里也隐含了很多技术细节。


因此,即使是非常简单的模型,如果要做推断都可能需要很多的数学推导,我们需要算梯度、调步长参数等等。而且很多步骤可能都会使我们犯错误,所以这是一个复杂的过程。而珠算要做的就是简化推导实现的过程,并用一个非常简洁的(概率)编程方式写出来,编程对计算机来说是最容易理解的。


微信图片_20211128201303.jpg


给大家两个例子看我们怎么通过珠算实现推断的。首先,比如我们要做一个变分推断,在珠算上变分推断只需要三步:第一步,我们要构造一个变分分布,这个变分分布就像我前面讲的生成模型一样,可以通过初始化一个 BayesianNet,然后非常直观地写每部分是确定性的还是随机的等等。第二步,可以调用一下变分目标(variational objective),比如 z.sgvb,珠算上实现了不同的变分目标。剩下的事情,就是使用梯度下降进行迭代,就像我们实现深度神经网络一样,不断地使用随机梯度下降进行迭代而达到优化,这是典型变分推断的实现。


微信图片_20211128201333.jpg


如果我们要做的是 HMC,HMC 是一个混合的蒙特卡罗方法或者哈密尔顿蒙特卡罗方法,这属于机器学习里面的一种十分优秀的算法,它可以处理高维空间里面的采样,该算法在珠算上也非常容易来实现。我们首先需要构建变量以储存样本,然后就可以初始化 HMC 采样器。接下来调用 sample() 函数就可以得到一个采样算子,随后的在不断运行样本迭代时,就像求解一个最优化算法一样。如果大家熟悉深度神经网络过程的话,基本上我们对这种贝叶斯神经网络可以完全对等地去实现。


贝叶斯深度学习怎么用?


微信图片_20211128201354.jpg


贝叶斯深度学习在什么地方可以用?我给大家看一些例子。在我们课题组里主要强调如何用非常少的标注数据进行有效的学习。在机器学习里边有一个大家研究很多的叫半监督学习(Semi—supervised Learning)的场景,它可以利用大量的未标注数据帮助从少量标注数据中学习分类器。技术细节我就不说了,来看看结果。这个红色框里面是我们做出的结果,比如说在 SVHN 的数据集上,我们大概用 1% 的训练数据就可以达到 5% 的错误率,这个是目前最佳的结果。


因为我们是一个生成模型,所以我们还可以去生成新的样本,比如说我们可以生成二维的手写体字符。在一维上固定一个变量,调另外一个变量,生成你想要的某个类别或者某种风格的字符。

微信图片_20211128201424.jpg


这是更新的工作,我们是在生成对抗网络(GAN)上做的。大家知道 GAN,它的生成效果很好了。我们在小样本的学习下面也可以做非常好的效果,我们提出了一个 Triple GAN 的工作。在这个自然图片的数据集上,比之前大家做的各种 GAN 变种的结果显著要好(错误率更低)。大家同时可以看出来,这个生成结果和自然图片也非常接近了。


微信图片_20211128201453.jpg


下面一个例子是我前面提到过的——用贝叶斯方法做小样本学习。这是一个极端的例子,就是在训练的时候给它看一些基本的数据,将来在测试的时候会遇到新的类别(或概念),我们只给它看一个训练样例,然后希望它能够从中学出来一个贝叶斯程序,可以生成同一类的数据或者做识别。我们现在有一些在汉字上做的初步结果。给大家看一些例子,比如最上面给出了某一种字的一个样例,下面是生成出来的;基本上,大家能看出来和原始给的那个字的风格还是非常一致的,所以这个效果还是非常好的。一些技术细节我在这里就不详细说了。


微信图片_20211128201504.jpg


最后一个例子也是我前面讲的鲁棒的 Deep Learning。Deep Learning 有很多潜在攻击样本,我怎么让它变得更鲁棒?实际上,最近有一些工作显示使用贝叶斯推理可以让深度神经网络变得更鲁棒,比如:剑桥做的一个工作,这是我们复现出来的在一个数据集上的比较。这个测试数据集有一半是攻击样本、一半是正常样本。这个黑色的线是一个标准的神经网络,不用贝叶斯推理,它的正确率从 0.9 几(可能 0.97、0.98)一下子降到 0.6 几,降得非常严重。蓝色的线是贝叶斯神经网络,它可以做到更好,可以达到 75%、80% 左右的正确率,已经是非常不错的。右边的图是说你可以过滤掉多少对抗样本。大家可以看出来,这个蓝色的线,用贝叶斯网络可以帮助我们更好地识别对抗样本,提升鲁棒性。我们最近做了一个工作,结果是红色的线,能够更显著地识别 adversarial sample 和 normal sample,两个混在一起的时候,测试准确度能够显著地提升,实际上我们可以在一定条件下 达到图中的 Normal Accuracy。


微信图片_20211128201521.jpg


我们已经开源了珠算平台,现在我们把它当作是一个研究平台,也欢迎大家去尝试。我们在上面也开发了很多当前最佳的模型,包括经典的贝叶斯 logistic 回归、最新的贝叶斯神经网络、变分自编码器、GAN、主题模型 等等,我们自己也在不断做一些新模型。下面是开源的页面,大家可以在 GitHub 上找到。我们也写了一些 Online Documents,解释 API 怎么定义的,另外还有教程可以指导大家很快来实现比如我前面举例的网络模型。


微信图片_20211128201544.jpg


特别感谢我们组的学生,这个项目主要是我的两个博士生 Jiaxin Shi(石佳欣)和 Jianfei Chen(陈键飞)主导的,贡献者还包括一些博士后和博士生以及本科生。这个项目也受到一些国家经费的支持,我们的合作者还有天工研究院、英伟达等等。


微信图片_20211128201603.jpg

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
2月前
|
机器学习/深度学习 自然语言处理 监控
深度学习之视频摘要生成
基于深度学习的视频摘要生成是一种通过自动化方式从长视频中提取关键片段,生成简洁且有代表性的视频摘要的技术。其目的是在保留视频主要内容的基础上,大幅缩短视频的播放时长,方便用户快速理解视频的核心信息。
147 7
|
1月前
|
机器学习/深度学习 测试技术 PyTorch
深度学习之测量GPU性能的方式
在深度学习中,测量GPU性能是一个多方面的任务,涉及运行时间、吞吐量、GPU利用率、内存使用情况、计算能力、端到端性能测试、显存带宽、框架自带性能工具和基准测试工具等多种方法。通过综合使用这些方法,可以全面评估和优化GPU的性能,提升深度学习任务的效率和效果。
52 5
|
2月前
|
机器学习/深度学习 数据处理 数据库
基于Django的深度学习视频分类Web系统
基于Django的深度学习视频分类Web系统
66 4
基于Django的深度学习视频分类Web系统
|
2月前
|
机器学习/深度学习 运维 监控
深度学习之视频内容理解
基于深度学习的视频内容理解(Video Content Understanding, VCU)是一项关键技术,旨在通过神经网络模型自动分析、解读和提取视频中的语义信息。
150 10
|
2月前
|
机器学习/深度学习 监控 人机交互
深度学习之视频中的姿态跟踪
基于深度学习的视频姿态跟踪是一项用于从视频序列中持续检测和跟踪人体姿态的技术。它能够识别人体的2D或3D关键点,并在时间维度上进行跟踪,主要应用于人机交互、体育分析、动作识别和虚拟现实等领域。
66 3
|
3月前
|
机器学习/深度学习 测试技术 PyTorch
深度学习之测量GPU性能的方式
在深度学习中,测量GPU性能是一个多方面的任务,涉及运行时间、吞吐量、GPU利用率、内存使用情况、计算能力、端到端性能测试、显存带宽、框架自带性能工具和基准测试工具等多种方法。通过综合使用这些方法,可以全面评估和优化GPU的性能,提升深度学习任务的效率和效果。
346 2
|
4月前
|
机器学习/深度学习 并行计算 PyTorch
如何搭建深度学习的多 GPU 服务器
如何搭建深度学习的多 GPU 服务器
158 5
如何搭建深度学习的多 GPU 服务器
|
4月前
|
机器学习/深度学习 人工智能 调度
显著提升深度学习 GPU 利用率,阿里云拿下国际网络顶会优胜奖!
显著提升深度学习 GPU 利用率,阿里云拿下国际网络顶会优胜奖!
349 7
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
【深度学习】python之人工智能应用篇——视频生成技术
视频生成技术是一种基于深度学习和机器学习的先进技术,它使得计算机能够根据给定的文本、图像、视频等单模态或多模态数据,自动生成符合描述的、高保真的视频内容。这种技术主要依赖于深度学习模型,如生成对抗网络(GAN)、自回归模型(Auto-regressive Model)、扩散模型(Diffusion Model)等。其中,GAN由两个神经网络组成:一个生成器用于生成逼真的图像或视频,另一个判别器用于判断生成的图像或视频是否真实。通过不断的对抗学习,生成器和判别器共同优化,以产生更高质量的视频。
115 2
|
4月前
|
机器学习/深度学习 监控 算法
基于深度学习网络的人员行为视频检测系统matlab仿真,带GUI界面
本仿真展示了基于GoogLeNet的人员行为检测系统在Matlab 2022a上的实现效果,无水印。GoogLeNet采用创新的Inception模块,高效地提取视频中人员行为特征并进行分类。核心程序循环读取视频帧,每十帧执行一次分类,最终输出最频繁的行为类别如“乐队”、“乒乓球”等。此技术适用于智能监控等多个领域。
75 4