pytorch使用布尔索引获取指定维度元素

简介: pytorch使用布尔索引获取指定维度元素

对于一些任务,我们想从tensor中提取符合指定要求的数值,那么一般我们有两种方法,第一种是采用布尔索引,第二种是使用masked_select()方法来实现。

其实还有一种方法torch.where(),但是这个与上述两个方法不同,上述两个方法会把我们需要的数值挑出来形成一个一维张量,对于where我们会得到与原来形状一样的tensor,所以本文只介绍上面两种方法。

方法一:采用布尔索引

该方法我们采用布尔索引进行提取,首先获取一个布尔矩阵mask,标记每个位置是否符合我们要求,符合则为True,不符则为False,然后我们会把True的位置提取出来。

a = torch.randn(3, 4)
print(a)
mask = a > 0
print(mask)
print(a[mask])
tensor([[ 0.5748,  1.4601,  1.8610, -0.8904],
        [-1.5891, -1.2431,  0.1356, -0.6111],
        [-0.5736, -0.7268, -0.2200,  0.4816]])
tensor([[ True,  True,  True, False],
        [False, False,  True, False],
        [False, False, False,  True]])
tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])

方法二:masked_select()

使用masked_select()方法同样可以实现,只需要将条件传入。

print(a.masked_select(a > 0))
tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])


目录
相关文章
|
9月前
|
机器学习/深度学习 自然语言处理 PyTorch
【NLP】深入了解PyTorch:功能与基本元素操作
【NLP】深入了解PyTorch:功能与基本元素操作
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习基础之Tensor的索引和切片讲解及实战(附源码 简单易懂)
PyTorch深度学习基础之Tensor的索引和切片讲解及实战(附源码 简单易懂)
104 0
|
PyTorch 算法框架/工具
Pytorch疑难小实验:Torch.max() Torch.min()在不同维度上的解释
Pytorch疑难小实验:Torch.max() Torch.min()在不同维度上的解释
128 0
|
PyTorch 算法框架/工具
Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式
Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式
205 0
|
机器学习/深度学习 数据采集 算法
基于Pytorch之深度学习模型数据类型和维度转换个人总结
基于Pytorch之深度学习模型数据类型和维度转换个人总结
281 0
基于Pytorch之深度学习模型数据类型和维度转换个人总结
|
PyTorch 算法框架/工具 索引
pytorch使用 ... 进行高级索引切片
pytorch使用 ... 进行高级索引切片
85 0
|
PyTorch 算法框架/工具 索引
pytorch使用None索引进行维度扩展
pytorch使用None索引进行维度扩展
157 0
|
PyTorch 算法框架/工具 索引
pytorch交换tensor的指定维度
pytorch交换tensor的指定维度
341 0
|
机器学习/深度学习 存储 数据挖掘
PyTorch: 张量的拼接、切分、索引
PyTorch: 张量的拼接、切分、索引
188 0
PyTorch: 张量的拼接、切分、索引
|
12天前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN