【机器学习】线性分类——线性判别分析LDA(理论+图解+公式推导)

简介: 【机器学习】线性分类——线性判别分析LDA(理论+图解+公式推导)

2021人工智能领域新星创作者,带你从入门到精通,该博客每天更新,逐渐完善机器学习各个知识体系的文章,帮助大家更高效学习。


一、概述

本篇讲解一种新的分类算法,它就是LDA(线性判别分析),它是一个比较经典的一个二分类算法,不过现在不怎么流行了,但是整个算法的思想很具有意义。

它是一种基于降维的方式将所有的样本映射到一维坐标轴上,然后设定一定的阈值,将样本进行区分,画幅图方便理解

它首先会将所有的样本向一个向量做投影,此时我们会获得新的坐标轴上的坐标位置,然后我们希望的是类内小,类间大,也就是同类样本映射后的坐标尽可能的聚在一起,而不同类的样本映射后尽可能的远离。

这个就是LDA的基本算法流程,下面将讲解整个算法的数学推导,以及如何求解方程

二、数学原理推导

1.数学符号

首先我们假设整个样本空间为两个类别,分别是1、-1,我们将1类视为 C 1 C_1C1 ,对应-1我们视为 C 2 C_2C2 ,其中:

∣ X C 1 ∣ = n 1 ∣ X C 2 ∣ = n 2 n 1 + n 2 = n |X_{C_1}|=n1\\|X_{C_2}|=n2\\n1+n2=nXC1=n1XC2=n2n1+n2=n

其中N1、N2代表每个类别样本的个数。

定义z为映射后的坐标即投影

2.定义类均值、方差

由于是将我们的样本数据X向w向量做投影,我们约定w的模长为1,所以有:

z i = w T x i z_i=w^Tx_izi=wTxi

由该式子,我们可以构造映射后的均值和方差,用来衡量样本的类间距离和类内距离

映射后Z坐标的均值为:

z ‾ = 1 n ∑ i = 1 n w T x i \overline z=\frac{1}{n}\sum_{i=1}^nw^Tx_iz=n1i=1nwTxi

Z坐标的方差为:

S z = 1 n ∑ i = 1 n ( z i − z ‾ ) 2 S_z=\frac{1}{n}\sum_{i=1}^n(z_i-\overline z)^2Sz=n1i=1n(ziz)2

这两个公式就可以评估我们映射后的样本密集程度。

C1类样本:

{ z ‾ 1 = 1 n 1 ∑ i = 1 n 1 w T x i S z 1 = 1 n 1 ∑ i = 1 n 1 ( z i − z 1 ‾ ) 2

{z¯¯¯1=1n1n1i=1wTxiSz1=1n1n1i=1(ziz1¯¯¯¯¯)2{z¯1=1n1∑i=1n1wTxiSz1=1n1∑i=1n1(zi−z1¯)2

{z1=n11i=1n1wTxiSz1=n11i=1n1(ziz1)2

C2类样本:

{ z ‾ 1 = 1 n 2 ∑ i = 1 n 2 w T x i S z 2 = 1 n 2 ∑ i = 1 n 2 ( z i − z 2 ‾ ) 2

{z¯¯¯1=1n2n2i=1wTxiSz2=1n2n2i=1(ziz2¯¯¯¯¯)2{z¯1=1n2∑i=1n2wTxiSz2=1n2∑i=1n2(zi−z2¯)2

{ z 1 = n2 1 i=1 n2 w T x i S z2 = n2 1 i=1 n2 ( z i z 2 ) 2

3.构造目标函数

由上面我们获得的方差和均值公式,我们的要求就是类内小,类间大,所以我们可以用相应的函数进行表达

  • 类内小:S z 1 + S z 2 S_{z_1}+S_{z_2}Sz1+Sz2 两个类的方差越小,说明样本越密集
  • 类间大:( z 1 ‾ − z 2 ‾ ) 2 (\overline {z_1}-\overline {z_2})^2(z1z2)2 ,用两个类的均值的距离说明两个类之间的距离

所以我们可以构造目标函数:

J ( w ) = ( z 1 ‾ − z 2 ‾ ) 2 S z 1 + S z 2 J(w)=\frac{(\overline {z_1}-\overline {z_2})^2}{S_{z_1}+S_{z_2}}J(w)=Sz1+Sz2(z1z2)2

我们希望该值是越大越好,也就是:

a r g m a x w J ( w ) argmax_wJ(w)argmaxwJ(w)

三、求解目标函数

上文我们获得了优化函数,接下来的目标就是对该方程进行求解,在求解之前我们先将其进行化简:

J ( w ) = ( z 1 ‾ − z 2 ‾ ) 2 S z 1 + S z 2 = ( 1 n 1 ∑ i = 1 n 1 w T x i − 1 n 2 ∑ i = 1 n 2 w T x i ) 2 1 n 1 ∑ i = 1 n 1 ( z i − z 1 ‾ ) 2 + 1 n 2 ∑ i = 1 n 2 ( z i − z 2 ‾ ) 2 = w T ( X C 1 ‾ − X C 2 ‾ ) ( X C 1 ‾ − X C 2 ‾ ) T w w T ( S C 1 + S C 2 ) w J(w)=\frac{(\overline {z_1}-\overline {z_2})^2}{S_{z_1}+S_{z_2}}\\=\frac{(\frac{1}{n_1}\sum_{i=1}^{n_1}w^Tx_i-\frac{1}{n_2}\sum_{i=1}^{n_2}w^Tx_i)^2}{\frac{1}{n_1}\sum_{i=1}^{n_1}(z_i-\overline {z_1})^2+\frac{1}{n_2}\sum_{i=1}^{n_2}(z_i-\overline {z_2})^2}\\=\frac{w^T(\overline {X_{C_1}}-\overline {X_{C_2}})(\overline {X_{C_1}}-\overline {X_{C_2}})^Tw}{w^T(S_{C_1}+S_{C_2})w}J(w)=Sz1+Sz2(z1z2)2=n11i=1n1(ziz1)2+n21i=1n2(ziz2)2(n11i=1n1wTxin21i=1n2wTxi)2=wT(SC1+SC2)wwT(XC1XC2)(XC1XC2)Tw

这里我们为了方便表示,我们定义:

S b = ( X C 1 ‾ − X C 2 ‾ ) ( X C 1 ‾ − X C 2 ‾ ) T S w = S C 1 + S C 2 S_b=(\overline {X_{C_1}}-\overline {X_{C_2}})(\overline {X_{C_1}}-\overline {X_{C_2}})^T\\S_w=S_{C_1}+S_{C_2}Sb=(XC1XC2)(XC1XC2)TSw=SC1+SC2

所以我们的目标函数就变为了:

J ( w ) = w T S b w w T S w w = w T S b w ( w T S w w ) − 1 J(w)=\frac{w^TS_bw}{w^TS_ww}\\=w^TS_bw(w^TS_ww)^{-1}J(w)=wTSwwwTSbw=wTSbw(wTSww)1

接下来我们需要对目标函数进行求导:

∂ J ( w ) ∂ w = 2 S b w ( w T S w w ) − 1 + w T S b w ( − 1 ) ( w T S w w ) − 2 2 S w w = 0 \frac{\partial J(w)}{\partial w}=2S_bw(w^TS_ww)^{-1}+w^TSbw(-1)(w^TS_ww)^{-2}2S_ww=0wJ(w)=2Sbw(wTSww)1+wTSbw(1)(wTSww)22Sww=0

化简后我们可以获得:

S b w ( w T S w w ) = ( w T S b w ) S w w S_bw(w^TS_ww)=(w^TS_bw)S_wwSbw(wTSww)=(wTSbw)Sww

我们可以观察到,我画括号内的式子为标量,所以此时又可以表达为:

S w w = w T S w w w T S b w S b w = w T S w w w T S b w ( ( X C 1 ‾ − X C 2 ‾ ) ( X C 1 ‾ − X C 2 ‾ ) T ) w S_ww=\frac{w^TS_ww}{w^TS_bw}S_bw\\=\frac{w^TS_ww}{w^TS_bw}((\overline {X_{C_1}}-\overline {X_{C_2}})(\overline {X_{C_1}}-\overline {X_{C_2}})^T)wSww=wTSbwwTSwwSbw=wTSbwwTSww((XC1XC2)(XC1XC2)T)w

又因为 ( X C 1 ‾ − X C 2 ‾ ) T w (\overline {X_{C_1}}-\overline {X_{C_2}})^Tw(XC1XC2)Tw 也是标量,所以我们最终的式子为:

我们的w正比为 S w − 1 ( X C 1 ‾ − X C 2 ‾ ) S_w^{-1}(\overline {X_{C_1}}-\overline {X_{C_2}})Sw1(XC1XC2)

即:

w ∼ S w − 1 ( X C 1 ‾ − X C 2 ‾ ) w \sim S_w^{-1}(\overline {X_{C_1}}-\overline {X_{C_2}})wSw1(XC1XC2)

这里我们没有解得w的准确值是因为我们只需要w的方向,只需要直到将数据映射到的那条直线的方向即可,因为我们之前已经规定过w的模长为1。


目录
相关文章
|
7月前
|
机器学习/深度学习 算法 数据可视化
JAMA | 机器学习中的可解释性:SHAP分析图像复刻与解读
JAMA | 机器学习中的可解释性:SHAP分析图像复刻与解读
1610 1
|
4月前
|
机器学习/深度学习 算法 数据中心
【机器学习】面试问答:PCA算法介绍?PCA算法过程?PCA为什么要中心化处理?PCA为什么要做正交变化?PCA与线性判别分析LDA降维的区别?
本文介绍了主成分分析(PCA)算法,包括PCA的基本概念、算法过程、中心化处理的必要性、正交变换的目的,以及PCA与线性判别分析(LDA)在降维上的区别。
109 4
|
4月前
|
机器学习/深度学习
【机器学习】准确率、精确率、召回率、误报率、漏报率概念及公式
机器学习评估指标中的准确率、精确率、召回率、误报率和漏报率等概念,并给出了这些指标的计算公式。
863 0
|
4月前
|
机器学习/深度学习 算法
【机器学习】简单解释贝叶斯公式和朴素贝叶斯分类?(面试回答)
简要解释了贝叶斯公式及其在朴素贝叶斯分类算法中的应用,包括算法的基本原理和步骤。
83 1
|
4月前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】聚类算法中的距离度量有哪些及公式表示?
聚类算法中常用的距离度量方法及其数学表达式,包括欧式距离、曼哈顿距离、切比雪夫距离、闵可夫斯基距离、余弦相似度等多种距离和相似度计算方式。
424 1
|
5月前
|
机器学习/深度学习 算法 数据可视化
Fisher模型在统计学和机器学习领域通常指的是Fisher线性判别分析(Fisher's Linear Discriminant Analysis,简称LDA)
Fisher模型在统计学和机器学习领域通常指的是Fisher线性判别分析(Fisher's Linear Discriminant Analysis,简称LDA)
|
7月前
|
机器学习/深度学习 算法 数据可视化
机器学习-生存分析:如何基于随机生存森林训练乳腺癌风险评估模型?
机器学习-生存分析:如何基于随机生存森林训练乳腺癌风险评估模型?
133 1
|
7月前
|
机器学习/深度学习 数据采集 自然语言处理
编写员工聊天监控软件的机器学习模块:Scikit-learn在行为分析中的应用
随着企业对员工行为监控的需求增加,开发一种能够自动分析员工聊天内容并检测异常行为的软件变得愈发重要。本文介绍了如何使用机器学习模块Scikit-learn来构建这样一个模块,并将其嵌入到员工聊天监控软件中。
246 3
|
7月前
|
机器学习/深度学习 算法 数据可视化
机器学习——主成分分析(PCA)
机器学习——主成分分析(PCA)
121 0
|
7月前
|
机器学习/深度学习 自然语言处理 JavaScript
GEE机器学习——最大熵分类器案例分析(JavaScript和python代码)
GEE机器学习——最大熵分类器案例分析(JavaScript和python代码)
141 0