(学习笔记)U-net++代码解读

简介: python: 3.10U-net++结构图

python: 3.10

U-net++结构图


017273df32ed463d877c635186e43273.png

遇到的问题

1. albumentations包安装的问题

最开始的问题是找不到源,试了好多个命令都没用。最后通过github找到了解决办法。

使用的aniconda的命令。

conda install -c conda-forge albumentations

2. AttributeError: module ‘albumentations.augmentations.transforms’ has no attribute ‘RandomRotate90’

只需重新impor talbumentations即可。

3. torch没安cuda

如果已经安装了cuda可以忽略,题主没有GPU因此需要使用CPU训练。

解决办法: 将代码train和val代码里面的所有的.cuda()更改成.cpu(),这样就在CPU上跑起来了。


05249e7912944dcea8e73c36837bd290.png

代码解读(主要解决py语法问题)

首先找到train.py的入口main函数,如图所示打断点。

e297ec03376249829bd77a393a86aa8b.png

1.读取配置文件

跳进这个函数。

5f274fa1e1cd41b99d7e583b2a9034bd.png

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None,
                        help='model name: (default: arch+timestamp)') # 指定网络的名字,就是U-net++
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')# 指定迭代次数
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 16)')#指定batch_size
    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                        ' | '.join(ARCH_NAMES) +
                        ' (default: NestedUNet)')  # 指定网络架构
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')
    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                        ' | '.join(LOSS_NAMES) +
                        ' (default: BCEDiceLoss)')
    config = parser.parse_args()
    return config

这段代码的作用是为了配置模型、损失函数、数据、优化器的各种参数,给每个参数定义一个名字,给一个默认值。

比如: parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run')# 指定迭代次数这个的作用是,添加一个epochs(迭代次数)参数,默认值是100,数据类型为int,help:相当与一个说明语句。

这样定义好参数后,便于后续统一管理和了解参数情况。并且py提供的这个argparse类,会根据自己定义的参数生成说明文档(因为参数是自己定义的,传参有误我们也得知道哪里错了),报错也能知道是哪里错了。

9d9c3d02b21641b692e7918800525c4a.png

最后的cofig相当于一个字典结构。

2.os.makedirs()

209行的含义是给’name’的值命名,%s%s为占位符,用来存放%连接的名字。

a9b5fb5389434e8ca3a4867f798c8a0e.png

209行运行前后对比:

028658110c7d498b8994d85303468db2.png

210行运行结果为:

在该path下创建一个目录,exist_ok=True的含义是:若目录存在也不会报错。

3.yaml.dump()

b1cd2429c24d4fc393844bce27423933.png

就是打开该path的文件,然后将config字典的内容存到该文件中。部分结果如下:


       0db4cfaf12924822bba173313bfb7683.png                        

4.losses._dict_()

54ee88dbefa44f57bb9ad41110ac264b.png

在该项目中,有一个losses.py这个文件,里面有相应的类实现。首先losses.dict[xxx]是从losses.py这个文件里扫描xxx这个类,并返回相应的属性值。即创建了一个相应的criterion对象。

5.cudnn.benchmark

    cudnn.benchmark = True

大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。

此代码为GPU的优化选项,详情参考。

6.filter()

    params = filter(lambda p: p.requires_grad, model.parameters())

filter()有过滤的功能,其中第一个参数为True(表示需要梯度更新),第二个参数会返回该模型的所有可学习的参数,并过滤出仅仅需要梯度优化的参数。将这些参数以列表的形式赋给params。

7.glob+basename+splitext

3091e32d167b49078e398d196c5bbda9.png

261行的代码:

glob()获取指定路径下符合特定条件的文件路径列表。join()这个函数就是拼接好文件路径。最后将所有符合条件的文件路径返回给img_ids。

262行的代码:

首先是遍历261行获取好的文件路径,basename是获取p路径的基本名称,即将路径中的目录部分去除,只保留文件名部分。splitext()函数将文件名和扩展名分开,并返回。[0]的含义是取文件名赋值给img_ids。


最终img_ids只包含文件名(不包括路径与扩展名)。

8. train_test_split()

3dc3096077244996aaa9e97d0a73876f.png将img_ids(文件名)按0.2的比例分成训练集和测试集。

9. Compose()

数据增强部分代码(注释在下面)。

    train_transform = Compose([
        # transforms.RandomRotate90(),
        # transforms.Flip(),
        albu.RandomRotate90(),
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),  #随机改变输入图像的色调、饱和度和值
            transforms.RandomBrightness(),    #随机改变亮度
            transforms.RandomContrast(),      #随机改变输入图像的对比度
        ], p=1),#按照归一化的概率选择执行哪一个
        # transforms.Resize(config['input_h'], config['input_w']),
        albu.Resize(config['input_h'], config['input_w']),   # 采用缩放的形式将图像变到期望大小
        transforms.Normalize(),
    ])
 #验证集就不增强了
    val_transform = Compose([
        # transforms.Resize(config['input_h'], config['input_w']),
        albu.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

10. Dateset()

定义好本地路径相关的信息。

041aeaa415c148569a21a17442900101.png

11. DataLoader()

创建一个用于训练数据的数据加载器。

代码注释如下:

70b1c807e16d444c93c55f89bcd163eb.png

12.OrderedDict()

该函数会定义一个有序字典,方便后续打印。

d75071e43c804b6e9c1e70720ee1150a.png

13.train()

将字典,训练集,模型,损失函数,优化器都传进去。

fff876eb4daf41ef924434fdea67edb5.png

13.1 AverageMeter()

54908ae6d94f4a23be76839fee30ea74.png

ffc5dcdfcac04052a98b5b3f5aec2886.png

13.2 tqdm()

创建一个进度条。总共为train_loader的长度。

2d820e77a16c47c187237127d44f4b99.png

13.3 获取数据

通过train_loader获取输入数据和目标值。

ab11eb6082024fbfb55ff99bebea33d6.png

13.4 model()开始训练

e76346cf58e44c2989ccd0545b62bda2.png

13.4.1 前向传播

从此NestedUNet模型中的,forward开始运行:

class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):  # 3,32,32
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)   # 输入是3,输出是32
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)#scale_factor:放大的倍数  插值
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
        output = self.final(x0_4)
        return output
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) #3,32,32  输入3,中间32,输出32(特征图个数)
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])  # 32,64,64
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
    def forward(self, input):
        print('input:',input.shape)
        x0_0 = self.conv0_0(input) # 8,32,96,96
        print('x0_0:',x0_0.shape)
        x1_0 = self.conv1_0(self.pool(x0_0))    # 下采样
        print('x1_0:',x1_0.shape)
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))#00和10维度不一样,因此要将x1_0升维、拼接完之后为8,96,96,96再执行卷积,8,32,96,96
        print('x0_1:',x0_1.shape)
        x2_0 = self.conv2_0(self.pool(x1_0))
        print('x2_0:',x2_0.shape)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) # x2_0up后(8,128,48,48),拼接后(8,192,48,48),卷积后(8,64,48,48)
        print('x1_1:',x1_1.shape)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        print('x0_2:',x0_2.shape)
        x3_0 = self.conv3_0(self.pool(x2_0))
        print('x3_0:',x3_0.shape)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        print('x2_1:',x2_1.shape)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        print('x1_2:',x1_2.shape)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        print('x0_3:',x0_3.shape)
        x4_0 = self.conv4_0(self.pool(x3_0))
        print('x4_0:',x4_0.shape)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        print('x3_1:',x3_1.shape)
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        print('x2_2:',x2_2.shape)
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        print('x1_3:',x1_3.shape)
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        print('x0_4:',x0_4.shape)    #(8,32,96,96)
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output  #(8,1,96,96)

网络流程图如下所示:


image.jpeg


13.5 criterion()

进入损失函数。

b550f207738641b6aadc1af9378290f7.png

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5    # 平滑因子,防止分母为0
        input = torch.sigmoid(input)    #  将输入的对数几率转化为概率值。
        num = target.size(0)
        input = input.view(num, -1)    # 将input转化为2维张量。
        target = target.view(num, -1)
        intersection = (input * target)   #input是概率,target是0/1,因此计算结果是预测的每个样本的正样本像素的概率。
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) # 预测结果和真实标签之间的相似度
        dice = 1 - dice.sum() / num   # 计算了平均Dice系数,并将其减去1,以得到Dice损失
        return 0.5 * bce + dice

13.6 iou_score()

计算iou,交并比。

def iou_score(output, target):
    smooth = 1e-5  # 平滑值
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5   # 预测值大于0.5为True
    target_ = target > 0.5  # 目标值大于0.5为True
    intersection = (output_ & target_).sum()   # 交集
    union = (output_ | target_).sum()          # 并集
    return (intersection + smooth) / (union + smooth)  # iou

13.7 方向传播,梯度优化

image.png

13.8 更新进度条

979b6b750f554a589ccbbf3ff753cf07.png                              

validate() 和train类似跳过

14. scheduler.step()

因为本文使用的余弦退火LR学习率更新方式,调度器通过余弦函数来调整学习率。

190e7a6bbe0949868a7243f7844218b1.png

目录
相关文章
|
7月前
|
JSON IDE 前端开发
[.NET开发者的福音]一个方便易用的在线.NET代码编辑工具.NET Fiddle
[.NET开发者的福音]一个方便易用的在线.NET代码编辑工具.NET Fiddle
|
网络协议 算法 Shell
来我们探究一下net/http 的代码流程
来我们探究一下net/http 的代码流程
|
4月前
|
API
【Azure 媒体服务】Media Service的编码示例 -- 创建缩略图子画面的.NET代码调试问题
【Azure 媒体服务】Media Service的编码示例 -- 创建缩略图子画面的.NET代码调试问题
|
15天前
|
开发框架 .NET PHP
ASP.NET Web Pages - 添加 Razor 代码
ASP.NET Web Pages 使用 Razor 标记添加服务器端代码,支持 C# 和 Visual Basic。Razor 语法简洁易学,类似于 ASP 和 PHP。例如,在网页中加入 `@DateTime.Now` 可以实时显示当前时间。
|
26天前
|
敏捷开发 缓存 中间件
.NET技术的高效开发模式,涵盖面向对象编程、良好架构设计及高效代码编写与管理三大关键要素
本文深入探讨了.NET技术的高效开发模式,涵盖面向对象编程、良好架构设计及高效代码编写与管理三大关键要素,并通过企业级应用和Web应用开发的实践案例,展示了如何在实际项目中应用这些模式,旨在为开发者提供有益的参考和指导。
23 3
|
4月前
|
C# 开发者 Windows
在VB.NET项目中使用C#编写的代码
在VB.NET项目中使用C#编写的代码
62 0
|
2月前
|
前端开发 JavaScript C#
CodeMaid:一款基于.NET开发的Visual Studio代码简化和整理实用插件
CodeMaid:一款基于.NET开发的Visual Studio代码简化和整理实用插件
|
4月前
|
Kubernetes 监控 Devops
【独家揭秘】.NET项目中的DevOps实践:从代码提交到生产部署,你不知道的那些事!
【8月更文挑战第28天】.NET 项目中的 DevOps 实践贯穿代码提交到生产部署全流程,涵盖健壮的源代码管理、GitFlow 工作流、持续集成与部署、容器化及监控日志记录。通过 Git、CI/CD 工具、Kubernetes 及日志框架的最佳实践应用,显著提升软件开发效率与质量。本文通过具体示例,助力开发者构建高效可靠的 DevOps 流程,确保项目成功交付。
90 0
|
4月前
|
XML 开发框架 .NET
.NET框架:软件开发领域的瑞士军刀,如何让初学者变身代码艺术家——从基础架构到独特优势,一篇不可错过的深度解读。
【8月更文挑战第28天】.NET框架是由微软推出的统一开发平台,支持多种编程语言,简化应用程序的开发与部署。其核心组件包括公共语言运行库(CLR)和类库(FCL)。CLR负责内存管理、线程管理和异常处理等任务,确保代码稳定运行;FCL则提供了丰富的类和接口,涵盖网络、数据访问、安全性等多个领域,提高开发效率。此外,.NET框架还支持跨语言互操作,允许开发者使用C#、VB.NET等语言编写代码并无缝集成。这一框架凭借其强大的功能和广泛的社区支持,已成为软件开发领域的重要工具,适合初学者深入学习以奠定职业生涯基础。
110 1
|
4月前
|
API
【Azure Key Vault】.NET 代码如何访问中国区的Key Vault中的机密信息(Get/Set Secret)
【Azure Key Vault】.NET 代码如何访问中国区的Key Vault中的机密信息(Get/Set Secret)