pytorch使用cat()和stack()拼接tensors

简介: pytorch使用cat()和stack()拼接tensors

有时我们在处理数据时,需要对指定的tensor按照指定维度进行拼接,对于这个需求,pytorch中提供了两个函数供我们使用,一个是torch.cat(),另外一个是torch.stack(),这两者都可以拼接tensor,但是这二者又有一些区别。

二者相同点就是都可以实现拼接tensor,不同之处就是是否是在新的维度上进行拼接(是否产生新的维度)。

一、torch.cat()

该方法可以将任意个tensor按照指定维度进行拼接,需要传入两个参数,一个参数是需要拼接的tensor,需要以列表的形式进行传入,第二个参数就是需要拼接的维度。

a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.cat([a, b], dim=0)
print(c)
print(c.shape)
tensor([[ 0.1040, -0.3168, -1.3974, -1.2703],
        [ 0.4375,  1.4254,  0.2875, -0.2420],
        [-0.9663, -1.8022, -1.2352,  0.7283],
        [-0.4226,  0.0375, -0.3861,  1.3939],
        [ 1.6275, -0.1319, -0.7143,  0.3624],
        [ 0.2245, -1.7482, -0.7933, -0.1008]])
torch.Size([6, 4])

该例子中我们定义了两个tensor,维度分别都是【3,4】,我们使用cat进行拼接,传入的维度是0,那么我们得到的结果就是会将两个tensor按照第一个维度进行拼接,可以理解为按行堆叠,把每一行想成一个样本,那么我们拼接后就会得到6个样本,维度变成【6,4】。

二、torch.stack()

第二种方法就是torch.stack()了,该方法也可以进行拼接,但是与cat有一些不同。

对于传入的参数列表和torch.cat是一样的,但是stack指定的dim是一个新的维度,最终是在这个新的维度上进行拼接。

a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.stack([a, b], dim=0)
print(c)
print(c.shape)
tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
         [ 0.4375,  1.4254,  0.2875, -0.2420],
         [-0.9663, -1.8022, -1.2352,  0.7283]],
        [[-0.4226,  0.0375, -0.3861,  1.3939],
         [ 1.6275, -0.1319, -0.7143,  0.3624],
         [ 0.2245, -1.7482, -0.7933, -0.1008]]])
torch.Size([2, 3, 4])

上面我们指定拼接的dim为0,那么我们会新产生一个维度,得到结果【2,3,4】,原来两个tensor的维度不变,新生成一个维度2,代表拼接后维度。

c = torch.stack([a, b], dim=1)
print(c)
print(c.shape)
tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
         [-0.4226,  0.0375, -0.3861,  1.3939]],
        [[ 0.4375,  1.4254,  0.2875, -0.2420],
         [ 1.6275, -0.1319, -0.7143,  0.3624]],
        [[-0.9663, -1.8022, -1.2352,  0.7283],
         [ 0.2245, -1.7482, -0.7933, -0.1008]]])
torch.Size([3, 2, 4])

如果我们设置为1,那么就会新产生1一个维度在第二位,得到结果【3,2,4】。


目录
相关文章
|
12月前
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
|
算法 PyTorch 算法框架/工具
Pytorch - 张量转换拼接
使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
160 0
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor的变换、拼接、拆分讲解及实战(附源码 超详细必看)
PyTorch深度学习基础之Tensor的变换、拼接、拆分讲解及实战(附源码 超详细必看)
234 0
|
机器学习/深度学习 存储 数据挖掘
PyTorch: 张量的拼接、切分、索引
PyTorch: 张量的拼接、切分、索引
338 0
PyTorch: 张量的拼接、切分、索引
|
15天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
69 1
|
5月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
673 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
15天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
52 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
140 9
|
7月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
464 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体

热门文章

最新文章

推荐镜像

更多