PyTorch的nn.Linear()详解

简介: 从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。

1. nn.Linear()


  • nn.Linear():用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量


  • 一般形状为[batch_size, size],不同于卷积层要求输入输出是四维张量。其用法与形参说明如下:


49c52fa5a5cd2e32ec5cac2632df8b44.png


  • in_features指的是输入的二维张量的大小,即输入的[batch_size, size]中的size。


  • out_features指的是输出的二维张量的大小,即输出的二维张量的形状为[batch_size,output_size],当然,它也代表了该全连接层的神经元个数。


  • 从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。


用法示例:


import torch as t
from torch import nn
from torch.nn import functional as F
# 假定输入的图像形状为[3,64,64]
x = t.randn(10, 3, 64, 64)      # 10张 3个channel 大小为64x64的图片
x = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)(x)
print(x.shape)
# 之前的特征图尺寸为多少,只要设置为(1,1),那么最终特征图大小都为(1,1) 
# x = F.adaptive_avg_pool2d(x, [1,1])    # [b, 64, h, w] => [b, 64, 1, 1]
# print(x.shape)
# 将四维张量转换为二维张量之后,才能作为全连接层的输入
x = x.view(x.size(0), -1)
print(x.shape)
# in_features由输入张量的形状决定,out_features则决定了输出张量的形状 
connected_layer = nn.Linear(in_features = 64*21*21, out_features = 10)
# 调用全连接层
output = connected_layer(x) 
print(output.shape)
torch.Size([10, 64, 21, 21])
torch.Size([10, 28224])
torch.Size([10, 10])


目录
相关文章
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
torch.nn.Linear的使用方法
torch.nn.Linear的使用方法
62 0
|
1月前
|
机器学习/深度学习 资源调度 监控
PyTorch使用Tricks:Dropout,R-Dropout和Multi-Sample Dropout等 !!
PyTorch使用Tricks:Dropout,R-Dropout和Multi-Sample Dropout等 !!
26 0
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
基于torch.nn.Dropout通过实例说明Dropout丢弃法(附代码)
基于torch.nn.Dropout通过实例说明Dropout丢弃法(附代码)
20 0
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch torch.nn库以及nn与nn.functional有什么区别?
Pytorch torch.nn库以及nn与nn.functional有什么区别?
44 0
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch中nn.ReLU()和F.relu()有什么区别?
pytorch中nn.ReLU()和F.relu()有什么区别?
369 0
|
8月前
|
机器学习/深度学习 算法 PyTorch
Linear Regression with PyTorch 用PyTorch实现线性回归
Linear Regression with PyTorch 用PyTorch实现线性回归
80 0
|
10月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch】nn.ReLU()与F.relu()的区别
【PyTorch】nn.ReLU()与F.relu()的区别
99 0
|
10月前
|
机器学习/深度学习 PyTorch 算法框架/工具
|
10月前
|
PyTorch 算法框架/工具
【PyTorch简明教程】torch.Tensor()与torch.tensor()的区别
【PyTorch简明教程】torch.Tensor()与torch.tensor()的区别
84 0
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
425 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法