torch.argmax(dim=1)用法

简介: 1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。

一、torch.argmax()

(1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;

(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。


二、栗子

# -*- coding: utf-8 -*-
"""
Created on Fri Jan  7 15:05:09 2022
@author: 86493
"""
import torch
a=torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b=torch.argmax(a,dim=1)
print(a)
print(a.shape)
print(b)

(1)这个例子,tensor(2, 3, 4),因为是dim=1,即将第二维度去掉,变成tensor(2, 4),将每一个3x4数组,变成1x4数组。

[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]

如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:

tensor([[[ 1,  5,  5,  2],
         [ 9, -6,  2,  8],
         [-3,  7, -9,  1]],
        [[-1,  7, -5,  2],
         [ 9,  6,  2,  8],
         [ 3,  7,  9,  1]]])
torch.Size([2, 3, 4])
tensor([[1, 2, 0, 1],
        [1, 0, 2, 1]])

(1)如果改成dim=2,即将第三维去掉,即取每一行的最大值对应的下标,结果为tensor(2, 3)

import torch
a=torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b=torch.argmax(a,dim=2)
print(b)
print(a.shape)
"""
tensor([[2, 0, 1],
        [1, 0, 2]])
torch.Size([2, 3, 4])
"""
相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
torch.nn.Linear的使用方法
torch.nn.Linear的使用方法
148 0
|
19天前
|
PyTorch 算法框架/工具
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
PyTorch中的`nn.AdaptiveAvgPool2d()`函数用于实现自适应平均池化,能够将输入特征图调整到指定的输出尺寸,而不需要手动计算池化核大小和步长。
66 1
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
|
20天前
|
TensorFlow 算法框架/工具
Tensorflow error(二):x and y must have the same dtype, got tf.float32 != tf.int32
本文讨论了TensorFlow中的一个常见错误,即在计算过程中,变量的数据类型(dtype)不一致导致的错误,并通过使用`tf.cast`函数来解决这个问题。
15 0
|
3月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
48 0
|
4月前
|
PyTorch 算法框架/工具 机器学习/深度学习
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
630 0
|
PyTorch 算法框架/工具 异构计算
Pytorch出现RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)
这个问题的主要原因是输入的数据类型与网络参数的类型不符。
558 0
|
存储 测试技术
测试模型时,为什么要with torch.no_grad(),为什么要model.eval(),如何使用with torch.no_grad(),model.eval(),同时使用还是只用其中之一
在测试模型时,我们通常使用with torch.no_grad()和model.eval()这两个方法来确保模型在评估过程中的正确性和效率。
939 0
|
PyTorch 算法框架/工具
torch.split 的用法
这将返回一个元组,包含 3 个大小分别为 (6, 2)、(6, 2) 和 (6, 4) 的张量。 需要注意的是,当给定的拆分大小不等于张量在指定维度上的大小时,torch.split() 会引发一个异常。
459 0