fast.ai 深度学习笔记(七)(4)

简介: fast.ai 深度学习笔记(七)

fast.ai 深度学习笔记(七)(3)https://developer.aliyun.com/article/1482690

数据

PATH = Path('data/carvana')
MASKS_FN = 'train_masks.csv'
META_FN = 'metadata.csv'
masks_csv = pd.read_csv(PATH/MASKS_FN)
meta_csv = pd.read_csv(PATH/META_FN)
def show_img(im, figsize=None, ax=None, alpha=None):
    if not ax: 
        fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha)
    ax.set_axis_off()
    return axTRAIN_DN = 'train-128'
MASKS_DN = 'train_masks-128'
sz = 128
bs = 64
nw = 16
TRAIN_DN = 'train'
MASKS_DN = 'train_masks_png'
sz = 128
bs = 64
nw = 16
class MatchedFilesDataset(FilesDataset):
    def __init__(self, fnames, y, transform, path):
        self.y=y
        assert(len(fnames)==len(y))
        super().__init__(fnames, transform, path)
    def get_y(self, i): 
        return open_image(os.path.join(self.path, self.y[i]))
    def get_c(self): 
        return 0
x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])
y_names = np.array([
    Path(MASKS_DN)/f'{o[:-4]}_mask.png' 
    for o in masks_csv['img']
])
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)
aug_tfms = [
    RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS)
]
tfms = tfms_from_model(
    esnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    ath=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denorm
x,y = next(iter(md.trn_dl))
x.shape,y.shape
'''
(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))
'''

简单的上采样

一开始,我有一个简单的上采样版本,只是为了再次向你展示非 U-net 版本。这次,我将加入一个称为 dice 指标的东西。Dice 非常类似,如你所见,与 Jaccard 或 I over U 非常相似。只是有一点小差别。基本上是交集除以并集,稍微调整了一下。我们要使用 dice 的原因是 Kaggle 竞赛使用了这个指标,而且要获得高 dice 分数比获得高准确度要困难一些,因为它真的在看正确像素与你的像素的重叠部分。但它非常相似。

在 Kaggle 竞赛中,表现良好的人得到了大约 99.6 点,而获胜者得到了大约 99.7 点。

f = resnet34
cut,lr_cut = model_meta[f]def get_base():
    layers = cut_model(f(True), cut)
    return nn.Sequential(*layers)
def dice(pred, targs):
    pred = (pred>0).float()
    return 2. * (pred*targs).sum() / (pred+targs).sum()

这是我们的标准上采样。

class StdUpsample(nn.Module):
    def __init__(self, nin, nout):
        super().__init__()
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)
    def forward(self, x): 
        return self.bn(F.relu(self.conv(x)))

这一切和以前一样。

class Upsample34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.features = nn.Sequential(
            rn, nn.ReLU(),
            StdUpsample(512,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            nn.ConvTranspose2d(256, 1, 2, stride=2)
        )
    def forward(self,x): 
        return self.features(x)[:,0]
class UpsampleModel():
    def __init__(self,model,name='upsample'):
        self.model,self.name = model,name
    def get_layer_groups(self, precompute):
        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        return lgs + [children(self.model.features)[1:]]
m_base = get_base() 
m = to_gpu(Upsample34(m_base))
models = UpsampleModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.freeze_to(1)
learn.lr_find()
learn.sched.plot()
'''
86%|█████████████████████████████████████████████████████████████          | 55/64 [00:22<00:03,  2.46it/s, loss=3.21]
'''

lr=4e-2
wd=1e-7
lrs = np.array([lr/100,lr/10,lr])/2
learn.fit(lr,1, wds=wd, cycle_len=4,use_clr=(20,8))
'''
0%|          | 0/64 [00:00<?, ?it/s]
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.216882   0.133512   0.938017   0.855221  
    1      0.169544   0.115158   0.946518   0.878381       
    2      0.153114   0.099104   0.957748   0.903353       
    3      0.144105   0.093337   0.964404   0.915084
[0.09333742126112893, 0.9644036065964472, 0.9150839788573129]
'''
learn.save('tmp')
learn.load('tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.174897   0.061603   0.976321   0.94382   
    1      0.122911   0.053625   0.982206   0.957624       
    2      0.106837   0.046653   0.985577   0.965792       
    3      0.099075   0.042291   0.986519   0.968925
[0.042291240323157536, 0.986519161670927, 0.9689251193924556]
'''

现在我们可以检查我们的 dice 指标[1:48:00]。所以你可以看到在 dice 指标上,我们在 128x128 处得到了大约 96.8。所以这不太好。

learn.save('128')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))
show_img(py[0]>0);

show_img(y[0]);

U-net(ish)[1:48:16]

所以让我们尝试 U-Net。我称之为 U-net(ish),因为通常我正在创建自己的有点 hacky 版本——尽量保持与你习惯的东西尽可能相似,并做我认为有意义的事情。所以至少有很多机会让你至少通过查看确切的网格大小来使其更加真实地成为 U-net,看看这里(左上角的卷积)大小有点下降。所以显然他们没有添加任何填充,然后有一些裁剪——有一些差异。但其中一件事是因为我想利用迁移学习——这意味着我不能完全使用 U-Net。

所以另一个重要的机会是,如果你创建了 U-Net 的下行路径,然后在末尾添加一个分类器,然后在 ImageNet 上训练它。现在你有了一个在 ImageNet 上训练过的分类器,专门设计为 U-Net 的良好骨干。然后你应该能够回来并接近赢得这个旧竞赛(实际上并不是很旧——是一个相当新的竞赛)。因为以前不存在这种预训练网络。但是如果你想一下 YOLO v3 是如何做的,基本上就是这样。他们创建了一个 DarkNet,他们在 ImageNet 上预训练了它,然后他们将其用作边界框的基础。所以,再次强调这种不仅为分类而设计而且为其他事物而设计的预训练的想法——这是迄今为止没有人做过的事情。但正如我们所展示的,你现在可以用 25 美元在三小时内训练 ImageNet。如果社区中的人们对此感兴趣,希望我也能提供帮助,如果你愿意,我可以帮助你设置并给我一个脚本,我可能可以为你运行它。但目前我们还没有。所以我们将使用 ResNet。

class SaveFeatures():
    features=None
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): 
        self.features = output
    def remove(self): 
        self.hook.remove()

所以我们基本上要从get_base开始[1:50:37]。Base 是我们的基础网络,这在第一部分中已经定义过了。

所以get_base将调用f是什么,fresnet34。所以我们将获取我们的 ResNet34 并且cut_model是我们的卷积网络构建器做的第一件事。它基本上删除了自适应池化之后的所有内容,这样我们就得到了 ResNet34 的骨干。所以get_base将给我们返回 ResNet34 的骨干。

class UnetBlock(nn.Module):
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)
    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))class Unet34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.sfs = [SaveFeatures(rn[i]) for i in [2,4,5,6]]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)
    def forward(self,x):
        x = F.relu(self.rn(x))
        x = self.up1(x, self.sfs[3].features)
        x = self.up2(x, self.sfs[2].features)
        x = self.up3(x, self.sfs[1].features)
        x = self.up4(x, self.sfs[0].features)
        x = self.up5(x)
        return x[:,0]
    def close(self):
        for sf in self.sfs: 
            sf.remove()
class UnetModel():
    def __init__(self,model,name='unet'):
        self.model,self.name = model,name
    def get_layer_groups(self, precompute):
        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        return lgs + [children(self.model)[1:]]

然后我们将把那个 ResNet34 主干转换成一个,我称之为 Unet34。因此,它将保存我们传入的 ResNet,然后我们将使用一个前向钩子,就像以前一样,在第 2、4、5 和 6 个块处保存结果,这些块是每个步幅 2 卷积之前的层。然后我们将创建一堆我们称之为UnetBlock的东西。我们需要告诉UnetBlock有多少东西来自我们要上采样的上一层,有多少来自交叉路径,然后我们想要输出多少。来自上一层的数量完全由基础网络定义——无论下行路径是什么,我们都需要那么多层。这有点尴尬。实际上我们这里的一个硕士学生,Kerem,实际上创建了一个叫做 DynamicUnet 的东西,你可以在fastai.model.DynamicUnet中找到,它实际上为你计算这一切,并自动从你的基础模型创建整个 Unet。它仍然有一些小问题,我想要修复。视频发布时,它肯定会正常工作,我至少会有一个展示如何使用它的笔记本,可能还有一个额外的视频。但现在你只能自己去做。一旦你有了一个 ResNet,你可以输入它的名称,它会打印出层。你可以看到每个块中有多少激活。或者你可以让它自动为每个块打印出来。无论如何,我只是手动做了这个。

所以 UnetBlock 的工作原理是这样的:

  • up_in:从上一层传入的数量
  • x_in:从下行路径传入的数量(因此x
  • n_out:我们想要输出的数量

现在我要做的是,然后我说,好的,我们将从上行路径创建一定数量的卷积,从交叉路径创建一定数量的卷积,所以我将它们连接在一起,所以让我们将我们想要的数量除以 2。因此,我们将让我们的交叉卷积从交叉路径中取出并除以 2(n_out//2)。然后上行路径将是ConvTranspose2d,因为我们想要增加/上采样。同样在这里,我们将我们想要的数量除以 2(up_out),然后最后,我只是将它们连接在一起。

所以我有一个上升样本,我有一个交叉卷积,我可以将这两者连接在一起。这就是 UnetBlock 的全部内容。所以这实际上是一个相当容易创建的模块。

然后在我的前向路径中,我需要将上升路径和交叉路径传递给 UnetBlock 的前向方法。上升路径只是到目前为止的任何事情。但是交叉路径是在下降过程中存储的激活。因此,当我上升时,我首先需要的是最后一组保存的特征。随着我逐渐向上走得更远,最终是第一组特征。

有一些更多的技巧可以让这个变得更好一点,但这已经是一个很好的东西了。所以简单的上采样方法看起来很糟糕,dice 值为 0.968。一个 Unet,除了现在我们有了这些 UnetBlocks 之外,其他一切都相同,dice 值为…

m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.summary()
'''
OrderedDict([('Conv2d-1',
              OrderedDict([('input_shape', [-1, 3, 128, 128]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', 9408)])),
             ('BatchNorm2d-2',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-3',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 64, 64]),
                           ('nb_params', 0)])),
             ('MaxPool2d-4',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-5',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-6',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-7',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-8',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-9',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-10',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-11',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-12',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-13',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-14',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-15',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-16',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-17',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-18',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-19',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-20',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-21',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-22',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 36864)])),
             ('BatchNorm2d-23',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('trainable', False),
                           ('nb_params', 128)])),
             ('ReLU-24',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('BasicBlock-25',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 64, 32, 32]),
                           ('nb_params', 0)])),
             ('Conv2d-26',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 73728)])),
             ('BatchNorm2d-27',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-28',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-29',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-30',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('Conv2d-31',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 8192)])),
             ('BatchNorm2d-32',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-33',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-34',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-35',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-36',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-37',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-38',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-39',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-40',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-41',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-42',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-43',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-44',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-45',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-46',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-47',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-48',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-49',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-50',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-51',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-52',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 147456)])),
             ('BatchNorm2d-53',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', False),
                           ('nb_params', 256)])),
             ('ReLU-54',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('BasicBlock-55',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('nb_params', 0)])),
             ('Conv2d-56',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 294912)])),
             ('BatchNorm2d-57',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-58',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-59',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-60',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('Conv2d-61',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 32768)])),
             ('BatchNorm2d-62',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-63',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-64',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-65',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-66',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-67',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-68',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-69',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-70',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-71',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-72',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-73',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-74',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-75',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-76',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-77',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-78',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-79',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-80',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-81',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-82',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-83',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-84',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-85',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-86',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-87',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-88',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-89',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-90',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-91',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-92',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-93',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-94',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-95',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-96',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 589824)])),
             ('BatchNorm2d-97',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', False),
                           ('nb_params', 512)])),
             ('ReLU-98',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('BasicBlock-99',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('Conv2d-100',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1179648)])),
             ('BatchNorm2d-101',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-102',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-103',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-104',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('Conv2d-105',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 131072)])),
             ('BatchNorm2d-106',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-107',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-108',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-109',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-110',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-111',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-112',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-113',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-114',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-115',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-116',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-117',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-118',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('Conv2d-119',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 2359296)])),
             ('BatchNorm2d-120',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('trainable', False),
                           ('nb_params', 1024)])),
             ('ReLU-121',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('BasicBlock-122',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 512, 4, 4]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-123',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 128, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 262272)])),
             ('Conv2d-124',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 128, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 32896)])),
             ('BatchNorm2d-125',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-126',
              OrderedDict([('input_shape', [-1, 512, 4, 4]),
                           ('output_shape', [-1, 256, 8, 8]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-127',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-128',
              OrderedDict([('input_shape', [-1, 128, 16, 16]),
                           ('output_shape', [-1, 128, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 16512)])),
             ('BatchNorm2d-129',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-130',
              OrderedDict([('input_shape', [-1, 256, 8, 8]),
                           ('output_shape', [-1, 256, 16, 16]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-131',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 128, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-132',
              OrderedDict([('input_shape', [-1, 64, 32, 32]),
                           ('output_shape', [-1, 128, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 8320)])),
             ('BatchNorm2d-133',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-134',
              OrderedDict([('input_shape', [-1, 256, 16, 16]),
                           ('output_shape', [-1, 256, 32, 32]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-135',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 128, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 131200)])),
             ('Conv2d-136',
              OrderedDict([('input_shape', [-1, 64, 64, 64]),
                           ('output_shape', [-1, 128, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 8320)])),
             ('BatchNorm2d-137',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('trainable', True),
                           ('nb_params', 512)])),
             ('UnetBlock-138',
              OrderedDict([('input_shape', [-1, 256, 32, 32]),
                           ('output_shape', [-1, 256, 64, 64]),
                           ('nb_params', 0)])),
             ('ConvTranspose2d-139',
              OrderedDict([('input_shape', [-1, 256, 64, 64]),
                           ('output_shape', [-1, 1, 128, 128]),
                           ('trainable', True),
                           ('nb_params', 1025)]))])
'''
[o.features.size() for o in m.sfs]
'''
[torch.Size([3, 64, 64, 64]),
 torch.Size([3, 64, 32, 32]),
 torch.Size([3, 128, 16, 16]),
 torch.Size([3, 256, 8, 8])]
'''
learn.freeze_to(1)learn.lr_find()
learn.sched.plot()
''' 0%|                                                                                           | 0/64 [00:00<?, ?it/s]92%|█████████████████████████████████████████████████████████████████▍     | 59/64 [00:22<00:01,  2.68it/s, loss=2.45]
'''

lr=4e-2
wd=1e-7
lrs = np.array([lr/100,lr/10,lr])
learn.fit(lr,1,wds=wd,cycle_len=8,use_clr=(5,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice           
    0      0.12936    0.03934    0.988571   0.971385  
    1      0.098401   0.039252   0.990438   0.974921        
    2      0.087789   0.02539    0.990961   0.978927        
    3      0.082625   0.027984   0.988483   0.975948        
    4      0.079509   0.025003   0.99171    0.981221        
    5      0.076984   0.022514   0.992462   0.981881        
    6      0.076822   0.023203   0.992484   0.982321        
    7      0.075488   0.021956   0.992327   0.982704
[0.021955982234979434, 0.9923273126284281, 0.9827044502137199]
'''
learn.save('128urn-tmp')
learn.load('128urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))
'''
0%|          | 0/64 [00:00<?, ?it/s]
epoch      trn_loss   val_loss   <lambda>   dice            
    0      0.073786   0.023418   0.99297    0.98283   
    1      0.073561   0.020853   0.992142   0.982725        
    2      0.075227   0.023357   0.991076   0.980879        
    3      0.074245   0.02352    0.993108   0.983659        
    4      0.073434   0.021508   0.993024   0.983609        
    5      0.073092   0.020956   0.993188   0.983333        
    6      0.073617   0.019666   0.993035   0.984102        
    7      0.072786   0.019844   0.993196   0.98435         
    8      0.072256   0.018479   0.993282   0.984277        
    9      0.072052   0.019479   0.993164   0.984147        
    10     0.071361   0.019402   0.993344   0.984541        
    11     0.070969   0.018904   0.993139   0.984499        
    12     0.071588   0.018027   0.9935     0.984543        
    13     0.070709   0.018345   0.993491   0.98489         
    14     0.072238   0.019096   0.993594   0.984825        
    15     0.071407   0.018967   0.993446   0.984919        
    16     0.071047   0.01966    0.993366   0.984952        
    17     0.072024   0.018133   0.993505   0.98497         
    18     0.071517   0.018464   0.993602   0.985192        
    19     0.070109   0.018337   0.993614   0.9852
[0.018336569653853538, 0.9936137114252362, 0.9852004420189631]
'''

0.985!这就像我们将错误减半,其他一切完全相同。而且更重要的是,你可以看一下。

learn.save('128urn-0')
learn.load('128urn-0')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

与我们的非 Unet 等效物相比,这实际上看起来有点像汽车,后者只是一个斑点。因为试图通过下行和上行路径来做这个——这只是要求太多了。而当我们实际上在每个点提供下行路径像素时,它实际上可以开始创建一些类似汽车的东西。

show_img(py[0]>0);

show_img(y[0]);

最后,我们将执行 m.close 以删除占用 GPU 内存的sfs.features

m.close()

512x512 [1:56:26]

转到较小的批量大小,更高的大小

sz=512
bs=16
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS, 
    aug_tfms=aug_tfms
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)
denorm = md.trn_ds.denormm_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]
learn.freeze_to(1)
learn.load('128urn-0')
learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))
'''
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.071421   0.02362    0.996459   0.991772  
    1      0.070373   0.014013   0.996558   0.992602          
    2      0.067895   0.011482   0.996705   0.992883          
    3      0.070653   0.014256   0.996695   0.992771          
    4      0.068621   0.013195   0.996993   0.993359
[0.013194938530288046, 0.996993034604996, 0.993358936574724]
'''

你可以看到 Dice 系数真的在上升[1:56:30]。所以请注意,我正在加载网络的 128x128 版本。我们再次使用渐进式调整大小的技巧,这样我们得到了 0.993。

learn.save('512urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
learn.load('512urn-tmp')
learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice              
    0      0.06605    0.013602   0.997      0.993014  
    1      0.066885   0.011252   0.997248   0.993563          
    2      0.065796   0.009802   0.997223   0.993817          
    3      0.065089   0.009668   0.997296   0.993744          
    4      0.064552   0.011683   0.997269   0.993835          
    5      0.065089   0.010553   0.997415   0.993827          
    6      0.064303   0.009472   0.997431   0.994046          
    7      0.062506   0.009623   0.997441   0.994118
[0.009623114736602894, 0.9974409020136273, 0.9941179137381296]
'''

然后解冻以达到 0.994。

learn.save('512urn')
learn.load('512urn')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

你可以看到,现在看起来很不错。

show_img(py[0]>0);

show_img(y[0]);

m.close()

1024x1024 [1:56:53]

将批量大小降至 4,大小为 1024。

sz=1024
bs=4
tfms = tfms_from_model(
    resnet34, sz, 
    crop_type=CropType.NO, 
    tfm_y=TfmType.CLASS
)
datasets = ImageData.get_ds(
    MatchedFilesDataset, 
    (trn_x,trn_y), 
    (val_x,val_y), 
    tfms, 
    path=PATH
)
md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)
denorm = md.trn_ds.denormm_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)
learn.opt_fn=optim.Adam
learn.crit=nn.BCEWithLogitsLoss()
learn.metrics=[accuracy_thresh(0.5),dice]

加载我们刚刚保存的 512。

learn.load('512urn')
learn.freeze_to(1)
learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.007656   0.008155   0.997247   0.99353   
    1      0.004706   0.00509    0.998039   0.995437
[0.005090427414942828, 0.9980387706605215, 0.995437301104031]
'''

这让我们达到了 0.995。

learn.save('1024urn-tmp')
learn.load('1024urn-tmp')
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/200,lr/30,lr])
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.005688   0.006135   0.997616   0.994616  
    1      0.004412   0.005223   0.997983   0.995349             
    2      0.004186   0.004975   0.99806    0.99554              
    3      0.004016   0.004899   0.99812    0.995627
[0.004898778487196458, 0.9981196409180051, 0.9956271404784823]
'''
learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))
'''
epoch      trn_loss   val_loss   <lambda>   dice                 
    0      0.004169   0.004962   0.998049   0.995517  
    1      0.004022   0.004595   0.99823    0.995818             
    2      0.003772   0.004497   0.998215   0.995916             
    3      0.003618   0.004435   0.998291   0.995991
[0.004434524739663753, 0.9982911745707194, 0.9959913929776539]
'''

解冻将我们带到…我们将称之为 0.996。

learn.sched.plot_loss()

learn.save('1024urn')
learn.load('1024urn')
x,y = next(iter(md.val_dl))
py = to_np(learn.model(V(x)))

正如你所看到的,实际上看起来很不错[1:57:17]。在准确性方面,99.82%。你可以看到这看起来像是你可以用来裁剪的东西。我认为,在这一点上,我们可以做一些微小的调整来达到 0.997,但真正的关键是,我认为,也许只需要做一些平滑处理或一点后处理。你可以去看看 Carvana 获奖者的博客,看看其中的一些技巧,但正如我所说,我们目前的 0.996 和获奖者得到的 0.997 之间的差距并不大。所以实际上,U-Net 基本上解决了这个问题。

show_img(py[0]>0);

show_img(y[0]);

回到边界框[1:58:15]

好的,就是这样。我想要提到的最后一件事是现在回到边界框,因为你可能还记得,我说我们的边界框模型在小物体上仍然表现不佳。所以希望你能猜到我接下来要做什么,那就是对于边界框模型,记得我们在不同的网格单元中输出了模型的输出。那些较早的具有较小网格大小的输出并不好。我们该如何修复呢?用 U-Net!让我们有一个带有交叉连接的向上路径。然后我们将使用 U-Net,然后从中输出。因为现在那些更精细的网格单元具有该路径的所有信息,以及该路径、该路径和该路径的信息。当然,这是深度学习,这意味着你不能写一篇论文说我们只是用 U-Net 来处理边界框。你必须发明一个新词,所以这被称为特征金字塔网络或 FPNs。这在 RetinaNet 论文中使用过,它是在早期关于 FPNs 的论文中创建的。如果我记得正确的话,他们确实简要引用了 U-Net 论文,但他们似乎让它听起来像是这个模糊地稍微相关的东西,也许有些人可能认为稍微有用。但实际上,FPNs 就是 U-Nets。

我没有实现它来展示给你,但这将是一件有趣的事情,也许对于我们中的一些人来尝试,我知道一些学生一直在尝试在论坛上使其良好运行。所以是的,尝试一下是有趣的事情。所以我认为在这堂课之后要看的一些事情,以及我提到的其他事情,可能是玩玩 FPNs,也可能尝试一下 Kerem 的 DynamicUnet。它们都是值得一看的有趣的东西。

所以你们现在已经经历了我对你们讲解的 14 堂课。对此我感到抱歉。谢谢你们忍受我。我认为你们会发现很难找到其他人对神经网络训练和实践了解得像你们这样多。你们很容易高估其他人的能力,低估自己的能力。所以我想说的是,请继续练习。因为现在没有每个星期一晚上都有我在这里让你们回来了。很容易失去动力。所以找到方法保持下去。组织一个学习小组,一个读书小组,或者和朋友们一起做项目,或者做一些不仅仅是决定我要继续做 X 的事情。除非你是那种超级有动力的人,每当你决定做某事,它就会发生。那不是我。我知道,要让事情发生,我必须说“是的,大卫。十月份,我绝对会教那门课程”,然后我就得开始写一些材料。这是我让事情发生的唯一方法。所以我们在论坛上有一个很棒的社区。如果有人有想法让它变得更好,请告诉我。如果你认为你可以帮忙,如果你想创建一些新的论坛或以某种不同的方式进行管理,或者其他什么的,只要告诉我。你可以随时私信我,GitHub 上也有很多项目正在进行中——很多东西。所以我希望能在其他地方再见到你们,非常感谢你们加入我的旅程。

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
1天前
|
存储 人工智能 Linux
|
5天前
|
机器学习/深度学习 人工智能 算法
【AI】从零构建深度学习框架实践
【5月更文挑战第16天】 本文介绍了从零构建一个轻量级的深度学习框架tinynn,旨在帮助读者理解深度学习的基本组件和框架设计。构建过程包括设计框架架构、实现基本功能、模型定义、反向传播算法、训练和推理过程以及性能优化。文章详细阐述了网络层、张量、损失函数、优化器等组件的抽象和实现,并给出了一个基于MNIST数据集的分类示例,与TensorFlow进行了简单对比。tinynn的源代码可在GitHub上找到,目前支持多种层、损失函数和优化器,适用于学习和实验新算法。
61 2
|
7天前
|
机器学习/深度学习 人工智能 算法
AI大咖说-关于深度学习的一点思考
周志华教授探讨深度学习的成效,指出其关键在于大量数据、强大算力和训练技巧。深度学习依赖于函数可导性、梯度下降及反向传播算法,尽管硬件和数据集有显著进步,但核心原理保持不变。深度意味着增加模型复杂度,相较于拓宽,加深网络更能增强泛函表达能力,促进表示学习,通过逐层加工处理和内置特征变换实现抽象语义理解。周志华教授还提到了非神经网络的深度学习方法——深度森林。5月更文挑战第12天
31 5
|
7天前
|
机器学习/深度学习 人工智能 算法
构建高效AI系统:深度学习优化技术解析
【5月更文挑战第12天】 随着人工智能技术的飞速发展,深度学习已成为推动创新的核心动力。本文将深入探讨在构建高效AI系统中,如何通过优化算法、调整网络结构及使用新型硬件资源等手段显著提升模型性能。我们将剖析先进的优化策略,如自适应学习率调整、梯度累积技巧以及正则化方法,并讨论其对模型训练稳定性和效率的影响。文中不仅提供理论分析,还结合实例说明如何在实际项目中应用这些优化技术。
|
7天前
|
机器学习/深度学习 敏捷开发 人工智能
吴恩达 x Open AI ChatGPT ——如何写出好的提示词视频核心笔记
吴恩达 x Open AI ChatGPT ——如何写出好的提示词视频核心笔记
229 0
|
7天前
|
机器学习/深度学习 人工智能 算法
【AI 初识】讨论深度学习和机器学习之间的区别
【5月更文挑战第3天】【AI 初识】讨论深度学习和机器学习之间的区别
|
7天前
|
机器学习/深度学习 自然语言处理 PyTorch
fast.ai 深度学习笔记(三)(4)
fast.ai 深度学习笔记(三)(4)
27 0
|
7天前
|
机器学习/深度学习 算法 PyTorch
fast.ai 深度学习笔记(三)(3)
fast.ai 深度学习笔记(三)(3)
34 0
|
7天前
|
机器学习/深度学习 编解码 自然语言处理
fast.ai 深度学习笔记(三)(2)
fast.ai 深度学习笔记(三)(2)
38 0
|
7天前
|
机器学习/深度学习 PyTorch 算法框架/工具
fast.ai 深度学习笔记(三)(1)
fast.ai 深度学习笔记(三)(1)
41 0

热门文章

最新文章