为什么在二分类问题中使用交叉熵函数作为损失函数

简介: 为什么在二分类问题中使用交叉熵函数作为损失函数

0. 前言

本文将直奔主题从数学角度介绍交叉熵函数,以及在深度学习中使用交叉熵函数作为损失函数的理由。


对于交叉熵在信息论中的实际意义,以及相关香农熵等概念在本文不做赘述介绍。


1. 交叉熵介绍


交叉熵(Cross-entropy)是一种常用于衡量两个概率分布之间差异的量度。具体来说,假设有两个概率分布 p i p_i piq i q_i qi,其中 p i p_i pi 表示 i i i事件真实概率分布, q i q_i qi 表示 i i i事件模型输出的概率分布。交叉熵的计算公式为:

image.png

特别地,对于二分类问题(是非问题)交叉熵可以写为:

image.png

2. 交叉熵函数数学特性


对于真实概率 p ∈ ( 0 , 1 ) p∈(0,1) p(0,1)实际上是一个定值,在交叉熵函数中可以作为一个常数。而模型输出概率 q ∈ ( 0 , 1 ) q∈(0,1) q(0,1)是一个变化量,在交叉熵函数中作为一个变量。


为了后续推导简便,指定对数的底为e。这样上面二分类问题的交叉熵函数可以写为:

image.png

其一阶导数为:

image.png

由于二阶导数恒大于0:

image.png

可知在 q ∈ ( 0 , 1 ) q∈(0,1) q(0,1)区间, H ′ ( q ) H'(q) H(q)是从负无穷到正无穷的单调递增函数, H ( q ) H(q) H(q)是一个凹函数。


又由于当 q = p q=p q=p时, H ′ ( q ) = 0 H'(q)=0 H(q)=0。可知在 q = p q=p q=p时, H ( q ) H(q) H(q)取最小值。


3. 为什么在二分类问题中使用交叉熵函数作为损失函数?

这个问题可以拆分为两个问题:

3.1 为什么交叉熵函数可以作为损失函数?


因为机器学习的过程是通过不断优化(降低)损失函数的值,来达到预测值 q q q不断接近真实值 p p p的目的。交叉熵函数由于是一个凹函数,且在 q = p q=p q=p时, H ( q ) H(q) H(q)取最小值(尽管这个最小值不是0),刚好满足作为损失函数的要求。


3.2 交叉熵函数作为损失函数好在哪?

为什么不用简单现有的均方差函数作为损失函数,而引入交叉熵函数?


这是因为上面所述特性:在 q ∈ ( 0 , 1 ) q∈(0,1) q(0,1)区间, H ′ ( q ) H'(q) H(q)是从负无穷到正无穷的单调递增函数,也就是说相比于均方差函数,交叉熵函数的梯度绝对值更大。而更大的梯度绝对值意味着在学习时可以更快更容易收敛。


这一点也可以通过两个函数的3D作图上看出:



明显看出绿色交叉熵的梯度明显比蓝色平方差大,这样就可以更快地收敛。


画图源码

from matplotlib import pyplot as plot
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
figure = plot.figure()
axes = Axes3D(figure)
x = np.arange(0.01, 0.99, 0.01)
y = np.arange(0.01, 0.99, 0.01)
x, y = np.meshgrid(x, y)
z1 = (x - y) ** 2
axes.plot_surface(x, y, z1, cmap='Blues_r')
z2 = -x * np.log(y) - (1 - x) * np.log(1 - y)
axes.plot_surface(x, y, z2, cmap='Greens_r')
plot.show()


相关文章
|
4月前
|
开发框架 人工智能 机器人
LangChain vs LangGraph:大模型应用开发的双子星框架
LangChain是大模型应用的“乐高积木”,提供标准化组件,助力快速构建简单应用;LangGraph则是“交通控制系统”,通过图结构支持复杂、有状态的工作流。两者互补,构成从原型到生产的一体化解决方案。
|
机器学习/深度学习 JSON 算法
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
本文介绍了DeepLab V3在语义分割中的应用,包括数据集准备、模型训练、测试和评估,提供了代码和资源链接。
4402 0
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
|
机器学习/深度学习 自然语言处理 PyTorch
Transformers入门指南:从零开始理解Transformer模型
【10月更文挑战第29天】作为一名机器学习爱好者,我深知在自然语言处理(NLP)领域,Transformer模型的重要性。自从2017年Google的研究团队提出Transformer以来,它迅速成为NLP领域的主流模型,广泛应用于机器翻译、文本生成、情感分析等多个任务。本文旨在为初学者提供一个全面的Transformers入门指南,介绍Transformer模型的基本概念、结构组成及其相对于传统RNN和CNN模型的优势。
13782 1
|
大数据 数据挖掘
大数据中配对删除(Pairwise Deletion)
【10月更文挑战第22天】
735 6
|
机器学习/深度学习 计算机视觉
YOLOv8改进 | 损失函数篇 | QualityFocalLoss质量焦点损失(含代码 + 详细修改教程)
YOLOv8改进 | 损失函数篇 | QualityFocalLoss质量焦点损失(含代码 + 详细修改教程)
2789 2
|
数据采集 资源调度 算法
【数据挖掘】十大算法之K-Means K均值聚类算法
K-Means聚类算法的基本介绍,包括算法步骤、损失函数、优缺点分析以及如何优化和改进算法的方法,还提到了几种改进的K-Means算法,如K-Means++和ISODATA算法。
2311 4
|
机器学习/深度学习 人工智能 算法
计算机视觉:目标检测算法综述
【7月更文挑战第13天】目标检测作为计算机视觉领域的重要研究方向,近年来在深度学习技术的推动下取得了显著进展。然而,面对复杂多变的实际应用场景,仍需不断研究和探索更加高效、鲁棒的目标检测算法。随着技术的不断发展和应用场景的不断拓展,相信目标检测算法将在更多领域发挥重要作用。
|
敏捷开发 缓存 安全
阿里云云效产品使用问题之如何对任务进行分类
云效作为一款全面覆盖研发全生命周期管理的云端效能平台,致力于帮助企业实现高效协同、敏捷研发和持续交付。本合集收集整理了用户在使用云效过程中遇到的常见问题,问题涉及项目创建与管理、需求规划与迭代、代码托管与版本控制、自动化测试、持续集成与发布等方面。
|
机器学习/深度学习 存储 并行计算
C语言与机器学习:K-近邻算法实现
C语言与机器学习:K-近邻算法实现
|
算法
【算法】递归总结:循环与递归的区别?递归与深搜的关系?
【算法】递归总结:循环与递归的区别?递归与深搜的关系?
660 0

热门文章

最新文章

下一篇
开通oss服务