torch.argmax(dim=1)用法

简介: 1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。

一、torch.argmax()

(1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;

(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。


二、栗子

# -*- coding: utf-8 -*-
"""
Created on Fri Jan  7 15:05:09 2022
@author: 86493
"""
import torch
a=torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b=torch.argmax(a,dim=1)
print(a)
print(a.shape)
print(b)

(1)这个例子,tensor(2, 3, 4),因为是dim=1,即将第二维度去掉,变成tensor(2, 4),将每一个3x4数组,变成1x4数组。

[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]

如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:

tensor([[[ 1,  5,  5,  2],
         [ 9, -6,  2,  8],
         [-3,  7, -9,  1]],
        [[-1,  7, -5,  2],
         [ 9,  6,  2,  8],
         [ 3,  7,  9,  1]]])
torch.Size([2, 3, 4])
tensor([[1, 2, 0, 1],
        [1, 0, 2, 1]])

(1)如果改成dim=2,即将第三维去掉,即取每一行的最大值对应的下标,结果为tensor(2, 3)

import torch
a=torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b=torch.argmax(a,dim=2)
print(b)
print(a.shape)
"""
tensor([[2, 0, 1],
        [1, 0, 2]])
torch.Size([2, 3, 4])
"""
相关文章
|
机器学习/深度学习 运维 安全
多分类机器学习中数据不平衡的处理(NSL-KDD 数据集+LightGBM)
多分类机器学习中数据不平衡的处理(NSL-KDD 数据集+LightGBM)
多分类机器学习中数据不平衡的处理(NSL-KDD 数据集+LightGBM)
|
机器学习/深度学习 存储 人工智能
基于内容的图像检索系统设计与实现(1)
基于内容的图像检索系统设计与实现(1)
基于内容的图像检索系统设计与实现(1)
|
IDE 测试技术 程序员
|
Java 项目管理 Maven
Java一分钟之-Maven profiles与dependencyManagement
【6月更文挑战第5天】本文探讨了Maven的profiles和dependencyManagement特性在Java项目管理中的应用,包括基本概念和常见问题。Profiles用于根据不同环境激活配置,易错点在于忘记激活,应通过命令行或设置默认profile来避免。dependencyManagement集中管理依赖版本,过度依赖会导致子模块灵活性降低,应合理使用。结合两者,可在不同环境中控制依赖版本,提高项目配置效率。
391 8
|
机器学习/深度学习 算法 PyTorch
卷积神经网络(CNN)——基础知识整理
卷积神经网络(CNN)——基础知识整理
332 2
|
机器学习/深度学习 自然语言处理
自注意力机制(Self-Attention Mechanism)
自注意力机制(Self-Attention Mechanism)
1115 6
|
搜索推荐 测试技术 UED
AIGC赋能游戏开发全流程
【1月更文挑战第14天】AIGC赋能游戏开发全流程
568 2
AIGC赋能游戏开发全流程
|
机器学习/深度学习 编解码 算法
【论文解析】CFPNet:用于目标检测的集中特征金字塔
【论文解析】CFPNet:用于目标检测的集中特征金字塔
749 0
【论文解析】CFPNet:用于目标检测的集中特征金字塔
|
机器学习/深度学习 数据可视化 计算机视觉
YOLOv5改进 | 2023注意力篇 | MSDA多尺度空洞注意力(附多位置添加教程)
YOLOv5改进 | 2023注意力篇 | MSDA多尺度空洞注意力(附多位置添加教程)
361 0
|
JavaScript Go Python
我愿称之为最容易上手的编程语言——Yaklang(I)
我愿称之为最容易上手的编程语言——Yaklang(I)
332 0