【Pytorch】torch.gather用法详解

简介:

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

沿指定的维收集值。

参数:

  • input (Tensor) –输入张量
  • dim (int) – 要索引的维
  • index (LongTensor) – 要收集的元素的索引
  • sparse_grad (bool, optional) – 如果为True,关于input 的梯度将是稀疏张量。
  • out (Tensor, optional) –输出张量

对于一维张量,输出由以下公式指定:

out[i] = input[index[i]]  # dim= 0

例如:

input_tensor= torch.tensor([1, 2])
index = torch.tensor([0, 0])
input[0]=1
input[1]=2

index[0]=0
index[0]=0
out = torch.gather(input, 0, index)
out[0]=input[index[0]]=input[0]=1
out[1]=input[index[1]]=input[0]=1

对于二维张量,输出由以下公式指定:

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1

举个栗子:

input_tensor= torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
input[0][0]=1
input[0][1]=2
input[1][0]=3
input[1][1]=4

index[0][0]=0
index[0][1]=0
index[1][0]=1
index[1][1]=0

dim=0:

out = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[index[0][0]][0]=input[0][0]=1
out[0][1]=input[index[0][1]][1]=input[0][1]=2
out[1][0]=input[index[1][0]][0]=input[1][0]=3
out[1][1]=input[index[1][1]][1]=input[0][1]=2

dim=1:

out = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[0][index[0][0]]=input[0][0]=1
out[0][1]=input[0][index[0][1]]=input[0][0]=1
out[1][0]=input[1][index[1][0]]=input[1][1]=4
out[1][1]=input[1][index[1][1]]=input[1][0]=3

对于三维张量,同理:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

注意inputindex必须有相同的维度。out尺寸和index相同;inputindex之间不会广播。

对于d=dim,可以有index.size(d)< input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0],[2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 2])
tensor([[2, 1],
        [6, 4]])

index.size(d)> input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0, 1, 0], [2, 0, 2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 4])
tensor([[2, 1, 2, 1],
        [6, 4, 6, 4]])
相关文章
|
6月前
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
1028 0
|
6月前
|
机器学习/深度学习 算法 PyTorch
PyTorch 的 10 条内部用法
PyTorch 的 10 条内部用法
62 0
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
510 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
|
PyTorch 算法框架/工具
关于Pytorch中torch.manual_seed()用法
关于Pytorch中torch.manual_seed()用法
|
1月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
211 2
|
1月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
60 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
1月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
60 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
2月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
145 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
2月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
43 3
PyTorch 模型调试与故障排除指南
|
1月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法

热门文章

最新文章

下一篇
无影云桌面