深度学习:Pytorch nn模块函数解读

简介: 深度学习:Pytorch nn模块函数解读

深度学习:Pytorch nn模块函数解读

nn

nn.Parameter()

这个方法可以把不可以训练的Tensor变成可以通过反向传播更新的参数。
我们以线性回归为例:
先看一下普通的版本

import torch
from torch import nn


class Linear_Regression(nn.Module):
    def __init__(self):
        super(Linear_Regression, self).__init__()
        self.test = torch.rand(1, 2)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        y = self.linear(x)
        return y


input = torch.rand([3, 2])
linear = Linear_Regression()
print(linear(input))
print((list(linear.named_parameters())))

打印结果分别为 线性回归的输出 与模型的参数
在这里插入图片描述

下面是 引用 nn.Parameter()的版本

import torch
from torch import nn


class Linear_Regression(nn.Module):
    def __init__(self):
        super(Linear_Regression, self).__init__()
        self.test = nn.Parameter(torch.rand(1, 2))
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        y = self.linear(x)
        return y


input = torch.rand([3, 2])
linear = Linear_Regression()
print(linear(input))
print((list(linear.named_parameters())))

结果如下:
发现可更新的参数多了个test
在这里插入图片描述

nn.Embedding()

nn.Embeddding接受两个重要参数:

num_embeddings:字典的大小,就是我这个序列有多少个词
embedding_dim:要将单词编码成多少维的向量

代码如下:
假设我有5个词,我要把这个5个词转换成3维的tensor

input = torch.arange(5)
print(input)
emb = nn.Embedding(5,3)
print(emb(input))

在这里插入图片描述

Torch

torch.squeeze

目录
相关文章
|
1天前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
111 58
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:
16 4
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
|
6天前
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
26 1
|
16天前
|
机器学习/深度学习 人工智能 PyTorch
【Deepin 20深度探索】一键解锁Linux深度学习潜能:从零开始安装Pytorch,驾驭AI未来从Deepin出发!
【8月更文挑战第2天】随着人工智能的迅猛发展,深度学习框架Pytorch已成为科研与工业界的必备工具。Deepin 20作为优秀的国产Linux发行版,凭借其流畅的用户体验和丰富的软件生态,为深度学习爱好者提供理想开发平台。本文引导您在Deepin 20上安装Pytorch,享受Linux下的深度学习之旅。
39 12
|
13天前
|
机器学习/深度学习 存储 PyTorch
【深度学习】Pytorch面试题:什么是 PyTorch?PyTorch 的基本要素是什么?Conv1d、Conv2d 和 Conv3d 有什么区别?
关于PyTorch面试题的总结,包括PyTorch的定义、基本要素、张量概念、抽象级别、张量与矩阵的区别、不同损失函数的作用以及Conv1d、Conv2d和Conv3d的区别和反向传播的解释。
35 2
|
13天前
|
机器学习/深度学习 算法 PyTorch
【深度学习】TensorFlow面试题:什么是TensorFlow?你对张量了解多少?TensorFlow有什么优势?TensorFlow比PyTorch有什么不同?该如何选择?
关于TensorFlow面试题的总结,涵盖了TensorFlow的基本概念、张量的理解、TensorFlow的优势、数据加载方式、算法通用步骤、过拟合解决方法,以及TensorFlow与PyTorch的区别和选择建议。
30 2
|
19天前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
【7月更文挑战第31天】在数据驱动时代,Python凭借其简洁性与强大的库支持,成为数据分析与机器学习的首选语言。**数据分析基础**从Pandas和NumPy开始,Pandas简化了数据处理和清洗,NumPy支持高效的数学运算。例如,加载并清洗CSV数据、计算总销售额等。
32 2
|
19天前
|
机器学习/深度学习 数据挖掘 TensorFlow
|
20天前
|
机器学习/深度学习 人工智能 数据挖掘
从0到1构建AI帝国:PyTorch深度学习框架下的数据分析与实战秘籍
【7月更文挑战第30天】PyTorch以其灵活性和易用性成为深度学习的首选框架。
39 2
|
21天前
|
机器学习/深度学习 并行计算 数据挖掘
🎓PyTorch深度学习入门课:编程小白也能玩转的高级数据分析术
【7月更文挑战第29天】踏入深度学习世界,新手也能用PyTorch解锁高级数据分析。
16 2

热门文章

最新文章