PyTorch 深度学习(GPT 重译)(五)(2)

简介: PyTorch 深度学习(GPT 重译)(五)

PyTorch 深度学习(GPT 重译)(五)(1)https://developer.aliyun.com/article/1485248

12.4.2 将平衡的 LunaDataset 与之前的运行进行对比

作为提醒,我们不平衡的训练运行结果如下:

$ python -m p2ch12.training
...
E1 LunaTrainingApp
E1 trn      0.0185 loss,  99.7% correct, 0.0000 precision, 0.0000 recall, nan f1 score
E1 trn_neg  0.0026 loss, 100.0% correct (494717 of 494743)
E1 trn_pos  6.5267 loss,   0.0% correct (0 of 1215)
...
E1 val      0.0173 loss,  99.8% correct, nan precision, 0.0000 recall, nan f1 score
E1 val_neg  0.0026 loss, 100.0% correct (54971 of 54971)
E1 val_pos  5.9577 loss,   0.0% correct (0 of 136)

但是当我们使用--balanced运行时,我们看到以下情况:

$ python -m p2ch12.training --balanced
...
E1 LunaTrainingApp
E1 trn      0.1734 loss,  92.8% correct, 0.9363 precision, 0.9194 recall, 0.9277 f1 score
E1 trn_neg  0.1770 loss,  93.7% correct (93741 of 100000)
E1 trn_pos  0.1698 loss,  91.9% correct (91939 of 100000)
...
E1 val      0.0564 loss,  98.4% correct, 0.1102 precision, 0.7941 recall, 0.1935 f1 score
E1 val_neg  0.0542 loss,  98.4% correct (54099 of 54971)
E1 val_pos  0.9549 loss,  79.4% correct (108 of 136)

这看起来好多了!我们放弃了大约 5%的负样本正确答案,以获得 86%的正确正样本答案。我们又回到了一个扎实的 B 范围内!⁵

然而,就像第十一章一样,这个结果是具有欺骗性的。由于负样本比正样本多 400 倍,即使只有 1%的错误,也意味着我们会将负样本错误地分类为正样本,比实际正样本总数多四倍!

尽管如此,这显然比第十一章的完全错误行为要好得多,比随机抛硬币要好得多。事实上,我们甚至已经进入了(几乎)在实际场景中有用的领域。回想一下我们过度劳累的放射科医生仔细检查每一个 CT 上的每一个斑点:现在我们有了一些可以合理地筛除 95%的假阳性的东西。这是一个巨大的帮助,因为这意味着机器辅助人类的生产力增加了大约十倍。

当然,还有那令人讨厌的 14%被错过的正样本问题,我们可能需要处理一下。也许增加一些额外的训练轮次会有所帮助。让我们看看(再次提醒,每轮至少需要花费 10 分钟):

$ python -m p2ch12.training --balanced --epochs 20
...
E2 LunaTrainingApp
E2 trn      0.0432 loss,  98.7% correct, 0.9866 precision, 0.9879 recall, 0.9873 f1 score
E2 trn_ben  0.0545 loss,  98.7% correct (98663 of 100000)
E2 trn_mal  0.0318 loss,  98.8% correct (98790 of 100000)
E2 val      0.0603 loss,  98.5% correct, 0.1271 precision, 0.8456 recall, 0.2209 f1 score
E2 val_ben  0.0584 loss,  98.6% correct (54181 of 54971)
E2 val_mal  0.8471 loss,  84.6% correct (115 of 136)
...
E5 trn      0.0578 loss,  98.3% correct, 0.9839 precision, 0.9823 recall, 0.9831 f1 score
E5 trn_ben  0.0665 loss,  98.4% correct (98388 of 100000)
E5 trn_mal  0.0490 loss,  98.2% correct (98227 of 100000)
E5 val      0.0361 loss,  99.2% correct, 0.2129 precision, 0.8235 recall, 0.3384 f1 score
E5 val_ben  0.0336 loss,  99.2% correct (54557 of 54971)
E5 val_mal  1.0515 loss,  82.4% correct (112 of 136)...
...
E10 trn      0.0212 loss,  99.5% correct, 0.9942 precision, 0.9953 recall, 0.9948 f1 score
E10 trn_ben  0.0281 loss,  99.4% correct (99421 of 100000)
E10 trn_mal  0.0142 loss,  99.5% correct (99530 of 100000)
E10 val      0.0457 loss,  99.3% correct, 0.2171 precision, 0.7647 recall, 0.3382 f1 score
E10 val_ben  0.0407 loss,  99.3% correct (54596 of 54971)
E10 val_mal  2.0594 loss,  76.5% correct (104 of 136)
...
E20 trn      0.0132 loss,  99.7% correct, 0.9964 precision, 0.9974 recall, 0.9969 f1 score
E20 trn_ben  0.0186 loss,  99.6% correct (99642 of 100000)
E20 trn_mal  0.0079 loss,  99.7% correct (99736 of 100000)
E20 val      0.0200 loss,  99.7% correct, 0.4780 precision, 0.7206 recall, 0.5748 f1 score
E20 val_ben  0.0133 loss,  99.8% correct (54864 of 54971)
E20 val_mal  2.7101 loss,  72.1% correct (98 of 136)

哎呀。要滚动到我们感兴趣的数字,需要滚过很多文本。让我们坚持下去,专注于val_mal XX.X% correct数字(或者直接跳到下一节的 TensorBoard 图表)。第 2 轮之后,我们达到了 87.5%;第 5 轮时,我们达到了 92.6%的峰值;然后到了第 20 轮,我们下降到了 86.8%–低于我们的第二轮!

注意 正如前面提到的,由于网络权重的随机初始化和每轮训练样本的随机选择和排序,预计每次运行都会有独特的行为。

训练集的数字似乎没有同样的问题。负训练样本被正确分类的概率为 98.8%,正样本则为 99.1%。发生了什么?

12.4.3 识别过拟合的症状

我们所看到的是过拟合的明显迹象。让我们看一下我们在正样本上的损失图,见图 12.18。

图 12.18 我们的正损失显示出明显的过拟合迹象,因为训练损失和验证损失趋势不同。

在这里,我们可以看到我们的正样本的训练损失几乎为零–每个正样本训练样本都得到了几乎完美的预测。然而,我们的正样本的验证损失却在增加,这意味着我们的实际表现可能正在变差。在这一点上,最好停止训练脚本,因为模型不再改进。

提示 通常,如果您的模型在训练集上的表现正在提高,而在验证集上表现变差,那么模型已经开始过拟合。

然而,我们必须注意检查正确的指标,因为这种趋势只发生在我们的损失上。如果我们看一下我们的整体损失,一切似乎都很好!这是因为我们的验证集不平衡,所以整体损失被我们的负样本所主导。正如图 12.19 所示,我们在我们的负样本中没有看到相同的发散行为。相反,我们的负损失看起来很好!这是因为我们有 400 倍的负样本,所以模型要记住个别细节要困难得多。然而,我们的正训练集只有 1,215 个样本。虽然我们多次重复这些样本,但这并不会使它们更难记忆。模型正在从泛化原则转变为基本上记住这 1,215 个样本的怪癖,并声称不属于这几个样本之一的任何东西都是负样本。这包括负训练样本和我们验证集中的所有内容(正负样本都有)。

图 12.19 我们的负损失没有显示过拟合的迹象

显然,仍然存在一些泛化,因为我们大约正确分类了 70%的正验证集。我们只需要改变我们训练模型的方式,使我们的训练集和验证集都朝着正确的方向发展。

12.5 重新审视过拟合问题

我们在第五章中提到了过拟合的概念,现在是时候更仔细地看看如何解决这种常见情况了。我们训练模型的目标是教会它识别我们感兴趣的类别的一般属性,如我们数据集中所表达的那样。这些一般属性存在于该类别的一些或所有样本中,并且可以泛化并用于预测未经训练的样本。当模型开始学习训练集的特定属性时,就会发生过拟合,模型开始失去泛化的能力。如果这有点抽象,让我们使用另一个类比。

12.5.1 一个过拟合的人脸到年龄预测模型

假设我们有一个模型,它以人脸图像作为输入,并输出预测的年龄。一个好的模型会注意到年龄的特征,如皱纹、白发、发型、服装选择等,并利用这些建立不同年龄看起来的一般模型。当呈现一张新图片时,它会考虑“保守的发型”、“眼镜”和“皱纹”等因素,得出“大约 65 岁”的结论。

与之相比,过拟合模型则是通过记住识别细节来记住特定的人。“那个发型和那副眼镜意味着那是弗兰克。他 62.8 岁了”;“哦,那个伤疤意味着那是哈里。他 39.3 岁了”;等等。当展示一个新的人时,模型将无法识别这个人,也完全不知道该预测多少岁。

更糟糕的是,如果展示弗兰克的儿子的照片(看起来像他爸爸,至少戴着眼镜时是这样!),模型会说:“我认为那是弗兰克。他 62.8 岁了。”尽管小弗兰克实际上年轻了 25 岁!

过拟合通常是由于训练样本太少,与模型仅仅记住答案的能力相比。普通人可以记住自己家人的生日,但在预测比一个小村庄规模更大的群体的年龄时,就必须求助于概括。

我们的人脸到年龄模型有能力简单地记住那些看起来不完全符合其年龄的照片。正如我们在第 1 部分中讨论的,模型容量是一个有点抽象的概念,但大致是模型参数数量乘以这些参数的有效使用方式。当模型的容量相对于需要记住训练集中难样本的数据量很高时,模型很可能会开始过拟合这些更难的训练样本。

12.6 通过数据增强防止过拟合

是时候将我们的模型训练从好到优秀了。我们需要完成图 12.20 中的最后一步。

图 12.20 本章的主题集,重点是数据增强

我们通过对单个样本应用合成的改变来增强数据集,从而得到一个有效大小比原始数据集更大的新数据集。典型的目标是使改变导致合成样本仍然代表与源样本相同的一般类别,但不能与原始样本一起轻松记忆。当正确执行时,这种增强可以将训练集大小增加到模型能够记忆的范围之外,从而迫使模型越来越依赖泛化,这正是我们想要的。在处理有限数据时,这种增强尤其有用,正如我们在第 12.4.1 节中看到的。

当然,并非所有的增强都同样有用。回到我们的面部年龄预测模型的例子,我们可以轻松地将每个图像的四个角像素的红色通道更改为随机值 0-255,这将导致数据集比原始数据集大 40 亿倍。当然,这并不特别有用,因为模型可以相当轻松地学会忽略图像角落的红点,而图像的其余部分仍然像单个未经增强的原始图像一样容易记忆。将这种方法与左右翻转图像进行对比。这样做只会使数据集比原始数据集大两倍,但每个图像对于训练目的来说会更有用。年龄的一般属性与左右无关,因此镜像图像仍然具有代表性。同样,面部图片很少是完全对称的,因此镜像版本不太可能与原始版本轻松记忆。

12.6.1 具体的数据增强技术

我们将实现五种特定类型的数据增强。我们的实现将允许我们单独或合并地对任何一种或全部进行实验。这五种技术如下:

  • 将图像上下、左右和/或前后镜像
  • 将图像移动几个体素
  • 将图像放大或缩小
  • 将图像围绕头-脚轴旋转
  • 添加噪声到图像

对于每种技术,我们希望确保我们的方法保持训练样本的代表性,同时又足够不同,以便样本用于训练时是有用的。

我们将定义一个函数 getCtAugmentedCandidate,负责获取我们标准的 CT 块并对其中的候选进行修改。我们的主要方法将定义一个仿射变换矩阵(mng.bz/Edxq),并将其与 PyTorch 的 affine_gridpytorch.org/docs/stable/nn.html#affine-grid)和 grid_samplepytorch.org/docs/stable/nn.html#torch.nn.functional.grid_sample)函数一起使用,以对我们的候选进行重新采样。

列表 12.11 dsets.py:149, def getCtAugmentedCandidate

def getCtAugmentedCandidate(
    augmentation_dict,
    series_uid, center_xyz, width_irc,
    use_cache=True):
  if use_cache:
    ct_chunk, center_irc = \
      getCtRawCandidate(series_uid, center_xyz, width_irc)
  else:
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
  ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

我们首先获取 ct_chunk,可以从缓存中获取,也可以直接通过加载 CT 获取(这在我们创建自己的候选中会很方便),然后将其转换为张量。接下来是仿射网格和采样代码。

列表 12.12 dsets.py:162, def getCtAugmentedCandidate

transform_t = torch.eye(4)
# ...                        # ❶
# ... line 195
affine_t = F.affine_grid(
    transform_t[:3].unsqueeze(0).to(torch.float32),
    ct_t.size(),
    align_corners=False,
  )
augmented_chunk = F.grid_sample(
    ct_t,
    affine_t,
    padding_mode='border',
    align_corners=False,
  ).to('cpu')
# ... line 214
return augmented_chunk[0], center_irc

❶ 转换 transform_tensor 的修改将在这里进行。

没有任何额外的东西,这个函数不会有太多作用。让我们看看需要添加一些实际变换的步骤。

注意 重要的是要构建数据流水线,使得缓存步骤发生在增强之前!否则将导致数据被增强一次,然后保留在那种状态,这违背了初衷。

镜像

当镜像一个样本时,我们保持像素值完全相同,只改变图像的方向。由于肿瘤生长与左右或前后没有强烈的相关性,我们应该能够在不改变样本代表性质的情况下翻转它们。指数轴(在患者坐标中称为Z)对应于直立人体中的重力方向,然而,肿瘤的顶部和底部可能存在差异的可能性。我们将假设这没问题,因为快速的视觉调查并没有显示任何明显的偏差。如果我们正在进行一个临床相关的项目,我们需要向专家确认这一假设。

列表 12.13 dsets.py:165,def getCtAugmentedCandidate

for i in range(3):
  if 'flip' in augmentation_dict:
    if random.random() > 0.5:
      transform_t[i,i] *= -1

grid_sample 函数将范围 [-1, 1] 映射到旧张量和新张量的范围(如果大小不同,则会隐式地进行重新缩放)。这个范围映射意味着为了镜像数据,我们只需要将变换矩阵的相关元素乘以 -1。

通过随机偏移进行移动

将结节候选物体移动一下不会产生很大的影响,因为卷积是独立于平移的,尽管这会使我们的模型对不完全居中的结节更加稳健。更重要的是,偏移量可能不是整数个体素数;相反,数据将使用三线性插值重新采样,这可能会引入一些轻微的模糊。样本边缘的体素将被重复,这可以看作是沿边界的一部分呈现出模糊、条纹状的区域。

列表 12.14 dsets.py:165,def getCtAugmentedCandidate

for i in range(3):
  # ... line 170
  if 'offset' in augmentation_dict:
    offset_float = augmentation_dict['offset']
    random_float = (random.random() * 2 - 1)
    transform_t[i,3] = offset_float * random_float

请注意,我们的 'offset' 参数是以与网格采样函数期望的 [-1, 1] 范围相同的比例表示的最大偏移量。

缩放

稍微缩放图像与镜像和移动非常相似。这样做也会导致我们刚刚讨论的在移动样本时提到的相同重复边缘体素。

列表 12.15 dsets.py:165,def getCtAugmentedCandidate

for i in range(3):
  # ... line 175
  if 'scale' in augmentation_dict:
    scale_float = augmentation_dict['scale']
    random_float = (random.random() * 2 - 1)
    transform_t[i,i] *= 1.0 + scale_float * random_float

由于 random_float 被转换为在范围 [-1, 1],所以实际上无论我们将 scale_float * random_float 添加到 1.0 还是从 1.0 中减去它都没有关系。

旋转

旋转是我们将使用的第一种增强技术,我们必须仔细考虑我们的数据,以确保我们不会通过导致其不再具有代表性的转换来破坏我们的样本。请记住,我们的 CT 切片在行和列(X 和 Y 轴)上具有均匀间距,但在指数(或 Z)方向上,体素是非立方体的。这意味着我们不能将这些轴视为可互换的。

一种选择是重新采样我们的数据,使得我们沿指数轴的分辨率与其他两个轴的分辨率相同,但这并不是一个真正的解决方案,因为沿着那个轴的数据会非常模糊和模糊。即使我们插入更多的体素,数据的保真度仍然很差。相反,我们将把这个轴视为特殊轴,并将我们的旋转限制在 X-Y 平面上。

列表 12.16 dsets.py:181,def getCtAugmentedCandidate

if 'rotate' in augmentation_dict:
  angle_rad = random.random() * math.pi * 2
  s = math.sin(angle_rad)
  c = math.cos(angle_rad)
  rotation_t = torch.tensor([
    [c, -s, 0, 0],
    [s, c, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
  ])
  transform_t @= rotation_t
噪音

我们的最终增强技术与其他技术不同,因为它在某种程度上对我们的样本进行了积极破坏,而翻转或旋转样本则没有这种情况。如果我们向样本添加太多噪音,它将淹没真实数据,并使其实际上无法分类。虽然如果我们使用极端输入值,移动和缩放样本也会产生类似的效果,但我们选择的值只会影响样本的边缘。噪音将对整个图像产生影响。

列表 12.17 dsets.py:208,def getCtAugmentedCandidate

if 'noise' in augmentation_dict:
  noise_t = torch.randn_like(augmented_chunk)
  noise_t *= augmentation_dict['noise']
  augmented_chunk += noise_t

其他增强类型已经增加了我们数据集的有效大小。噪音使我们模型的工作更加困难。一旦我们看到一些训练结果,我们将重新审视这一点。

检查增强候选物体

我们可以在图 12.21 中看到我们努力的结果。左上角的图像显示了一个未增强的正候选样本,接下来的五个图像显示了每种增强类型的效果。最后,底部行显示了三次组合结果。

图 12.21 在正结节样本上执行的各种增强类型

由于对增强数据集的每次__getitem__调用都会随机重新应用增强,底部行的每个图像看起来都不同。这也意味着几乎不可能再次生成完全相同的图像!还要记住,有时'flip'增强会导致没有翻转。始终返回翻转图像与一开始不翻转一样限制。现在让我们看看这是否有所不同。

12.6.2 从数据增强中看到改进

我们将训练额外的模型,每种增强类型一个,还有一个将所有增强类型组合在一起的额外模型训练运行。一旦它们完成,我们将在 TensorBoard 中查看我们的数据。

为了能够打开和关闭我们的新增强类型,我们需要将augmentation_dict的构建暴露给我们的命令行界面。程序的参数将通过parser.add_argument调用添加(未显示,但类似于我们的程序已经具有的那些),然后将被馈送到实际构建augmentation_dict的代码中。

列表 12.18 training.py:105,LunaTrainingApp.__init__

self.augmentation_dict = {}
if self.cli_args.augmented or self.cli_args.augment_flip:
  self.augmentation_dict['flip'] = True
if self.cli_args.augmented or self.cli_args.augment_offset:
  self.augmentation_dict['offset'] = 0.1                     # ❶
if self.cli_args.augmented or self.cli_args.augment_scale:
  self.augmentation_dict['scale'] = 0.2                      # ❶
if self.cli_args.augmented or self.cli_args.augment_rotate:
  self.augmentation_dict['rotate'] = True
if self.cli_args.augmented or self.cli_args.augment_noise:
  self.augmentation_dict['noise'] = 25.0                     # ❶

❶ 这些值是经验选择的,具有合理的影响,但可能存在更好的值。

现在我们已经准备好这些命令行参数,您可以运行以下命令,或者重新查看 p2_run_everything.ipynb 并运行第 8 到 16 个单元格。无论如何运行,都需要花费相当长的时间才能完成:

$ .venv/bin/python -m p2ch12.prepcache                   # ❶
$ .venv/bin/python -m p2ch12.training --epochs 20 \
        --balanced sanity-bal                            # ❷
$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-flip   sanity-bal-flip
$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-shift  sanity-bal-shift
$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-scale  sanity-bal-scale
$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-rotate sanity-bal-rotate
$ .venv/bin/python -m p2ch12.training --epochs 10 \
        --balanced --augment-noise  sanity-bal-noise
$ .venv/bin/python -m p2ch12.training --epochs 20 \
        --balanced --augmented sanity-bal-aug

❶ 您每章只需要准备一次缓存。

❷ 您可能在本章的早些时候运行过这个;在这种情况下,无需重新运行!

在此期间,我们可以启动 TensorBoard。让我们通过更改logdir参数来指示它仅显示这些运行,如下所示:../path/to/tensorboard --logdir runs/p2ch12

根据您手头的硬件情况,训练可能需要很长时间。如果需要加快进度,可以跳过flipshiftscale训练任务,并将第一次和最后一次运行减少到 11 个周期。我们选择了 20 次运行,因为这有助于使它们脱颖而出,但 11 次也可以。

如果让所有内容运行到完成,您的 TensorBoard 应该有类似图 12.22 所示的数据。我们将取消选择除验证数据之外的所有内容,以减少混乱。当您实时查看数据时,还可以更改平滑值,这有助于澄清趋势线。快速查看一下图,然后我们将详细介绍它。

图 12.22 用各种增强方案训练的网络在验证集上正确分类的百分比、损失、F1 分数、精度和召回率

在左上角的图表中第一件要注意的事情(“标签:正确/全部”)是各个增强类型有些混乱。我们的未增强和完全增强的运行位于该混乱的两侧。这意味着当结合时,我们的增强效果超过了其各部分之和。还有一个有趣的地方是,我们的完全增强运行得到了更多错误答案。虽然这通常是不好的,但如果我们看一下右侧的图像列(重点是我们实际关心的正候选样本–那些真正的结节),我们会发现我们的完全增强模型在查找正候选样本方面要好得多。完全增强模型的召回率很高!它也更不容易过拟合。正如我们之前看到的,我们的未增强模型随着时间的推移变得更糟。

值得注意的一点是,噪声增强模型在识别结节方面比未增强模型更差。如果我们记得我们说过噪声会让模型的工作变得更困难,这就说得通了。

在实时数据中看到的另一个有趣的事情(在这里有点混乱)是,旋转增强模型在召回方面几乎与完全增强模型一样好,并且在精度上有很大提高。由于我们的 F1 分数受精度限制(由于负样本数量较高),旋转增强模型的 F1 分数也更高。

未来我们将继续使用完全增强的模型,因为我们的用例需要高召回率。F1 分数仍将用于确定哪个时期保存为最佳。在实际项目中,我们可能希望花费额外的时间来调查不同的增强类型和参数值组合是否能产生更好的结果。

12.7 结论

在本章中,我们花了很多时间和精力重新构思我们对模型性能的看法。通过糟糕的评估方法很容易被误导,而且对评估模型的因素有强烈的直觉理解至关重要。一旦这些基本原理内化,就更容易发现我们何时被误导。

我们还学习了如何处理数据源不足的情况。能够合成代表性的训练样本非常有用。确实很少有太多的训练数据的情况!

现在我们有一个表现合理的分类器,我们将把注意力转向自动查找候选结节进行分类。第十三章将从那里开始;然后,在第十四章中,我们将把这些候选者反馈到我们在这里开发的分类器中,并着手构建另一个分类器来区分恶性结节和良性结节。

12.8 练习

  1. F1 分数可以推广支持除 1 以外的值。
  1. 阅读en.wikipedia.org/wiki/F1_score,并实现 F2 和 F0.5 分数。
  2. 确定 F1、F2 和 F0.5 中哪个对这个项目最有意义。跟踪该值,并与 F1 分数进行比较和对比。⁶
  1. 实现WeightedRandomSampler方法来平衡LunaDataset的正负训练样本,ratio_int设置为0
  1. 您如何获取每个样本类别的所需信息?
  2. 哪种方法更容易?哪种导致更易读的代码?
  1. 尝试不同的类平衡方案。
  1. 两个时期后哪个比例得分最高?20 个时期后呢?
  2. 如果比例是epoch_ndx的函数会怎样?
  1. 尝试不同的数据增强方法。
  1. 是否可以使任何现有方法更具侵略性(噪声、偏移等)?
  2. 噪声增强的包含是否有助于或妨碍您的训练结果?
  • 是否有其他值会改变这个结果?
  1. 研究其他项目使用的数据增强方法。这里有哪些适用的?
  • 为正结节候选实现“mixup”增强。这有帮助吗?
  1. 将初始归一化从nn.BatchNorm更改为自定义内容,并重新训练模型。
  1. 使用固定归一化能获得更好的结果吗?
  2. 什么归一化偏移和比例是有意义的?
  3. 非线性归一化如平方根是否有帮助?
  1. TensorBoard 除了我们在这里介绍的内容之外还可以显示哪些其他数据?
  1. 你能让它显示有关网络权重的信息吗?
  2. 在运行模型对特定样本的中间结果时有什么?
  • 将模型的骨干包装在nn.Sequential的实例中是否有助于或妨碍这一努力?

12.9 总结

  • 二进制标签和二进制分类阈值结合在一起,将数据集分成四个象限:真正阳性、真正阴性、假阴性和假阳性。这四个量为我们改进的性能指标提供了基础。
  • 回忆是模型最大化真正阳性的能力。选择每一个项目都能保证完美的回忆——因为所有正确答案都包括在内——但也表现出较低的精度。
  • 精度是模型最小化假阳性的能力。不选择任何内容保证了完美的精度——因为没有错误答案被包括在内——但也表现出较低的回忆。
  • F1 分数将精度和回忆结合成一个描述模型性能的单一指标。我们使用 F1 分数来确定对训练或模型进行的更改对我们的性能有何影响。
  • 在训练过程中平衡训练集,使得正负样本数量相等,可以使模型表现更好(定义为具有正的、增加的 F1 分数)。
  • 数据增强是指采用现有的有机数据样本并对其进行修改,使得生成的增强样本与原始样本有明显不同,但仍代表同一类别的样本。这样可以在数据有限的情况下进行额外的训练而不会过拟合。
  • 常见的数据增强策略包括改变方向、镜像、重新缩放、偏移、添加噪音。根据项目的不同,其他更具体的策略也可能相关。

¹ 没有人实际说过这个。

² 如果花费的时间超过这个时间,请确保您已运行prepcache脚本。

³ 请记住,这些图像只是分类空间的一种表示,不代表真实情况。

⁴ 目前尚不清楚这是否属实,但这是有可能的,而且损失确实在改善中……

⁵ 请记住,这是在仅呈现了 200,000 个训练样本之后,而不是不平衡数据集的 500,000+个样本之后,所以我们用了不到一半的时间就达到了这个结果。

⁶ 是的,这是一个暗示,这不是 F1 分数!

十三、使用分割找到可疑结节

本章涵盖

  • 使用像素到像素模型对数据进行分割
  • 使用 U-Net 进行分割
  • 使用 Dice 损失理解掩模预测
  • 评估分割模型的性能

在过去的四章中,我们取得了很大的进展。我们了解了 CT 扫描和肺部肿瘤,数据集和数据加载器,以及指标和监控。我们还应用了我们在第一部分学到的许多东西,并且我们有一个可用的分类器。然而,我们仍然在一个有些人为的环境中操作,因为我们需要手动注释的结节候选信息加载到我们的分类器中。我们没有一个很好的方法可以自动创建这个输入。仅仅将整个 CT 输入到我们的模型中——也就是说,插入重叠的 32×32×32 数据块——会导致每个 CT 有 31×31×7=6,727 个数据块,大约是我们拥有的注释样本数量的 10 倍。我们需要重叠边缘;我们的分类器期望结节候选位于中心,即使如此,不一致的定位可能会带来问题。

正如我们在第九章中解释的,我们的项目使用多个步骤来解决定位可能结节、识别它们,并指示可能恶性的问题。这是从业者中常见的方法,而在深度学习研究中,有一种倾向是展示单个模型解决复杂问题的能力。我们在本书中使用的多阶段项目设计给了我们一个很好的借口,逐步介绍新概念。

13.1 向我们的项目添加第二个模型

在前两章中,我们完成了图 13.1 中显示的计划的第 4 步:分类。在本章中,我们不仅要回到上一步,而是回到上两步。我们需要找到一种方法告诉我们的分类器在哪里查找。为此,我们将对原始 CT 扫描进行处理,找出可能是结节的所有内容。这是图中突出显示的第 2 步。为了找到这些可能的结节,我们必须标记看起来可能是结节的体素,这个过程被称为分割。然后,在第十四章中,我们将处理第 3 步,并通过将这幅图像的分割掩模转换为位置注释来提供桥梁。

图 13.1 我们的端到端肺癌检测项目,重点关注本章主题:第 2 步,分割

到本章结束时,我们将创建一个新模型,其架构可以执行像素级标记,或分割。完成这项任务的代码将与上一章的代码非常相似,特别是如果我们专注于更大的结构。我们将要做出的所有更改都将更小且有针对性。正如我们在图 13.2 中看到的,我们需要更新我们的模型(图中的第 2A 步),数据集(2B),以及训练循环(2C),以适应新模型的输入、输出和其他要求。(如果你在图中右侧的步骤 2 中不认识每个组件,不要担心。我们在到达每个步骤时会详细讨论。)最后,我们将检查运行新模型时得到的结果(图中的第 3 步)。

图 13.2 用于分割的新模型架构,以及我们将实施的模型、数据集和训练循环更新

将图 13.2 分解为步骤,我们本章的计划如下:

  1. 分割。首先,我们将学习使用 U-Net 模型进行分割的工作原理,包括新模型组件是什么,以及在我们进行分割过程中会发生什么。这是图 13.2 中的第 1 步。
  2. 更新。为了实现分割,我们需要在三个主要位置更改我们现有的代码库,如图 13.2 右侧的子步骤所示。代码在结构上与我们为分类开发的代码非常相似,但在细节上有所不同:
  1. 更新模型(步骤 2A)。我们将把一个现有的 U-Net 集成到我们的分割模型中。我们在第十二章的模型输出一个简单的真/假分类;而在本章中的模型将输出整个图像。
  2. 更改数据集(步骤 2B)。我们需要更改我们的数据集,不仅提供 CT 的片段,还要为结节提供掩模。分类数据集由围绕结节候选的 3D 裁剪组成,但我们需要收集完整的 CT 切片和用于分割训练和验证的 2D 裁剪。
  3. 调整训练循环(步骤 2C)。我们需要调整训练循环,以引入新的损失进行优化。因为我们想在 TensorBoard 中显示我们的分割结果的图像,我们还会做一些事情,比如将我们的模型权重保存到磁盘上。
  1. 结果。最后,当我们查看定量分割结果时,我们将看到我们努力的成果。

13.2 各种类型的分割

要开始,我们需要讨论不同类型的分割。对于这个项目,我们将使用语义分割,这是使用标签对图像中的每个像素进行分类的行为,就像我们在分类任务中看到的那样,例如,“熊”,“猫”,“狗”等。如果做得正确,这将导致明显的块或区域,表示诸如“所有这些像素都是猫的一部分”之类的事物。这采用标签掩模或热图的形式,用于识别感兴趣的区域。我们将有一个简单的二进制标签:真值将对应结节候选,假值表示无趣的健康组织。这部分满足了我们找到结节候选的需求,稍后我们将把它们馈送到我们的分类网络中。

在深入细节之前,我们应该简要讨论我们可以采取的其他方法来找到结节候选。例如,实例分割使用不同的标签标记感兴趣的单个对象。因此,语义分割会为两个人握手的图片使用两个标签(“人”和“背景”),而实例分割会有三个标签(“人 1”,“人 2”和“背景”),其中边界大约在握手处。虽然这对我们区分“结节 1”和“结节 2”可能有用,但我们将使用分组来识别单个结节。这种方法对我们很有效,因为结节不太可能接触或重叠。

另一种处理这类任务的方法是目标检测,它在图像中定位感兴趣的物品并在该物品周围放置一个边界框。虽然实例分割和目标检测对我们来说可能很好,但它们的实现有些复杂,我们认为它们不是你接下来学习的最好内容。此外,训练目标检测模型通常需要比我们的方法更多的计算资源。如果你感到挑战,YOLOv3 论文比大多数深度学习研究论文更有趣。² 对我们来说,语义分割就是最好的选择。

注意 当我们在本章的代码示例中进行操作时,我们将依赖您从 GitHub 检查大部分更大上下文的代码。我们将省略那些无趣或与之前章节类似的代码,以便我们可以专注于手头问题的关键。

13.3 语义分割:逐像素分类

通常,分割用于回答“这张图片中的猫在哪里?”这种问题。显然,大多数猫的图片,如图 13.3,其中有很多非猫的部分;背景中的桌子或墙壁,猫坐在上面的键盘,这种情况。能够说“这个像素是猫的一部分,这个像素是墙壁的一部分”需要基本不同的模型输出和不同的内部结构,与我们迄今为止使用的分类模型完全不同。分类可以告诉我们猫是否存在,而分割将告诉我们在哪里可以找到它。

图 13.3 分类结果产生一个或多个二进制标志,而分割产生一个掩码或热图。

如果您的项目需要区分近处猫和远处猫,或者左边的猫和右边的猫,那么分割可能是正确的方法。迄今为止我们实现的图像消费分类模型可以被看作是漏斗或放大镜,将大量像素聚焦到一个“点”(或者更准确地说,一组类别预测)中,如图 13.4 所示。分类模型提供的答案形式为“是的,这一大堆像素中有一只猫”,或者“不,这里没有猫”。当您不关心猫在哪里,只关心图像中是否有猫时,这是很好的。

图 13.4 用于分类的放大镜模型结构

重复的卷积和下采样层意味着模型从消耗原始像素开始,产生特定的、详细的检测器,用于识别纹理和颜色等内容,然后构建出更高级的概念特征检测器,用于眼睛、耳朵、嘴巴和鼻子等部位³,最终得出“猫”与“狗”的结论。由于每个下采样层后卷积的接受域不断增加,这些更高级的检测器可以利用来自输入图像越来越大区域的信息。

不幸的是,由于分割需要产生类似图像的输出,最终得到一个类似于单一分类列表的二进制标志是行不通的。正如我们从第 11.4 节回忆的那样,下采样是增加卷积层接受域的关键,也是帮助将构成图像的像素数组减少到单一类别列表的关键。请注意图 13.5,它重复了图 11.6。

图 13.5 LunaModel块的卷积架构,由两个 3×3 卷积和一个最大池组成。最终像素具有 6×6 的接受域。

在图中,我们的输入从左到右在顶部行中流动,并在底部行中继续。为了计算出影响右下角单个像素的接受域–我们可以向后推导。最大池操作有 2×2 的输入,产生每个最终输出像素。底部行中的 3×3 卷积在每个方向(包括对角线)查看一个相邻像素,因此导致 2×2 输出的卷积的总接受域为 4×4(带有右侧的“x”字符)。顶部行中的 3×3 卷积然后在每个方向添加一个额外的像素上下文,因此右下角单个输出像素的接受域是顶部左侧输入的 6×6 区域。通过来自最大池的下采样,下一个卷积块的接受域将具有双倍宽度,每次额外的下采样将再次使其加倍,同时缩小输出的大小。

如果我们希望输出与输入大小相同,我们将需要不同的模型架构。一个用于分割的简单模型可以使用重复的卷积层而没有任何下采样。在适当的填充下,这将导致输出与输入大小相同(好),但由于基于多层小卷积的有限重叠,会导致非常有限的感受野(坏)。分类模型使用每个下采样层来使后续卷积的有效范围加倍;没有这种有效领域大小的增加,每个分割像素只能考虑一个非常局部的邻域。

注意 假设 3×3 卷积,堆叠卷积的简单模型的感受野大小为 2 * L + 1,其中L是卷积层数。

四层 3×3 卷积将每个输出像素的感受野大小为 9×9。通过在第二个和第三个卷积之间插入一个 2×2 最大池,并在最后插入另一个,我们将感受野增加到…

注意 看看你是否能自己算出数学问题;完成后,回到这里查看。

… 16×16。最终的一系列 conv-conv-pool 具有 6×6 的感受野,但这发生在第一个最大池之后,这使得原始输入分辨率中的最终有效感受野为 12×12。前两个卷积层在 12×12 周围添加了总共 2 个像素的边框,总共为 16×16。

因此问题仍然是:如何在保持输入像素与输出像素 1:1 比率的同时改善输出像素的感受野?一个常见的答案是使用一种称为上采样的技术,它将以给定分辨率的图像生成更高分辨率的图像。最简单的上采样只是用一个N×N像素块替换每个像素,每个像素的值与原始输入像素相同。从那里开始,可能性变得更加复杂,选项包括线性插值和学习反卷积。

13.3.1 U-Net 架构

在我们陷入可能的上采样算法的兔子洞之前,让我们回到本章的目标。根据图 13.6,第一步是熟悉一个名为 U-Net 的基础分割算法。

图 13.6 我们将使用的分割新模型架构

U-Net 架构是一种可以产生像素级输出的神经网络设计,专为分割而发明。从图 13.6 的突出部分可以看出,U-Net 架构的图表看起来有点像字母U,这解释了名称的起源。我们还立即看到,它比我们熟悉的大多数顺序结构的分类器要复杂得多。不久我们将在图 13.7 中看到 U-Net 架构的更详细版本,并了解每个组件的具体作用。一旦我们了解了模型架构,我们就可以开始训练一个来解决我们的分割任务。

图 13.7 来自 U-Net 论文的架构,带有注释。来源:本图的基础由 Olaf Ronneberger 等人提供,来源于论文“U-Net:用于生物医学图像分割的卷积网络”,可在arxiv.org/abs/1505.04597lmb.informatik.uni-freiburg.de/people/ronneber/u-net找到。

图 13.7 中显示的 U-Net 架构是图像分割的一个早期突破。让我们看一看,然后逐步了解架构。

在这个图表中,方框代表中间结果,箭头代表它们之间的操作。架构的 U 形状来自网络操作的多个分辨率。顶部一行是完整分辨率(对我们来说是 512×512),下面一行是其一半,依此类推。数据从左上流向底部中心,通过一系列卷积和下采样,正如我们在分类器中看到的并在第八章中详细讨论的那样。然后我们再次上升,使用上采样卷积回到完整分辨率。与原始 U-Net 不同,我们将填充物,以便不会在边缘丢失像素,因此我们左右两侧的分辨率相同。

早期的网络设计已经具有这种 U 形状,人们试图利用它来解决完全卷积网络的有限感受野大小问题。为了解决这个有限的感受野大小问题,他们使用了一种设计,复制、反转并附加图像分类网络的聚焦部分,以创建一个从精细详细到宽感受野再到精细详细的对称模型。

然而,早期的网络设计存在收敛问题,这很可能是由于在下采样过程中丢失了空间信息。一旦信息到达大量非常缩小的图像,对象边界的确切位置变得更难编码,因此更难重建。为了解决这个问题,U-Net 的作者在图中心添加了我们看到的跳跃连接。我们在第八章首次接触到跳跃连接,尽管它们在这里的应用方式与 ResNet 架构中的不同。在 U-Net 中,跳跃连接将输入沿着下采样路径短路到上采样路径中的相应层。这些层接收来自 U 较低位置的宽感受野层的上采样结果以及通过“复制和裁剪”桥接连接的早期精细详细层的输出作为输入。这是 U-Net 的关键创新(有趣的是,这比 ResNet 更早)。

所有这些意味着这些最终的细节层在最佳状态下运作。它们既具有关于周围环境的更大背景信息,又具有来自第一组全分辨率层的精细详细数据。

最右侧的“conv 1x1”层位于网络头部,将通道数从 64 改变为 2(原始论文有 2 个输出通道;我们的情况下有 1 个)。这在某种程度上类似于我们在分类网络中使用的全连接层,但是逐像素、逐通道:这是一种将最后一次上采样步骤中使用的滤波器数量转换为所需的输出类别数量的方法。


PyTorch 深度学习(GPT 重译)(五)(3)https://developer.aliyun.com/article/1485250

相关文章
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(七)(4)
JavaScript 权威指南第七版(GPT 重译)(七)
24 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(3)
JavaScript 权威指南第七版(GPT 重译)(七)
33 0
|
前端开发 JavaScript Unix
JavaScript 权威指南第七版(GPT 重译)(七)(2)
JavaScript 权威指南第七版(GPT 重译)(七)
42 0
|
前端开发 JavaScript 算法
JavaScript 权威指南第七版(GPT 重译)(七)(1)
JavaScript 权威指南第七版(GPT 重译)(七)
60 0
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(4)
JavaScript 权威指南第七版(GPT 重译)(六)
90 2
JavaScript 权威指南第七版(GPT 重译)(六)(4)
|
13天前
|
前端开发 JavaScript API
JavaScript 权威指南第七版(GPT 重译)(六)(3)
JavaScript 权威指南第七版(GPT 重译)(六)
55 4
|
13天前
|
XML 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(六)(2)
JavaScript 权威指南第七版(GPT 重译)(六)
60 4
JavaScript 权威指南第七版(GPT 重译)(六)(2)
|
13天前
|
前端开发 JavaScript 安全
JavaScript 权威指南第七版(GPT 重译)(六)(1)
JavaScript 权威指南第七版(GPT 重译)(六)
27 3
JavaScript 权威指南第七版(GPT 重译)(六)(1)
|
13天前
|
存储 前端开发 JavaScript
JavaScript 权威指南第七版(GPT 重译)(五)(4)
JavaScript 权威指南第七版(GPT 重译)(五)
39 9
|
13天前
|
前端开发 JavaScript 程序员
JavaScript 权威指南第七版(GPT 重译)(五)(3)
JavaScript 权威指南第七版(GPT 重译)(五)
36 8