pytorch新手需要注意的隐晦操作Tensor,max,gather

简介: 文章移到:https://oldpan.me/archives/pytorch-rookie-care-tensor-max-gather pytorch中有很多操作比较隐晦,需要仔细研究结合一些例子才能知道如何操作,在此对这些进行总结!torch.

文章移到:https://oldpan.me/archives/pytorch-rookie-care-tensor-max-gather

pytorch中有很多操作比较隐晦,需要仔细研究结合一些例子才能知道如何操作,在此对这些进行总结!

torch.gather(input, dim, index, out=None) → Tensor

先看官方的介绍:
如果input是一个n维的tensor,size为 (x0,x1…,xi−1,xi,xi+1,…,xn−1),dim为i,然后index必须也为n维tensor,size为 (x0,x1,…,xi−1,y,xi+1,…,xn−1),其中y >= 1,最后输出的out与index的size是一样的。
意思就是按照一个指定的轴(维数)收集值
对于一个三维向量来说:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

参数:
input (Tensor) – 源tensor
dim (int) – 指定的轴数(维数)
index (LongTensor) – 需要聚集起来的数据的索引
out (Tensor, optional) – 目标tensor

看完介绍后,稍微思考一下,然后再看一个例子:

scores是一个计算出来的分数,类型为[torch.FloatTensor of size 5x1000]
而y_var是正确分数的索引,类型为[torch.LongTensor of size 5]
容易知道,这里有1000个类别,有5个输入图像,每个图像得出的分数中只有一个是正确的,正确的索引就在y_var中,这里要做的是将正确分数根据索引标号提取出来。

    scores = model(X_var)  # 分数
    scores = scores.gather(1, y_var.view(-1, 1)).squeeze()  #进行提取

提取后的scores格式也为[torch.FloatTensor of size 5]
这里讲一下变化过程:
1、首先要知道之前的scores的size为[5,1000],而y_var的size为[5],scores为2维,y_var为1维不匹配,所以先用view将其展开为[5,1]的size,这样维数n就与scroes匹配了。
2、接下来进行gather,gather函数中第一个参数为1,意思是在第二维进行汇聚,也就是说通过y_var中的五个值来在scroes中第二维的5个1000中进行一一挑选,挑选出来后的size也为[5,1],然后再通过squeeze将那个一维去掉,最后结果为[5]

再看一个使用相同思想的例子

def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(s.gather(1, y.view(-1, 1)).squeeze())
gather_example()

结果为:

-0.9526  1.7607 -1.0142 -0.6761  0.3022
-0.8421  0.5325  0.4834  0.8441 -0.1592
 0.8786  2.6909  1.3635  0.1197  0.4031
-0.8397  1.4782  0.4514 -0.8381 -2.0638
[torch.FloatTensor of size 4x5]


 1
 2
 1
 3
[torch.LongTensor of size 4]


 1.7607
 0.4834
 2.6909
-0.8381
[torch.FloatTensor of size 4]

使用普通python函数实现的例子

假设一个numpy数组s的shape为 (N, C),y是一个shape为(N,)的numpy数组,内容为 0 <= y[i] < C 整数,然后我们使用s[np.arange(N), y] 来进行在s中挑选每一个和y索引对应的数字,其shape同样为(N,)

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

max函数需要注意的是,它是一个过载函数,函数参数不同函数的功能和返回值也不同。
当max函数中有维数参数的时候,它的返回值为两个,一个为最大值,另一个为最大值的索引

>> a = torch.randn(4, 4)
>> a

0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4]
,
 2
 0
 0
 0
[torch.LongTensor of size 4]
)

Tensor隐晦操作

使用Tensor型数据进行比较的时候需要注意,如果比较的是其中的值,那么必须将其化为普通值再进行比较,即使是一维的单个数据,也要用[0]操作符来进行读取。
如果想要整个进行比较,建议使用torch.equal来进行比较

>>> apple = torch.Tensor([1,2,3])
>>> apple
Out[20]: 
 1
 2
 3
[torch.FloatTensor of size 3]
>>> apple[0]
Out[21]: 1.0
>>> banana = torch.Tensor([1])
>>> banana
Out[23]: 
 1
[torch.FloatTensor of size 1]
>>> banana[0]
Out[24]: 1.0
目录
相关文章
|
机器学习/深度学习 运维 数据可视化
chat GPT在常用的数据分析方法中的应用
ChatGPT在常用的数据分析方法中有多种应用,包括描述统计分析、探索性数据分析、假设检验、回归分析和聚类分析等。下面将详细介绍ChatGPT在这些数据分析方法中的应用。 1. 描述统计分析: 描述统计分析是对数据进行总结和描述的方法,包括计算中心趋势、离散程度和分布形状等指标。ChatGPT可以帮助你理解和计算这些描述统计指标。你可以向ChatGPT询问如何计算平均值、中位数、标准差和百分位数等指标,它可以给出相应的公式和计算方法。此外,ChatGPT还可以为你提供绘制直方图、箱线图和散点图等图表的方法,帮助你可视化数据的分布和特征。 2. 探索性数据分析: 探索性数据分析是对数据进行探
449 0
|
网络安全 Nacos 数据安全/隐私保护
nacos常见问题之使用默认用户名密码提示错误如何解决
Nacos是阿里云开源的服务发现和配置管理平台,用于构建动态微服务应用架构;本汇总针对Nacos在实际应用中用户常遇到的问题进行了归纳和解答,旨在帮助开发者和运维人员高效解决使用Nacos时的各类疑难杂症。
|
SQL 关系型数据库 MySQL
远程访问GitLab内置的PostgreSQL数据库
业务系统需要接入GitLab,业务系统以及GitLab都有一套各自的用户系统,需要实现同一套账户密码的话需要将数据同步给GitLab(主要是密码),然而由于GitLab安全策略,通过api进行同步GitLab用户数据并不满足需求,所以需要能直接访问GitLab数据库进行数据修改。
远程访问GitLab内置的PostgreSQL数据库
|
消息中间件 移动开发 NoSQL
Redis 协议 事务 发布订阅 异步连接
Redis 协议 事务 发布订阅 异步连接
|
8月前
|
人工智能 数据处理 语音技术
Pipecat实战:5步快速构建语音与AI整合项目,创建你的第一个多模态语音 AI 助手
Pipecat 是一个开源的 Python 框架,专注于构建语音和多模态对话代理,支持与多种 AI 服务集成,提供实时处理能力,适用于语音助手、企业服务等场景。
447 23
Pipecat实战:5步快速构建语音与AI整合项目,创建你的第一个多模态语音 AI 助手
|
Perl
解决Cocoapods重装或更新后版本不生效的问题
解决Cocoapods重装或更新后版本不生效的问题
441 1
|
存储 Apache 对象存储
MinIO是什么?
MinIO是什么?
628 0
|
Kubernetes Cloud Native 应用服务中间件
云原生之旅:Kubernetes集群部署与应用管理
【8月更文挑战第28天】在这篇文章中,我们将一起探索云原生技术的奇妙世界,尤其是围绕Kubernetes这一核心组件。从搭建一个基本的Kubernetes集群开始,到运行和管理容器化应用,每一步都将是一次深入浅出的学习旅程。无论你是初学者还是有经验的开发者,本指南都旨在为你提供清晰的操作步骤和必要的理论知识,让你能够自信地管理和扩展你的云原生应用。让我们携手开启这场技术冒险,共同见证云原生带来的无限可能。
|
监控 druid Java
SpringBoot 使用【druid-spring-boot-starter】集成 druid 监控数据库
SpringBoot 使用【druid-spring-boot-starter】集成 druid 监控数据库
767 0
|
开发工具 git
【Git】push代码时候报错,出现fatal: unable to access xxx Recv failure: Connection was reset
【Git】push代码时候报错,出现fatal: unable to access xxx Recv failure: Connection was reset
565 0