一、函数参数
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
torch.gather()函数:利用index来索引input特定位置的数值
dim = 1表示横向。
对于三维张量,其output是:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
二、小栗子1
比如现在有4个句子(句子长度不一),现在的序列标注问题需要给每个单词都标上一个标签,标签如下:
input = [ [2, 3, 4, 5], [1, 4, 3], [4, 2, 2, 5, 7], [1] ]
长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。在NLP中,一般需要对不同长度的句子进行padding到相同长度(用0进行padding),所以padding后的结果:
input = [ [2, 3, 4, 5, 0, 0], [1, 4, 3, 0, 0, 0], [4, 2, 2, 5, 7, 0], [1, 0, 0, 0, 0, 0] ]
# -*- coding: utf-8 -*- """ Created on Sun Dec 12 15:49:27 2021 @author: 86493 """ import torch input = [ [2, 3, 4, 5, 0, 0], [1, 4, 3, 0, 0, 0], [4, 2, 2, 5, 7, 0], [1, 0, 0, 0, 0, 0] ] input = torch.tensor(input) length = torch.LongTensor([[4], [3], [5], [1]]) # index之所以减1,是因为序列维度从0开始计算的 out = torch.gather(input, 1, length - 1) print(out)
out的结果为如下,比如length的第一行是[4]
,即找出input的第一行的第4个元素为5(这里length-1
后就是下标从1开始计算了)。
tensor([[5], [3], [7], [1]])
三、小栗子2
如果每行需要索引多个元素:
>>> t = torch.Tensor([[1,2],[3,4]]) 1 2 3 4 >>> torch.gather(t,1,torch.LongTensor([[0,0],[1,0]]) 1 1 4 3 [torch.FloatTensor of size 2x2]