适合离散值分类的多分类模型——softmax回归

简介: 适合离散值分类的多分类模型——softmax回归

适合离散值分类的多分类模型——softmax回归

一、什么是softmax回归?

简单来说softmax回归是用来做多分类任务的模型,是一个单层神经网络。与线性回归模型针对连续值的预测(预测房屋价格、天气温度变化等)不同,它更适合离散值的训练和预测。由于该模型是为了识别多种类别,故最终的输出不再是单个值(像relu/sigmoid输入一个实数,输出一个实数)或者是二分类情景(只有两种标记0和1)的两个值,而是多个值(向量),但这些值并不是简单的 0~n-1 的标记,而是经过运算后的类别概率。

在猫狗分类模型中,将猫和狗的类别分别标记为0和1,现假设为了要识别出更多的动物,同样可以把不同的动物类别分别标记为0~n-1,这里,n为训练时输入的样本最终被分类的总个数。

在吴恩达视频中,做了4种可能的类别,即n=4,于是分别为(猫、狗、鸡、其它)这四种类别标记成(0、1、2、3)。

为此,可以构建如下神经网络:

1ecd1b2606ed46e9956a89f231c9802c.png

当输入为样本X时,最后一层即输出层L,它有4个神经元,也就是4种类别,从上到下分别对应其它、猫、狗、鸡。实际上,多分类任务是将输出当作某个预测类别的置信度,就是将值最大的输出所对应的类作为预测输出,例如当输出的值为0.1,0.1,10,0.1,因为10最大,那么预测类别为2,是狗。然而,通常并不直接使用输出层的输出,而是使用softmax将每种类别的的输出值转化为概率,且概率和为1。

这是为什么呢?

1.当一个样本经过模型进行预测时,执行完最后一层全连接,便得到的了T*1维的向量,其值就是预测标签,但得到的这些值的大小很随机,也就是说,由于输出层的输出值的范围不确定,我们难以直观的判断这些值的意义。就比如说,当上面例子的输出由于w的更新变化成(100,100,10,100),这个时候输出值10表示图像类别为猫的概率又变得很低。

2.由于真实标签是离散值,这些离散值与不确定范围的输出值之间的误差难以衡量。

于是softmax层就闪亮登场了,它出现在输出层L之后,经过运算,会得到的与上层相同维度的向量输出,只不过向量值变成了各个类别的概率。

其实看下面这张图就很明显。

2020062310470442.png

上图等号左边就是全连接计算(网络最后一层),即权值W和上层输出特征X相乘(WX),可以看到特征X是N*1的向量,而W是个T*N的矩阵,这个N和X的N对应,然后得到一个T*1的向量(图中的logits[T*1]),这里的T就表示类别数,且这个向量里面每个数的大小都没有限制,即每个数的取值范围是从负无穷到正无穷。所以,对于多分类问题,在全连接层后面接一个softmax层,就相当于是一个范围的约束。其中,这个softmax的输入是T*1的向量,即全连接的输出,运算后得到的结果也是T*1的向量(也就是图中的prob[T*1],这个向量的每个值表示这个样本属于每个类的概率),只不过输出的向量的每个值的大小范围为0到1,同时各类别概率和为1,这就使得某一类别概率确定了,其他类别的概率不会没有范围限制的变化。

二、softmax回归运算

先来看下softmax的公式:

20200617123557795.jpg

首先要计算神经网络最后一层的输出,即线性计算:

1ecd1b2606ed46e9956a89f231c9802c.png

然后应用softmax激活函数,它的计算过程是这样的,先要计算一个临时变量t,也就是t=e^{z[l]},设z[l]是4*1的向量,然后对向量中的每个值求幂,就得到的t,它也就是4*1维的,然后将t中的值相加,最后输出:

1ecd1b2606ed46e9956a89f231c9802c.png

这个a[l]也是4*1维的,而4维向量的第i个元素就是:

2020062310470442.png

所以上面的公式也不难理解了,再说一遍,分子代表向量中的第j个元素,分母是向量中的4个值相加,这是个4分类问题,所以就是4,如果是T,就是向量中的T个值相加,然后就能得到某种类别小于1的概率值了。


假设在测试模型的时候,当一个样本经过softmax层并输出一个T*1的向量,就会取这个向量中值最大的那个数的index作为这个样本的预测标签。


因此我们训练全连接层的W的目标就是使得其输出的WX在经过softmax层计算后其对应于真实标签的预测概率要最高。


比如,z[l]=WX=[5,2,-1,3],那么经过softmax层后就会得到[148.4,7.4,0.4,20.1],这4个数字表示这个样本属于第0,1,2,3类的概率分别是148.4,7.4,0.4,20.1,故这个样本预测出的标签就是0,猫。

三、交叉熵损失函数

在经过softmax后,就可以将预测输出的标签与真实的离散标签进行误差估计了。因为,预测类别用概率来输出,同样的,对于真实标签也可以用类别概率来表示。即,对于样本i,我们可以构造一个输出类别为q的向量,然后使向量中某个(i)类别的元素值为1(样本i类别的离散数值) ,其余为0,那么,我们的训练目标可以设为使预测概率分布尽可能接近真实的标签概率分布。

一个比较适合衡量两个概率分布差异的测量函数就是交叉熵(cross entropy):

1ecd1b2606ed46e9956a89f231c9802c.png

首先H就是损失。y^(i)j 是softmax的输出向量y^(i) 的第j个值,它表示的是这个样本属于第j个类别的概率。y(i)j 前面有个求和符号,j的范围也是1到类别数q,因此y是一个1*q的向量,里面有q个值,而且只有1个值是1(z向量中的最大元素置1),其他q-1个值都是0(这就是所谓的hardmax)。那么哪个位置的值是1呢?答案是真实标签对应的位置的那个值是1,其他都是0。所以这个公式其实有一个更简单的形式:

2020062310470442.png

下标应该是j,j指向当前样本的真实标签。这就意味着,如果你的学习算法试图将该式变小,(因为梯度下降法是用来减少训练集的损失的),唯一方式就是使等式右边的式子变小,要想做到这一点,就需要使y^(i)j 尽可能大,因为这些是概率,所以不可能比1大,概括来讲,损失函数所做的就是找到训练集中的真实类别,然后试图使该类别相应的概率尽可能地高,也就是说,交叉熵只关心对正确类别的预测概率,因为只要其值足够大,就可以确保分类结果正确。举个例子:


假设一个5分类问题,然后一个样本i的标签y=[0,0,0,1,0],也就是说样本I的真实标签是4,假设模型预测的结果概率(softmax的输出)p=[0.1,0.15,0.05,0.6,0.1],可以看出这个预测是对的,那么对应的损失H=-log(0.6),也就是当这个样本经过这样的网络参数产生这样的预测p时,它的损失是-log(0.6)。那么假设p=[0.15,0.2,0.4,0.1,0.15],这个预测结果就很离谱了,因为真实标签是4,而预测出这个样本是4的概率只有0.1(远不如其他概率高,如果是在测试阶段,那么模型就会预测该样本属于类别3),故对应损失H=-log(0.1)。那么假设p=[0.05,0.15,0.4,0.3,0.1],这个预测结果虽然也错了,但是没有前面那个那么离谱,对应的损失H=-log(0.3)。我们知道log函数在输入小于1的时候是个负数,而且log函数是递增函数,所以-log(0.6) < -log(0.3) < -log(0.1)。简单讲就是你预测错比预测对的损失要大,预测错得离谱比预测错得轻微的损失要大。

上面是单个训练样本的损失(0.6),当遇到一个样本有多个标签时,例如图像里含有不止一个物体时,我们并不能做这一步简化。但即便对于这种情况,交叉熵同样只关心对图像中出现的物体类别的预测概率。

假设训练数据集的样本数为n,交叉熵损失函数定义为:

1ecd1b2606ed46e9956a89f231c9802c.png

同样地,如果每个样本只有一个标签,那么交叉熵损失可以简写成:

2020062310470442.png

其中Θ代表模型参数。

为什么不用线性回归使用的平方损失函数?

实际上是可以用的,然而,想要预测分类结果正确,其实并不需要预测概率完全等于标签概率。如果真样本实标签y(i)=3,那么我们只需要预测输出值y^(i)3比其他两个预测值y^(i)1和y^(i)2大就行了。即使y^(i)3值为0.6,不管其他两个预测值为多少,类别预测均正确。而平方损失则过于严格,例如y^(i)1=y^(i)2=0.2要比y^(i)1=0,y^(i)2=0.4的损失小很多(代入平方损失函数),虽然两者都有同样正确的分类预测结果。

四、在有Softmax输出层时如何实现梯度下降法

其实初始化反向传播所需要的关键步骤或者说关键方程是这个表达式:

1ecd1b2606ed46e9956a89f231c9802c.png

 未完待续。。。

相关文章
|
8月前
|
机器学习/深度学习 人工智能 测试技术
使用随机森林分类器对基于NDRE(归一化差异水体指数)的特征进行分类
使用随机森林分类器对基于NDRE(归一化差异水体指数)的特征进行分类
66 1
|
1月前
|
机器学习/深度学习 算法 数据可视化
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择分类心肌梗塞数据模型案例(上)
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择分类心肌梗塞数据模型案例
|
1月前
|
机器学习/深度学习 数据可视化
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择分类心肌梗塞数据模型案例(下)
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择分类心肌梗塞数据模型案例
|
1月前
|
机器学习/深度学习 算法 数据可视化
R语言用标准最小二乘OLS,广义相加模型GAM ,样条函数进行逻辑回归LOGISTIC分类
R语言用标准最小二乘OLS,广义相加模型GAM ,样条函数进行逻辑回归LOGISTIC分类
|
1月前
|
机器学习/深度学习 算法 数据可视化
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择的分类模型案例
R语言惩罚logistic逻辑回归(LASSO,岭回归)高维变量选择的分类模型案例
|
1月前
|
SQL 数据可视化 数据挖掘
R语言线性分类判别LDA和二次分类判别QDA实例
R语言线性分类判别LDA和二次分类判别QDA实例
|
1月前
|
机器学习/深度学习 数据采集 算法
乳腺癌预测:特征交叉+随机森林=成功公式?
乳腺癌预测:特征交叉+随机森林=成功公式?
34 0
乳腺癌预测:特征交叉+随机森林=成功公式?
|
8月前
为什么进行线性回归前需要对特征进行离散化处理?
为什么进行线性回归前需要对特征进行离散化处理?
145 1
|
9月前
|
机器学习/深度学习 算法 索引
逻辑回归与多项式特征:解密分类问题的强大工具
逻辑回归与多项式特征:解密分类问题的强大工具
|
11月前
|
机器学习/深度学习 存储 索引
用4种回归方法绘制预测结果图表:向量回归、随机森林回归、线性回归、K-最近邻回归
用4种回归方法绘制预测结果图表:向量回归、随机森林回归、线性回归、K-最近邻回归
123 0