【Pytorch--代码技巧】各种论文代码常见技巧

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,1000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
简介: 博主在阅读论文原代码的时候常常看见一些没有见过的代码技巧,特此将这些内容进行汇总

 博主在阅读论文原代码的时候常常看见一些没有见过的代码技巧,特此将这些内容进行汇总

1 torch.view()

作用:重置Tensor对象维度

注意点:参数中的-1表示系统自动判断,因此每个view里面只能出现一个-1

x = torch.randn(4,4)
# 重置为向量
x.view(16).size()
# 重置为多维矩阵
x.view(2,2,4).size()
# -1 数字用法
x.view(-1,8).sieze

image.gif

torch.Size([16])
torch.Size([2,2,4])
torch.Size([2,8])

image.gif

2  torch.unsqueeze()

作用:升维,最常见的就是unsqueeze(-1)表示将一维升到二维

x = torch.randn(4)
x = x.unsqueeze(-1)

image.gif

torch.Size([16,1])

image.gif

3 torch.expand()

作用:升维

# x 维度为[4]
x =torch.tensor([1,2,3,4])
# x1 维度为[3,1,4]
x1 = x.expend(3,1,4)
print(x1)
>>> 
tensor([[[1, 2, 3, 4]],
        [[1, 2, 3, 4]],
        [[1, 2, 3, 4]]])

image.gif

torch.Size([16,1])

image.gif

4 torch.transpose(0, 1)

作用:转置

注意:

1 只能拓展维度,比如 A的shape为 2x4的,不能 A.expend(1,4),只能保证原结构不变,在前面增维,比如A.shape(1,1,4)

2 可以增加多维,比如x的shape为(4),x.expend(2,2,1,4)只需保证本身是4

3 不能拓展低维,比如x的shape为(4),不能x.expend(4,2)

x = torch.randn(16,1)
x = x.transpose(0,1)

image.gif

torch.Size([1,16])

image.gif

5 去除对角线元素

一般使用z方法,因为x方法对float不可用而z可以

x = torch.randint(1, 4, (4, 4))
y = x ^ torch.diag_embed(torch.diag(x))
z = x - torch.diag_embed(torch.diag(x))

image.gif

tensor([[2, 2, 2, 1],
        [3, 1, 1, 2],
        [3, 1, 3, 1],
        [3, 1, 2, 2]])
tensor([[0, 2, 2, 1],
        [3, 0, 1, 2],
        [3, 1, 0, 1],
        [3, 1, 2, 0]])
tensor([[0, 2, 2, 1],
        [3, 0, 1, 2],
        [3, 1, 0, 1],
        [3, 1, 2, 0]])

image.gif

6 torch.gather()

作用:根据维度dim按照索引列表index从input中选取指定元素

b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))

image.gif

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 3.]])

image.gif

7 nn.Parameter()

作用:将一个不可训练的tensor转换为一个有梯度的可训练的tensor。往往用在需要自己定义的bisa中

self.bias = nn.Parameter(torch.ones([5]))
output = self.Weight(x) + self.bias

image.gif

8 nn.Sequential

作用:将多个网络模块组合成一个模块,需要注意的是相邻的两个网络模块之间的输入输出尺寸

self.block = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 2, 128),
            nn.Linear(128, 16),
            nn.Linear(16, num_classes),
            nn.Softmax(dim=1)
        )

image.gif

9 nn.moduleList

作用:将多个网络模块放到一个类List中,后续可以从中进行调用

# 定义三个输入channels为1,输出channels为2,卷积核为[2,768]、[3,768]、[4,768]的卷积
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels=1, out_channels=self.num_filters,
                       kernel_size=(k, 768), ) for k in self.filter_sizes])

image.gif

10 nn.MaxPool2d

作用:MaxPooling,提取重要信息,去掉不重要信息,从而减少计算开销

image.gif编辑

其一共有6个基本参数

    • kernel_size:池化窗口大小,输入单值如3则为3×3,输入元组如(3,2)则为3×2
    • stride:步长,单值元组均可。默认与池化窗口大小一致
    • padding:填充,单值元组均可。默认为0
    • dilation:控制窗口中元素步幅,不重要!
    • return_indices:布尔类型,返回最大值位置索引
    • ceil_mode:布尔类型,默认为False。False为向下取整,True为向上取整

    11 permute()

    作用:根据位置进行维度转化

    >>> x = torch.randn(2, 3, 5) 
    >>> x.size() 
    torch.Size([2, 3, 5]) 
    >>> x.permute(2, 0, 1).size() 
    torch.Size([5, 2, 3])

    image.gif

    12 torch.cat(inputs,dim)

    作用:根据维度进行Tensor拼接

    b = torch.cat(a, 4)

    image.gif

    13 torch.clamp(input, min, max, out=None)

    作用:将维度限制在min和max之间

    sum_mask = torch.clamp(sum_mask, min=1e-9)

    image.gif


    目录
    相关文章
    |
    8月前
    |
    机器学习/深度学习 JavaScript PyTorch
    9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
    生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
    677 7
    9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
    |
    3月前
    |
    PyTorch 算法框架/工具 异构计算
    PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
    我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
    246 9
    |
    存储 物联网 PyTorch
    基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
    **Torchtune**是由PyTorch团队开发的一个专门用于LLM微调的库。它旨在简化LLM的微调流程,提供了一系列高级API和预置的最佳实践
    629 59
    基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
    |
    4月前
    |
    机器学习/深度学习 数据可视化 PyTorch
    Flow Matching生成模型:从理论基础到Pytorch代码实现
    本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
    1817 0
    Flow Matching生成模型:从理论基础到Pytorch代码实现
    |
    5月前
    |
    机器学习/深度学习 PyTorch 算法框架/工具
    提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
    本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
    174 4
    提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
    |
    7月前
    |
    机器学习/深度学习 数据可视化 机器人
    比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
    扩散模型和流匹配是生成高分辨率数据(如图像和机器人轨迹)的先进技术。扩散模型通过逐步去噪生成数据,其代表应用Stable Diffusion已扩展至机器人学领域形成“扩散策略”。流匹配作为更通用的方法,通过学习时间依赖的速度场将噪声转化为目标分布,适用于图像生成和机器人轨迹生成,且通常以较少资源实现更快生成。 本文深入解析流匹配在图像生成中的应用,核心思想是将图像视为随机变量的实现,并通过速度场将源分布转换为目标分布。文中提供了一维模型训练实例,展示了如何用神经网络学习速度场,以及使用最大均值差异(MMD)改进训练效果。与扩散模型相比,流匹配结构简单,资源需求低,适合多模态分布生成。
    545 13
    比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
    |
    7月前
    |
    机器学习/深度学习 编解码 PyTorch
    从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
    本文介绍了一种基于扩散模型的文本到视频生成系统,详细展示了模型架构、训练流程及生成效果。通过3D U-Net结构和多头注意力机制,模型能够根据文本提示生成高质量视频。
    309 1
    从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
    |
    机器学习/深度学习 PyTorch 算法框架/工具
    CNN中的注意力机制综合指南:从理论到Pytorch代码实现
    注意力机制已成为深度学习模型的关键组件,尤其在卷积神经网络(CNN)中发挥了重要作用。通过使模型关注输入数据中最相关的部分,注意力机制显著提升了CNN在图像分类、目标检测和语义分割等任务中的表现。本文将详细介绍CNN中的注意力机制,包括其基本概念、不同类型(如通道注意力、空间注意力和混合注意力)以及实际实现方法。此外,还将探讨注意力机制在多个计算机视觉任务中的应用效果及其面临的挑战。无论是图像分类还是医学图像分析,注意力机制都能显著提升模型性能,并在不断发展的深度学习领域中扮演重要角色。
    664 10
    |
    9月前
    |
    机器学习/深度学习 存储 算法
    近端策略优化(PPO)算法的理论基础与PyTorch代码详解
    近端策略优化(PPO)是深度强化学习中高效的策略优化方法,广泛应用于大语言模型的RLHF训练。PPO通过引入策略更新约束机制,平衡了更新幅度,提升了训练稳定性。其核心思想是在优势演员-评论家方法的基础上,采用裁剪和非裁剪项组成的替代目标函数,限制策略比率在[1-ϵ, 1+ϵ]区间内,防止过大的策略更新。本文详细探讨了PPO的基本原理、损失函数设计及PyTorch实现流程,提供了完整的代码示例。
    4026 10
    近端策略优化(PPO)算法的理论基础与PyTorch代码详解
    |
    机器学习/深度学习 PyTorch 算法框架/工具
    聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
    本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
    973 0
    聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现

    推荐镜像

    更多