pytorch中Sequential()的三种构造方法

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: pytorch中Sequential()的三种构造方法

对于图像分类任务中,一般神经网络的框架都是卷积、激活、池化,然后不断堆叠,最终嘉善输出层,并不会有较为复杂的连接结构,都是每一层的输出为下一层的输入。

对于简单的前馈神经网络,为了在forward中避免重复的运算操作,可以使用pytorch中内置的Sequential()这个特殊的module,我们可以将重复操作放在Sequential()中,然后模型就会自动识别内部的子module,然后就会按照顺序一层一层进行传播。

方法一

首先创建一个Sequential,然后不断添加模型子层,传入的第一个参数为该层的名称,第二个是该层的实例。

model1 = nn.Sequential()
model1.add_module('conv', nn.Conv2d(1, 5, 2))
model1.add_module('fc', nn.Linear(10, 2))
model1.add_module('sigmoid', nn.Sigmoid())

方法二

对于上面方法进行简化,可以直接将所有实例化的子类按顺序传入,但是这种方法有个缺点,就是不能对每个层进行命名,默认每个层的名称是0,1,2。

model2 = nn.Sequential(nn.Conv2d(1, 5, 2),
                      nn.Linear(10, 2),
                      nn.Sigmoid())

方法三

对于方法三是首先创建一个字典,然后将这个字典传入,这个字典和方法一一致都是层的名称和该层的实例化的键值对。

from collections import OrderedDict
model3 = nn.Sequential(OrderedDict([
    ('conv', nn.Conv2d(1, 5, 2)),
    ('fc', nn.Linear(10, 2)),
    ('sigmoid', nn.Sigmoid())
]))

输出模型结构

我们可以打印模型,查看模型的每层参数

print(model1)
print(model2)
print(model3)
Sequential(
  (conv): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
  (fc): Linear(in_features=10, out_features=2, bias=True)
  (sigmoid): Sigmoid()
)
Sequential(
  (0): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
  (1): Linear(in_features=10, out_features=2, bias=True)
  (2): Sigmoid()
)
Sequential(
  (conv): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
  (fc): Linear(in_features=10, out_features=2, bias=True)
  (sigmoid): Sigmoid()
)

查看指定层

可以查看Sequential中的具体的某一个子模块,对于方法一和方法三都是需要传入层的名称,所以我们可以直接使用名称获得该层,而对于方法二没有名字,所以只可以通过索引获取。

print(model1.conv)
print(model2[0])
print(model3.conv)
Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))


相关实践学习
【AI破次元壁合照】少年白马醉春风,函数计算一键部署AI绘画平台
本次实验基于阿里云函数计算产品能力开发AI绘画平台,可让您实现“破次元壁”与角色合照,为角色换背景效果,用AI绘图技术绘出属于自己的少年江湖。
从 0 入门函数计算
在函数计算的架构中,开发者只需要编写业务代码,并监控业务运行情况就可以了。这将开发者从繁重的运维工作中解放出来,将精力投入到更有意义的开发任务上。
目录
相关文章
|
NoSQL Java 关系型数据库
基于Java swing和mysql实现的学生选课管理系统(源码+数据库+运行指导视频)
基于Java swing和mysql实现的学生选课管理系统(源码+数据库+运行指导视频)
712 0
用于演化博弈中,列出复制动态方程后,求解复制动态方程的均衡点
用于演化博弈中,列出复制动态方程后,求解复制动态方程的均衡点
|
弹性计算 网络协议 Linux
为什么我的幻兽帕鲁服务器搭建好了之后连不上,提示超时?
幻兽帕鲁服务器刚刚搭建完成,你一定迫不及待的的想要连上去玩耍了,但是连接等待半天后,不是进入到游戏而是提示超时,令人崩溃。
9811 2
成功解决:443端口被vmware-host(8992)占用。请关掉占用443端口的程序或者尝试使用系统代理模式
该博客文章提供了解决443端口被vmware-host占用问题的方法,包括关闭占用端口的程序或尝试使用系统代理模式。
成功解决:443端口被vmware-host(8992)占用。请关掉占用443端口的程序或者尝试使用系统代理模式
|
数据采集 监控 异构计算
transformers+huggingface训练模型
本教程介绍了如何使用 Hugging Face 的 `transformers` 库训练一个 BERT 模型进行情感分析。主要内容包括:导入必要库、下载 Yelp 评论数据集、数据预处理、模型加载与配置、定义训练参数、评估指标、实例化训练器并开始训练,最后保存模型和训练状态。整个过程详细展示了如何利用预训练模型进行微调,以适应特定任务。
832 3
|
编解码 人工智能 Linux
SD中的VAE,你不能不懂
要想生成一幅美丽的图片,没有VAE可不行
SD中的VAE,你不能不懂
|
机器学习/深度学习 数据采集 人工智能
预测知识 | 机器学习预测模型局限性
预测知识 | 机器学习预测模型局限性
|
前端开发 JavaScript C++
CSS 【详解】样式选择器(含ID、类、标签、通配、属性、伪类、伪元素、Content属性、子代、后代、兄弟、相邻兄弟、交集、并集等选择器)
CSS 【详解】样式选择器(含ID、类、标签、通配、属性、伪类、伪元素、Content属性、子代、后代、兄弟、相邻兄弟、交集、并集等选择器)
2849 0
|
机器学习/深度学习 数据可视化 算法
多项式Logistic逻辑回归进行多类别分类和交叉验证准确度箱线图可视化
多项式Logistic逻辑回归进行多类别分类和交叉验证准确度箱线图可视化
|
机器学习/深度学习 JSON 并行计算