torch.gather()详解

简介: torch.gather()函数:利用index来索引input特定位置的数值dim = 1表示横向。

一、函数参数

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

torch.gather()函数:利用index来索引input特定位置的数值

dim = 1表示横向。


对于三维张量,其output是:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

二、小栗子1

比如现在有4个句子(句子长度不一),现在的序列标注问题需要给每个单词都标上一个标签,标签如下:

input = [
    [2, 3, 4, 5],
    [1, 4, 3],
    [4, 2, 2, 5, 7],
    [1]
]

长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。在NLP中,一般需要对不同长度的句子进行padding到相同长度(用0进行padding),所以padding后的结果:

input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 12 15:49:27 2021
@author: 86493
"""
import torch
input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
length = torch.LongTensor([[4], [3], [5], [1]])
# index之所以减1,是因为序列维度从0开始计算的
out = torch.gather(input, 1, length - 1)
print(out)

ut的结果为如下,比如length的第一行是[4],即找出input的第一行的第4个元素为5(这里length-1后就是下标从1开始计算了)。

tensor([[5],
        [3],
        [7],
        [1]])

三、小栗子2

如果每行需要索引多个元素:

>>> t = torch.Tensor([[1,2],[3,4]])
1  2
3  4
>>> torch.gather(t,1,torch.LongTensor([[0,0],[1,0]])
1  1
4  3
[torch.FloatTensor of size 2x2]
相关文章
|
自然语言处理 前端开发 JavaScript
【第52期】一文读懂React国际化 (i18n)
【第52期】一文读懂React国际化 (i18n)
1328 1
|
机器学习/深度学习 自然语言处理 并行计算
Self-Attention 原理与代码实现
Self-Attention 原理与代码实现
950 0
|
弹性计算 缓存 安全
阿里云ECS服务器搭建FTP服务
阿里云ECS服务器搭建FTP服务
2801 0
阿里云ECS服务器搭建FTP服务
|
10月前
|
机器学习/深度学习 测试技术 知识图谱
DeepSeek-R1:Incentivizing Reasoning Capability in LLMs via Reinforcement Learning论文解读
DeepSeek团队推出了第一代推理模型DeepSeek-R1-Zero和DeepSeek-R1。DeepSeek-R1-Zero通过大规模强化学习训练,展示了卓越的推理能力,但存在可读性和语言混合问题。为此,团队引入多阶段训练和冷启动数据,推出性能与OpenAI-o1-1217相当的DeepSeek-R1,并开源了多个密集模型。实验表明,DeepSeek-R1在多项任务上表现出色,尤其在编码任务上超越多数模型。未来研究将聚焦提升通用能力和优化提示工程等方向。
684 16
|
12月前
|
人工智能 弹性计算 监控
分布式大模型训练的性能建模与调优
阿里云智能集团弹性计算高级技术专家林立翔分享了分布式大模型训练的性能建模与调优。内容涵盖四大方面:1) 大模型对AI基础设施的性能挑战,强调规模增大带来的显存和算力需求;2) 大模型训练的性能分析和建模,介绍TOP-DOWN和bottom-up方法论及工具;3) 基于建模分析的性能优化,通过案例展示显存预估和流水线失衡优化;4) 宣传阿里云AI基础设施,提供高效算力集群、网络及软件支持,助力大模型训练与推理。
|
关系型数据库 MySQL Java
Django学习二:配置mysql,创建model实例,自动创建数据库表,对mysql数据库表已经创建好的进行直接操作和实验。
这篇文章是关于如何使用Django框架配置MySQL数据库,创建模型实例,并自动或手动创建数据库表,以及对这些表进行操作的详细教程。
504 0
Django学习二:配置mysql,创建model实例,自动创建数据库表,对mysql数据库表已经创建好的进行直接操作和实验。
|
Python
Polars实践(2):阿里天池——淘宝用户购物行为分析
Polars实践(2):阿里天池——淘宝用户购物行为分析
307 0
|
SQL 中间件 API
Flask框架在Python面试中的应用与实战
【4月更文挑战第18天】**Flask是Python的轻量级Web框架,以其简洁API和强大扩展性受欢迎。本文深入探讨了面试中关于Flask的常见问题,包括路由、Jinja2模板、数据库操作、中间件和错误处理。同时,提到了易错点,如路由冲突、模板安全、SQL注入,以及请求上下文管理。通过实例代码展示了如何创建和管理数据库、使用表单以及处理请求。掌握这些知识将有助于在面试中展现Flask技能。**
434 1
Flask框架在Python面试中的应用与实战
|
固态存储 Ubuntu Linux
Linux(29) 多线程快速解压缩|删除|监视大型文件
Linux(29) 多线程快速解压缩|删除|监视大型文件
1543 1
|
机器学习/深度学习 存储 自然语言处理
深度探索自适应学习率调整:从传统方法到深度学习优化器
【5月更文挑战第15天】 在深度学习的复杂网络结构与海量数据中,学习率作为模型训练的关键超参数,其调整策略直接影响着模型的收敛速度与最终性能。传统的学习率调整方法,如固定学习率、学习率衰减等,虽然简单易行,但在多样化的任务面前往往显得力不从心。近年来,随着自适应学习率技术的兴起,一系列创新的优化器如Adam、RMSProp和Adagrad等应运而生,它们通过引入自适应机制动态调整学习率,显著改善了模型的训练效率与泛化能力。本文将深入剖析传统学习率调整方法的原理与局限性,并详细介绍当前主流的自适应学习率优化器,最后探讨未来可能的发展方向。

热门文章

最新文章