pytorch使用cat()和stack()拼接tensors

简介: pytorch使用cat()和stack()拼接tensors

有时我们在处理数据时,需要对指定的tensor按照指定维度进行拼接,对于这个需求,pytorch中提供了两个函数供我们使用,一个是torch.cat(),另外一个是torch.stack(),这两者都可以拼接tensor,但是这二者又有一些区别。

二者相同点就是都可以实现拼接tensor,不同之处就是是否是在新的维度上进行拼接(是否产生新的维度)。

一、torch.cat()

该方法可以将任意个tensor按照指定维度进行拼接,需要传入两个参数,一个参数是需要拼接的tensor,需要以列表的形式进行传入,第二个参数就是需要拼接的维度。

a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.cat([a, b], dim=0)
print(c)
print(c.shape)
tensor([[ 0.1040, -0.3168, -1.3974, -1.2703],
        [ 0.4375,  1.4254,  0.2875, -0.2420],
        [-0.9663, -1.8022, -1.2352,  0.7283],
        [-0.4226,  0.0375, -0.3861,  1.3939],
        [ 1.6275, -0.1319, -0.7143,  0.3624],
        [ 0.2245, -1.7482, -0.7933, -0.1008]])
torch.Size([6, 4])

该例子中我们定义了两个tensor,维度分别都是【3,4】,我们使用cat进行拼接,传入的维度是0,那么我们得到的结果就是会将两个tensor按照第一个维度进行拼接,可以理解为按行堆叠,把每一行想成一个样本,那么我们拼接后就会得到6个样本,维度变成【6,4】。

二、torch.stack()

第二种方法就是torch.stack()了,该方法也可以进行拼接,但是与cat有一些不同。

对于传入的参数列表和torch.cat是一样的,但是stack指定的dim是一个新的维度,最终是在这个新的维度上进行拼接。

a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.stack([a, b], dim=0)
print(c)
print(c.shape)
tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
         [ 0.4375,  1.4254,  0.2875, -0.2420],
         [-0.9663, -1.8022, -1.2352,  0.7283]],
        [[-0.4226,  0.0375, -0.3861,  1.3939],
         [ 1.6275, -0.1319, -0.7143,  0.3624],
         [ 0.2245, -1.7482, -0.7933, -0.1008]]])
torch.Size([2, 3, 4])

上面我们指定拼接的dim为0,那么我们会新产生一个维度,得到结果【2,3,4】,原来两个tensor的维度不变,新生成一个维度2,代表拼接后维度。

c = torch.stack([a, b], dim=1)
print(c)
print(c.shape)
tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
         [-0.4226,  0.0375, -0.3861,  1.3939]],
        [[ 0.4375,  1.4254,  0.2875, -0.2420],
         [ 1.6275, -0.1319, -0.7143,  0.3624]],
        [[-0.9663, -1.8022, -1.2352,  0.7283],
         [ 0.2245, -1.7482, -0.7933, -0.1008]]])
torch.Size([3, 2, 4])

如果我们设置为1,那么就会新产生1一个维度在第二位,得到结果【3,2,4】。


目录
相关文章
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
|
6月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
51 0
|
8月前
|
算法 PyTorch 算法框架/工具
Pytorch - 张量转换拼接
使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。
|
9月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor的变换、拼接、拆分讲解及实战(附源码 超详细必看)
PyTorch深度学习基础之Tensor的变换、拼接、拆分讲解及实战(附源码 超详细必看)
153 0
|
机器学习/深度学习 存储 数据挖掘
PyTorch: 张量的拼接、切分、索引
PyTorch: 张量的拼接、切分、索引
269 0
PyTorch: 张量的拼接、切分、索引
|
20天前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
185 66
|
4月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
605 2
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
85 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
4月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
124 7
利用 PyTorch Lightning 搭建一个文本分类模型
|
4月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
290 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力