(学习笔记)U-net++代码解读

简介: python: 3.10U-net++结构图

python: 3.10

U-net++结构图


017273df32ed463d877c635186e43273.png

遇到的问题

1. albumentations包安装的问题

最开始的问题是找不到源,试了好多个命令都没用。最后通过github找到了解决办法。

使用的aniconda的命令。

conda install -c conda-forge albumentations

2. AttributeError: module ‘albumentations.augmentations.transforms’ has no attribute ‘RandomRotate90’

只需重新impor talbumentations即可。

3. torch没安cuda

如果已经安装了cuda可以忽略,题主没有GPU因此需要使用CPU训练。

解决办法: 将代码train和val代码里面的所有的.cuda()更改成.cpu(),这样就在CPU上跑起来了。


05249e7912944dcea8e73c36837bd290.png

代码解读(主要解决py语法问题)

首先找到train.py的入口main函数,如图所示打断点。

e297ec03376249829bd77a393a86aa8b.png

1.读取配置文件

跳进这个函数。

5f274fa1e1cd41b99d7e583b2a9034bd.png

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None,
                        help='model name: (default: arch+timestamp)') # 指定网络的名字,就是U-net++
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')# 指定迭代次数
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 16)')#指定batch_size
    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                        ' | '.join(ARCH_NAMES) +
                        ' (default: NestedUNet)')  # 指定网络架构
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')
    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                        ' | '.join(LOSS_NAMES) +
                        ' (default: BCEDiceLoss)')
    config = parser.parse_args()
    return config

这段代码的作用是为了配置模型、损失函数、数据、优化器的各种参数,给每个参数定义一个名字,给一个默认值。

比如: parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run')# 指定迭代次数这个的作用是,添加一个epochs(迭代次数)参数,默认值是100,数据类型为int,help:相当与一个说明语句。

这样定义好参数后,便于后续统一管理和了解参数情况。并且py提供的这个argparse类,会根据自己定义的参数生成说明文档(因为参数是自己定义的,传参有误我们也得知道哪里错了),报错也能知道是哪里错了。

9d9c3d02b21641b692e7918800525c4a.png

最后的cofig相当于一个字典结构。

2.os.makedirs()

209行的含义是给’name’的值命名,%s%s为占位符,用来存放%连接的名字。

a9b5fb5389434e8ca3a4867f798c8a0e.png

209行运行前后对比:

028658110c7d498b8994d85303468db2.png

210行运行结果为:

在该path下创建一个目录,exist_ok=True的含义是:若目录存在也不会报错。

3.yaml.dump()

b1cd2429c24d4fc393844bce27423933.png

就是打开该path的文件,然后将config字典的内容存到该文件中。部分结果如下:


       0db4cfaf12924822bba173313bfb7683.png                        

4.losses._dict_()

54ee88dbefa44f57bb9ad41110ac264b.png

在该项目中,有一个losses.py这个文件,里面有相应的类实现。首先losses.dict[xxx]是从losses.py这个文件里扫描xxx这个类,并返回相应的属性值。即创建了一个相应的criterion对象。

5.cudnn.benchmark

    cudnn.benchmark = True

大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。

此代码为GPU的优化选项,详情参考。

6.filter()

    params = filter(lambda p: p.requires_grad, model.parameters())

filter()有过滤的功能,其中第一个参数为True(表示需要梯度更新),第二个参数会返回该模型的所有可学习的参数,并过滤出仅仅需要梯度优化的参数。将这些参数以列表的形式赋给params。

7.glob+basename+splitext

3091e32d167b49078e398d196c5bbda9.png

261行的代码:

glob()获取指定路径下符合特定条件的文件路径列表。join()这个函数就是拼接好文件路径。最后将所有符合条件的文件路径返回给img_ids。

262行的代码:

首先是遍历261行获取好的文件路径,basename是获取p路径的基本名称,即将路径中的目录部分去除,只保留文件名部分。splitext()函数将文件名和扩展名分开,并返回。[0]的含义是取文件名赋值给img_ids。


最终img_ids只包含文件名(不包括路径与扩展名)。

8. train_test_split()

3dc3096077244996aaa9e97d0a73876f.png将img_ids(文件名)按0.2的比例分成训练集和测试集。

9. Compose()

数据增强部分代码(注释在下面)。

    train_transform = Compose([
        # transforms.RandomRotate90(),
        # transforms.Flip(),
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),  #随机改变输入图像的色调、饱和度和值
            transforms.RandomBrightness(),    #随机改变亮度
            transforms.RandomContrast(),      #随机改变输入图像的对比度
        ], p=1),#按照归一化的概率选择执行哪一个
        # transforms.Resize(config['input_h'], config['input_w']),
        albu.Resize(config['input_h'], config['input_w']),   # 采用缩放的形式将图像变到期望大小
        transforms.Normalize(),
    ])
 #验证集就不增强了
    val_transform = Compose([
        # transforms.Resize(config['input_h'], config['input_w']),
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

10. Dateset()

定义好本地路径相关的信息。

041aeaa415c148569a21a17442900101.png

11. DataLoader()

创建一个用于训练数据的数据加载器。

代码注释如下:

70b1c807e16d444c93c55f89bcd163eb.png

12.OrderedDict()

该函数会定义一个有序字典,方便后续打印。

d75071e43c804b6e9c1e70720ee1150a.png

13.train()

将字典,训练集,模型,损失函数,优化器都传进去。

fff876eb4daf41ef924434fdea67edb5.png

13.1 AverageMeter()

54908ae6d94f4a23be76839fee30ea74.png

ffc5dcdfcac04052a98b5b3f5aec2886.png

13.2 tqdm()

创建一个进度条。总共为train_loader的长度。

2d820e77a16c47c187237127d44f4b99.png

13.3 获取数据

通过train_loader获取输入数据和目标值。

ab11eb6082024fbfb55ff99bebea33d6.png

13.4 model()开始训练

e76346cf58e44c2989ccd0545b62bda2.png

13.4.1 前向传播

从此NestedUNet模型中的,forward开始运行:

class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):  # 3,32,32
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)   # 输入是3,输出是32
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)#scale_factor:放大的倍数  插值
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
        output = self.final(x0_4)
        return output
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) #3,32,32  输入3,中间32,输出32(特征图个数)
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])  # 32,64,64
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
    def forward(self, input):
        print('input:',input.shape)
        x0_0 = self.conv0_0(input) # 8,32,96,96
        print('x0_0:',x0_0.shape)
        x1_0 = self.conv1_0(self.pool(x0_0))    # 下采样
        print('x1_0:',x1_0.shape)
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))#00和10维度不一样,因此要将x1_0升维、拼接完之后为8,96,96,96再执行卷积,8,32,96,96
        print('x0_1:',x0_1.shape)
        x2_0 = self.conv2_0(self.pool(x1_0))
        print('x2_0:',x2_0.shape)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) # x2_0up后(8,128,48,48),拼接后(8,192,48,48),卷积后(8,64,48,48)
        print('x1_1:',x1_1.shape)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        print('x0_2:',x0_2.shape)
        x3_0 = self.conv3_0(self.pool(x2_0))
        print('x3_0:',x3_0.shape)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        print('x2_1:',x2_1.shape)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        print('x1_2:',x1_2.shape)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        print('x0_3:',x0_3.shape)
        x4_0 = self.conv4_0(self.pool(x3_0))
        print('x4_0:',x4_0.shape)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        print('x3_1:',x3_1.shape)
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        print('x2_2:',x2_2.shape)
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        print('x1_3:',x1_3.shape)
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        print('x0_4:',x0_4.shape)    #(8,32,96,96)
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output  #(8,1,96,96)

网络流程图如下所示:


image.jpeg


13.5 criterion()

进入损失函数。

b550f207738641b6aadc1af9378290f7.png

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5    # 平滑因子,防止分母为0
        input = torch.sigmoid(input)    #  将输入的对数几率转化为概率值。
        num = target.size(0)
        input = input.view(num, -1)    # 将input转化为2维张量。
        target = target.view(num, -1)
        intersection = (input * target)   #input是概率,target是0/1,因此计算结果是预测的每个样本的正样本像素的概率。
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) # 预测结果和真实标签之间的相似度
        dice = 1 - dice.sum() / num   # 计算了平均Dice系数,并将其减去1,以得到Dice损失
        return 0.5 * bce + dice

13.6 iou_score()

计算iou,交并比。

def iou_score(output, target):
    smooth = 1e-5  # 平滑值
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5   # 预测值大于0.5为True
    target_ = target > 0.5  # 目标值大于0.5为True
    intersection = (output_ & target_).sum()   # 交集
    union = (output_ | target_).sum()          # 并集
    return (intersection + smooth) / (union + smooth)  # iou

13.7 方向传播,梯度优化

image.png

13.8 更新进度条

979b6b750f554a589ccbbf3ff753cf07.png                              

validate() 和train类似跳过

14. scheduler.step()

因为本文使用的余弦退火LR学习率更新方式,调度器通过余弦函数来调整学习率。

190e7a6bbe0949868a7243f7844218b1.png

目录
相关文章
|
2月前
|
开发框架 网络协议 .NET
深入.net框架
深入.net框架
12 0
|
5月前
|
开发框架 安全 C#
掌握.NET基础知识(一)
掌握.NET基础知识(一)
52 0
|
存储 XML SQL
.NET、C#基础知识
.NET、C#基础知识
112 0
|
开发框架 .NET 容器
.NET基础2
引用类型有哪些方法比较相等性呢?栈集合和队列集合有啥子区别呢?泛型又有什么东西呢?
80 0
|
存储 JSON 安全
.NET 基础-3
服务端和客户端之间要传送的自定义数据类型
200 0
|
存储 开发框架 安全
.NET 基础知识
修饰符有什么作用呢?它是什么东西呢?
116 0
.NET 基础知识
.NET简谈面“.NET技术”向接口编程
  过程式的开发方式已逐渐退出大众的眼线,随之而来的是各种各样的高抽象的开发模式;我们不得不承认在没有设计模式的时候,我们很难总结出有价值的开发模型,便于以后重复使用和推广;面向对象的流行,让我们开发人员重新站在一个高的起点来看待软件模型,抽象固然是好事,但是也给初学者带来了迷惑,将软件中的东西都想成很简单的封装,我们只需要调用就行,这样越来越多的开发人员开始慢慢的往上浮,有一定编程经验和感触的人,能够明白我所说的浮,也算是给初学者提个醒吧。
905 0
|
.NET 程序员
一起谈.NET技术,.NET 4.0里异常处理的新机制
  前几天,有一个朋友问我为什么在.NET里不能捕捉(catch)到一些异常了,而且在调试器里也捕捉不到。研究了一下,是.NET 4.0里新的异常处理机制捣的鬼。   在.NET 4.0之后,CLR将会区别出一些异常(都是SEH异常),将这些异常标识为破坏性异常(Corrupted State Exception)。
755 0
一起谈.NET技术,20条.NET编码习惯
1、不要硬编string/ numeric,可以使用一些常量代替。 (提高可读性) int Count;Count = 100;private static const int ZERO  =  0;if(  Count  ==  ZERO ){// 执行一些操作} 2、对于字符串比较-使用String. Empty ,而不是""。
789 0
|
SQL C# 数据库
一起谈.NET技术,.NET远程处理框架详解
  第1章系统总体结构   1.1 总体结构   系统实现需要部署服务器端的远程对象(即一个DbServerLibrary.dll),服务器端要注册通道和该远程对象。客户端要实现一个本地查询的服务器,同时根据SQL解析的结果向各个服务器发送命令,并将结果显示在客户端界面,服务器端可以接受并显示相应的命令。
895 0