Pytorch中张量的高级选择操作

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 在某些情况下,我们需要用Pytorch做一些高级的索引/选择,所以在这篇文章中,我们将介绍这类任务的三种最常见的方法:torch.index_select, torch.gather and torch.take

我们首先从一个2D示例开始,并将选择结果可视化,然后延申到3D和更复杂场景。最后以表格的形式总结了这些函数及其区别。

torch.index_select

torch.index_select

是 PyTorch 中用于按索引选择张量元素的函数。它的作用是从输入张量中按照给定的索引值,选取对应的元素形成一个新的张量。它沿着一个维度选择元素,同时保持其他维度不变。也就是说:保留所有其他维度的元素,但在索引张量之后的目标维度中选择元素。

 num_picks = 2

 values = torch.rand((len_dim_0, len_dim_1))
 indices = torch.randint(0, len_dim_1, size=(num_picks,))
 # [len_dim_0, num_picks]
 picked = torch.index_select(values, 1, indices)

上面代码将得到的张量形状为[len_dim_0, num_picks]:对于沿维度0的每个元素,我们从维度1中选择了相同的元素。

现在我们使用3D张量,一个形状为[batch_size, num_elements, num_features]的张量:这样我们就有了num_elements元素和num_feature特征,并且是一个批次进行处理的。我们为每个批处理/特性组合选择相同的元素:

 import torch

 batch_size = 16
 num_elements = 64
 num_features = 1024
 num_picks = 2

 values = torch.rand((batch_size, num_elements, num_features))
 indices = torch.randint(0, num_elements, size=(num_picks,))
 # [batch_size, num_picks, num_features]
 picked = torch.index_select(values, 1, indices)

下面是如何使用简单的for循环重新实现这个函数的方法:

 picked_manual = torch.zeros_like(picked)
 for i in range(batch_size):
     for j in range(num_picks):
         for k in range(num_features):
             picked_manual[i, j, k] = values[i, indices[j], k]

 assert torch.all(torch.eq(picked, picked_manual))

这样对比可以对index_select有一个更深入的了解

torch.gather

torch.gather

是 PyTorch 中用于按照指定索引从输入张量中收集值的函数。它允许你根据指定的索引从输入张量中取出对应位置的元素,并组成一个新的张量。它的行为类似于index_select,但是现在所需维度中的元素选择依赖于其他维度——也就是说对于每个批次索引,对于每个特征,我们可以从“元素”维度中选择不同的元素——我们将从一个张量作为另一个张量的索引。

 num_picks = 2

 values = torch.rand((len_dim_0, len_dim_1))
 indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
 # [len_dim_0, num_picks]
 picked = torch.gather(values, 1, indices)

现在的选择不再以直线为特征,而是对于沿着维度0的每个索引,在维度1中选择一个不同的元素:

我们继续扩展为3D的张量,并展示Python代码来重新实现这个选择:

 import torch

 batch_size = 16
 num_elements = 64
 num_features = 1024
 num_picks = 5
 values = torch.rand((batch_size, num_elements, num_features))
 indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
 picked = torch.gather(values, 1, indices)

 picked_manual = torch.zeros_like(picked)
 for i in range(batch_size):
     for j in range(num_picks):
         for k in range(num_features):
             picked_manual[i, j, k] = values[i, indices[i, j, k], k]

 assert torch.all(torch.eq(picked, picked_manual))
torch.gather

是一个灵活且强大的函数,可以在许多情况下用于数据收集和操作,尤其在需要按照指定索引收集数据的情况下非常有用。

torch.take

torch.take

是 PyTorch 中用于从输入张量中按照给定索引取值的函数。它类似于

torch.index_select

torch.gather

,但是更简单,只需要一个索引张量即可。它本质上是将输入张量视为扁平的,然后从这个列表中选择元素。例如:当对形状为[4,5]的输入张量应用take,并选择指标6和19时,我们将获得扁平张量的第6和第19个元素——即来自第2行的第2个元素,以及最后一个元素。

 num_picks = 2

 values = torch.rand((len_dim_0, len_dim_1))
 indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
 # [num_picks]
 picked = torch.take(values, indices)

我们现在只得到两个元素:

3D张量也是一样的这里索引张量可以是任意形状的,只要最大索引不超过张量的总数即可:

 import torch

 batch_size = 16
 num_elements = 64
 num_features = 1024
 num_picks = (2, 5, 3)

 values = torch.rand((batch_size, num_elements, num_features))
 indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
 # [2, 5, 3]
 picked = torch.take(values, indices)

 picked_manual = torch.zeros(num_picks)
 for i in range(num_picks[0]):
     for j in range(num_picks[1]):
         for k in range(num_picks[2]):
             picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]

 assert torch.all(torch.eq(picked, picked_manual))

总结

为了总结这篇文章,我们在一个表格中总结了这些函数之间的区别——包含简短的描述和示例形状。样本形状是针对前面提到的3D ML示例量身定制的,并将列出索引张量的必要形状,以及由此产生的输出形状:

当你想要从一个张量中按照索引选取子集时可以使用

torch.index_select

,它通常用于在给定维度上选择元素。适用于较为简单的索引选取操作。

torch.gather

适用于根据索引从输入张量中收集元素并形成新张量的情况。可以根据需要在不同维度上进行收集操作。

torch.take

适用于一维索引,从输入张量中取出对应索引位置的元素。当只需要按照一维索引取值时,非常方便。

https://avoid.overfit.cn/post/e4844e899c4d4600813be7d09e91b9ef

作者:Oliver S

目录
相关文章
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:
44 4
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
|
2月前
|
机器学习/深度学习 算法 PyTorch
【深度学习】TensorFlow面试题:什么是TensorFlow?你对张量了解多少?TensorFlow有什么优势?TensorFlow比PyTorch有什么不同?该如何选择?
关于TensorFlow面试题的总结,涵盖了TensorFlow的基本概念、张量的理解、TensorFlow的优势、数据加载方式、算法通用步骤、过拟合解决方法,以及TensorFlow与PyTorch的区别和选择建议。
112 2
|
2月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
18 0
|
4月前
|
算法 PyTorch 算法框架/工具
Pytorch - 张量转换拼接
使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。
|
4月前
|
存储 机器学习/深度学习 PyTorch
Pytorch-张量形状操作
PyTorch中,张量形状操作至关重要,如reshape用于改变维度而不变元素,transpose/permute用于维度交换,view改形状需内存连续,squeeze移除单维度,unsqueeze添加维度。这些函数帮助数据适应神经网络层间的转换。例如,reshape能调整数据适配层的输入,transpose用于矩阵转置或多维排列,而squeeze和unsqueeze则用于处理单维度。理解并熟练运用这些工具是深度学习中必要的技能。
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch-张量
PyTorch 是Facebook AI团队开发的深度学习框架,其核心是张量,它是同类型数据的多维数组。张量可以通过`torch.tensor()`、`torch.Tensor()`、指定类型如`IntTensor`等创建。张量操作包括线性(`torch.arange`, `torch.linspace`)、随机(`torch.randn`, `torch.manual_seed`)和全0/1张量(`torch.zeros`, `torch.ones`)。张量间可进行阿达玛积(逐元素相乘),类型转换用`type()`或`double()`。
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch-张量基础操作
张量是一个多维数组,它是标量、向量和矩阵概念的推广。在深度学习中,张量被广泛用于表示数据和模型参数。
|
4月前
|
并行计算 PyTorch 算法框架/工具
pytorch张量的创建
• 张量(Tensors)类似于NumPy的ndarrays ,但张量可以在GPU上进行计算。从本质上来说,PyTorch是一个处理张量的库。一个张量是一个数字、向量、矩阵或任何n维数组。
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch】-了解张量(Tensor)
【PyTorch】-了解张量(Tensor)
下一篇
无影云桌面