torch.gather()详解

简介: torch.gather()函数:利用index来索引input特定位置的数值dim = 1表示横向。

一、函数参数

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

torch.gather()函数:利用index来索引input特定位置的数值

dim = 1表示横向。


对于三维张量,其output是:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

二、小栗子1

比如现在有4个句子(句子长度不一),现在的序列标注问题需要给每个单词都标上一个标签,标签如下:

input = [
    [2, 3, 4, 5],
    [1, 4, 3],
    [4, 2, 2, 5, 7],
    [1]
]

长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。在NLP中,一般需要对不同长度的句子进行padding到相同长度(用0进行padding),所以padding后的结果:

input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 12 15:49:27 2021
@author: 86493
"""
import torch
input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
length = torch.LongTensor([[4], [3], [5], [1]])
# index之所以减1,是因为序列维度从0开始计算的
out = torch.gather(input, 1, length - 1)
print(out)

ut的结果为如下,比如length的第一行是[4],即找出input的第一行的第4个元素为5(这里length-1后就是下标从1开始计算了)。

tensor([[5],
        [3],
        [7],
        [1]])

三、小栗子2

如果每行需要索引多个元素:

>>> t = torch.Tensor([[1,2],[3,4]])
1  2
3  4
>>> torch.gather(t,1,torch.LongTensor([[0,0],[1,0]])
1  1
4  3
[torch.FloatTensor of size 2x2]
相关文章
|
4月前
tf.zeros(), tf.zeros_like(), tf.ones(),tf.ones_like()
【8月更文挑战第11天】tf.zeros(), tf.zeros_like(), tf.ones(),tf.ones_like()。
47 5
|
7月前
|
存储 PyTorch 算法框架/工具
torch.Storage()是什么?和torch.Tensor()有什么区别?
torch.Storage()是什么?和torch.Tensor()有什么区别?
61 1
|
机器学习/深度学习 PyTorch API
Torch
Torch是一个用于构建深度学习模型的开源机器学习库,它基于Lua编程语言。然而,由于PyTorch的出现,现在通常所说的"torch"指的是PyTorch。PyTorch是一个基于Torch的Python库,它提供了一个灵活而高效的深度学习框架。
295 1
|
机器学习/深度学习 PyTorch 算法框架/工具
|
PyTorch 算法框架/工具 开发者
RuntimeError: Can‘t call numpy() on Variable that requires grad. Use var.detach().numpy()
出现这个现象的原因是:待转换类型的PyTorch Tensor变量带有梯度,直接将其转换为numpy数据将破坏计算图,因此numpy拒绝进行数据转换,实际上这是对开发者的一种提醒。如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加detach()调用。
195 0
|
PyTorch 算法框架/工具 Python
查看torch中的所有函数
hspmm hstack hub hypot i0 i0_ igamma igammac iinfo imag import_ir_module import_ir_module_from_buffer index_add index_copy index_fill index_put index_put_ index_select init_num_t
361 0
torch中如何将tensor([[1, 2, 3]]) 和 tensor([4]) 合并成 tensor([[1,2,3,4]])
可以使用 torch.cat() 方法将两个张量沿着指定的维度进行拼接
495 0
怎么将[tensor([[ 1, 2]]), tensor([[5, 6]]), tensor([[9, 10]])] 合并成 tensor([[1,2],[3,4],[5,6]])
可以先使用 torch.cat() 函数将列表中的张量在第0维(行)上进行拼接,然后再使用 .view() 函数将形状调整为需要的形状。
195 0
|
PyTorch 算法框架/工具 异构计算
pytorch中的torch.manual_seed()
random.seed(rand_seed)随机数种子,当使用random.seed(rand_seed)设定好种子之后,其中rand_seed可以是任意数字,比如10,那么每次调用生成的随机数将会是同一个。
158 0