PyTorch 深度学习(GPT 重译)(五)(1)

简介: PyTorch 深度学习(GPT 重译)(五)

十二、通过指标和增强改进训练

本章涵盖

  • 定义和计算精确率、召回率以及真/假阳性/阴性
  • 使用 F1 分数与其他质量指标
  • 平衡和增强数据以减少过拟合
  • 使用 TensorBoard 绘制质量指标图

上一章的结束让我们陷入了困境。虽然我们能够将深度学习项目的机制放置好,但实际上没有任何结果是有用的;网络只是将一切都分类为非结节!更糟糕的是,结果表面看起来很好,因为我们正在查看训练和验证集中被正确分类的整体百分比。由于我们的数据严重倾向于负样本,盲目地将一切都视为负面是我们的模型快速得分的一种简单而快速的方法。太糟糕了,这样做基本上使模型无用!

这意味着我们仍然专注于与第十一章相同的图 12.1 的同一部分。但现在我们正在努力使我们的分类模型工作良好而不是只是工作。本章重点讨论如何衡量、量化、表达,然后改进我们的模型执行工作的能力。

图 12.1 我们的端到端肺癌检测项目,重点放在本章的主题上:第 4 步,分类

12.1 改进的高层计划

虽然有点抽象,图 12.2 向我们展示了我们将如何处理那些广泛的主题。

让我们详细地走过本章的这张有些抽象的地图。我们将处理我们面临的问题,比如过度关注单一、狭窄的指标以及由此产生的行为在一般意义上是无用的。为了使本章的一些概念更具体化,我们将首先使用一个比喻来将我们的困境更具体化:在图 12.2 中,(1)看门狗和(2)鸟和窃贼。

图 12.2 我们将使用的比喻来修改衡量我们模型的指标,使其变得出色

之后,我们将开发一个图形语言来代表上一章实施中所需的核心概念:(3)比率:召回率和精确率。一旦我们将这些概念巩固下来,我们将涉及一些使用这些概念的数学,这将包括一种更健壮的评估我们模型性能的方式,并将其压缩为一个数字:(4)新指标:F1 分数。我们将实施这些新指标的公式,并查看在训练过程中每个时期这些结果值如何变化。最后,我们将对我们的LunaDataset实现进行一些急需的更改,以改善我们的训练结果:(5)平衡和(6)增强。然后我们将看看这些实验性的更改是否对我们的性能指标产生了预期的影响。

到本章结束时,我们训练的模型将表现得更好:(7)工作得很棒!虽然它还没有准备好立即投入临床使用,但它将能够产生明显优于随机的结果。这意味着我们已经有了可行的第 4 步实现,结节候选分类;一旦完成,我们可以开始考虑如何将第 2 步(分割)和第 3 步(分组)纳入项目中。

12.2 好狗与坏家伙:假阳性和假阴性

我们不再考虑模型和肿瘤,而是考虑图 12.3 中的两只看门狗,它们刚从服从学校毕业。它们都想警告我们有窃贼——这是一种罕见但严重的情况,需要及时处理。

图 12.3 本章的主题集,重点放在框架比喻上

不幸的是,虽然两只狗都是好狗,但都不是好的警卫狗。我们的梗犬(Roxie)对几乎所有事情都会吠,而我们的老猎犬(Preston)几乎只会对入室者吠叫——但前提是他在他们到达时恰好醒着。

Roxie 几乎每次都会警告我们有入室者。她还会警告我们有消防车、雷暴、直升机、鸟、邮递员、松鼠、路人等。如果我们对每次吠叫进行跟进,我们几乎永远不会被抢劫(只有最狡猾的偷窃者才能溜过)。完美!… 除了那么勤奋意味着我们实际上并没有通过养警卫狗节省任何工作。相反,我们每隔几个小时就会起床,手持手电筒,因为 Roxie 闻到了猫的气味,或者听到了猫头鹰的叫声,或者看到了一辆晚点的公共汽车经过。Roxie 有一个问题性的假阳性数量。

假阳性是被分类为感兴趣或所需类别的成员(阳性表示“是的,这是我感兴趣了解的类型”)的事件,但实际上并不是真正感兴趣的。对于结节检测问题,当一个实际上无趣的候选者被标记为结节,因此需要放射科医生的关注时,就会发生假阳性。对于 Roxie 来说,这些可能是消防车、雷暴等。在接下来的章节和随后的图中,我们将使用一张猫的图片作为典型的假阳性。

将假阳性与真阳性进行对比:被正确分类的感兴趣项目。这些将在图中由一个人类强盗表示。

与此同时,如果 Preston 吠叫,请立即报警,因为这意味着几乎肯定有人闯入,房子着火了,或者哥斯拉在袭击。然而,Preston 睡得很沉,正在进行家庭入侵的声音不太可能唤醒他,所以每当有人尝试时,我们几乎总是会被抢劫。虽然比没有好,但我们并没有真正获得最初让我们养狗的平静心态。Preston 有一个问题性的假阴性数量。

假阴性是被分类为不感兴趣或不是所需类别的成员(阴性表示“不,这不是我感兴趣了解的类型”)的事件,但实际上确实是感兴趣的。对于结节检测问题,当一个结节(即潜在的癌症)未被检测到时,就会发生假阴性。对于 Preston 来说,这些将是他睡过的抢劫案。在这里我们将有点创意,使用一张啮齿强盗的图片来代表假阴性。它们很狡猾!

将假阴性与真阴性进行对比:正确识别为无趣项目的项目。我们将用一只鸟的图片来代表这些。

为了完成这个比喻,第十一章的模型基本上是一只拒绝对不是金鱼罐的任何东西发出喵声的猫(同时坚定地忽视 Roxie)。我们在上一章末尾的重点是整体训练和验证集的正确百分比。显然,这不是一个很好的评分方式,正如我们从我们每只狗对单一指标的近视关注——比如真阳性或真阴性的数量——可以看出的那样,我们需要一个更广泛关注的指标来捕捉我们的整体表现。

12.3 绘制阳性和阴性

让我们开始制定我们将用来描述真/假阳性/阴性的视觉语言。如果我们的解释变得重复,请耐心等待;我们希望确保您对我们将要讨论的比率形成坚实的心理模型。考虑图 12.4,显示了可能对我们其中一只警卫狗感兴趣的事件。


图 12.4 中的猫、鸟、啮齿动物和强盗构成了我们的四个分类象限。它们由人类标签和狗的分类阈值分隔。

在图 12.4 中,我们将使用两个阈值。第一个是人为决定的将入室盗窃犯与无害动物分开的分界线。具体来说,这是为每个训练或验证样本分配的标签。第二个是狗确定的分类阈值,它决定了狗是否会对某物吠叫。对于深度学习模型,这是在考虑样本时模型产生的预测值。

这两个阈值的组合将我们的事件分成四个象限:真/假阳性/阴性。我们将关注的事件用较深的背景色进行阴影处理(因为那些坏家伙总是在黑暗中潜行)。

当然,现实要复杂得多。并不存在一个关于入室盗窃犯的柏拉图理想,也没有一个相对于分类阈值的单一点,所有入室盗窃犯都会在那里。相反,图 12.5 向我们展示了一些入室盗窃犯会特别狡猾,一些鸟类会特别烦人。我们还将把我们的实例放在一个图中。我们的 X 轴将保持每个事件的吠声价值,由我们的一只看门狗确定。我们将让 Y 轴代表我们作为人类能够感知的一些模糊特质,但我们的狗却无法感知。

由于我们的模型产生二元分类,我们可以将预测阈值视为将单一数值输出与我们的分类阈值值进行比较。这就是为什么我们要求图 12.5 中的分类阈值线是完全垂直的。


图 12.5 每种事件都会有许多可能的实例,我们的看门狗需要评估。

每个可能的入室盗窃犯都是不同的,因此我们的看门狗将需要评估许多不同的情况,这意味着更多犯错的机会。我们可以看到明显的对角线将鸟类与入室盗窃犯分开,但普雷斯顿和洛克西只能在这里感知 X 轴:他们在我们的图中间有一组混乱、重叠的事件。他们必须选择一个垂直的吠声价值阈值,这意味着他们中的任何一个都不可能完美地做到。有时候把你的家电搬到他们的货车上的人是你雇来修理洗衣机的维修人员,有时候入室盗窃犯会开着一辆侧面写着“洗衣机维修”的货车出现。期望狗能够察觉到这些微妙之处注定会失败。

我们要使用的实际输入数据具有高维度–我们需要考虑大量 CT 体素值,以及更抽象的事物,如候选大小、在肺部的整体位置等等。我们模型的工作是将每个事件及其属性映射到这个矩形中,以便我们可以使用单一垂直线(我们的分类阈值)清晰地分离这些正面和负面事件。这是通过我们模型末端的nn.Linear层完成的。垂直线的位置与我们在第 11.6.1 节中看到的classificationThreshold_float完全对应。在那里,我们选择了硬编码值 0.5 作为我们的阈值。

请注意,实际上,所呈现的数据不是二维的;在倒数第二层之后,它变成了非常高维度,到输出时变成了一维(这里是我们的 X 轴)–每个样本只有一个标量(然后被分类阈值二分)。在这里,我们使用第二维(Y 轴)来表示我们的模型无法看到或使用的每个样本特征:例如患者的年龄或性别,结节候选在肺部的位置,甚至模型尚未利用的候选局部特征。它还为我们提供了一种方便的方式来表示非结节和结节样本之间的混淆。

图 12.5 中的象限区域和每个区域中包含的样本数将是我们用来讨论模型性能的值,因为我们可以使用这些值之间的比率来构建越来越复杂的指标,以客观地衡量我们的表现。正如他们所说,“证据在于比例。”¹ 接下来,我们将使用这些事件子集之间的比率来开始定义更好的指标。

12.3.1 召回率是罗克西的优势

召回率基本上是“确保你永远不会错过任何有趣的事件!”正式地说,召回率是真阳性与真阳性和假阴性的并集的比率。我们可以在图 12.6 中看到这一点。

图 12.6 召回率是真阳性与真阳性和假阴性的并集的比率。高召回率可以最小化假阴性。

在某些情境中,召回率被称为敏感性

为了提高召回率,要尽量减少假阴性。在看门狗的术语中,这意味着如果你不确定,就叫一声,以防万一。不要让任何啮齿动物小偷在你的监视下溜走!

罗克西通过将分类阈值推到最左边,使其包含图 12.7 中几乎所有的正面事件,从而实现了极高的召回率。请注意,这样做意味着她的召回值接近 1.0,即 99%的盗贼都会被叫唤。由于这是罗克西定义成功的方式,在她看来,她做得很好。不要在意大量的假阳性!

图 12.7 罗克西选择的阈值优先考虑减少假阴性。每只老鼠都会被叫唤……还有猫,大多数鸟。

12.3.2 精确性是普雷斯顿的长处

精确性基本上是“除非你确定,否则不要叫。”为了提高精确性,要尽量减少假阳性。普雷斯顿不会对某事物叫唤,除非他确定那是一个盗贼。更正式地说,精确性是真阳性与真阳性和假阳性的并集的比率,如图 12.8 所示。

图 12.8 精确性是真阳性与真阳性和假阳性的并集的比率。高精确性可以最小化假阳性。

普雷斯顿通过将分类阈值推到最右边,排除尽可能多的无趣、负面事件,从而实现了极高的精确性(见图 12.9)。这与罗克西的方法相反,意味着普雷斯顿的精确性接近 1.0:他叫的 99%的事物都是盗贼。尽管有大量事件未被检测到,但这也符合他作为一只好看门狗的定义。

虽然精确性和召回率都不能作为评估我们模型的单一指标,但它们在训练过程中是有用的数字。让我们计算并显示这些作为我们训练程序的一部分,然后我们将讨论其他可以使用的指标。

图 12.9 普雷斯顿选择的阈值优先考虑减少假阳性。猫被放过,只有盗贼会被叫唤!

12.3.3 在 logMetrics 中实现精确性和召回率

在训练过程中,精确性和召回率都是有价值的指标,因为它们提供了关于模型行为的重要见解。如果它们中的任何一个降至零(正如我们在第十一章中看到的!),那么我们的模型可能已经开始表现出退化的方式。我们可以使用行为的确切细节来指导我们在哪里进行调查和实验,以使训练重新回到正轨。我们希望更新logMetrics函数,以在每个时期的输出中添加精确性和召回率,以补充我们已经拥有的损失和正确性指标。

到目前为止,我们一直在以“真阳性”等术语定义精确度和召回率,因此我们将在代码中继续这样做。事实证明,我们已经计算了一些我们需要的值,尽管我们给它们起了不同的名称。

列表 12.1 training.py:315,LunaTrainingApp.logMetrics

neg_count = int(negLabel_mask.sum())
pos_count = int(posLabel_mask.sum())
trueNeg_count = neg_correct = int((negLabel_mask & negPred_mask).sum())
truePos_count = pos_correct = int((posLabel_mask & posPred_mask).sum())
falsePos_count = neg_count - neg_correct
falseNeg_count = pos_count - pos_correct

在这里,我们可以看到neg_correcttrueNeg_count是相同的!这其实是有道理的,因为非结节是我们的“负面”值(如“负面诊断”),如果分类器预测正确,那么这就是真阴性。同样,正确标记的结节样本是真阳性。

我们确实需要添加我们的假阳性和假阴性值的变量。这很简单,因为我们可以取良性标签的总数并减去正确的计数。剩下的是被误分类为阳性的非结节样本的计数。因此,它们是假阳性。同样,假阴性计算形式相同,但使用结节计数。

有了这些数值,我们可以计算precisionrecall,并将它们存储在metrics_dict中。

列表 12.2 training.py:333,LunaTrainingApp.logMetrics

precision = metrics_dict['pr/precision'] = \
  truePos_count / np.float32(truePos_count + falsePos_count)
recall  = metrics_dict['pr/recall'] = \
  truePos_count / np.float32(truePos_count + falseNeg_count)

注意双重赋值:虽然有单独的precisionrecall变量并不是绝对必要的,但它们提高了下一节的可读性。我们还扩展了logMetrics中的日志语句以包括新值,但我们暂时跳过实现(我们将在本章稍后重新讨论日志记录)。

12.3.4 我们的终极性能指标:F1 分数

尽管有用,但精确度和召回率都无法完全捕捉我们评估模型所需的内容。正如我们在 Roxie 和 Preston 中看到的,通过操纵我们的分类阈值,可能会单独操纵其中一个,导致模型在其中一个上得分良好,但以牺牲任何实际效用为代价。我们需要一种以防止这种操纵的方式结合这两个值的东西。正如我们在图 12.10 中看到的,现在是引入我们的终极指标的时候了。

通常接受的结合精确度和召回率的方法是使用 F1 分数(en.wikipedia.org/wiki/F1_score)。与其他指标一样,F1 分数的范围在 0(没有真实世界预测能力的分类器)和 1(具有完美预测的分类器)之间。我们将更新logMetrics以包括这一点。

列表 12.3 training.py:338,LunaTrainingApp.logMetrics

metrics_dict['pr/f1_score'] = \
  2 * (precision * recall) / (precision + recall)

乍一看,这可能比我们需要的更复杂,当在精确度和召回率之间进行权衡时,F1 分数的行为可能不会立即显而易见。然而,这个公式有很多好的性质,并且与我们可能考虑的几种其他更简单的替代方案相比较有利。

图 12.10 本章的主题集,重点是最终的 F1 分数指标

一个立即可能的评分函数是将精确度和召回率的值平均起来。不幸的是,这使得avg(p=1.0, r=0.0)avg(p=0.5, r=0.5)都得到相同的 0.5 分数,正如我们之前讨论的,精确度或召回率为零的分类器通常是无用的。将无用的东西与有用的东西赋予相同的非零分数,立即使平均成为一个没有意义的指标。

然而,让我们在图 12.11 中直观比较平均和 F1。有几件事引人注目。首先,我们可以看到平均值的等高线中没有曲线或拐点。这就是让我们的精确度或召回率偏向一侧的原因!永远不会出现这样的情况,即通过使召回率达到 100%(Roxie 方法)然后消除任何容易消除的假阳性来最大化分数是没有意义的。这就为添加分数至少为 0.5 设置了一个底线!拥有一个质量指标,可以轻松获得至少 50% 的分数,感觉不对劲。

图 12.11 使用avg(p, r)计算最终分数。较浅的值接近 1.0。

注意 我们实际上在这里做的是取精确度和召回率的算术平均值en.wikipedia.org/wiki/Arithmetic_mean),这两者都是比率而不是可计数的标量值。取比率的算术平均值通常不会给出有意义的结果。F1 分数是两个比率的调和平均值en.wikipedia.org/wiki/Harmonic_mean)的另一个名称,这是结合这些值的更合适的方式。

与 F1 分数相比:当召回率高而精确度低时,为了将分数移动到平衡的甜蜜点,牺牲很多召回率以换取一点精确度将使分数更接近。有一个漂亮、深刻的拐点,很容易滑入其中。鼓励具有平衡精确度和召回率是我们希望从我们的评分指标中得到的。

假设我们仍然希望有一个更简单的指标,但不会奖励任何偏斜。为了纠正加法的弱点,我们可能会取精确度和召回率的最小值(图 12.12)。

图 12.12 使用min(p, r)计算最终分数

这很好,因为如果任一值为 0,分数也为 0,而要获得 1.0 的分数的唯一方法是两个值都为 1.0。然而,它仍然有待改进,因为使召回率从 0.7 提高到 0.9 而将精确度保持在 0.5 不会改善分数,降低召回率到 0.6 也不会改善分数!尽管这个指标肯定惩罚了精确度和召回率之间的不平衡,但它并没有捕捉到关于这两个值的许多细微差别。正如我们所见,通过简单地移动分类阈值,很容易将一个值换成另一个值。我们希望我们的指标能反映这些交易。

为了更好地实现我们的目标,我们将不得不接受至少更复杂一点。我们可以将这两个值相乘,如图 12.13 所示。这种方法保持了一个很好的特性,即如果任一值为 0,分数也为 0,而分数为 1.0 意味着两个输入都完美。它还有利于在低值处精确度和召回率之间的平衡折衷,尽管当接近完美结果时,它变得更加线性。这并不好,因为我们真的需要将两者都提高才能在那一点上有意义的改进。


图 12.13 使用mult(p, r)计算最终分数

注意 在这里我们正在取两个比率的几何平均值en.wikipedia.org/wiki/Geometric_mean),这也不会产生有意义的结果。

还有一个问题,几乎整个象限从(0, 0)到(0.5, 0.5)都非常接近于零。正如我们将看到的,拥有一个对该区域的变化敏感的指标是重要的,特别是在我们模型设计的早期阶段。

虽然将乘法作为我们的评分函数是可行的(它没有任何立即淘汰的资格,就像之前的评分函数一样),但我们将使用 F1 分数来评估我们的分类模型的性能。

更新日志输出以包括精确度、召回率和 F1 分数

现在我们有了新的指标,将它们添加到我们的日志输出中非常简单。我们将在我们的训练和验证集的主要日志声明中包括精确度、召回率和 F1。

列表 12.4 training.py:341, LunaTrainingApp.logMetrics

log.info(
  ("E{} {:8} {loss/all:.4f} loss, "
     + "{correct/all:-5.1f}% correct, "
     + "{pr/precision:.4f} precision, "   # ❶
     + "{pr/recall:.4f} recall, "         # ❶
     + "{pr/f1_score:.4f} f1 score"       # ❶
  ).format(
    epoch_ndx,
    mode_str,
    **metrics_dict,
  )
)

❶ 格式字符串已更新

另外,我们将为每个负样本和正样本的正确识别计数和总样本数包括精确值。

列表 12.5 training.py:353, LunaTrainingApp.logMetrics

log.info(
  ("E{} {:8} {loss/neg:.4f} loss, "
     + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
  ).format(
    epoch_ndx,
    mode_str + '_neg',
    neg_correct=neg_correct,
    neg_count=neg_count,
    **metrics_dict,
  )
)

新版本的正日志声明看起来基本相同。

12.3.5 我们的模型如何使用我们的新指标?

现在我们已经实施了闪亮的新指标,让我们试试它们;我们将在展示 Bash shell 会话结果后讨论结果。在您的系统进行数字计算时,您可能想提前阅读;这可能需要大约半个小时,具体时间取决于您的系统。实际所需时间取决于您的系统的 CPU、GPU 和磁盘速度;我们的系统配备 SSD 和 GTX 1080 Ti,每个完整时期大约需要 20 分钟:

$ ../.venv/bin/python -m p2ch12.training
Starting LunaTrainingApp...
...
E1 LunaTrainingApp
.../p2ch12/training.py:274: RuntimeWarning: invalid value encountered in double_scalars
  metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)                                          # ❶
E1 trn      0.0025 loss,  99.8% correct, 0.0000 prc, 0.0000 rcl, nan f1
E1 trn_ben  0.0000 loss, 100.0% correct (494735 of 494743)
E1 trn_mal  1.0000 loss,   0.0% correct (0 of 1215)
.../p2ch12/training.py:269: RuntimeWarning: invalid value encountered in long_scalars
  precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
E1 val      0.0025 loss,  99.8% correct, nan prc, 0.0000 rcl, nan f1
E1 val_ben  0.0000 loss, 100.0% correct (54971 of 54971)
E1 val_mal  1.0000 loss,   0.0% correct (0 of 136)

❶ 这些 RuntimeWarning 行的确切计数和行号可能会因运行而异。

糟糕。我们收到了一些警告,考虑到我们计算的一些值是nan,可能在某处发生了除零操作。让我们看看我们能找出什么。

首先,由于训练集中没有一个正样本被分类为正,这意味着精确度和召回率都为零,导致我们的 F1 分数计算除以零。其次,对于我们的验证集,由于没有任何东西被标记为正,truePos_countfalsePos_count都为零。这导致我们的precision计算的分母也为零;这是合理的,因为这就是我们看到另一个RuntimeWarning的地方。

少数负训练样本被分类为正(494743 个中有 494735 个被分类为负,因此有 8 个样本被错误分类)。虽然一开始可能看起来很奇怪,但请记住我们是在整个时期内收集我们的训练结果,而不是像我们对验证结果那样使用模型的时期末状态。这意味着第一批次实际上产生了随机结果。其中一些来自第一批次的样本被标记为正并不奇怪。

注意 由于网络权重的随机初始化和训练样本的随机排序,单独的运行可能会表现出略有不同的行为。确切可重现的行为可能是可取的,但超出了我们在本书第 2 部分中所尝试的范围。

嗯,那有点痛苦。切换到我们的新指标导致从 A+降至“零,如果你幸运的话”–如果我们不幸运,分数会很糟糕,甚至不是一个数字。哎呀。

话虽如此,在长期来看,这对我们是有利的。自第十一章以来,我们就知道我们的模型性能很差。如果我们的指标告诉我们除了那个,那将指向指标中的一个基本缺陷!

12.4 理想数据集是什么样的?

在我们为当前糟糕的情况哭泣之前,让我们想想我们实际上希望我们的模型做什么。图 12.14 说,首先我们需要平衡我们的数据,以便我们的模型能够正确训练。让我们建立起达到这个目标所需的逻辑步骤。

图 12.14 本章主题集,重点关注平衡我们的正负样本

回想一下之前的图 12.5,并讨论分类阈值。通过移动阈值来获得更好的结果具有有限的效果–正负类之间有太多重叠无法处理。

相反,我们想看到一个类似图 12.15 的图像。在这里,我们的标签阈值几乎是垂直的。这就是我们想要的,因为这意味着标签阈值和我们的分类阈值可以相当好地对齐。同样,大多数样本集中在图表的两端。这两件事都要求我们的数据易于分离,并且我们的模型具有执行该分离的能力。我们的模型目前具有足够的容量,所以问题不在于此。相反,让我们看看我们的数据。

图 12.15 一个训练良好的模型可以清晰地分离数据,使得很容易选择一个具有少量折衷的分类阈值。

图 12.16 一个大致近似我们 LUNA 分类数据中不平衡的数据集

请记住,我们的数据极度不平衡。正样本与负样本的比例为 400:1。这是极端不平衡的!图 12.16 展示了这种情况。难怪我们的“实际结节”样本在人群中被忽略!

现在,让我们非常清楚:当我们完成时,我们的模型将能够很好地处理这种数据不平衡。我们甚至可以在不改变平衡的情况下训练模型到最后,假设我们愿意等待无数个纪元。但我们是忙碌的人,有很多事情要做,所以与其等到 GPU 烧毁到宇宙的热死亡,不如尝试通过改变我们训练的类平衡使我们的训练数据看起来更理想。

12.4.1 使数据看起来不那么实际,更像“理想”的方法

最好的做法是有相对更多的正样本。在训练的初始时期,当我们从随机混乱过渡到更有组织的状态时,由于正样本太少,它们会被淹没。

然而,这种情况发生的方式有些微妙。请记住,由于我们的网络权重最初是随机的,网络每个样本的输出也是随机的(但被夹在[0-1]范围内)。

注意 我们的损失函数是nn.CrossEntropyLoss,严格来说是在原始 logits 上操作,而不是类概率。在我们的讨论中,我们将忽略这一区别,假设损失和标签预测之间的差异是相同的。

预测与正确标签数值接近的结果对网络权重几乎没有影响,而与正确答案明显不同的预测结果会导致权重发生更大的变化。由于模型在随机权重初始化时输出是随机的,我们可以假设在我们约 500k 个训练样本(准确地说是 495,958 个)中,我们将有以下近似组:

  1. 250,000 个负样本将被预测为负面(0.0 到 0.5),并最多对网络权重产生一点朝向预测负面的变化。
  2. 250,000 个负样本将被预测为正面(0.5 到 1.0),并导致网络权重向预测负面的方向发生大幅变化。
  3. 500 个正样本将被预测为负面,并导致网络权重向预测正面的方向发生变化。
  4. 500 个正样本将被预测为正面,并几乎不会对网络权重产生任何变化。

注意 请记住,实际预测是介于 0.0 和 1.0 之间的实数,因此这些组没有严格的界限。

这里的关键是:1 组和 4 组可以是任意大小,它们对训练几乎没有影响。唯一重要的是 2 组和 3 组能够相互抵消,防止网络崩溃到退化的“只输出一种结果”的状态。由于 2 组比 3 组大 500 倍,我们使用的批量大小为 32,大约需要经过 500/32 = 15 批次才能看到一个正样本。这意味着 15 个训练批次中有 14 个将是 100%负面的,只会将所有模型权重拉向预测负面的方向。这种不平衡的拉力产生了我们一直看到的退化行为。

相反,我们希望正样本和负样本数量相同。因此,在训练的第一部分中,一半的标签将被错误分类,这意味着第 2 组和第 3 组的大小应该大致相等。我们还希望确保我们呈现的批次中包含负样本和正样本的混合。平衡将导致拉锯战平衡,每个批次中的类别混合将使模型有很好的机会学会区分这两个类别。由于我们的 LUNA 数据只有少量固定数量的正样本,我们将不得不接受我们拥有的正样本并在训练期间重复呈现它们。

歧视

在这里,我们将歧视定义为“将两个类别彼此分开的能力”。构建和训练一个能够将“实际结节”候选者与正常解剖结构区分开的模型是我们在第 2 部分所做的全部工作的重点。

歧视的一些其他定义更具问题性。虽然超出了我们在这里讨论工作的范围,但从真实世界数据训练的模型存在更大的问题。如果真实世界数据集是从存在真实世界歧视性偏见的来源收集的(例如,种族偏见在逮捕和定罪率中,或者从社交媒体收集的任何内容),并且在数据集准备或训练期间没有纠正这种偏见,那么生成的模型将继续表现出训练数据中存在的相同偏见。就像在人类中一样,种族主义是被学习的。

这意味着几乎任何从互联网大数据源训练的模型都会在某种程度上受到损害,除非极端小心地清除这些模型中的偏见。请注意,就像我们在第 2 部分的目标一样,这被认为是一个未解决的问题。

回想一下我们在第十一章中提到的教授,他的期末考试有 99 个错误答案和 1 个正确答案。下学期,在被告知“你应该有更平衡的真假答案”后,教授决定增加一次期中考试,其中有 99 个正确答案和 1 个错误答案。“问题解决了!”

显然,正确的方法是以一种不允许学生利用测试的更大结构来回答问题的方式交替真实和错误答案。虽然学生可能会注意到“奇数问题是真实的,偶数问题是错误的”这样的模式,但 PyTorch 使用的批处理系统不允许模型“注意到”或利用那种模式。我们的训练数据集将需要更新,以在正样本和负样本之间交替,就像图 12.17 中那样。

不平衡数据就像我们在第九章开始提到的草堆中的针。如果您必须手动执行这项分类工作,您可能会开始同情普雷斯顿。

图 12.17 不平衡数据的批次将在第一个正事件之前只有负事件,而平衡数据可以每隔一个样本交替出现。

然而,我们不会为验证进行任何平衡。我们的模型需要在现实世界中表现良好,而现实世界是不平衡的(毕竟,这就是我们获取原始数据的地方!)。

我们应该如何实现这种平衡?让我们讨论我们的选择。

取样器可以重塑数据集

DataLoader 的一个可选参数是sampler=...。这允许数据加载器覆盖传入数据集的本机迭代顺序,而是根据需要塑造、限制或重新强调底层数据。当使用一个不受您控制的数据集时,这可能非常有用。将公共数据集重新塑造以满足您的需求比从头开始重新实现该数据集要少得多。

不足之处在于,我们可以通过采样器实现的许多变异需要我们打破底层数据集的封装。例如,假设我们有一个类似于 CIFAR-10(www.cs.toronto.edu/~kriz/cifar.html)的数据集,由 10 个权重相同的类组成,我们想要让 1 个类(比如“飞机”)现在占所有训练图像的 50%。我们可以决定使用WeightedRandomSamplermng.bz/8plK)并将每个“飞机”样本索引的权重提高,但构建weights参数需要我们事先知道哪些索引是飞机。

正如我们讨论的那样,Dataset API 只规定子类提供__len____getitem__,但我们无法直接询问“哪些样本是飞机?”我们要么事先加载每个样本以查询该样本的类别,要么打破封装并希望我们需要的信息可以轻松从查看Dataset子类的内部实现中获得。

由于在我们可以直接控制数据集的情况下,这两种选项都不是特别理想的,因此第 2 部分的代码在Dataset子类内部实现任何所需的数据整形,而不依赖外部采样器。

在数据集中实现类平衡

我们将直接更改我们的LunaDataset,以呈现平衡的正负样本比例进行训练。我们将保留负训练样本和正训练样本的分开列表,并交替从这两个列表中返回样本。这将防止模型通过简单地回答每个呈现的样本为“false”而得分良好的退化行为。此外,正负类别将交错排列,以便权重更新被迫区分类别。

让我们在LunaDataset中添加一个ratio_int,用于控制第N个样本的标签,并跟踪我们按标签分开的样本。

列表 12.6 dsets.py:217,class LunaDataset

class LunaDataset(Dataset):
  def __init__(self,
         val_stride=0,
         isValSet_bool=None,
         ratio_int=0,
      ):
    self.ratio_int = ratio_int
    # ... line 228
    self.negative_list = [
      nt for nt in self.candidateInfo_list if not nt.isNodule_bool
    ]
    self.pos_list = [
      nt for nt in self.candidateInfo_list if nt.isNodule_bool
    ]
    # ... line 265
  def shuffleSamples(self):               # ❶
    if self.ratio_int:
      random.shuffle(self.negative_list)
      random.shuffle(self.pos_list)

❶ 我们将在每个周期的开头调用这个函数,以随机化呈现的样本顺序。

有了这个,我们现在为每个标签都有专门的列表。使用这些列表,更容易根据数据集中的索引返回我们想要的标签。为了确保我们的索引正确,我们应该勾画出我们想要的排序。假设ratio_int为 2,意味着负样本与正样本的比例为 2:1。这意味着每三个索引应该是正样本:

DS Index   0 1 2 3 4 5 6 7 8 9 ...
Label      + - - + - - + - - +
Pos Index  0     1     2     3
Neg Index    0 1   2 3   4 5

数据集索引与正索引之间的关系很简单:将数据集索引除以 3 然后向下取整。负索引稍微复杂一些,因为我们必须从数据集索引中减去 1,然后再减去最近的正索引。

在我们的LunaDataset类中实现,看起来像下面这样。

列表 12.7 dsets.py:286,LunaDataset.__getitem__

def __getitem__(self, ndx):
  if self.ratio_int:                                   # ❶
    pos_ndx = ndx // (self.ratio_int + 1)
    if ndx % (self.ratio_int + 1):                     # ❷
      neg_ndx = ndx - 1 - pos_ndx
      neg_ndx %= len(self.negative_list)               # ❸
      candidateInfo_tup = self.negative_list[neg_ndx]
    else:
      pos_ndx %= len(self.pos_list)                    # ❸
      candidateInfo_tup = self.pos_list[pos_ndx]
  else:
    candidateInfo_tup = self.candidateInfo_list[ndx]   # ❹

❶ 零的ratio_int意味着使用本地平衡。

❷ 非零余数表示这应该是一个负样本。

❸ 溢出导致环绕。

❹ 如果不平衡类别,则返回第 N 个样本

这可能有点复杂,但如果你仔细检查一下,就会明白。请记住,如果比率较低,我们会在用尽数据集之前用完正样本。我们通过在索引到self.pos_list之前取pos_ndx的模来处理这个问题。虽然由于大量负样本的存在,neg_ndx不太可能发生相同类型的索引溢出,但我们仍然执行模运算,以防以后做出可能导致溢出的更改。

我们还将对数据集的长度进行更改。虽然这并非绝对必要,但加快单个周期的速度是很好的。我们将硬编码我们的__len__为 200,000。

列表 12.8 dsets.py:280,LunaDataset.__len__

def __len__(self):
  if self.ratio_int:
    return 200000
  else:
    return len(self.candidateInfo_list)

我们不再受限于特定数量的样本,并且提供“完整的一轮”在我们必须多次重复正样本以呈现平衡的训练集时并不是很有意义。通过选择 20 万个样本,我们减少了开始训练运行并看到结果之间的时间(更快的反馈总是不错!),并且我们给自己一个漂亮、清晰的每轮样本数。随时调整一轮的长度以满足您的需求。

为了完整起见,我们还添加了一个命令行参数。

列表 12.9 training.py:31,class LunaTrainingApp

class LunaTrainingApp:
  def __init__(self, sys_argv=None):
    # ... line 52
    parser.add_argument('--balanced',
      help="Balance the training data to half positive, half negative.",
      action='store_true',
      default=False,
    )

然后我们将该参数传递给LunaDataset构造函数。

列表 12.10 training.py:137,LunaTrainingApp.initTrainDl

def initTrainDl(self):
  train_ds = LunaDataset(
    val_stride=10,
    isValSet_bool=False,
    ratio_int=int(self.cli_args.balanced),    # ❶
  )

❶ 这里我们依赖于 Python 的True可转换为1

我们已经准备就绪。让我们运行它!

PyTorch 深度学习(GPT 重译)(五)(2)https://developer.aliyun.com/article/1485249


相关实践学习
日志服务之使用Nginx模式采集日志
本文介绍如何通过日志服务控制台创建Nginx模式的Logtail配置快速采集Nginx日志并进行多维度分析。
相关文章
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(七)(4)
JavaScript 权威指南第七版(GPT 重译)(七)
24 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(3)
JavaScript 权威指南第七版(GPT 重译)(七)
33 0
|
前端开发 JavaScript Unix
JavaScript 权威指南第七版(GPT 重译)(七)(2)
JavaScript 权威指南第七版(GPT 重译)(七)
42 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(1)
JavaScript 权威指南第七版(GPT 重译)(七)
60 0
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(4)
JavaScript 权威指南第七版(GPT 重译)(六)
93 2
JavaScript 权威指南第七版(GPT 重译)(六)(4)
|
13天前
|
前端开发 JavaScript API
JavaScript 权威指南第七版(GPT 重译)(六)(3)
JavaScript 权威指南第七版(GPT 重译)(六)
55 4
|
13天前
|
XML 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(2)
JavaScript 权威指南第七版(GPT 重译)(六)
60 4
JavaScript 权威指南第七版(GPT 重译)(六)(2)
|
13天前
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(六)(1)
JavaScript 权威指南第七版(GPT 重译)(六)
28 3
JavaScript 权威指南第七版(GPT 重译)(六)(1)
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(五)(4)
JavaScript 权威指南第七版(GPT 重译)(五)
39 9
|
13天前
|
前端开发 JavaScript 程序员
JavaScript 权威指南第七版(GPT 重译)(五)(3)
JavaScript 权威指南第七版(GPT 重译)(五)
36 8