sigmoid 函数的损失函数与参数更新

简介: 1 sigmoid 函数的损失函数与参数更新 逻辑回归对应线性回归,但旨在解决分类问题,即将模型的输出转换为 $[0, 1]$ 的概率值。逻辑回归直接对分类的可能性进行建模,无需事先假设数据的分布。最理想的转换函数为单位阶跃函数(也称 Heaviside 函数),但单位阶跃函数是不连续的,没法在实际计算中使用。

1 sigmoid 函数的损失函数与参数更新

逻辑回归对应线性回归,但旨在解决分类问题,即将模型的输出转换为 $[0, 1]$ 的概率值。逻辑回归直接对分类的可能性进行建模,无需事先假设数据的分布。最理想的转换函数为单位阶跃函数(也称 Heaviside 函数),但单位阶跃函数是不连续的,没法在实际计算中使用。故而,在分类过程中更常使用对数几率函数(即 sigmoid 函数):

$$ \sigma(x) = \frac{1}{1+e^{-x}} $$

易推知,$\sigma(x)' = \sigma(x)(1- \sigma(x))$.

假设我们有 $m$ 个样本 $D = \{(x_i, y_i)\}_i^m$, 令 $X = (x_1, x_2, \cdots, x_m)^T, y = (y_1, y_2, \cdots, y_m)^T$, 其中 $x_i \in \mathbb{R}^n, y_i \in \{0, 1\}$, 关于参数 $w \in \mathbb{R}^n, b \in \mathbb{R}$, ($b$ 需要广播操作),我们定义正例的概率为

$$ P(y_j=1|x_j;w,b) = \sigma(x_j^Tw +b) = \sigma(z_j) $$

这样属于类别 $y$ 的概率可改写为

$$ P(y_j|x_j;w,b) = \sigma(z_j)^{y_j}(1-\sigma(z_j))^{1-y_j} $$

令 $z = (z_1, \cdots, z_m)^T$, 则记 $h(z) = (\sigma(z_1), \cdots, \sigma(z_m))^T$, 且 Logistic Regression 的损失函数为

$$ \begin{aligned} L(w, b) =& - \displaystyle \frac{1}{m} \sum_{i=1}^m (y_i \log (\sigma(z_i)) +(1-y_i) \log (1 - \sigma(z_i)))\\ =& - \frac{1}{m} (y^T\log (h(z)) + (\mathbf{1}-y)^T\log(\mathbf{1}- h(z))), \text{ 此时做了广播操作} \end{aligned} $$

这样,我们有

$$ \begin{cases} \nabla_w L(w,b) = \frac{\text{d}z}{\text{d}w} \frac{\text{d}L}{\text{d}z} = - \frac{1}{m}X^T(y-h(z))\\ \nabla_b L(w,b) = \frac{\text{d}z}{\text{d}b} \frac{\text{d}L}{\text{d}z} = - \frac{1}{m}\mathbf{1}^T(y-h(z)) \end{cases} $$

其中,$\mathbf{1}$ 表示全一列向量。这样便有参数更新公式 ($\eta$ 为学习率):

$$ \begin{cases} w \leftarrow w - \eta \nabla_{w} L(w,b)\\ b \leftarrow b - \eta \nabla_b L(w,b) \end{cases} $$

更多机器学习中的数见:机器学习中的数学

目录
相关文章
|
存储 数据采集 缓存
医学影像PACS:大容量图像存储 报告单多种模式及自定义样式
医学影像PACS:大容量图像存储 报告单多种模式及自定义样式
1132 0
医学影像PACS:大容量图像存储 报告单多种模式及自定义样式
|
SQL 数据挖掘 数据库
HiveSQL分位数函数percentile()使用详解+实例代码
HiveSQL分位数函数percentile()使用详解+实例代码
6321 0
HiveSQL分位数函数percentile()使用详解+实例代码
|
机器学习/深度学习 并行计算 PyTorch
ONNX 优化技巧:加速模型推理
【8月更文第27天】ONNX (Open Neural Network Exchange) 是一个开放格式,用于表示机器学习模型,使模型能够在多种框架之间进行转换。ONNX Runtime (ORT) 是一个高效的推理引擎,旨在加速模型的部署。本文将介绍如何使用 ONNX Runtime 和相关工具来优化模型的推理速度和资源消耗。
6681 4
|
Perl
QPS的计算
QPS = req/sec = 请求数/秒   Q:如何根据日志查看一个服务的qps   A: 一般access.log是记录请求的日志,tail  -f XXX.access.log ,可发现格式如下:     前面是请求的时间,后面有接请求的方法名字,那么我们要统计getCart的qps cat osp-cart.
6601 0
|
存储 编解码 PyTorch
Transformers 4.37 中文文档(六十六)(1)
Transformers 4.37 中文文档(六十六)
245 0
|
开发者 Python
|
机器学习/深度学习 数据采集 人工智能
【AI 生成式】生成式 AI 中变分自动编码器 (VAE) 的概念
【5月更文挑战第4天】【AI 生成式】生成式 AI 中变分自动编码器 (VAE) 的概念
|
存储 负载均衡 API
跨语言的GRPC协议
【2月更文挑战第11天】
|
SQL 关系型数据库 MySQL
MySQL事务原理分析(ACID特性、隔离级别、锁、MVCC、并发读异常、并发死锁以及如何避免死锁)
MySQL事务原理分析(ACID特性、隔离级别、锁、MVCC、并发读异常、并发死锁以及如何避免死锁)
374 1
|
安全 Python
一文彻底搞懂Python异常处理:try-except-else-finally
一文彻底搞懂Python异常处理:try-except-else-finally