深度学习:Xavier初始化理论+代码实现

简介: 深度学习:Xavier初始化理论+代码实现

深度学习:Xavier初始化理论+代码实现

Xavier初始化理论

权值初始化对网络优化至关重要。早年深度神经网络无法有效训练的一个重要原因就是早期人们对初始化不太重视。我们早期用的方法大部分都是随机初始化,而随着网络深度的加深,随机初始化在控制数值稳定性上也可能失效。Xavier这个方法可以考虑输入层与输出层的维度,使在forward 和backward阶段保持每层之间均值与方差接近。
请添加图片描述
我们拿mlp举例,为了方便运算,忽略激活函数,上图是神经网络的一部分,我们假设$h_1^t$为输入层$h_t^{t+1}$为输出层,我们假设权重系数W~iid(independent identically distribution),均值为0,方差为a,其中$h^t、h^{t+1}$独立于w。
前向计算公式为:
$$h_{j}^{t+1}=\sum _{i}w_{ij}\cdot h_{i}^{t}$$
Xavier的核心思想是让输入层与输出层方差接近,我们首先考虑$h^t、h^{t+1}$的均值,因为t层最初可以追溯到数据输入层,可以通过归一化的手段控制,所以我们直接考虑t+1层。
$$\begin{aligned}E\left[ h_{j}^{t+1}\right] =E\left[ \sum _{i}w_{ij}\cdot h_{i}^{t}\right] \\ =\sum _{1}E\left[ w_{ij}\right] \cdot E\left[ hi^{t}\right] \\ =\sum _{i}0\cdot E\left[ hi^{t}\right] \\ =0\end{aligned}$$
我们发现t+1层均值为0,之后计算方差:
$$\begin
{aligned}Var\left[ h_{j}^{t+1}\right] \
=E\left[ \left( h_j^{t+1}\right) ^{2}\right] -E\left[ h_{j}^{t+1}\right] ^{2}\
=E\left[ h_j^{t+1}) ^{2}\right] -0\
=E\left[ \left( \sum _{i}w_{ij}\cdot h_{i}^{t}\right) ^{2}\right] \
=E\left[ \sum _{i}\left( w_{ij}\right) ^{2}\left( h_{i}^{t}\right) ^{2}\right] \
=\sum _{i}E[(w_{ij})^2]E[(h_{i}^{t})^2]\
=\sum _{i}Var\left[ w_{ij}\right] \cdot Var\left[ h_{i}^t\right] \
\end{aligned}$$
我们的目标是让前后层方差相等,所以并且w的方差在上面我们假设为a,所以我们要满足:
$$n^t*a = 1$$
到目前为止,我们的前向计算的满足条件就计算完成了,我们接下来计算反向传播:
$$\dfrac{\partial Loss}{\partial h^{t}}=\dfrac{\partial loss}{\partial h^{t+1}}\cdot W_{ij}$$
计算步骤可前面一样,最终我们可以得出:
$$n^{t+1}*a = 1$$
我们到了一个进退两难的地步,因为无法同时满足:$n^t*a = 1$、$n^{t+1}*a = 1$,所以Xavier采取了一个折中的方案:
$$\begin{aligned}\left( n_{t}+n_{t+1}\right) \cdot a=2\\ a=\dfrac{n_{t}+n_{t+1}}{2}\end{aligned}$$
我们有了权重的均值和方差,我们就可以初始化了。
在这里插入图片描述
当加入激活函数,是否他们会改变呢?
我们加入激活函数:为了方便运算,假设线性激活函数为:
$$\begin{aligned}\sigma \left( x\right) =\alpha x+\beta \\ E\left[ \sigma \left( hj^{t+1}\right) \right] \\ =\alpha E\left[ hj^{t+1}\right] +E\left[ \beta \right] =0\end{aligned}$$
为了保证均值为0,其中E$[h_j^{t+1}]$均值为0,$\beta$也要为0.
$$\begin{aligned}Var\left[ \sigma (h_{j}^{t+1}) \right] \\ =E\left[ (h_{j}^{t+1}) ^{2}\right] -E\left[ h_j^{t+1}) \right] ^{2}\\ = E\left[ \left( \alpha h_{j}^{t+1}+\beta \right) ^{2}\right] \\ =E\left[ \left( \left( \alpha h_j^{t+1}\right) ^{2}+2\alpha h_j^{t+1}\beta +\beta ^{2}\right) \right] \\ =\alpha ^{2}E\left[ \left( h_{j}^{t+1}\right) ^{2}\right] \\ =\alpha ^{2}Var\left[ hj^{t+1}\right] \end{aligned}$$
我们发现,经过激活函数,变成了之前的alpha 方倍,为了保持方差不变,让 alpha =1。也就是说,我们的激活函数尽量选择与y =x 接近的函数,才可以在Xavier上表现较好。
在这里插入图片描述

代码实现

import torch
from torch import nn
model = nn.Linear(20, 30)
input = torch.randn(128, 20)
model.weight=torch.nn.Parameter(nn.init.uniform_(torch.Tensor(30,20)))##均匀分布
model.weight=torch.nn.Parameter(nn.init.normal_(torch.Tensor(30,20)))##正态分布
output = m(input)
目录
相关文章
|
2月前
|
机器学习/深度学习 算法 计算机视觉
基于深度学习的停车位关键点检测系统(代码+原理)
基于深度学习的停车位关键点检测系统(代码+原理)
125 0
|
2月前
|
机器学习/深度学习 数据采集 PyTorch
图像分类保姆级教程-深度学习入门教程(附全部代码)
图像分类保姆级教程-深度学习入门教程(附全部代码)
45 1
|
2月前
|
机器学习/深度学习 编解码 API
深度学习+不良身体姿势检测+警报系统+代码+部署(姿态识别矫正系统)
深度学习+不良身体姿势检测+警报系统+代码+部署(姿态识别矫正系统)
57 0
|
5月前
|
机器学习/深度学习
深度学习模型调参技巧分享 视频讲解代码实战
深度学习模型调参技巧分享 视频讲解代码实战
41 0
|
2月前
|
机器学习/深度学习 算法 数据可视化
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)-2
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)
99 0
|
2月前
|
机器学习/深度学习 Ubuntu Linux
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)-1
计算机视觉+深度学习+机器学习+opencv+目标检测跟踪+一站式学习(代码+视频+PPT)
55 1
|
2月前
|
机器学习/深度学习 JSON 自然语言处理
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
46 0
|
2月前
|
机器学习/深度学习 算法 算法框架/工具
基于深度学习的交通标志检测和识别(从原理到环境配置/代码运行)
基于深度学习的交通标志检测和识别(从原理到环境配置/代码运行)
194 0
|
4月前
|
机器学习/深度学习 Linux TensorFlow
基于Python TensorFlow Keras的深度学习回归代码——keras.Sequential深度神经网络
基于Python TensorFlow Keras的深度学习回归代码——keras.Sequential深度神经网络
|
4月前
|
机器学习/深度学习 编译器 TensorFlow
基于Python TensorFlow Estimator的深度学习回归与分类代码——DNNRegressor
基于Python TensorFlow Estimator的深度学习回归与分类代码——DNNRegressor