JavaScript 深度学习(一)(4)https://developer.aliyun.com/article/1516944
3.2. 输出的非线性:用于分类的模型
我们到目前为止看到的两个例子都是回归任务,我们试图预测一个数值(如下载时间或平均房价)。然而,机器学习中另一个常见的任务是分类。一些分类任务是二元分类,其中目标是对一个是/否问题的答案。技术世界充满了这种类型的问题,包括
- 是否给定的电子邮件是垃圾邮件
- 是否给定的信用卡交易是合法的还是欺诈的
- 是否给定的一秒钟音频样本包含特定的口语单词
- 两个指纹图像是否匹配(来自同一个人的同一个手指)
另一种分类问题是多类别分类任务,对此类任务也有很多例子:
- 一篇新闻文章是关于体育、天气、游戏、政治还是其他一般话题
- 一幅图片是猫、狗、铲子等等
- 给定电子笔的笔触数据,确定手写字符是什么
- 在使用机器学习玩一个类似 Atari 的简单视频游戏的场景中,确定游戏角色应该向四个可能的方向之一(上、下、左、右)前进,给定游戏的当前状态
3.2.1. 什么是二元分类?
我们将从一个简单的二元分类案例开始。给定一些数据,我们想要一个是/否的决定。对于我们的激励示例,我们将谈论钓鱼网站数据集。任务是,给定关于网页和其 URL 的一组特征,预测该网页是否用于钓鱼(伪装成另一个站点,目的是窃取用户的敏感信息)。
⁷
Rami M. Mohammad, Fadi Thabtah, 和 Lee McCluskey,“Phishing Websites Features,”
mng.bz/E1KO。
数据集包含 30 个特征,所有特征都是二元的(表示值为-1 和 1)或三元的(表示为-1、0 和 1)。与我们为波士顿房屋数据集列出所有单个特征不同,这里我们提供一些代表性的特征:
HAVING_IP_ADDRESS—是否使用 IP 地址作为域名的替代(二进制值:{-1, 1})SHORTENING_SERVICE—是否使用 URL 缩短服务(二进制值:{1, -1})SSLFINAL_STATE—URL 是否使用 HTTPS 并且发行者是受信任的,它是否使用 HTTPS 但发行者不受信任,或者没有使用 HTTPS(三元值:{-1, 0, 1})
数据集由大约 5500 个训练示例和相同数量的测试示例组成。在训练集中,大约有 45%的示例是正面的(真正的钓鱼网页)。在测试集中,正面示例的百分比大约是相同的。
这只是最容易处理的数据集类型——数据中的特征已经在一致的范围内,因此无需对其均值和标准偏差进行归一化,就像我们为波士顿房屋数据集所做的那样。此外,相对于特征数量和可能预测数量(两个——是或否),我们有大量的训练示例。总的来说,这是一个很好的健全性检查,表明这是一个我们可以处理的数据集。如果我们想要花更多时间研究我们的数据,我们可能会进行成对特征相关性检查,以了解是否有冗余信息;但是,这是我们的模型可以容忍的。
由于我们的数据与我们用于波士顿房屋(后归一化)的数据相似,我们的起始模型基于相同的结构。此问题的示例代码可在 tfjs-examples 存储库的 website-phishing 文件夹中找到。您可以按照以下方式查看和运行示例:
git clone https://github.com/tensorflow/tfjs-examples.git cd tfjs-examples/website-phishing yarn && yarn watch
列表 3.5. 为钓鱼检测定义二分类模型(来自 index.js)
const model = tf.sequential(); model.add(tf.layers.dense({ inputShape: [data.numFeatures], units: 100, activation: 'sigmoid' })); model.add(tf.layers.dense({units: 100, activation: 'sigmoid'})); model.add(tf.layers.dense({units: 1, activation: 'sigmoid'})); model.compile({ optimizer: 'adam', loss: 'binaryCrossentropy', metrics: ['accuracy'] });
这个模型与我们为波士顿房屋问题构建的多层网络有很多相似之处。它以两个隐藏层开始,两者都使用 sigmoid 激活。最后(输出)有确切的 1 个单元,这意味着模型为每个输入示例输出一个数字。然而,这里的一个关键区别是,我们用于钓鱼检测的模型的最后一层具有 sigmoid 激活,而不是波士顿房屋模型中的默认线性激活。这意味着我们的模型受限于只能输出介于 0 和 1 之间的数字,这与波士顿房屋模型不同,后者可能输出任何浮点数。
之前,我们已经看到 sigmoid 激活对隐藏层有助于增加模型容量。但是为什么在这个新模型的输出处使用 sigmoid 激活?这与我们手头问题的二分类特性有关。对于二分类,我们通常希望模型产生正类别的概率猜测——也就是说,模型“认为”给定示例属于正类别的可能性有多大。您可能还记得高中数学中的知识,概率始终是介于 0 和 1 之间的数字。通过让模型始终输出估计的概率值,我们获得了两个好处:
- 它捕获了对分配的分类的支持程度。
sigmoid值为0.5表示完全不确定性,其中每个分类都得到了同等的支持。值为0.6表示虽然系统预测了正分类,但支持程度很低。值为0.99表示模型非常确定该示例属于正类,依此类推。因此,我们使得将模型的输出转换为最终答案变得简单而直观(例如,只需在给定值处对输出进行阈值处理,例如0.5)。现在想象一下,如果模型的输出范围可能变化很大,那么找到这样的阈值将会有多难。 - 我们还使得更容易构造一个可微的损失函数,它根据模型的输出和真实的二进制目标标签产生一个衡量模型错失程度的数字。至于后者,当我们检查该模型使用的实际二元交叉熵时,我们将会更详细地阐述。
但是,问题是如何将神经网络的输出强制限制在[0, 1]范围内。神经网络的最后一层通常是一个密集层,它对其输入执行矩阵乘法(matMul)和偏置加法(biasAdd)操作。在matMul或biasAdd操作中都没有固有的约束,以保证结果在[0, 1]范围内。将sigmoid等压缩非线性添加到matMul和biasAdd的结果中是实现[0, 1]范围的一种自然方法。
清单 3.5 中代码的另一个新方面是优化器的类型:'adam',它与之前示例中使用的'sgd'优化器不同。adam与sgd有何不同?正如你可能还记得上一章第 2.2.2 节所述,sgd优化器总是将通过反向传播获得的梯度乘以一个固定数字(学习率乘以-1)以计算模型权重的更新。这种方法有一些缺点,包括当选择较小的学习率时,收敛速度较慢,并且当损失(超)表面的形状具有某些特殊属性时,在权重空间中出现“之”形路径。adam优化器旨在通过以一种智能方式使用梯度的历史(来自先前的训练迭代)的乘法因子来解决这些sgd的缺点。此外,它对不同的模型权重参数使用不同的乘法因子。因此,与一系列深度学习模型类型相比,adam通常导致更好的收敛性和对学习率选择的依赖性较小;因此,它是优化器的流行选择。TensorFlow.js 库提供了许多其他优化器类型,其中一些也很受欢迎(如rmsprop)。信息框 3.1 中的表格提供了它们的简要概述。
TensorFlow.js 支持的优化器
下表总结了 TensorFlow.js 中最常用类型的优化器的 API,以及对每个优化器的简单直观解释。
TensorFlow.js 中常用的优化器及其 API
| 名称 | API(字符串) | API(函数) | 描述 |
| 随机梯度下降(SGD) | ‘sgd’ | tf.train.sgd | 最简单的优化器,始终使用学习率作为梯度的乘子 |
| Momentum | ‘momentum’ | tf.train.momentum | 以一种方式累积过去的梯度,使得对于某个权重参数的更新在过去的梯度更多地朝着同一方向时变得更快,并且当它们在方向上发生大变化时变得更慢 |
| RMSProp | ‘rmsprop’ | tf.train.rmsprop | 通过跟踪模型不同权重参数的最近梯度的均方根(RMS)值的历史记录,为不同的权重参数设置不同的乘法因子;因此得名 |
| AdaDelta | ‘adadelta’ | tf.train.adadelta | 类似于 RMSProp,以一种类似的方式为每个单独的权重参数调整学习率 |
| ADAM | ‘adam’ | tf.train.adam | 可以理解为 AdaDelta 的自适应学习率方法和动量方法的结合 |
| AdaMax | ‘adamax’ | tf.train.adamax | 类似于 ADAM,但使用稍微不同的算法跟踪梯度的幅度 |
一个明显的问题是,针对你正在处理的机器学习问题和模型,应该使用哪种优化器。不幸的是,在深度学习领域尚无共识(这就是为什么 TensorFlow.js 提供了上表中列出的所有优化器!)。在实践中,你应该从流行的优化器开始,包括 adam 和 rmsprop。在有足够的时间和计算资源的情况下,你还可以将优化器视为超参数,并通过超参数调整找到为你提供最佳训练结果的选择(参见 section 3.1.2)。
3.2.2. 衡量二元分类器的质量:准确率、召回率、准确度和 ROC 曲线
在二元分类问题中,我们发出两个值之一——0/1、是/否等等。在更抽象的意义上,我们将讨论正例和负例。当我们的网络进行猜测时,它要么正确要么错误,所以我们有四种可能的情况,即输入示例的实际标签和网络输出,如 table 3.1 所示。
表 3.1. 二元分类问题中的四种分类结果类型
| 预测 | ||
| 正类 | ||
| 正类 | 真正例(TP) | |
| 负类 | 假正例(FP) |
真正的正例(TP)和真正的负例(TN)是模型预测出正确答案的地方;假正例(FP)和假负例(FN)是模型出错的地方。如果我们用计数填充这四个单元格,我们就得到了一个混淆矩阵;表 3.2 显示了我们钓鱼检测问题的一个假设性混淆矩阵。
表 3.2. 一个假设的二元分类问题的混淆矩阵
| 预测 | ||
| 正例 | ||
| 正例 | 4 | |
| 负例 | 1 |
在我们假设的钓鱼示例结果中,我们看到我们正确识别了四个钓鱼网页,漏掉了两个,而且有一个误报。现在让我们来看看用于表达这种性能的不同常见指标。
准确率是最简单的度量标准。它量化了多少百分比的示例被正确分类:
Accuracy = (#TP + #TN) / #examples = (#TP + #TN) / (#TP + #TN + #FP + #FN)
在我们特定的例子中,
Accuracy = (4 + 93) / 100 = 97%
准确率是一个易于沟通和易于理解的概念。然而,它可能会具有误导性——在二元分类任务中,我们通常没有相等分布的正负例。我们通常处于这样的情况:正例要远远少于负例(例如,大多数链接不是钓鱼网站,大多数零件不是有缺陷的,等等)。如果 100 个链接中只有 5 个是钓鱼的,我们的网络可以总是预测为假,并获得 95% 的准确率!这样看来,准确率似乎是我们系统的一个非常糟糕的度量。高准确率听起来总是很好,但通常会误导人。监视准确率是件好事,但作为损失函数使用则是一件非常糟糕的事情。
下一对指标试图捕捉准确率中缺失的微妙之处——精确率和召回率。在接下来的讨论中,我们通常考虑的是一个正例意味着需要进一步的行动——一个链接被标记,一篇帖子被标记为需要手动审查——而负例表示现状不变。这些指标专注于我们的预测可能出现的不同类型的“错误”。
精确率是模型预测的正例中实际为正例的比率:
precision = #TP / (#TP + #FP)
根据我们混淆矩阵的数字,我们将计算
precision = 4 / (4 + 1) = 80%
与准确率类似,通常可以操纵精确率。例如,您可以通过仅将具有非常高 S 型输出(例如 >0.95,而不是默认的 >0.5)的输入示例标记为正例,从而使您的模型非常保守地发出正面预测。这通常会导致精确率提高,但这样做可能会导致模型错过许多实际的正例(将它们标记为负例)。这最后一个成本被常与精确率配合使用并补充的度量所捕获,即召回率。
召回率是模型将实际正例分类为正例的比率:
recall = #TP / (#TP + #FN)
根据示例数据,我们得到了一个结果
recall = 4 / (4 + 2) = 66.7%
在样本集中所有阳性样本中,模型发现了多少个?通常会有一个有意识的决定,即接受较高的误报率以降低遗漏的可能性。为了优化这一指标,你可以简单地声明所有样本为阳性;由于假阳性不进入计算,因此你可以在降低精确度的代价下获得 100%的召回率。
我们可以看到,制作一个在准确度、召回率或精确度上表现出色的系统相当容易。在现实世界中的二元分类问题中,同时获得良好的精确度和召回率通常很困难。(如果这样做很容易,你就会面临一个简单的问题,可能根本不需要使用机器学习。)精确度和召回率涉及在对正确答案存在根本不确定的复杂区域调整模型。你会看到更多细致和组合的指标,如在 X%召回率下的精确度,其中 X 通常为 90%——如果我们调整到至少发现 X%的阳性样本,精确度是多少?例如,在图 3.5 中,我们看到经过 400 个轮次的训练后,当模型的概率输出门槛设为 0.5 时,我们的钓鱼检测模型能够达到 96.8%的精确度和 92.9%的召回率。
图 3.5。训练模型用于钓鱼网页检测的一轮结果示例。注意底部的各种指标:精确度、召回率和 FPR。曲线下面积(AUC)在 3.2.3 节中讨论。
如我们已略有提及的,一个重要的认识是,对正预测的选择,不需要在 sigmoid 输出上设置恰好为 0.5 的门槛。事实上,根据情况,它可能最好设定为 0.5 以上(但小于 1)或 0.5 以下(但大于 0)。降低门槛使模型在将输入标记为阳性时更加自由,这会导致更高的召回率但可能降低精确度。另一方面,提高门槛使模型在将输入标记为阳性时更加谨慎,通常会导致更高的精确度但可能降低召回率。因此,我们可以看到精确度和召回率之间存在权衡,这种权衡很难用我们迄今讨论过的任何一种指标来量化。幸运的是,二元分类研究的丰富历史为我们提供了更好的方式来量化和可视化这种权衡关系。我们接下来将讨论的 ROC 曲线是这种常用的工具之一。
3.2.3。ROC 曲线:展示二元分类中的权衡
ROC 曲线被用于广泛的工程问题,其中包括二分类或特定类型事件的检测。全名“接收者操作特性”是一个来自雷达早期的术语。现在,你几乎看不到这个扩展名了。图 3.6 是我们应用程序的一个样本 ROC 曲线。
图 3.6. 在钓鱼检测模型训练期间绘制的一组样本 ROC 曲线。每条曲线对应不同的周期数。这些曲线显示了二分类模型随着训练的进展而逐渐改进的质量。
正如你可能已经在图 3.6 的坐标轴标签中注意到的,ROC 曲线并不是通过将精确度和召回率指标相互绘制得到的。相反,它们是基于两个稍微不同的指标。ROC 曲线的横轴是假阳性率(FPR),定义为
FPR = #FP / (#FP + #TN)
ROC 曲线的纵轴是真阳性率(TPR),定义为
TPR = #TP / (#TP + #FN) = recall
TPR 与召回率具有完全相同的定义,只是使用了不同的名称。然而,FPR 是一些新的东西。分母是实际类别为负的案例数量;分子是所有误报的数量。换句话说,FPR 是将实际上是负的案例错误分类为正的比例,这是一个常常被称为*虚警(false alarm)*的概率。表 3.3 总结了在二分类问题中遇到的最常见的指标。
表 3.3. 二分类问题中常见的指标
| 指标名称 | 定义 | ROC 曲线或精确度/召回率曲线中的使用方式 |
| 准确度(Accuracy) | (#TP + #TN) / (#TP + #TN + # FP + #FN) | (ROC 曲线中不使用) |
| 精确度(Precision) | #TP / (#TP + #FP) | 精确度/召回率曲线的纵轴 |
| 召回率/灵敏度/真阳性率(TPR) | #TP / (#TP + #FN) | ROC 曲线的纵轴(如图 3.6)或精确度/召回率曲线的横轴 |
| 假阳性率(False positive rate,FPR) | #FP / (#FP + #TN) | ROC 曲线的横轴(见图 3.6) |
| 曲线下面积(Area under the curve,AUC) | 将 ROC 曲线的数值积分计算得出;查看代码示例 3.7 以获取示例 | (ROC 曲线不使用,而是从 ROC 曲线计算得到) |
图 3.6 中的七条 ROC 曲线分别绘制于七个不同的训练周期的开头,从第一个周期 (周期 001) 到最后一个周期 (周期 400)。每条曲线都是基于模型在测试数据上的预测结果(而不是训练数据)创建的。代码清单 3.6 显示了如何利用 Model.fit() API 中的 onEpochBegin 回调函数详细实现此过程。这种方法使您可以在训练过程中执行有趣的分析和可视化,而不需要编写 for 循环或使用多个 Model.fit() 调用。
代码清单 3.6 使用回调函数在模型训练中间绘制 ROC 曲线
await model.fit(trainData.data, trainData.target, { batchSize, epochs, validationSplit: 0.2, callbacks: { onEpochBegin: async (epoch) => { if ((epoch + 1)% 100 === 0 || epoch === 0 || epoch === 2 || epoch === 4) { ***1*** const probs = model.predict(testData.data); drawROC(testData.target, probs, epoch); } }, onEpochEnd: async (epoch, logs) => { await ui.updateStatus( `Epoch ${epoch + 1} of ${epochs} completed.`); trainLogs.push(logs); ui.plotLosses(trainLogs); ui.plotAccuracies(trainLogs); } } });
- 1 每隔几个周期绘制 ROC 曲线。
函数 drawROC() 的主体包含了如何创建 ROC 曲线的细节(参见代码清单 3.7)。它执行以下操作:
- 根据神经网络的 S 型输出(概率)的阈值,可获取不同分类结果的集合。
- 将 TPR 绘制在 FPR 上以形成 ROC 曲线。
- ⁸
如 图 3.6 所示,在训练开始时(周期 001),由于模型的权重是随机初始化的,ROC 曲线非常接近连接点 (0, 0) 和点 (1, 1) 的对角线。这就是随机猜测的样子。随着训练的进行,ROC 曲线越来越向左上角推进——那里的 FPR 接近 0,TPR 接近 1。如果我们专注于任何一个给定的 FPR 级别,例如 0.1,我们可以看到在训练过程中,相应的 TPR 值随着训练的进展而单调递增。简而言之,这意味着随着训练的进行,如果我们将假报警率(FPR)保持不变,就可以实现越来越高的召回率(TPR)。
“理想”的 ROC 曲线向左上角弯曲得越多,就会变成一个类似 γ^([8]) 形状的曲线。在这种情况下,您可以获得 100% 的 TPR 和 0% 的 FPR,这是任何二元分类器的“圣杯”。然而,在实际问题中,我们只能改进模型,将 ROC 曲线推向左上角,但理论上的左上角理想状态是无法实现的。
注释:γ 字母
对于每个分类结果,将其与实际标签(目标)结合使用,计算 TPR 和 FPR。
基于对 ROC 曲线形状及其含义的讨论,我们可以看到通过查看其下方的区域(即 ROC 曲线和 x 轴之间的单位正方形的空间)来量化 ROC 曲线的好坏是可能的。这被称为曲线下面积(AUC),并且也在 listing 3.7 的代码中计算。这个指标比精确率、召回率和准确率更好,因为它考虑了假阳性和假阴性之间的权衡。随机猜测的 ROC 曲线(对角线)的 AUC 为 0.5,而γ形状的理想 ROC 曲线的 AUC 为 1.0。我们的钓鱼检测模型在训练后达到了 0.981 的 AUC。
listing 3.7 的代码用于计算和绘制 ROC 曲线和 AUC
function drawROC(targets, probs, epoch) { return tf.tidy(() => { const thresholds = [ ***1*** 0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, ***1*** 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, ***1*** 0.9, 0.92, 0.94, 0.96, 0.98, 1.0 ***1*** ]; ***1*** const tprs = []; // True positive rates. const fprs = []; // False positive rates. let area = 0; for (let i = 0; i < thresholds.length; ++i) { const threshold = thresholds[i]; const threshPredictions = ***2*** utils.binarize(probs, threshold).as1D(); ***2*** const fpr = falsePositiveRate( ***3*** targets, ***3*** threshPredictions).arraySync(); ***3*** const tpr = tf.metrics.recall(targets, threshPredictions).arraySync(); fprs.push(fpr); tprs.push(tpr); if (i > 0) { ***4*** area += (tprs[i] + tprs[i - 1]) * (fprs[i - 1] - fprs[i]) / 2; ***4*** } ***4*** } ui.plotROC(fprs, tprs, epoch); return area; }); }
- 1 一组手动选择的概率阈值
- 2 通过阈值将概率转换为预测
- 3 falsePositiveRate()函数通过比较预测和实际目标来计算假阳性率。该函数在同一文件中定义。
- 4 用于 AUC 计算的面积累积
除了可视化二元分类器的特性外,ROC 还帮助我们在实际情况下做出明智的选择,比如如何选择概率阈值。例如,想象一下,我们是一家商业公司,正在开发钓鱼检测器作为一项服务。我们想要做以下哪项?
- 由于错过了真实的网络钓鱼网站将会在责任或失去合同方面给我们造成巨大的损失,因此将阈值设定相对较低。
- 由于我们更不愿意接受将正常网站误分类为可疑而导致用户提交投诉,因此将阈值设定相对较高。
每个阈值对应于 ROC 曲线上的一个点。当我们将阈值从 0 逐渐增加到 1 时,我们从图的右上角(其中 FPR 和 TPR 都为 1)移动到图的左下角(其中 FPR 和 TPR 都为 0)。在实际的工程问题中,选择 ROC 曲线上的哪个点的决定总是基于权衡这种相反的现实生活成本,并且在不同的客户和不同的业务发展阶段可能会有所不同。
除了 ROC 曲线之外,二元分类的另一个常用可视化方法是精确率-召回率曲线(有时称为 P/R 曲线,在 table 3.3 中简要提到)。与 ROC 曲线不同,精确率-召回率曲线将精确率绘制为召回率的函数。由于精确率-召回率曲线在概念上与 ROC 曲线相似,我们在这里不会深入讨论它们。
在 代码清单 3.7 中值得指出的一点是使用了 tf.tidy()。这个函数确保了在作为参数传递给它的匿名函数内创建的张量被正确地处理,这样它们就不会继续占用 WebGL 内存。在浏览器中,TensorFlow.js 无法管理用户创建的张量的内存,主要是因为 JavaScript 中缺乏对象终结和底层 TensorFlow.js 张量下层的 WebGL 纹理缺乏垃圾回收。如果这样的中间张量没有被正确清理,就会发生 WebGL 内存泄漏。如果允许这样的内存泄漏持续足够长的时间,最终会导致 WebGL 内存不足错误。附录 B 的 章节 1.3 包含了有关 TensorFlow.js 内存管理的详细教程。此外,附录 B 的 章节 1.5 中还有关于这个主题的练习题。如果您计划通过组合 TensorFlow.js 函数来定义自定义函数,您应该仔细研究这些章节。
3.2.4. 二元交叉熵:二元分类的损失函数
到目前为止,我们已经讨论了几种不同的度量标准,用于量化二元分类器的不同表现方面,比如准确率、精确率和召回率(表 3.3)。但我们还没有讨论一个重要的度量标准,一个可以微分并生成梯度来支持模型梯度下降训练的度量标准。这就是我们在 代码清单 3.5 中简要看到的 binaryCrossentropy,但我们还没有解释过:
model.compile({ optimizer: 'adam', loss: 'binaryCrossentropy', metrics: ['accuracy'] });
首先,你可能会问,为什么不能直接以精确度、准确度、召回率,或者甚至 AUC 作为损失函数?毕竟这些指标容易理解。此外,在之前我们见过的回归问题中,我们使用了 MSE 作为训练的损失函数,这是一个相当容易理解的指标。答案是,这些二分类度量指标都无法产生我们需要训练的梯度。以精确度指标为例:要了解为什么它不友好的梯度,请认识到计算精确度需要确定模型的预测哪些是正样本,哪些是负样本(参见 表 3.3 的第一行)。为了做到这一点,必须应用一个 阈值函数,将模型的 sigmoid 输出转换为二进制预测。这里就是问题的关键:虽然阈值函数(在更技术的术语中称为step function)几乎在任何地方都是可微分的(“几乎”是因为它在 0.5 的“跳跃点”处不可微分),但其导数始终恰好为零(参见图 3.7)!如果您试图通过该阈值函数进行反向传播会发生什么呢?因为上游梯度值在某些地方需要与该阈值函数的所有零导数相乘,所以您的梯度最终将全是零。更简单地说,如果将精确度(或准确度、召回率、AUC 等)选为损失,底层阶跃函数的平坦部分使得训练过程无法知道在权重空间中向哪个方向移动可以降低损失值。
图 3.7 用于转换二分类模型的概率输出的阶跃函数,几乎在每个可微点都是可微分的。不幸的是,每个可微分点的梯度(导数)恰好为零。
因此,如果使用精确度作为损失函数,便无法计算有用的梯度,从而阻止了在模型的权重上获得有意义的更新。此限制同样适用于包括准确度、召回率、FPR 和 AUC 在内的度量。虽然这些指标对人类理解二分类器的行为很有用,但对于这些模型的训练过程来说是无用的。
我们针对二分类任务使用的损失函数是二进制交叉熵,它对应于我们的钓鱼检测模型代码中的 'binaryCrossentropy' 配置(见列表 3.5 和 3.6)。算法上,我们可以用以下伪代码来定义二进制交叉熵。
列表 3.8 二进制交叉熵损失函数的伪代码^([9])
⁹
binaryCrossentropy的实际代码需要防范prob或1 - prob等恰好为零的情况,否则如果将这些值直接传递给log函数,会导致无穷大。这是通过在将它们传递给对数函数之前添加一个非常小的正数(例如1e-6,通常称为“epsilon”或“修正因子”)来实现的。
function binaryCrossentropy(truthLabel, prob): if truthLabel is 1: return -log(prob) else: return -log(1 - prob)
在此伪代码中,truthLabel 是一个数字,取 0 到 1 的值,指示输入样本在现实中是否具有负(0)或正(1)标签。prob 是模型预测的样本属于正类的概率。请注意,与 truthLabel 不同,prob 应为实数,可以取 0 到 1 之间的任何值。log 是自然对数,以 e(2.718)为底,您可能还记得它来自高中数学。binaryCrossentropy 函数的主体包含一个 if-else 逻辑分支,根据 truthLabel 是 0 还是 1 执行不同的计算。图 3.8 在同一图中绘制了这两种情况。
图 3.8。二元交叉熵损失函数。两种情况(truthLabel = 1 和 truthLabel = 0)分别绘制在一起,反映了 代码清单 3.8 中的 if-else 逻辑分支。
在查看 图 3.8 中的图表时,请记住较低的值更好,因为这是一个损失函数。关于损失函数需要注意的重要事项如下:
- 如果
truthLabel为 1,prob值接近 1.0 会导致较低的损失函数值。这是有道理的,因为当样本实际上是正例时,我们希望模型输出的概率尽可能接近 1.0。反之亦然:如果truthLabel为 0,则当概率值接近 0 时,损失值较低。这也是有道理的,因为在这种情况下,我们希望模型输出的概率尽可能接近 0。 - 与 图 3.7 中显示的二进制阈值函数不同,这些曲线在每个点都有非零斜率,导致非零梯度。这就是为什么它适用于基于反向传播的模型训练。
你可能会问的一个问题是,为什么不重复我们为回归模型所做的事情——只是假装 0-1 值是回归目标,并使用 MSE 作为损失函数?毕竟,MSE 是可微分的,并且计算真实标签和概率之间的 MSE 会产生与binaryCrossentropy一样的非零导数。答案与 MSE 在边界处具有“递减收益”有关。例如,在 表 3.4 中,我们列出了当 truthLabel 为 1 时一些 prob 值的 binaryCrossentropy 和 MSE 损失值。当 prob 接近 1(期望值)时,MSE 相对于binaryCrossentropy的减小速度会越来越慢。因此,当 prob 已经接近 1(例如,0.9)时,它不太好地“鼓励”模型产生较高(接近 1)的 prob 值。同样,当 truthLabel 为 0 时,MSE 也不如 binaryCrossentropy 那样好,不能生成推动模型的 prob 输出向 0 靠近的梯度。
表 3.4. 比较假想的二分类结果的二元交叉熵和 MSE 值
| 真实标签 | 概率 | 二元交叉熵 | MSE |
| 1 | 0.1 | 2.302 | 0.81 |
| 1 | 0.5 | 0.693 | 0.25 |
| 1 | 0.9 | 0.100 | 0.01 |
| 1 | 0.99 | 0.010 | 0.0001 |
| 1 | 0.999 | 0.001 | 0.000001 |
| 1 | 1 | 0 | 0 |
这展示了二分类问题与回归问题不同的另一个方面:对于二分类问题,损失(binaryCrossentropy)和指标(准确率、精确率等)是不同的,而对于回归问题通常是相同的(例如,meanSquaredError)。正如我们将在下一节看到的那样,多类别分类问题也涉及不同的损失函数和指标。
3.3. 多类别分类
在 第 3.2 节 中,我们探讨了如何构建二分类问题的结构;现在我们将快速进入 非二分类 的处理方式——即,涉及三个或更多类别的分类任务。^([10]) 我们将使用用于说明多类别分类的数据集是 鸢尾花数据集,这是一个有着统计学根源的著名数据集(参见 en.wikipedia.org/wiki/Iris_flower_data_set)。这个数据集关注于三种鸢尾花的品种,分别为 山鸢尾、变色鸢尾 和 维吉尼亚鸢尾。这三种鸢尾花可以根据它们的形状和大小来区分。在 20 世纪初,英国统计学家罗纳德·费舍尔测量了 150 个鸢尾花样本的花瓣和萼片(花的不同部位)的长度和宽度。这个数据集是平衡的:每个目标标签都有确切的 50 个样本。
¹⁰
不要混淆 多类别 分类和 多标签 分类。在多标签分类中,单个输入示例可能对应于多个输出类别。一个例子是检测输入图像中各种类型物体的存在。一个图像可能只包括一个人;另一个图像可能包括一个人、一辆车和一个动物。多标签分类器需要生成一个表示适用于输入示例的所有类别的输出,无论该类别是一个还是多个。本节不涉及多标签分类。相反,我们专注于更简单的单标签、多类别分类,其中每个输入示例都对应于>2 个可能类别中的一个输出类别。
在这个问题中,我们的模型以四个数值特征(花瓣长度、花瓣宽度、萼片长度和萼片宽度)作为输入,并尝试预测一个目标标签(三种物种之一)。该示例位于 tfjs-examples 的 iris 文件夹中,您可以使用以下命令查看并运行:
git clone https://github.com/tensorflow/tfjs-examples.git cd tfjs-examples/iris yarn && yarn watch
3.3.1. 对分类数据进行 one-hot 编码
在研究解决鸢尾花分类问题的模型之前,我们需要强调这个多类别分类任务中分类目标(物种)的表示方式。到目前为止,在本书中我们看到的所有机器学习示例都涉及更简单的目标表示,例如下载时间预测问题中的单个数字以及波士顿房屋问题中的数字,以及钓鱼检测问题中的二进制目标的 0-1 表示。然而,在鸢尾问题中,三种花的物种以稍微不那么熟悉的方式称为 one-hot 编码 进行表示。打开 data.js,您将注意到这一行:
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
这里,shuffledTargets 是一个普通的 JavaScript 数组,其中包含按随机顺序排列的示例的整数标签。其元素的值均为 0、1 和 2,反映了数据集中的三种鸢尾花品种。通过调用 tf.tensor1d(shuffledTargets).toInt(),它被转换为 int32 类型的 1D 张量。然后将结果的 1D 张量传递到 tf.oneHot() 函数中,该函数返回形状为 [numExamples, IRIS_NUM_CLASSES] 的 2D 张量。numExamples 是 targets 包含的示例数,而 IRIS_NUM_CLASSES 简单地是常量 3。您可以通过在先前引用的行下面添加一些打印行来查看 targets 和 ys 的实际值,例如:
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES); // Added lines for printing the values of `targets` and `ys`. console.log('Value of targets:', targets); ys.print();[11]
¹¹
与
targets不同,ys不是一个普通的 JavaScript 数组。相反,它是由 GPU 内存支持的张量对象。因此,常规的 console.log 不会显示其值。print()方法是专门用于从 GPU 中检索值,以形状感知和人性化的方式进行格式化,并将其记录到控制台的方法。
一旦您进行了这些更改,Yarn watch 命令在终端启动的包捆绑器进程将自动重建 Web 文件。然后,您可以打开用于观看此演示的浏览器选项卡中的开发工具,并刷新页面。console.log() 和 print() 调用的打印消息将记录在开发工具的控制台中。您将看到的打印消息将类似于这样:
Value of targets: (50) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] Tensor [[1, 0, 0], [1, 0, 0], [1, 0, 0], ..., [1, 0, 0], [1, 0, 0], [1, 0, 0]]
或者
Value of targets: (50) [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] Tensor [[0, 1, 0], [0, 1, 0], [0, 1, 0], ..., [0, 1, 0], [0, 1, 0], [0, 1, 0]]
等等。用言语来描述,以整数标签 0 为例,您会得到一个值为 [1, 0, 0] 的值行;对于整数标签为 1 的示例,您会得到一个值为 [0, 1, 0] 的行,依此类推。这是独热编码的一个简单明了的例子:它将一个整数标签转换为一个向量,该向量除了在对应标签的索引处的值为 1 之外,其余都为零。向量的长度等于所有可能类别的数量。向量中只有一个 1 值的事实正是这种编码方案被称为“独热”的原因。
对于您来说,这种编码可能看起来过于复杂了。在一个类别中使用三个数字来表示,为什么不使用一个单一的数字就能完成任务呢?为什么我们选择这种复杂的编码而不是更简单和更经济的单整数索引编码呢?这可以从两个不同的角度来理解。
首先,对于神经网络来说,输出连续的浮点型值要比整数值容易得多。在浮点型输出上应用舍入也不够优雅。一个更加优雅和自然的方法是,神经网络的最后一层输出几个单独的浮点型数值,每个数值通过一个类似于我们用于二元分类的 S 型激活函数的精心选择的激活函数被限制在 [0, 1] 区间内。在这种方法中,每个数字都是模型对输入示例属于相应类别的概率的估计。这正是独热编码的用途:它是概率分数的“正确答案”,模型应该通过其训练过程来拟合。
第二,通过将类别编码为整数,我们隐含地为类别创建了一个顺序。例如,我们可以将 鸢尾花 setosa 标记为 0,鸢尾花 versicolor 标记为 1,鸢尾花 virginica 标记为 2。但是,这样的编号方案通常是人为的和不合理的。例如,这种编号方案暗示 setosa 比 versicolor 更“接近” virginica,这可能并不正确。神经网络基于实数进行操作,并且基于诸如乘法和加法之类的数学运算。因此,它们对数字的数量和顺序敏感。如果将类别编码为单一数字,则成为神经网络必须学习的额外非线性关系。相比之下,独热编码的类别不涉及任何隐含的排序,因此不会以这种方式限制神经网络的学习能力。
就像我们将在第九章中看到的那样,独热编码不仅用于神经网络的输出目标,而且还适用于分类数据形成神经网络的输入。
3.3.2. Softmax 激活函数
了解了输入特征和输出目标的表示方式后,我们现在可以查看定义模型的代码(来自 iris/index.js)。
列表 3.9. 用于鸢尾花分类的多层神经网络
const model = tf.sequential(); model.add(tf.layers.dense( {units: 10, activation: 'sigmoid', inputShape: [xTrain.shape[1]]})); model.add(tf.layers.dense({units: 3, activation: 'softmax'})); model.summary(); const optimizer = tf.train.adam(params.learningRate); model.compile({ optimizer: optimizer, loss: 'categoricalCrossentropy', metrics: ['accuracy'], });
在列表 3.9 中定义的模型导致了以下摘要:
_________________________________________________________________ Layer (type) Output shape Param # ================================================================= dense_Dense1 (Dense) [null,10] 50 ________________________________________________________________ dense_Dense2 (Dense) [null,3] 33 ================================================================= Total params: 83 Trainable params: 83 Non-trainable params: ________________________________________________________________
通过查看打印的概述,我们可以看出这是一个相当简单的模型,具有相对较少的(83 个)权重参数。输出形状[null, 3]对应于分类目标的独热编码。最后一层使用的激活函数,即softmax,专门设计用于多分类问题。softmax 的数学定义可以写成以下伪代码:
softmax([x1, x2, ..., xn]) = [exp(x1) / (exp(x1) + exp(x2) + ... + exp(xn)), exp(x2) / (exp(x1) + exp(x2) + ... + exp(xn)), ..., exp(xn) / (exp(x1) + exp(x2) + ... + exp(xn))]
与我们之前见过的 sigmoid 激活函数不同,softmax 激活函数不是逐元素的,因为输入向量的每个元素都以依赖于所有其他元素的方式进行转换。具体来说,输入的每个元素被转换为其指数(以* e*=2.718 为底)的自然指数。然后指数被除以所有元素的指数的和。这样做有什么作用?首先,它确保了每个数字都在 0 到 1 的区间内。其次,保证了输出向量的所有元素之和为 1。这是一个理想的属性,因为 1)输出可以被解释为分配给各个类别的概率得分,2)为了与分类交叉熵损失函数兼容,输出必须满足此属性。第三,该定义确保输入向量中的较大元素映射到输出向量中的较大元素。举个具体的例子,假设最后一个密集层的矩阵乘法和偏置相加生成了一个向量
[-3, 0, -8]
它的长度为 3,因为密集层被配置为具有 3 个单元。请注意,这些元素是浮点数,不受特定范围的约束。softmax 激活函数将向量转换为
[0.0474107, 0.9522698, 0.0003195]
您可以通过运行以下 TensorFlow.js 代码(例如,在页面指向js.tensorflow.org时,在开发工具控制台中)来自行验证这一点:
const x = tf.tensor1d([-3, 0, -8]); tf.softmax(x).print();
Softmax 函数的输出有三个元素。1)它们都在[0, 1]区间内,2)它们的和为 1,3)它们的顺序与输入向量中的顺序相匹配。由于这些属性的存在,输出可以被解释为被模型分配的(概率)值,表示所有可能的类别。在前面的代码片段中,第二个类别被分配了最高的概率,而第一个类别被分配了最低的概率。
因此,当使用这种多类别分类器的输出时,你可以选择最高 softmax 元素的索引作为最终决策——也就是输入属于哪个类别的决策。这可以通过使用方法 argMax() 来实现。例如,这是 index.js 的摘录:
const predictOut = model.predict(input); const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
predictOut 是形状为 [numExamples, 3] 的二维张量。调用它的 argMax0 方法会导致形状被减少为 [numExample]。参数值 -1 表示 argMax() 应该在最后一个维度上查找最大值并返回它们的索引。例如,假设 predictOut 有以下值:
[[0 , 0.6, 0.4], [0.8, 0 , 0.2]]
那么,argMax(-1) 将返回一个张量,指示沿着最后(第二个)维度找到的最大值分别在第一个和第二个示例的索引为 1 和 0:
[1, 0]
3.3.3. 分类交叉熵:多类别分类的损失函数
在二元分类示例中,我们看到了如何使用二元交叉熵作为损失函数,以及为什么其他更易于人类理解的指标,如准确率和召回率,不能用作损失函数。多类别分类的情况相当类似。存在一个直观的度量标准——准确率——它是模型正确分类的例子的比例。这个指标对于人们理解模型的性能有重要意义,并且在 列表 3.9 中的这段代码片段中使用:
model.compile({ optimizer: optimizer, loss: 'categoricalCrossentropy', metrics: ['accuracy'], });
然而,准确率对于损失函数来说是一个糟糕的选择,因为它遇到了与二元分类中的准确率相同的零梯度问题。因此,人们为多类别分类设计了一个特殊的损失函数:分类交叉熵。它只是将二元交叉熵推广到存在两个以上类别的情况。
列表 3.10. 用于分类交叉熵损失的伪代码
function categoricalCrossentropy(oneHotTruth, probs): for i in (0 to length of oneHotTruth) if oneHotTruth(i) is equal to 1 return -log(probs[i]);
在前面的伪代码中,oneHotTruth 是输入示例的实际类别的独热编码。probs 是模型的 softmax 概率输出。从这段伪代码中可以得出的关键信息是,就分类交叉熵而言,probs 中只有一个元素是重要的,那就是与实际类别对应的索引的元素。probs 的其他元素可以随意变化,但只要它们不改变实际类别的元素,就不会影响分类交叉熵。对于 probs 的特定元素,它越接近 1,交叉熵的值就越低。与二元交叉熵类似,分类交叉熵直接作为 tf.metrics 命名空间下的一个函数可用,你可以用它来计算简单但说明性的示例的分类交叉熵。例如,使用以下代码,你可以创建一个假设的独热编码的真实标签和一个假设的 probs 向量,并计算相应的分类交叉熵值:
const oneHotTruth = tf.tensor1d([0, 1, 0]); const probs = tf.tensor1d([0.2, 0.5, 0.3]); tf.metrics.categoricalCrossentropy(oneHotTruth, probs).print();
这给出了一个约为 0.693 的答案。这意味着当模型对实际类别分配的概率为 0.5 时,categoricalCrossentropy的值为 0.693。你可以根据 pseudo-code(伪代码)进行验证。你也可以尝试将值从 0.5 提高或降低,看看categoricalCrossentropy如何变化(例如,参见 table 3.5)。表中还包括一列显示了单热真实标签和probs向量之间的 MSE。
表 3.5. 不同概率输出下的分类交叉熵值。不失一般性,所有示例(行)都是基于有三个类别的情况(如鸢尾花数据集),并且实际类别是第二个类别。
| One-hot truth label | probs (softmax output) | Categorical cross entropy | MSE |
| [0, 1, 0] | [0.2, 0.5, 0.3] | 0.693 | 0.127 |
| [0, 1, 0] | [0.0, 0.5, 0.5] | 0.693 | 0.167 |
| [0, 1, 0] | [0.0, 0.9, 0.1] | 0.105 | 0.006 |
| [0, 1, 0] | [0.1, 0.9, 0.0] | 0.105 | 0.006 |
| [0, 1, 0] | [0.0, 0.99, 0.01] | 0.010 | 0.00006 |
通过比较表中的第 1 行和第 2 行,或比较第 3 行和第 4 行,可以明显看出更改probs中与实际类别不对应的元素不会改变二元交叉熵的值,尽管这可能会改变单热真实标签和probs之间的 MSE。同样,就像在二元交叉熵中一样,当probs值接近 1 时,MSE 显示出递减的回报,并且在这个区间内,MSE 不适合鼓励正确类别的概率值上升,而分类熵则更适合作为多类别分类问题的损失函数。
3.3.4. 混淆矩阵:多类别分类的细致分析
点击示例网页上的从头开始训练模型按钮,你可以在几秒钟内得到一个经过训练的模型。正如图 3.9 所示,模型经过 40 个训练周期后几乎达到了完美的准确度。这反映了鸢尾花数据集是一个相对较小且在特征空间中类别边界相对明确的数据集的事实。
图 3.9. 40 个训练周期后鸢尾花模型的典型结果。左上方:损失函数随训练周期变化的图表。右上方:准确度随训练周期变化的图表。底部:混淆矩阵。
图 3.9 的底部显示了描述多类分类器行为的另一种方式,称为混淆矩阵。混淆矩阵根据其实际类别和模型预测类别将多类分类器的结果进行了细分。它是一个形状为[numClasses, numClasses]的方阵。索引[i, j](第 i 行和第 j 列)处的元素是属于类别i并由模型预测为类别j的示例数量。因此,混淆矩阵的对角线元素对应于正确分类的示例。一个完美的多类分类器应该产生一个没有对角线之外的非零元素的混淆矩阵。这正是图 3.9 中的混淆矩阵的情况。
除了展示最终的混淆矩阵外,鸢尾花示例还在每个训练周期结束时使用onTrainEnd()回调绘制混淆矩阵。在早期周期中,您可能会看到一个不太完美的混淆矩阵,与图 3.9 中的混淆矩阵不同。图 3.10 中的混淆矩阵显示,24 个输入示例中有 8 个被错误分类,对应的准确率为 66.7%。然而,混淆矩阵告诉我们不仅仅是一个数字:它显示了哪些类别涉及最多的错误,哪些涉及较少。在这个特定的示例中,所有来自第二类的花都被错误分类(要么作为第一类,要么作为第三类),而来自第一类和第三类的花总是被正确分类。因此,您可以看到,在多类分类中,混淆矩阵比简单的准确率更具信息量,就像精确率和召回率一起形成了比二分类准确率更全面的衡量标准一样。混淆矩阵可以提供有助于与模型和训练过程相关的决策的信息。例如,某些类型的错误可能比混淆其他类别对更为昂贵。也许将一个体育网站误认为游戏网站不如将体育网站误认为钓鱼网站那么严重。在这些情况下,您可以调整模型的超参数以最小化最昂贵的错误。
图 3.10. 一个“不完美”混淆矩阵的示例,在对角线之外存在非零元素。该混淆矩阵是在训练收敛之前的仅 2 个周期后生成的。
到目前为止,我们所见的模型都将一组数字作为输入。换句话说,每个输入示例都表示为一组简单的数字列表,其中长度固定,元素的排序不重要,只要它们对馈送到模型的所有示例都一致即可。虽然这种类型的模型涵盖了重要和实用的机器学习问题的大量子集,但它远非唯一的类型。在接下来的章节中,我们将研究更复杂的输入数据类型,包括图像和序列。在 第四章 中,我们将从图像开始,这是一种无处不在且广泛有用的输入数据类型,为此已经开发了强大的神经网络结构,以将机器学习模型的准确性推向超人级别。
练习
- 当创建用于波士顿房屋问题的神经网络时,我们停留在一个具有两个隐藏层的模型上。鉴于我们所说的级联非线性函数会增强模型的容量,那么将更多的隐藏层添加到模型中会导致评估准确性提高吗?通过修改 index.js 并重新运行训练和评估来尝试一下。
- 是什么因素阻止了更多的隐藏层提高评估准确性?
- 是什么让您得出这个结论?(提示:看一下训练集上的误差。)
- 看看 清单 3.6 中的代码如何使用
onEpochBegin回调在每个训练时期的开始计算并绘制 ROC 曲线。您能按照这种模式并对回调函数的主体进行一些修改,以便您可以在每个时期的开始打印精度和召回率值(在测试集上计算)吗?描述这些值随着训练的进行而如何变化。 - 研究 清单 3.7 中的代码,并理解它是如何计算 ROC 曲线的。您能按照这个示例并编写一个新的函数,名为
drawPrecisionRecallCurve(),它根据名称显示一个精度-召回率曲线吗?写完函数后,从onEpochBegin回调中调用它,以便在每个训练时期的开始绘制一个精度-召回率曲线。您可能需要对 ui.js 进行一些修改或添加。 - 假设您得知二元分类器结果的 FPR 和 TPR。凭借这两个数字,您能计算出整体准确性吗?如果不能,您需要什么额外信息?
- 二元交叉熵(3.2.4 节)和分类交叉熵(3.3.3 节)的定义都基于自然对数(以 e 为底的对数)。如果我们改变定义,让它们使用以 10 为底的对数会怎样?这会如何影响二元和多类分类器的训练和推断?
- 将超参数网格搜索的伪代码转换为实际的 JavaScript 代码,并使用该代码对列表 3.1 中的两层波士顿房屋模型进行超参数优化。具体来说,调整隐藏层的单位数和学习率。可以自行决定要搜索的单位和学习率的范围。注意,机器学习工程师通常使用近似几何序列(即对数)间隔进行这些搜索(例如,单位= 2、5、10、20、50、100、200,…)。
摘要
- 分类任务与回归任务不同,因为它们涉及进行离散预测。
- 分类有两种类型:二元和多类。在二元分类中,对于给定的输入,有两种可能的类别,而在多类分类中,有三个或更多。
- 二元分类通常可以被看作是在所有输入示例中检测一种称为正例的特定类型事件或对象。从这个角度来看,我们可以使用精确率、召回率和 FPR 等指标,除了准确度,来量化二元分类器行为的各个方面。
- 在二元分类任务中,需要在捕获所有正例和最小化假阳性(误报警)之间进行权衡是很常见的。ROC 曲线与相关的 AUC 指标是一种帮助我们量化和可视化这种关系的技术。
- 为了进行二元分类而创建的神经网络应该在其最后(输出)层使用 sigmoid 激活,并在训练过程中使用二元交叉熵作为损失函数。
- 为了创建一个用于多类分类的神经网络,输出目标通常由独热编码表示。神经网络应该在其输出层使用 softmax 激活,并使用分类交叉熵损失函数进行训练。
- 对于多类分类,混淆矩阵可以提供比准确度更细粒度的信息,关于模型所犯错误的信息。
- 表 3.6 总结了迄今为止我们见过的最常见的机器学习问题类型(回归、二元分类和多类分类)的推荐方法。
- 超参数是关于机器学习模型结构、其层属性以及其训练过程的配置。它们与模型的权重参数不同,因为 1)它们在模型的训练过程中不变化,2)它们通常是离散的。超参数优化是一种寻找超参数值以在验证数据集上最小化损失的过程。超参数优化仍然是一个活跃的研究领域。目前,最常用的方法包括网格搜索、随机搜索和贝叶斯方法。
表格 3.6. 最常见的机器学习任务类型,它们适用的最后一层激活函数和损失函数,以及有助于量化模型质量的指标的概述
| 任务类型 | 输出层的激活函数 | 损失函数 | 在 Model.fit() 调用中支持的适用指标 | 额外的指标 |
| 回归 | ‘linear’ (默认) | ‘meanSquaredError’ 或 ‘meanAbsoluteError’ | (与损失函数相同) | |
| 二分类 | ‘sigmoid’ | ‘binaryCrossentropy’ | ‘accuracy’ | 精确率,召回率,精确-召回曲线,ROC 曲线,AUC 值 |
| 单标签,多类别分类 | ‘softmax’ | ‘categoricalCrossentropy’ | ‘accuracy’ | 混淆矩阵 |
- 7 ↩︎