【Pytorch】torch.gather用法详解

简介:

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

沿指定的维收集值。

参数:

  • input (Tensor) –输入张量
  • dim (int) – 要索引的维
  • index (LongTensor) – 要收集的元素的索引
  • sparse_grad (bool, optional) – 如果为True,关于input 的梯度将是稀疏张量。
  • out (Tensor, optional) –输出张量

对于一维张量,输出由以下公式指定:

out[i] = input[index[i]]  # dim= 0

例如:

input_tensor= torch.tensor([1, 2])
index = torch.tensor([0, 0])
input[0]=1
input[1]=2

index[0]=0
index[0]=0
out = torch.gather(input, 0, index)
out[0]=input[index[0]]=input[0]=1
out[1]=input[index[1]]=input[0]=1

对于二维张量,输出由以下公式指定:

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1

举个栗子:

input_tensor= torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
input[0][0]=1
input[0][1]=2
input[1][0]=3
input[1][1]=4

index[0][0]=0
index[0][1]=0
index[1][0]=1
index[1][1]=0

dim=0:

out = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[index[0][0]][0]=input[0][0]=1
out[0][1]=input[index[0][1]][1]=input[0][1]=2
out[1][0]=input[index[1][0]][0]=input[1][0]=3
out[1][1]=input[index[1][1]][1]=input[0][1]=2

dim=1:

out = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[0][index[0][0]]=input[0][0]=1
out[0][1]=input[0][index[0][1]]=input[0][0]=1
out[1][0]=input[1][index[1][0]]=input[1][1]=4
out[1][1]=input[1][index[1][1]]=input[1][0]=3

对于三维张量,同理:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

注意inputindex必须有相同的维度。out尺寸和index相同;inputindex之间不会广播。

对于d=dim,可以有index.size(d)< input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0],[2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 2])
tensor([[2, 1],
        [6, 4]])

index.size(d)> input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0, 1, 0], [2, 0, 2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 4])
tensor([[2, 1, 2, 1],
        [6, 4, 6, 4]])
相关文章
|
2月前
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
643 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
PyTorch 的 10 条内部用法
PyTorch 的 10 条内部用法
48 0
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
443 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
|
PyTorch 算法框架/工具
关于Pytorch中torch.manual_seed()用法
关于Pytorch中torch.manual_seed()用法
|
2月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
17天前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
17天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】16. Pytorch中神经网络模型的构造方法:Module、Sequential、ModuleList、ModuleDict的区别
【从零开始学习深度学习】16. Pytorch中神经网络模型的构造方法:Module、Sequential、ModuleList、ModuleDict的区别
|
2月前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
36 1