【PyTorch】实现SqueezeNet的Fire模块

简介: 【PyTorch】实现SqueezeNet的Fire模块

问题

SqueezeNet是一款非常经典的CV网络,其设计理念对后续的很多网络都有非常强的指导意义,其核心思想包括:

  • 使用1x1卷积核替代3x3,主要原因是3x3的卷积核参数量是1x1的9倍多;
  • 降低3x3卷积核的通道数量;
  • 网络结构中延迟下采样的时机以获得较大尺寸的激活特征图;

方法

下面介绍PyTorch实现的SqueezeNet网络最核心的Fire模块,如下:

import torch
from torch import nn, Tensor
from typing import Any
class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: # 增加in_xxx和out_xxx的好处是,调用的时候可以省略参数名
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, **kwargs) # **容易漏掉
        self.relu = nn.ReLU()
    def forward(self, x: Tensor) -> Tensor:
        x = self.conv2d(x)
        out = self.relu(x)
        return out
class Fire(nn.Module):
    def __init__(self, in_channels: int, s_1x1: int, e_1x1: int, e_3x3: int) -> None:
        super().__init__()
        self.squeeze = BasicConv2d(in_channels, s_1x1, kernel_size=1)
        self.expand_1x1 = BasicConv2d(s_1x1, e_1x1, kernel_size = 1)
        self.expand_3x3 = BasicConv2d(s_1x1, e_3x3, kernel_size = 3, padding = 1) # p=1是为了保持3x3特征图不变
    def forward(self, x: Tensor) -> Tensor:
        x = self.squeeze(x)
        return torch.cat([
            self.expand_1x1(x), 
            self.expand_3x3(x)
        ], dim=1)
if __name__ == '__main__':
    x = torch.rand(size=(1, 3, 224, 224))
    conv2d = BasicConv2d(3,  64, kernel_size = 3, padding = 1, stride = 1)
    print(conv2d(x).shape) # torch.Size([1, 64, 224, 224])   
    fire = Fire(3, 32, 32, 48)
    print(fire(x).shape) # torch.Size([1, 80, 224, 224])


目录
相关文章
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)
PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)
374 1
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
1431 0
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之激活函数模块中Sigmoid、Tanh、ReLU、LeakyReLU函数讲解(附源码)
PyTorch基础之激活函数模块中Sigmoid、Tanh、ReLU、LeakyReLU函数讲解(附源码)
1368 0
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
2309 0
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之张量模块数据类型、基本操作、与Numpy数组的操作详解(附源码 简单全面)
PyTorch基础之张量模块数据类型、基本操作、与Numpy数组的操作详解(附源码 简单全面)
292 0
|
PyTorch 算法框架/工具
Pytorch学习笔记(一):torch.cat()模块的详解
这篇博客文章详细介绍了Pytorch中的torch.cat()函数,包括其定义、使用方法和实际代码示例,用于将两个或多个张量沿着指定维度进行拼接。
809 0
Pytorch学习笔记(一):torch.cat()模块的详解
|
机器学习/深度学习 PyTorch 数据处理
PyTorch数据处理:torch.utils.data模块的7个核心函数详解
在机器学习和深度学习项目中,数据处理是至关重要的一环。PyTorch作为一个强大的深度学习框架,提供了多种灵活且高效的数据处理工具
263 1
|
机器学习/深度学习 算法 PyTorch
Pytorch的常用模块和用途说明
肆十二在B站分享PyTorch常用模块及其用途,涵盖核心库torch、神经网络库torch.nn、优化库torch.optim、数据加载工具torch.utils.data、计算机视觉库torchvision等,适合深度学习开发者参考学习。链接:[肆十二-哔哩哔哩](https://space.bilibili.com/161240964)
469 0
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:自动微分模块
探索PyTorch:自动微分模块
|
机器学习/深度学习 存储 PyTorch
Pytorch-自动微分模块
PyTorch的torch.autograd模块提供了自动微分功能,用于深度学习中的梯度计算。它包括自定义操作的函数、构建计算图、数值梯度检查、错误检测模式和梯度模式设置等组件。张量通过设置`requires_grad=True`来追踪计算,`backward()`用于反向传播计算梯度,`grad`属性存储张量的梯度。示例展示了如何计算标量和向量张量的梯度,并通过`torch.no_grad()`等方法控制梯度计算。在优化过程中,梯度用于更新模型参数。注意,使用numpy转换要求先`detach()`以避免影响计算图。
284 10

热门文章

最新文章

推荐镜像

更多