变分贝叶斯深度学习综述

news2025/7/14 9:12:19

**©PaperWeekly 原创 · 作者 |**薛博阳

**单位 |**香港中文大学

**研究方向 |**语言模型

引言

近年来,贝叶斯深度学习(Bayesian Deep Learn-ing)在诸多领域得到广泛关注应用,效果显著。本文将针对贝叶斯深度学习框架进行系统性的概述,包括模型不确定性的引入;贝叶斯神经网络(Bayesian Neural Network)、高斯过程(Gaussian Process)、变分自编码器(Variational Auto-Encoder)三个主流模型的介绍,以及如何使用变分推断(Variational In-ference)求解上述模型的潜在变量分布;最后对相关参考文献进行总结。

深度学习的不确定性

首先谈下为什么要引入贝叶斯深度学习方法。贝叶斯深度学习能够对传统深度学习模型的不确定性(Model Uncertainty)建模,随着近年来卷积神经网络,Transformer 等发展,主流深度学习框架变得越来越复杂,网络深度可达成百甚至上千层,参数量也超过数千亿。这些大规模的神经网络虽然对信息感知和特征提取能力越来越强,但也存在在有限数据集上容易过拟合及模型泛化能力弱的隐患。

针 对 这 个 问 题, 一 种 常 用 的 方 法 是 引入 Dropout,在模型训练时使用由一个超参数控制的伯努利分布对所有网络节点随机选择丢弃,使每次训练迭代的网络都不完全相同,引入了模型结构的不确定性;而在预测时则考虑所有节点,可看作所 有训练中不同网络的集成(Ensemble)或平均,从而有效减小了过拟合,Dropout 中的超参数通常需要手动调节。

另一种方法是在网络参数上加入随机噪声,这相当于引入参数的不确定性。然而这些不确定 性建模方法都只是简单的正则化,缺乏严谨的数学表述推理论证。深度学习任务通常有两种不确定性,一种是来源于数据噪声的随机不确定性(Aleatoric Uncertainty),另一种就是重点关注的模型不确定性,在不同模型框架上又可分为参数不确定性,结构不确定性以及隐变量的不确定性。下面使用一个更直观的例子来说明引入不确定性如何有效提升模型泛化能力。

传统深度学习模型都是确定函数的点估计(Point Estimation),对于一个鉴别狗品种的图像分类模型,如果输入一张训练集分布外的图像,如一只猫的照片,那样识别结果将很离谱。于是我们希望能在模型返回预测结果时附带不确定性信息,也可以看做对结果的置信度。这需要引入能对不确定性建模的概率模型(Probabilistic Model),如下图所示的二氧化碳浓度预测模型,蓝色虚线左边是观测数据,右边是测试数据;对于测试部分的数据点 ,传统的深度学习模型会给出确定的预测,如左图红色虚线所示;而概率模型则会对未知的测试部分均采用概率分布来表示,如右图蓝色阴影部分所示,在数据集不能学习到准确的数据分布的情况下(当然实际所用的所有数据集都是有限的),很显然引入不确定性更合理,这也是能有效提升模型泛化能力的原因。

▲ 图1. 传统深度学习模型(左)与概率模型(右)对二氧化碳浓度的预测

这里的概率模型就是用的贝叶斯方法建模,假设存在数据集 和标签集 ,在预测测试数据对 的概率分布时,根据边缘概率计算,我们有

其中 为模型参数,问题就转换为求参数 在训练集 上的最大后验分布的问题。根据贝叶斯公式,有

传统深度学习通常是对参数 进行定参估计,而贝叶斯模型把参数看做概率分布,需要对所有 值进行积分,按照 Bishop 的《Pattern Recognition and Machine Learning》的定义,这种积分是贝叶斯方法的核心,在深度学习模型上应用贝叶斯方法就统称为贝叶斯深度学习。我们对公式(2)分母部分进行归一化积分,有

这部分也称作模型证据(Model Evidence)或边缘似然(Marginal Likelihood) 分布。由于积分的存在,通常很难求得解析解,这就需要用到一些近似推断方 法。至此,我们从不确定性,概率模型,贝叶斯方法的基本思路着手,明确了贝叶斯深度学习和不确定性的关系,下面就来讲贝叶斯深度学习里最经典的模型——贝叶斯神经网络。

贝叶斯神经网络

广义的贝叶斯深度学习在不同文章课题中定义略有不同,但狭义的贝叶斯深度学习公认是指贝叶斯神经网络。结合上文不确定性,再来详细讨论贝叶斯神经网络和传统神经网络的区别:传统神经网络中,我们认为模型参数 是定值,如图二左所示,并且在一个任务上存在最优参数 ;训练时,给模型参数赋一组初值 ,基于观测数据集 训练模型不断更新 ,训练时可以使用最大似然估计

或者加入正则项将最大似然变为最大后验估计

其中 正则项是将 假设为拉普拉斯先验, 正则项是高斯先验,不论 MLE 还是 MAP,最终学习目标都是让参数无限逼近 。

▲ 图2. 参数固定的传统神经网络(左)以及参数服从概率分布的贝叶斯神经网络(右)

需要指出虽然最大后验估计也引入先验,但仍属于定参估计,没有引入概率模型,不需要对参数积分,因此不属于贝叶斯方法。传统神经网络无法对不确定性建模,在监督学习中往往对预测结果过于自信,很容易发生过拟合。

顺着前文在参数上引入不确定性的思路,我们 认为 服从某种概率分布而非固定参数,如图二右所示,为了最大化不确定性,我们先假设 服从高 斯分布,这样训练的也不再是单一网络,而是无数个相同位置节点参数服从同一概率分布的集成网络。此时参数 的先验分布不再是简单的正则项,而是对应的共轭分布,高斯分布的共轭先验也应该是一 个高斯分布,对应的后验分布也是一个高斯分布。模型的优化目标就是最大化后验高斯分布 ,也就是公式(1)中的后验。

现在问题的关键就是计算公式(3)的边缘算子了。如果模型是线性回归之类的简单模型,其实也不难求出解析解,但换成神经网络后就会变得异常复杂,下面简单推导一下。假设有模型 ,输入向量 ,标签 ,模型参数为 ,假设模型输出服从均值为 ,方差为 的高斯分布,有

权重 的共轭先验也是高斯分布,假设其均值为 ,方差为 ,可得

后验分布由贝叶斯公式

计算,代入多元高斯概率密度函数,对后验分布取

其中 为常数项。对于一般的线性回归模型, 是关于 的线性函数,后验概率仍是是关于 的高斯分布,可以直接计算出解析解,但是在神经网络中,由于大量非线性单元,模型输出 与 不再是线性关系,网络模型对于参数值的高度非线性意味着精确的贝叶斯方法(即数值求解)不可行,因此我们不得不借助一些近似方法,如拉普拉斯近似(Laplace Approximation),马尔科夫链蒙特卡罗采样(Markov-Chain Monte-Carlo Sampling),以及近几年使用最多的变分推断法。

变分推断

本节主要讲变分推断求解贝叶斯神经网络的过程,类似的也可以用于其他贝叶斯深度学习模型上。

变分法最早起源于 18 世纪欧拉、拉格朗日等关于泛函优化的研究,泛函数 (Functional)是以函数作为输入,返回泛函值作为输出的一种映射,它以一个函数作为输入,返回泛函的值作为输出。研究所有可能的输入函数,找到最大化或者最小化泛函的函数就是问题的解。相比其他近似推断方法,变分法具有更好的收敛性和可扩展性,适合大规模问题的求解。贝叶斯深度学习将参数视作概率分布后,误差函数的输入也就从定值变为函数,从而转变为泛函优化,这就是用变分法来求解贝叶斯深度学习模型的原因。

第三节已经证明了贝叶斯神经网络中的 无法直接计算解析解,甚至很难采样。变分法的核心就是用一个可解的近似分布 逼近真实分布。第二节分析表明估计后验分布 需要最大化公式(3)边缘分布的积分,假设公式(3),根据 Jensen 不等式,有

这也被称为变分下限(Variational Lower Bound), 是对后验概率 的变分近似, 是参数的先验分布,KL 散度用来度量两个概率分布的距离,如下图所示

▲ 图3. 变分下限

一种更直观的理解是,已知后验分布 是 一个未知分布,我们引入已知参数分布的 去逼近 ,所以只需最小化 KL ,可以作如下推导

最终结果第一项 与 无关可以忽略,第二项 和第三项分别求 与先验 的距离,以及 时似然函数 的期望值。这与公式(11)的结果一致,也就是目标函数或误差函数,即

与正则化的传统神经网络对比,贝叶斯神经网络误差函数也分为两部分,一是训练数据相关的似然代 价(Likelihood Cost),其中 服从 ;二是先验相关的复杂性代价(Complexity Cost),也就是把正则项变成 KL 散度,传统方法中引入正则项就有让模型参数变得稀疏的作用,控制了模型的复杂度。误差函数的优化就是在两项函数之间取平衡。

下面说说误差函数两项的求解方法,为了最大化不确定性,假设近似和后验均服从高斯分布,即

似然代价因为积分存在无法直接求解,在此借助蒙特卡罗采样(Monte Carlo Sampling),即

其中 是每次训练中对 的采样次数,如果直接对均值 和方差 采样代入高斯分布因指数运算在反向传播时会造成训练过程不稳定,在此使用一种重参数化(Reparameterize)方法,即

这意味着在前向传播计算似然代价时,参数 需要从公式(17)随机采样获得,对应贝叶斯神经网络的参数 不再是一个定值。

对于 与 的 KL 散度项,需要对积分离散化,然后代入高斯分布的概率密度函数,有

其中 表示 中的第 项参数, 并且相互独立,根据高斯分布均值和二阶矩的性质 ,即可完成最后一步推导。至此,我们已推导出误差函数项的形式,利用 梯度下降和反向传播算法,就可以完成大规模参数贝叶斯神经网络的训练了,为了稀疏模型我们可以 的标准高斯分布,反向传播算法如下

由于我们假设参数服从高斯分布,因此使用了均值和方差两个参数,参数量为同等规模的传统神经网络的 2 倍。当然实际上只需要对部分参数做贝叶斯推断,就可以取得较好的效果了。

简化后的代价函数也可以进行小批量梯度下降,训练时将 随机分成 个相等的子集。每次梯度更新是小批量上的平均。如果想要衡量复杂性成本与小批量之间的关系,我们可以将小批量均匀随机划分,那么 KL 代价可以在每个训练周期非均匀地分布在小批量 之间:令 ,并且 。

研究发现 时效果最好,这也意味着在前几个小批量更新时在模型比较依赖先验也就是复杂性成本的影响,而后面训练时很大程度上受数据的影响。也就是说,当数据集趋于无穷时,贝叶斯神经网络和传统神经网络相差并不大,但是在有限数据集上,贝叶斯神经网络明显性能更优。

高斯过程

高斯过程是结合连续函数和概率模型的一种非参数化方法,函数 的高斯过程可表示为

其中 是训练集中的随机数据对, 是核函数,上述公式是高斯过程的核空间表述(kernel space view),然而,由于数据集中所有数据点都需要相互运算,在大规模数据集上计算复杂度会特别高,另一种权重空间表述(weight space view)的高斯过程形式为模型中一系列基函数的插值

核函数与基函数的关系为 , 是第 个基函数的系数。

此前已有研究证明,对于单隐层的神经网络,当隐层结点数不断增加并趋于无穷时,输出服从高斯分布。由于输出被描述为基函数的无穷和,因此可以将输出看作高斯过程,如下图所示可以看出高斯

▲ 图4. 单隐层神经网络结点数不断增加 (a),(b),©,(d) 时输出分布

过程是和模型结构相关的,这也启发了一系列将高斯过程与贝叶斯深度学习相结合的研究,一种思路就是利用上述权重空间的高斯过程,对基函数插值系数使用贝叶斯估计从而为模型结构的不确定性建模,公式(11)可表示如下

相较之下多了一个变量的积分,其余推导过程就按照变分下限,蒙特卡罗采样,重参数化,反向传播等一系列步骤进行,推导过程与第四节类似,在此不作赘述。

变分自编码器

变分自编码器本质上也是贝叶斯深度学习,只不过这次是对隐变量(Latent Variables)进行不确定性建模,也就是将神经网络中的隐藏层输出视作随机变量。传统自编码器是一种由编码器和解码器组成的用于特征提取或数据降维的模型。如下图所示左边是编码器,右边是解码器,通过自编码器将输

▲ 图5. 自编码器

入 映射到低维空间 再通过解码器还原回真实数据。

在数据处理时会遇到数据量不足的情况,这时就会考虑使用生成模型生成数据,变分自编码器就是在自编码器基础上对 引入变分贝叶斯估计,使其能够生成数据。这涉及到一类利用变分贝叶斯求解图模型变量的方法,也是变分贝叶斯在深度学习隐藏变量上的应用,变分自编码器是其中的典型代表。

对于如下具有连续隐变量的概率图模型

▲ 图6. 连续隐变量的概率图模型

我们试图推断和学习有向概率图模型的隐分布, 并通过对 的采样来实现数据 的生成。由于连续 随机隐变量 不可见,我们无法根据条件概率分布 生成 ,也就无法得到生成模型 。而数据的先验分布

因存在积分也无法求解,这时候就可以构建模型 来近似

这个过程可视作编码器,即由样本数据 学出一个对应的隐层分布 ,并使用 作为解码器,实现模型生成。数据集的先验可以写作

其中 就是变分下限,也可以写作

还可以进一步写作

接下来就是蒙特卡罗采样,重参数化,反向传播等一系列算法的运用,推导过程与第四章基本类似,在此不作赘述。

总结

本文从深度学习不确定性的角度切入,总结了贝叶斯深度学习模型提升模型的泛化能力的作用,并讲了三个主流的框架:贝叶斯神经网络,高斯过程,变分自编码器,分别在模型参数,模型结构和隐藏变量进行不确定性建模,并且给出变分法求解上述模型的过程。参考资料详见下文。

参考文献

模型不确定性:

[1] Gal, Y. “Uncertainty in Deep Learning.”PhD Thesis, 2016.

贝叶斯神经网络,变分推断:

[2] D. Barber and C. M. Bishop, “Ensemble Learning in Bayesian Neural Networks,”Nato ASI Series F Computer and Systems Sciences, 1998

[3] R. M. Neal, “Bayesian Learning for Neu- ral Networks,”Springer Science & Business Media, 2012.

[4] C. M. Bishop, “Pattern Recognition and Machine Learning,”Machine Learning, 2006.

[5] A. Graves, “Practical Variational Inference for Neural Networks,”NIPS 2012.

[6] C. Blundell et al., “Weight Uncertainty in Neural Network,”ICML 2014.

[7] Goan, E. Bayesian. “Neural Networks: An Introduction and Survey,”In Case Studies in Applied Bayesian Data Science 2020.

高斯过程:

[8] C. E. Rasmussen, “Gaussian Processes for Machine Learning,”Machine Learning, 2006.

变分自编码器:

[9] D. P. Kingma et al., “Auto-Encoding Vari- ational Bayes,”stat, 2014.

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

更多阅读




#投 稿 通 道#

让你的文字被更多人看到

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

△长按添加PaperWeekly小编

🔍

现在,在**「知乎」**也能找到我们了

进入知乎首页搜索**「PaperWeekly」**

点击**「关注」**订阅我们的专栏吧

·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1123931.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

​如何使用ArcGIS Pro制作一张地形图

01数据来源 本教程所使用的数据是从水经微图中下载的DEM数据,除了DEM数据,常见的GIS数据都可以从水经微图中下载,你可以通过关注“水经注GIS”,然后在后台回复“微图”即可获取软件下载地址,当然也可以直接在水经注…

TensorFlow2从磁盘读取图片数据集的示例(tf.data.Dataset.list_files)

import os import warnings warnings.filterwarnings("ignore") import tensorflow as tf from tensorflow.keras.optimizers import Adam from tensorflow.keras.applications.resnet import ResNet50 from pathlib import Path import numpy as np#数据所在文件夹 …

AI爆文变现脚本:0基础小白的保姆级操作教程-更新迭代

脚本作用:这个脚本主要是辅助训练营的同学使用的,脚本可以增加发文的效率。 脚本现在已经更新了9个版本了。目的是为了更方便大家操作使用。 AI爆文流量主(广告)变现项目的实际操作教程,我之前分享过了,大家感兴趣的可以再去看看…

灰色和测试环境打包串台

事情是这样的: 最近开发总说jenkins灰色环境打包总是到成测试环境的,测试环境总是走到了线上了。我们排查了也很久最终发现原来是这个问题导致的。如下: 修改如下: 问题解决

Tomcat+nginx负载均衡和动静分离

Nginx实现负载均衡和动静分离的原理 Nginx实现负载均衡是通过反向代理实现Nginx服务器作为前端,Tomcat服务器作为后端,web页面请求由Nginx服务来进行转发。 但是不是把所有的web请求转发,而是将静态页面请求Ncinx服务器自己来处理&#xff0c…

当年很流行,现在已经淘汰的前端技术有哪些?

近几年,前端技术真可谓是飞速发展,不断有新的技术涌现,爆火的前端框架 Astro,前端运行时 Bun,构建工具 Vite 等都给前端提供了强大动力。当然,也有很多前端技术随着技术的发展不再需要使用,有了…

【数据结构】线性表(十一)队列:双端队列及其基本操作(初始化、判空、判满、头部入队、尾部入队、头部出队、尾部出队、存取队首队尾元素)

文章目录 一、队列1. 定义2. 基本操作 二、顺序队列三、链式队列双端队列0. 头文件1. 队列结构体2. 初始化3. 判断队列是否为空4. 判断队列是否已满5. 头部入队6. 尾部入队7. 头部出队8. 尾部出队9. 存取队列头部的元素10. 存取队列尾部的元素11. 释放队列内存12. 主函数13. 代…

每日一题 2678. 老人的数目(简单)

简单题,不多说 class Solution:def countSeniors(self, details: List[str]) -> int:ans 0for l in details:if int(l[11:13]) > 60:ans 1return ans

CSS设置超出范围滚动条和滚动条样式

CSS设置超出范围滚动条和滚动条样式 效果展示 当块级内容区域超出块级元素范围的时候,就会以滚动条的形式展示,你可以滚动里面的内容,里面的内容不会超出块级区域范围。 未设置超出隐藏,显示滚动条 超出隐藏,显示滚动…

APP软件外包开发设计原则

设计一个成功的APP需要遵循一些关键的设计原则,以确保用户体验良好、功能明晰、吸引力和易用性。以下是一些重要的APP设计原则,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作。 用户为中心&…

Ubuntu系统中安装libcurl库用来做爬虫

在Ubuntu系统上运行爬虫,可以使用libcurl的方式简单部署libcurl爬虫管理平台。在libcurl库中,可以使用普通任务和定时任务来运行爬虫。同时,还可以添加依赖包和配置消息通知钉钉机器人等功能。如果需要使用Python-bs4库,可以通过系…

从GitHub火到头条!这份万众期待的阿里内部JAVA面试手册,开源了

前言: 现在的互联网开发岗招聘,程序员面试背八股文已经成为了不可逆转的形式,其中一个Java岗几百人在投简历也已经成为了常态!更何况一份面试题动辄七八百道,你吃透了,技术只要不是很差,面试怎…

开发中常用的版本管理工具有哪些?

一、是什么 版本控制(Version control),是维护工程蓝图的标准作法,能追踪工程蓝图从诞生一直到定案的过程。此外,版本控制也是一种软件工程技巧,借此能在软件开发的过程中,确保由不同人所编辑的…

网站页脚展示备案号并在新标签页中打开超链接

备案时,我们就注意到,备案成功后需要在网站首页底部展示“备案号”,并将备案号链接至https://beian.miit.gov.cn。 这里我使用了WrodPress中的主题,主题自定义中有提供对页脚文本的编辑,支持用css标签定义样式。若是自…

创意无限,图文生成如虎添翼:星火大模型的威力

在数字化的时代,讯飞(iFlyTek)的星火大模型已经走在了创新的前沿。它以卓越的技术和无与伦比的免费政策,为创作者、开发者和企业家们提供了创新无限的可能性。 星火大模型最新亮点 多样性无限,星火助手数量达4000 星火…

Post-Process1-水下

一、新建第三人称游戏项目,我这里选择C,你也可以选择Blueprint。 新建一个Level,命名为DemoUnderWater 保存一下,命名为DownUnderWater 添加水插件 选择Yes 勾选Show Engine Content和Show Plugin Content,在左侧可以看…

Mysql如何确定执行计划是最优开销?Mysql优化器!

1. 什么是 MySQL 优化器? MySQL 优化器是 MySQL 中的一个核心组件。MySQL 优化器的主要职责在于确定查询的执行计划。在数据库中,同样的查询可以有多种不同的执行方式,如使用不同的索引,使用不同的连接顺序等。每种执行方式都有其…

C++之std::string

string类与头文件包含&#xff1a;#include <string> string构造方法&#xff1a; // string constructor #include <iostream> #include <string>int main () {std::string s0 ("Initial string"); //根据已有字符串构造新的string实例// cons…

紫光展锐发布全新6G白皮书,展望泛在融合发展蓝图

自2019年5G蜂窝技术正式商用以来&#xff0c;5G网络建设如火如荼&#xff0c;各类形态的5G终端层出不穷。5G商用推进的同时&#xff0c;6G研究也在全球范围内拉开帷幕。2023年6月ITU发布了《IMT面向2030及未来发展的框架和总体目标建议书》&#xff08;下文简称“建议书”&…

Java基于SSM开发的企业员工管理系统源码

主要功能 包括部门、岗位、工资、员工、请假、审批管理。普通员工可请假查看工资等&#xff0c;管理员可审批、管理员工工资等。 演示视频&#xff1a; https://www.bilibili.com/video/BV1c94y1j7QM/?share_sourcecopy_web&vd_source11344bb73ef9b33550b8202d07ae139b …