PyTorch 深度学习(GPT 重译)(六)(1)

简介: PyTorch 深度学习(GPT 重译)(六)

十四、端到端结节分析,以及接下来的步骤

本章内容包括

  • 连接分割和分类模型
  • 为新任务微调网络
  • 直方图和其他指标类型添加到 TensorBoard
  • 从过拟合到泛化

在过去的几章中,我们已经构建了许多对我们的项目至关重要的系统。我们开始加载数据,构建和改进结节候选的分类器,训练分割模型以找到这些候选,处理训练和评估这些模型所需的支持基础设施,并开始将我们的训练结果保存到磁盘。现在是时候将我们拥有的组件统一起来,以便实现我们项目的完整目标:是时候自动检测癌症了。

14.1 迈向终点

通过查看图 14.1 我们可以得到剩余工作的一些线索。在第 3 步(分组)中,我们看到我们仍需要建立第十三章的分割模型和第十二章的分类器之间的桥梁,以确定分割网络找到的是否确实是结节。右侧是第 5 步(结节分析和诊断),整体目标的最后一步:查看结节是否为癌症。这是另一个分类任务;但为了在过程中学到一些东西,我们将通过借鉴我们已有的结节分类器来采取新的方法。

图 14.1 我们的端到端肺癌检测项目,重点关注本章的主题:第 3 步和第 5 步,分组和结节分析

当然,这些简短的描述及其在图 14.1 中的简化描述遗漏了很多细节。让我们通过图 14.2 放大一下,看看我们还有哪些任务要完成。

图 14.2 一个关于我们端到端项目剩余工作的详细查看

正如您所看到的,还有三项重要任务。以下列表中的每一项对应于图 14.2 的一个主要项目:

  1. 生成结节候选。这是整个项目的第 3 步。这一步骤包括三项任务:
  1. 分割 --第十三章的分割模型将预测给定像素是否感兴趣:如果我们怀疑它是结节的一部分。这将在每个 2D 切片上完成,并且每个 2D 结果将被堆叠以形成包含结节候选预测的体素的 3D 数组。
  2. 分组 --我们将通过将预测应用于阈值来将体素分组为结节候选,然后将连接区域的标记体素分组。
  3. 构建样本元组 --每个识别的结节候选将用于构建一个用于分类的样本元组。特别是,我们需要生成该结节中心的坐标(索引、行、列)。

一旦实现了这一点,我们将拥有一个应用程序,该应用程序接收患者的原始 CT 扫描并生成检测到的结节候选列表。生成这样的列表是 LUNA 挑战的任务。如果这个项目被临床使用(我们再次强调我们的项目不应该被使用!),这个结节列表将适合由医生进行更仔细的检查。

  1. 对结节和恶性进行分类。我们将取出我们刚刚产生的结节候选并将其传递到我们在第十二章实现的候选分类步骤,然后对被标记为结节的候选进行恶性检测:
  1. 结节分类 --从分割和分组中得到的每个结节候选将被分类为结节或非结节。这样做将允许我们筛选出被我们的分割过程标记为许多正常解剖结构。
  2. ROC/AUC 指标 --在我们开始最后的分类步骤之前,我们将定义一些用于检查分类模型性能的新指标,并建立一个基准指标,以便与我们的恶性分类器进行比较。
  3. 微调恶性模型 --一旦我们的新指标就位,我们将定义一个专门用于分类良性和恶性结节的模型,对其进行训练,并查看其表现。我们将通过微调进行训练:这个过程会剔除现有模型的一些权重,并用新值替换它们,然后我们将这些值调整到我们的新任务中。

到那时,我们将离我们的最终目标不远了:将结节分类为良性和恶性类别,然后从 CT 中得出诊断。再次强调,在现实世界中诊断肺癌远不止盯着 CT 扫描,因此我们进行这种诊断更多是为了看看我们能够使用深度学习和成像数据单独走多远。

  1. 端到端检测。最后,我们将把所有这些组合起来,达到终点,将组件组合成一个端到端的解决方案,可以查看 CT 并回答问题“肺部是否存在恶性结节?”
  1. IRC --我们将对我们的 CT 进行分割,以获取结节候选样本进行分类。
  2. 确定结节 --我们将对候选进行结节分类,以确定是否应将其输入恶性分类器。
  3. *确定恶性程度 --*我们将对通过结节分类器的结节进行恶性分类,以确定患者是否患癌症。

我们有很多事情要做。冲刺终点!

注意 正如前一章中所述,我们将在文本中详细讨论关键概念,并略过重复、繁琐或显而易见的代码部分。完整的细节可以在书籍的代码存储库中找到。

14.2 验证集的独立性

我们面临着一个微妙但关键的错误的危险,我们需要讨论并避免:我们有一个潜在的从训练集到验证集的泄漏!对于分割和分类模型的每一个,我们都小心地将数据分割成一个训练集和一个独立的验证集,通过将每十个示例用于验证,其余用于训练。

然而,分类模型的分割是在结节列表上进行的,分割模型的分割是在 CT 扫描列表上进行的。这意味着我们很可能在分类模型的训练集中有来自分割验证集的结节,反之亦然。我们必须避免这种情况!如果不加以修正,这种情况可能导致性能指标人为地高于我们在独立数据集上获得的性能。这被称为泄漏,它将使我们的验证失效。

为了纠正这种潜在的数据泄漏,我们需要重新设计分类数据集,以便像我们在第十三章中为分割任务所做的那样也在 CT 扫描级别上工作。然后我们需要用这个新数据集重新训练分类模型。好消息是,我们之前没有保存我们的分类模型,所以我们无论如何都需要重新训练。

你应该从中得到的启示是在定义验证集时要注意整个端到端的过程。可能最简单的方法(也是对大多数重要数据集采用的方法)是尽可能明确地进行验证分割–例如,通过为训练和验证分别设置两个目录–然后在整个项目中坚持这种分割。当您需要重新分割时(例如,当您需要按某些标准对数据集进行分层时),您需要使用新分割的数据集重新训练所有模型。

我们为您做的是从第 10-12 章的LunaDataset中复制候选列表,并从第十三章的Luna2dSegmentationDataset中将其分割为测试和验证数据集。由于这是非常机械的,并且没有太多细节可供学习(您现在已经是数据集专家了),我们不会详细展示代码。

我们将通过重新运行分类器的训练来重新训练我们的分类模型:¹

$ python3 -m p2ch14.training --num-workers=4 --epochs 100 nodule-nonnodule

经过 100 个周期,我们对正样本的准确率达到约 95%,对负样本达到 99%。由于验证损失没有再次上升的趋势,我们可以继续训练模型以查看是否会继续改善。

经过 90 个周期,我们达到了最大的 F1 分数,并且在验证准确率方面达到了 99.2%,尽管在实际结节上只有 92.8%。我们将采用这个模型,尽管我们可能也会尝试在恶性结节的准确率上稍微牺牲一些总体准确率(在此期间,模型在实际结节上的准确率为 95.4%,总准确率为 98.9%)。这对我们来说已经足够了,我们准备连接这些模型。

14.3 连接 CT 分割和结节候选分类

现在我们已经从第十三章保存了一个分割模型,并且在上一节刚刚训练了一个分类模型,图 14.3 的步骤 1a、1b 和 1c 显示我们已经准备好开始编写代码,将我们的分割输出转换为样本元组。我们正在进行分组:在图 14.3 的步骤 1b 的高亮周围找到虚线轮廓。我们的输入是分割:由第 1a 中的分割模型标记的体素。我们想要找到 1c,即每个“块”中心的质心坐标:我们需要在样本元组列表中提供的是 1b 加号标记的索引、行和列。

图 14.3 我们本章的计划,重点是将分割的体素分组为结节候选

运行模型时,其处理方式与我们在训练和验证(尤其是验证)期间处理它们的方式非常相似。这里的区别在于对 CT 进行循环。对于每个 CT,我们会分割每个切片,然后将所有分割输出作为分组的输入。分组的输出将被馈送到结节分类器中,通过该分类器幸存下来的结节将被馈送到恶性分类器中。

这是对 CT 的外部循环,对每个 CT 进行分割、分组、分类候选,并提供分类以进行进一步处理。

列表 14.1 nodule_analysis.py:324,NoduleAnalysisApp.main

for _, series_uid in series_iter:                        # ❶
  ct = getCt(series_uid)                                 # ❷
  mask_a = self.segmentCt(ct, series_uid)                # ❸
  candidateInfo_list = self.groupSegmentationOutput(     # ❹
    series_uid, ct, mask_a)
  classifications_list = self.classifyCandidates(        # ❺
    ct, candidateInfo_list)

❶ 循环遍历系列 UID

❷ 获取 CT(大图中的步骤 1)

❸ 在其上运行我们的分割模型(步骤 2)

❹ 对输出中的标记体素进行分组(步骤 3)

❺ 在它们上运行我们的结节分类器(步骤 4)

我们将在以下部分详细介绍segmentCtgroupSegmentationOutputclassifyCandidates方法。

14.3.1 分割

首先,我们将对整个 CT 扫描的每个切片执行分割。由于我们需要逐个患者的 CT 逐个切片进行处理,我们构建一个Dataset,加载具有单个series_uid的 CT 并返回每个切片,每次调用__getitem__

注意 特别是在 CPU 上执行时,分割步骤可能需要相当长的时间。尽管我们在这里只是简单提及,但代码将在可用时使用 GPU。

除了更广泛的输入之外,主要区别在于我们如何处理输出。回想一下,输出是每个像素的概率数组(即在 0…1 范围内),表示给定像素是否属于结节。在遍历切片时,我们在一个与我们的 CT 输入形状相同的掩模数组中收集切片预测。之后,我们对预测进行阈值处理以获得二进制数组。我们将使用 0.5 的阈值,但如果需要,我们可以尝试不同的阈值来在增加假阳性的情况下获得更多真阳性。

我们还包括一个使用 scipy.ndimage.morphology 中的腐蚀操作进行小的清理步骤。它删除一个边缘体素层,仅保留内部体素——那些所有八个相邻体素在轴方向上也被标记的体素。这使得标记区域变小,并导致非常小的组件(小于 3 × 3 × 3 体素)消失。结合数据加载器的循环,我们指示它向我们提供来自单个 CT 的所有切片,我们有以下内容。

列表 14.2 nodule_analysis.py:384, .segmentCt

def segmentCt(self, ct, series_uid):
  with torch.no_grad():                                     # ❶
    output_a = np.zeros_like(ct.hu_a, dtype=np.float32)     # ❷
    seg_dl = self.initSegmentationDl(series_uid)  #         # ❸
    for input_t, _, _, slice_ndx_list in seg_dl:
      input_g = input_t.to(self.device)                     # ❹
      prediction_g = self.seg_model(input_g)                # ❺
      for i, slice_ndx in enumerate(slice_ndx_list):        # ❻
        output_a[slice_ndx] = prediction_g[i].cpu().numpy()
    mask_a = output_a > 0.5                                 # ❼
    mask_a = morphology.binary_erosion(mask_a, iterations=1)
  return mask_a

❶ 我们这里不需要梯度,所以我们不构建图。

❷ 这个数组将保存我们的输出:一个概率注释的浮点数组。

❸ 我们获得一个数据加载器,让我们可以按批次循环遍历我们的 CT。

❹ 将输入移动到 GPU 后…

❺ … 我们运行分割模型 …

❻ … 并将每个元素复制到输出数组中。

❼ 将概率输出阈值化以获得二进制输出,然后应用二进制腐蚀进行清理

这已经足够简单了,但现在我们需要发明分组。

14.3.2 将体素分组为结节候选

我们将使用一个简单的连通分量算法将我们怀疑的结节体素分组成块以输入分类。这种分组方法标记连接的组件,我们将使用 scipy.ndimage.measurements.label 完成。label 函数将获取所有与另一个非零像素共享边缘的非零像素,并将它们标记为属于同一组。由于我们从分割模型输出的大部分都是高度相邻像素的块,这种方法很好地匹配了我们的数据。

列表 14.3 nodule_analysis.py:401

def groupSegmentationOutput(self, series_uid,  ct, clean_a):
  candidateLabel_a, candidate_count = measurements.label(clean_a)   # ❶
  centerIrc_list = measurements.center_of_mass(                     # ❷
    ct.hu_a.clip(-1000, 1000) + 1001,
    labels=candidateLabel_a,
    index=np.arange(1, candidate_count+1),
  )

❶ 为每个体素分配所属组的标签

❷ 获取每个组的质心作为索引、行、列坐标

输出数组 candidateLabel_a 与我们用于输入的 clean_a 具有相同的形状,但在背景体素处为 0,并且递增的整数标签 1、2、…,每个连接的体素块组成一个结节候选。请注意,这里的标签 是分类意义上的标签!这只是在说“这个体素块是体素块 1,这边的体素块是体素块 2,依此类推”。

SciPy 还提供了一个函数来获取结节候选的质心:scipy.ndimage.measurements.center_of_mass。它接受一个每个体素密度的数组,刚刚调用的 label 函数返回的整数标签,以及需要计算质心的这些标签的列表。为了匹配函数期望的质量为非负数,我们将(截取的)ct.hu_a 偏移了 1,001。请注意,这导致所有标记的体素都携带一些权重,因为我们将最低的空气值在本机 CT 单位中夹紧到 -1,000 HU。

列表 14.4 nodule_analysis.py:409

candidateInfo_list = []
for i, center_irc in enumerate(centerIrc_list):
  center_xyz = irc2xyz(                                                   # ❶
    center_irc,
    ct.origin_xyz,
    ct.vxSize_xyz,
    ct.direction_a,
  )
  candidateInfo_tup = \
    CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)  # ❷
  candidateInfo_list.append(candidateInfo_tup)
return candidateInfo_list

❶ 将体素坐标转换为真实患者坐标

❷ 构建我们的候选信息元组并将其附加到检测列表中

作为输出,我们得到一个包含三个数组的列表(分别为索引、行和列),与我们的 candidate_count 长度相同。我们可以使用这些数据来填充一个 candidateInfo_tup 实例的列表;我们已经对这种小数据结构产生了依恋,所以我们将结果放入自从第十章以来一直在使用的相同类型的列表中。由于我们实际上没有适合的数据来填充前四个值(isNodule_boolhasAnnotation_boolisMal_booldiameter_mm),我们插入了适当类型的占位符值。然后我们在循环中将我们的坐标从体素转换为物理坐标,创建列表。将我们的坐标从基于数组的索引、行和列移开可能看起来有点愚蠢,但所有消耗 candidateInfo_tup 实例的代码都期望 center_xyz,而不是 center_irc。如果我们尝试互换一个和另一个,我们将得到极其错误的结果!

耶–我们征服了第 3 步,从体素级别的检测中获取结节位置!现在我们可以裁剪出疑似结节,并将它们馈送给我们的分类器,以进一步消除一些假阳性。

14.3.3 我们找到了结节吗?分类以减少假阳性

当我们开始本书的第 2 部分时,我们描述了放射科医生查看 CT 扫描以寻找癌症迹象的工作如下:

目前,审查数据的工作必须由经过高度训练的专家执行,需要对细节进行仔细的注意,主要是在不存在癌症的情况下。

做好这项工作就像被放在 100 堆草垛前,并被告知:“确定这些草垛中是否有针。”

我们已经花费了时间和精力讨论谚语中的针;让我们通过查看图 14.4 来讨论一下草垛。我们的工作,可以说,就是尽可能多地从我们那位眼睛发直的放射科医生面前的草垛中分离出来,这样他们就可以重新聚焦他们经过高度训练的注意力,以便发挥最大的作用。

图 14.4 我们端到端检测项目的步骤,以及每个步骤删除的数据的数量级。

让我们看看在执行端到端诊断时每个步骤丢弃了多少数据。图 14.4 中的箭头显示了数据从原始 CT 体素流经我们的项目到最终恶性确定的过程。以 X 结尾的每个箭头表示上一步丢弃的一部分数据;指向下一步的箭头代表经过筛选幸存下来的数据。请注意,这里的数字是非常近似的。

让我们更详细地看一下图 14.4 中的步骤:

  1. 分割 --分割从整个 CT 开始:数百张切片,或大约 3300 万(225)体素(加减很多)。大约有 220 个体素被标记为感兴趣的;这比总输入要小几个数量级,这意味着我们要丢弃 97%的体素(这是左边导致 X 的 225)。
  2. 分组。虽然分组并没有明确删除任何内容,但它确实减少了我们考虑的项目数量,因为我们将体素合并为结节候选者。分组从 100 万体素中产生了大约 1000 个候选者(210)。一个 16×16×2 体素的结节将有总共 210 个体素。²
  3. 结节分类。这个过程丢弃了剩下的大多数~210 个项目。从我们成千上万的结节候选者中,我们剩下了数十个结节:大约 25 个。
  4. 恶性分类。最后,恶性分类器会取出数十个结节(25 个),找出其中一个或两个(21 个)是癌症的。

沿途的每一步都允许我们丢弃大量数据,我们的模型确信这些数据与我们的癌症检测目标无关。我们从数百万数据点到少数肿瘤。

完全自动化与辅助系统

完全自动化系统和旨在增强人类能力的系统之间存在差异。对于我们的自动化系统,一旦一条数据被标记为无关紧要,它就永远消失了。然而,当向人类呈现数据供其消化时,我们应该允许他们剥开一些层次,查看近似情况,并用一定的信心程度注释我们的发现。如果我们设计一个用于临床使用的系统,我们需要仔细考虑我们确切的预期用途,并确保我们的系统设计能够很好地支持这些用例。由于我们的项目是完全自动化的,我们可以继续前进,而不必考虑如何最好地展示近似情况和不确定的答案。

现在我们已经确定了图像中我们的分割模型认为是潜在候选的区域,我们需要从 CT 中裁剪这些候选并将它们馈送到分类模块中。幸运的是,我们有前一节的 candidateInfo_list,所以我们只需要从中创建一个 DataSet,将其放入 DataLoader,并对其进行迭代。概率预测的第一列是预测的这是一个结节的概率,这是我们想要保留的。就像以前一样,我们收集整个循环的输出。

列表 14.5 结节分析.py:357,.classifyCandidates

def classifyCandidates(self, ct, candidateInfo_list):
  cls_dl = self.initClassificationDl(candidateInfo_list)        # ❶
  classifications_list = []
  for batch_ndx, batch_tup in enumerate(cls_dl):
    input_t, _, _, series_list, center_list = batch_tup
    input_g = input_t.to(self.device)                           # ❷
    with torch.no_grad():
      _, probability_nodule_g = self.cls_model(input_g)         # ❸
      if self.malignancy_model is not None:                     # ❹
        _, probability_mal_g = self.malignancy_model(input_g)
      else:
        probability_mal_g = torch.zeros_like(probability_nodule_g)
    zip_iter = zip(center_list,
      probability_nodule_g[:,1].tolist(),
      probability_mal_g[:,1].tolist())
    for center_irc, prob_nodule, prob_mal in zip_iter:          # ❺
      center_xyz = irc2xyz(center_irc,
        direction_a=ct.direction_a,
        origin_xyz=ct.origin_xyz,
        vxSize_xyz=ct.vxSize_xyz,
      )
      cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
      classifications_list.append(cls_tup)
  return classifications_list

❶ 再次,我们获得一个数据加载器来循环遍历,这次是基于我们的候选列表。

❷ 将输入发送到设备

❸ 将输入通过结节与非结节网络运行

❹ 如果我们有一个恶性模型,我们也运行它。

❺ 进行我们的簿记,构建我们结果的列表

这太棒了!我们现在可以将输出概率阈值化,得到我们的模型认为是实际结节的列表。在实际设置中,我们可能希望将它们输出供放射科医生检查。同样,我们可能希望调整阈值以更安全地出错一点:也就是说,如果我们的阈值是 0.3 而不是 0.5,我们将呈现更多的候选,结果证明不是结节,同时减少错过实际结节的风险。

列表 14.6 结节分析.py:333,NoduleAnalysisApp.main

if not self.cli_args.run_validation:                                  # ❶
    print(f"found nodule candidates in {series_uid}:")
    for prob, prob_mal, center_xyz, center_irc in classifications_list:
      if prob > 0.5:                                                    # ❷
        s = f"nodule prob {prob:.3f}, "
        if self.malignancy_model:
          s += f"malignancy prob {prob_mal:.3f}, "
        s += f"center xyz {center_xyz}"
        print(s)
  if series_uid in candidateInfo_dict:                                  # ❸
    one_confusion = match_and_score(
      classifications_list, candidateInfo_dict[series_uid]
    )
    all_confusion += one_confusion
    print_confusion(
      series_uid, one_confusion, self.malignancy_model is not None
    )
print_confusion(
  "Total", all_confusion, self.malignancy_model is not None
)

❶ 如果我们不通过运行验证,我们打印单独的信息…

❷ … 对于分割找到的所有候选,其中分类器分配的结节概率为 50% 或更高。

❸ 如果我们有真实数据,我们计算并打印混淆矩阵,并将当前结果添加到总数中。

让我们针对验证集中的给定 CT 运行这个:³

$ python3.6 -m p2ch14.nodule_analysis 1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864
...
found nodule candidates in 1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864:
nodule prob 0.533, malignancy prob 0.030, center xyz XyzTuple   # ❶(x=-128.857421875, y=-80.349609375, z=-31.300007820129395) 
nodule prob 0.754, malignancy prob 0.446, center xyz XyzTuple(x=-116.396484375, y=-168.142578125, z=-238.30000233650208)
...
nodule prob 0.974, malignancy prob 0.427, center xyz XyzTuple   # ❷(x=121.494140625, y=-45.798828125, z=-211.3000030517578)
nodule prob 0.700, malignancy prob 0.310, center xyz XyzTuple(x=123.759765625, y=-44.666015625, z=-211.3000030517578)
...

❶ 这个候选被分配了 53% 的恶性概率,所以它勉强达到了 50% 的概率阈值。恶性分类分配了一个非常低(3%)的概率。

❷ 被检测为结节,具有非常高的置信度,并被分配了 42% 的恶性概率

脚本总共找到了 16 个结节候选。由于我们正在使用验证集,我们对每个 CT 都有完整的注释和恶性信息,我们可以使用这些信息创建一个混淆矩阵来展示我们的结果。行是真相(由注释定义),列显示我们的项目如何处理每种情况:

1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864           # ❶
                  |    Complete Miss |     Filtered Out |     Pred. Nodule # ❷
      Non-Nodules |                  |             1088 |               15 # ❸
           Benign |                1 |                0 |                0
        Malignant |                0 |                0 |                1

❶ 扫描 ID

❷ 预后:完全未检出表示分割未找到结节,被过滤掉是分类器的工作,预测结节是它标记为结节的。

❸ 行包含了真相。

完全未检出列是当我们的分割器根本没有标记结节时。由于分割器并不试图标记非结节,我们将该单元格留空。我们的分割器经过训练具有很高的召回率,因此有大量的非结节,但我们的结节分类器很擅长筛选它们。

所以我们在这个扫描中找到了 1 个恶性结节,但漏掉了第 17 个良性结节。此外,有 15 个误报的非结节通过了结节分类器。分类器的过滤将误报降至 1,000 多个!正如我们之前看到的,1,088 大约是 O(210),所以这符合我们的预期。同样,15 大约是 O(24),这与我们估计的 O(25) 差不多。

很棒!但更大的画面是什么?

14.4 定量验证

现在我们有了一些个案证据表明我们建立的东西可能在一个案例上起作用,让我们看看我们的模型在整个验证集上的表现。这样做很简单:我们将我们的验证集通过之前的预测运行,检查我们得到了多少结节,漏掉了多少,以及多少候选被错误地识别为结节。

我们运行以下内容,如果在 GPU 上运行,应该需要半小时到一个小时。喝完咖啡(或者睡个好觉)后,这是我们得到的结果:

$ python3 -m p2ch14.nodule_analysis --run-validation
...
Total
                 |    Complete Miss |     Filtered Out |     Pred. Nodule
     Non-Nodules |                  |           164893 |             2156
          Benign |               12 |                3 |               87
       Malignant |                1 |                6 |               45

我们检测到了 154 个结节中的 132 个,或者 85%。我们错过的 22 个中,有 13 个未被分割认为是候选结节,因此这将是改进的明显起点。

大约 95%的检测到的结节是假阳性。这当然不是很好;另一方面,这并不是很关键–不得不查看 20 个结节候选才能找到一个结节要比查看整个 CT 要容易得多。我们将在第 14.7.2 节中更详细地讨论这一点,但我们要强调的是,与其将这些错误视为黑匣子,不如调查被错误分类的情况并看看它们是否有共同点。有什么特征可以将它们与被正确分类的样本区分开吗?我们能找到什么可以用来改善我们表现的东西吗?

目前,我们将接受我们的数字如此:不错,但并非完美。当您运行自己训练的模型时,确切的数字可能会有所不同。在本章末尾,我们将提供一些指向可以帮助改善这些数字的论文和技术。通过灵感和一些实验,我们确信您可以获得比我们在这里展示的更好的分数。

14.5 预测恶性

现在我们已经实现了 LUNA 挑战的结节检测任务,并可以生成自己的结节预测,我们问自己一个逻辑上的下一个问题:我们能区分恶性结节和良性结节吗?我们应该说,即使有一个好的系统,诊断恶性可能需要更全面地查看患者,额外的非 CT 背景信息,最终可能需要活检,而不仅仅是孤立地查看 CT 扫描中的单个结节。因此,这似乎是一个可能由医生执行的任务,未来可能会有一段时间。

14.5.1 获取恶性信息

LUNA 挑战专注于结节检测,并不包含恶性信息。LIDC-IDRI 数据集(mng.bz/4A4R)包含了用于 LUNA 数据集的 CT 扫描的超集,并包括有关已识别肿瘤恶性程度的额外信息。方便地,有一个可以轻松安装的 PyLIDC 库,如下所示:

$ pip3 install pylidc

pylicd库为我们提供了我们想要的额外恶性信息的便捷访问。就像我们在第 10 章中所做的那样,将 LIDC 的注释与 LUNA 候选者的坐标匹配,我们需要将 LIDC 的注释信息与 LUNA 候选者的坐标关联起来。

在 LIDC 注释中,恶性信息按照每个结节和诊断放射科医师(最多四位医师查看同一结节)使用从 1(高度不可能)到适度不可能、不确定、适度可疑,最后是 5(高度可疑)的有序五值量表进行编码。这些注释基于图像本身,并受到关于患者的假设的影响。为了将数字列表转换为单个布尔值是/否,我们将考虑当至少有两位放射科医师将该结节评为“适度可疑”或更高时,结节被认为是恶性的。请注意,这个标准有些是任意的;事实上,文献中有许多不同的处理这些数据的方法,包括预测五个步骤,使用平均值,或者从数据集中删除放射科医师评级不确定或不一致的结节。

结合数据的技术方面与第十章相同,因此我们跳过在此处显示代码(代码存储库中有此章节的代码),并将使用扩展的 CSV 文件。我们将以与我们为结节分类器所做的非常相似的方式使用数据集,只是现在我们只需要处理实际结节,并使用给定结节是否为恶性作为要预测的标签。这在结构上与我们在第十二章中使用的平衡非常相似,但我们不是从pos_listneg_list中抽样,而是从mal_listben_list中抽样。就像我们为结节分类器所做的那样,我们希望保持训练数据平衡。我们将这些放入MalignancyLunaDataset类中,该类是LunaDataset的子类,但在其他方面非常相似。

为了方便起见,我们在 training.py 中创建了一个dataset命令行参数,并动态使用命令行指定的数据集类。我们通过使用 Python 的getattr函数来实现这一点。例如,如果self.cli_args.dataset是字符串MalignancyLunaDataset,它将获取p2ch14.dsets.MalignancyLunaDataset并将此类型分配给ds_cls,我们可以在这里看到。

列表 14.7 training.py:154,.initTrainDl

ds_cls = getattr(p2ch14.dsets, self.cli_args.dataset)   # ❶
train_ds = ds_cls(
  val_stride=10,
  isValSet_bool=False,
  ratio_int=1,                                          # ❷
)

❶ 动态类名查找

❷ 请记住,这是训练数据之间的一对一平衡,这里是良性和恶性之间的平衡。

14.5.2 曲线下面积基线:按直径分类

有一个基线总是好的,可以看到什么性能比没有好。我们可以追求比随机更好,但在这里我们可以使用直径作为恶性的预测因子–更大的结节更有可能是恶性的。图 14.5 的第 2b 步提示了一个我们可以用来比较分类器的新度量标准。

图 14.5 我们在本章中实施的端到端项目,重点是 ROC 图

我们可以将结节直径作为假设分类器预测结节是否为恶性的唯一输入。这不会是一个很好的分类器,但事实证明,说“一切大于这个阈值 X 的东西都是恶性的”比我们预期的更好地预测了恶性。当然,选择正确的阈值是关键–有一个甜蜜点,可以获取所有巨大的肿瘤,而没有任何微小的斑点,并且大致分割了那个不确定区域,其中有一堆较大的良性结节和较小的恶性结节。

正如我们可能从第十二章中记得的那样,我们的真正阳性、假正性、真正性和假负性计数会根据我们选择的阈值值而改变。当我们降低我们预测结节为恶性的阈值时,我们将增加真正阳性的数量,但也会增加假正性的数量。假正率(FPR)是 FP /(FP + TN),而真正率(TPR)是 TP /(TP + FN),您可能还记得这是从第十二章中的召回中得到的。

测量假阳性没有一种真正的方法:精度与假阳性率

这里的 FPR 和第十二章中的精度是(介于 0 和 1 之间的)率,用于衡量不完全相反的事物。正如我们讨论过的,精度是 TP /(TP + FP),用于衡量预测为阳性的样本中有多少实际上是阳性的。FPR 是 FP /(FP + TN),用于衡量实际上为负的样本中有多少被预测为阳性。对于极度不平衡的数据集(如结节与非结节分类),我们的模型可能会实现非常好的 FPR(这与交叉熵标准作为损失密切相关),而精度–因此 F1 分数–仍然非常差。低 FPR 意味着我们正在淘汰我们不感兴趣的很多内容,但如果我们正在寻找那根传说中的针,我们仍然主要是干草。

让我们为我们的阈值设定一个范围。下限将是使得所有样本都被分类为阳性的值,上限将是相反的情况,即所有样本都被分类为阴性。在一个极端情况下,我们的 FPR 和 TPR 都将为零,因为不会有任何阳性;在另一个极端情况下,两者都将为一,因为不会有 TN 和 FN(一切都是阳性!)。

对于我们的结节数据,直径范围从 3.25 毫米(最小结节)到 22.78 毫米(最大结节)。如果我们选择一个介于这两个值之间的阈值,然后可以计算 FPR(阈值)和 TPR(阈值)。如果我们将 FPR 值设为X,TPR 设为Y,我们可以绘制代表该阈值的点;如果我们反而绘制每个可能阈值的 FPR 对 TPR,我们得到一个名为受试者工作特征(ROC)的图表,如图 14.6 所示。阴影区域是ROC 曲线下的面积,或者 AUC。它的取值范围在 0 到 1 之间,数值越高越好。⁵


图 14.6 我们基线的受试者工作特征(ROC)曲线

在这里,我们还指出了两个特定的阈值:直径为 5.42 毫米和 10.55 毫米。我们选择这两个值,因为它们为我们可能考虑的阈值范围提供了相对合理的端点,如果我们需要选择一个单一的阈值。小于 5.42 毫米,我们只会降低我们的 TPR。大于 10.55 毫米,我们只会将恶性结节标记为良性而没有任何收益。这个分类器的最佳阈值可能会在中间某处。

我们实际上是如何计算这里显示的数值的呢?我们首先获取候选信息列表,过滤出已注释的结节,并获取恶性标签和直径。为了方便起见,我们还获取了良性和恶性结节的数量。

列表 14.8 p2ch14_malben_baseline.ipynb

# In[2]:
ds = p2ch14.dsets.MalignantLunaDataset(val_stride=10, isValSet_bool=True) # ❶
nodules = ds.ben_list + ds.mal_list
is_mal = torch.tensor([n.isMal_bool for n in nodules])                    # ❷
diam  = torch.tensor([n.diameter_mm for n in nodules])
num_mal = is_mal.sum()                                                    # ❸
num_ben = len(is_mal) - num_mal

❶ 获取常规数据集,特别是良性和恶性结节的列表

❷ 获取恶性状态和直径的列表

❸ 为了对 TPR 和 FPR 进行归一化,我们获取了恶性和良性结节的数量。

要计算 ROC 曲线,我们需要一个可能阈值的数组。我们从 torch.linspace 获取这个数组,它取两个边界元素。我们希望从零预测的阳性开始,所以我们从最大阈值到最小阈值。这就是我们已经提到的 3.25 到 22.78:

# In[3]:
threshold = torch.linspace(diam.max(), diam.min())

然后我们构建一个二维张量,其中行是每个阈值,列是每个样本信息,值是该样本是否被预测为阳性。然后根据样本的标签(恶性或良性)对此布尔张量进行过滤。我们对行求和以计算True条目的数量。除以恶性或良性结节的数量给出了 TPR 和 FPR–ROC 曲线的两个坐标:

# In[4]:
predictions = (diam[None] >= threshold[:, None])                   # ❶
tp_diam = (predictions & is_mal[None]).sum(1).float() / num_mal    # ❷
fp_diam = (predictions & ~is_mal[None]).sum(1).float() / num_ben

❶ 通过 None 索引添加了一个大小为 1 的维度,就像 .unsqueeze(ndx) 一样。这使我们得到一个 2D 张量,其中给定结节(在列中)是否被分类为恶性,直径(在行中)。

❷ 使用预测矩阵,我们可以通过对列求和来计算每个直径的 TPR 和 FPR。

要计算这条曲线下的面积,我们使用梯形法进行数值积分(en.wikipedia.org/wiki/Trapezoidal_rule),其中我们将两点之间的平均 TPR(Y 轴上)乘以两个 FPR 之间的差值(X 轴上)–图表中两点之间梯形的面积。然后我们将梯形的面积相加:

# In[5]:
fp_diam_diff =  fp_diam[1:] - fp_diam[:-1]
tp_diam_avg  = (tp_diam[1:] + tp_diam[:-1])/2
auc_diam = (fp_diam_diff * tp_diam_avg).sum()

现在,如果我们运行pyplot.plot(fp_diam, tp_diam, label=f"diameter baseline, AUC={auc_diam:.3f}")(以及我们在第 8 单元中看到的适当图表设置),我们将得到图 14.6 中看到的图表。

PyTorch 深度学习(GPT 重译)(六)(2)https://developer.aliyun.com/article/1485254

相关文章
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(七)(4)
JavaScript 权威指南第七版(GPT 重译)(七)
24 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(3)
JavaScript 权威指南第七版(GPT 重译)(七)
33 0
|
前端开发 JavaScript Unix
JavaScript 权威指南第七版(GPT 重译)(七)(2)
JavaScript 权威指南第七版(GPT 重译)(七)
42 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(1)
JavaScript 权威指南第七版(GPT 重译)(七)
60 0
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(4)
JavaScript 权威指南第七版(GPT 重译)(六)
93 2
JavaScript 权威指南第七版(GPT 重译)(六)(4)
|
13天前
|
前端开发 JavaScript API
JavaScript 权威指南第七版(GPT 重译)(六)(3)
JavaScript 权威指南第七版(GPT 重译)(六)
55 4
|
13天前
|
XML 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(2)
JavaScript 权威指南第七版(GPT 重译)(六)
60 4
JavaScript 权威指南第七版(GPT 重译)(六)(2)
|
13天前
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(六)(1)
JavaScript 权威指南第七版(GPT 重译)(六)
28 3
JavaScript 权威指南第七版(GPT 重译)(六)(1)
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(五)(4)
JavaScript 权威指南第七版(GPT 重译)(五)
39 9
|
13天前
|
前端开发 JavaScript 程序员
JavaScript 权威指南第七版(GPT 重译)(五)(3)
JavaScript 权威指南第七版(GPT 重译)(五)
36 8