【PyTorch】两种不同分类层的设计方法

简介: 【PyTorch】两种不同分类层的设计方法

问题

涉及到图像分类的网络的最后一层分类层,有两种实现方法,如下所示,你更偏向于哪种方法呢?

方法

方法1

import torch
from torch import nn
'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(32, 2)
    def forward(self, x):
        x = self.conv(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1) # 展开所有元素
        out = self.classifier(x)
        return out
if __name__ == '__main__':
    from torchsummary import summary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    summary(net, (3, 7, 7))

方法2

import torch
from torch import nn
'''
测试池化和卷积组合的分类层
'''
class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 32, 3, padding=1)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 2)
        )
    def forward(self, x):
        x = self.conv(x)
        out = self.classifier(x)
        return out
if __name__ == '__main__':
    from torchsummary import summary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.rand(size=(1, 3, 7, 7)).to(device)
    net = MyNet().to(device)
    out = net(x)
    summary(net, (3, 7, 7))

结语

从扩展性、可读性的角度来说,更偏向于方法2的设计。

目录
相关文章
|
2月前
|
机器学习/深度学习 存储 PyTorch
Pytorch中in-place操作相关错误解析及detach()方法说明
Pytorch中in-place操作相关错误解析及detach()方法说明
152 0
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】18. Pytorch中自定义层的几种方法:nn.Module、ParameterList和ParameterDict
【从零开始学习深度学习】18. Pytorch中自定义层的几种方法:nn.Module、ParameterList和ParameterDict
|
19天前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
19天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】17. Pytorch中模型参数的访问、初始化和共享方法
【从零开始学习深度学习】17. Pytorch中模型参数的访问、初始化和共享方法
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用FP8加速PyTorch训练的两种方法总结
在PyTorch中,FP8数据类型用于高效训练和推理,旨在减少内存占用和加快计算速度。虽然官方尚未全面支持,但在2.2版本中引入了`torch.float8_e4m3fn`和`torch.float8_e5m2`。文章通过示例展示了如何利用FP8优化Vision Transformer模型,使用Transformer Engine库提升性能,并探讨了PyTorch原生FP8支持的初步使用方法。实验表明,结合TE和FP8,训练速度可提升3倍,性能有显著增强,特别是在NVIDIA GPU上。然而,PyTorch的FP8支持仍处于试验阶段,可能带来不稳定性。
64 0
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
184 0
|
12月前
|
并行计算 PyTorch 算法框架/工具
离线下载安装PyTorch的不报错方法
离线下载安装PyTorch的不报错方法
|
存储 PyTorch 算法框架/工具
聊一聊pytorch中的张量基本方法
聊一聊pytorch中的张量基本方法
106 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 15 过拟合问题的优化技巧(二):Dropout()方法
异常数据的特点:与主流样本中的规律不同,在一个样本中出现的概率要比主流数据出现的概率低很多。在每次训练中,忽略模型中一些节点,将小概率的异常数据获得学习的机会变得更低。这样,异常数据对模型的影响就会更小。
171 0
|
机器学习/深度学习 人工智能 PyTorch