一、torch.argmax()
(1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。
二、栗子
# -*- coding: utf-8 -*- """ Created on Fri Jan 7 15:05:09 2022 @author: 86493 """ import torch a=torch.tensor([ [ [1, 5, 5, 2], [9, -6, 2, 8], [-3, 7, -9, 1] ], [ [-1, 7, -5, 2], [9, 6, 2, 8], [3, 7, 9, 1] ]]) b=torch.argmax(a,dim=1) print(a) print(a.shape) print(b)
(1)这个例子,tensor(2, 3, 4)
,因为是dim=1
,即将第二维度去掉,变成tensor(2, 4)
,将每一个3x4数组,变成1x4数组。
[1, 5, 5, 2], [9, -6, 2, 8], [-3, 7, -9, 1]
如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:
tensor([[[ 1, 5, 5, 2], [ 9, -6, 2, 8], [-3, 7, -9, 1]], [[-1, 7, -5, 2], [ 9, 6, 2, 8], [ 3, 7, 9, 1]]]) torch.Size([2, 3, 4]) tensor([[1, 2, 0, 1], [1, 0, 2, 1]])
(1)如果改成dim=2
,即将第三维去掉,即取每一行的最大值对应的下标,结果为tensor(2, 3)
。
import torch a=torch.tensor([ [ [1, 5, 5, 2], [9, -6, 2, 8], [-3, 7, -9, 1] ], [ [-1, 7, -5, 2], [9, 6, 2, 8], [3, 7, 9, 1] ]]) b=torch.argmax(a,dim=2) print(b) print(a.shape) """ tensor([[2, 0, 1], [1, 0, 2]]) torch.Size([2, 3, 4]) """