pytorch中nn.Parameter()使用方法

简介: pytorch中nn.Parameter()使用方法

对于nn.Parameter()是pytorch中定义可学习参数的一种方法,因为我们在搭建网络时,网络中会存在一些矩阵,这些矩阵内部的参数是可学习的,也就是可梯度求导的。

对于一些常用的网络层,例如nn.Conv2d()卷积层nn.LInear()线性层nn.LSTM()循环网络层等,这些网络层在pytorch中的nn模块中已经定义好,所以我们搭建模型时可以直接使用,但是有些自定义网络在pytorch中是没有实现的,我们就需要自定义可学习参数,那就用到了nn.Parameter()这个函数。

该函数会为我们创建一个矩阵,该矩阵是默认可梯度求导的,之后我们就可以利用这个矩阵进行计算,该函数需要传入的参数是一个tensor,一般我们会传入一个初始化好的tensor。

下面我们将使用一个简单的线性层作为实例,来理解如何使用nn.Parameter()。

一、nn.Linear()定义参数

在类中我们定义了一个线性层,输入维度是10,输出维度是3,对于nn.Linear()层内部已经封装好了nn.Parameter(),所以不需要我们自定义,直接使用即可。

class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 3)
    def forward(self, x):
        return F.sigmoid(self.linear(x))

二、nn.Parameter()定义参数

对于一个线性层,我们会需要两个矩阵,分别是权重W和偏置b,所以我们要用nn.Parameter()定义两个可学习参数,然后传入对应维度的tensor作为参数,之后就可以在forward中定义计算过程。

class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn(10, 3))
        self.b = nn.Parameter(torch.randn(3))
    def forward(self, x):
        return F.sigmoid(self.W @ x + self.b)

三、查看可学习参数

利用下面代码就可以看定义好的模型中的参数

model1 = Net1()
model2 = Net2()
for name, parameters in model1.named_parameters():
    print(name, ':', parameters.size())
for name, parameters in model2.named_parameters():
    print(name, ':', parameters.size())
linear.weight : torch.Size([3, 10])
linear.bias : torch.Size([3])
W : torch.Size([10, 3])
b : torch.Size([3


目录
相关文章
|
PyTorch 算法框架/工具
pytorch中torch.clamp()使用方法
pytorch中torch.clamp()使用方法
506 0
pytorch中torch.clamp()使用方法
|
并行计算 PyTorch 测试技术
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-2
由于要进行 tensor 的学习,因此,我们先导入我们需要的库。
|
机器学习/深度学习 人工智能 自然语言处理
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-1
PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebook 的人工智能小组开发,不仅能够实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架如 TensorFlow 都不支持的。
|
机器学习/深度学习 人工智能 PyTorch
|
PyTorch 算法框架/工具
pytorch中ImageFolder()使用方法
pytorch中ImageFolder()使用方法
294 0
pytorch中ImageFolder()使用方法
|
PyTorch 算法框架/工具 异构计算
基于Pytorch查看本地或者远程服务器GPU及使用方法
基于Pytorch查看本地或者远程服务器GPU及使用方法
466 0
基于Pytorch查看本地或者远程服务器GPU及使用方法
|
PyTorch 算法框架/工具
pytorch中keepdim参数归并操作使用方法
pytorch中keepdim参数归并操作使用方法
134 0
|
PyTorch 算法框架/工具
pytorch中meter.AverageValueMeter()使用方法
pytorch中meter.AverageValueMeter()使用方法
257 0
|
PyTorch 算法框架/工具
pytorch中meter.ClassErrorMeter()使用方法
pytorch中meter.ClassErrorMeter()使用方法
162 0
|
PyTorch 算法框架/工具
pytorch中nn.ModuleList()使用方法
pytorch中nn.ModuleList()使用方法
319 0