【Pytorch】torch.gather用法详解

简介:

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

沿指定的维收集值。

参数:

  • input (Tensor) –输入张量
  • dim (int) – 要索引的维
  • index (LongTensor) – 要收集的元素的索引
  • sparse_grad (bool, optional) – 如果为True,关于input 的梯度将是稀疏张量。
  • out (Tensor, optional) –输出张量

对于一维张量,输出由以下公式指定:

out[i] = input[index[i]]  # dim= 0

例如:

input_tensor= torch.tensor([1, 2])
index = torch.tensor([0, 0])
input[0]=1
input[1]=2

index[0]=0
index[0]=0
out = torch.gather(input, 0, index)
out[0]=input[index[0]]=input[0]=1
out[1]=input[index[1]]=input[0]=1

对于二维张量,输出由以下公式指定:

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1

举个栗子:

input_tensor= torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
input[0][0]=1
input[0][1]=2
input[1][0]=3
input[1][1]=4

index[0][0]=0
index[0][1]=0
index[1][0]=1
index[1][1]=0

dim=0:

out = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[index[0][0]][0]=input[0][0]=1
out[0][1]=input[index[0][1]][1]=input[0][1]=2
out[1][0]=input[index[1][0]][0]=input[1][0]=3
out[1][1]=input[index[1][1]][1]=input[0][1]=2

dim=1:

out = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[0][index[0][0]]=input[0][0]=1
out[0][1]=input[0][index[0][1]]=input[0][0]=1
out[1][0]=input[1][index[1][0]]=input[1][1]=4
out[1][1]=input[1][index[1][1]]=input[1][0]=3

对于三维张量,同理:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

注意inputindex必须有相同的维度。out尺寸和index相同;inputindex之间不会广播。

对于d=dim,可以有index.size(d)< input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0],[2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 2])
tensor([[2, 1],
        [6, 4]])

index.size(d)> input.size(d)

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0, 1, 0], [2, 0, 2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 4])
tensor([[2, 1, 2, 1],
        [6, 4, 6, 4]])
相关文章
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
2127 0
|
PyTorch 算法框架/工具
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
769 2
PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法
|
机器学习/深度学习 算法 PyTorch
PyTorch 的 10 条内部用法
PyTorch 的 10 条内部用法
|
PyTorch 算法框架/工具
关于Pytorch中torch.manual_seed()用法
关于Pytorch中torch.manual_seed()用法
|
2月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
178 1
|
6月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
939 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
1月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
138 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
8月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
664 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
3月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
241 9

热门文章

最新文章

推荐镜像

更多