对于一些任务,我们想从tensor中提取符合指定要求的数值,那么一般我们有两种方法,第一种是采用布尔索引
,第二种是使用masked_select()
方法来实现。
其实还有一种方法torch.where()
,但是这个与上述两个方法不同,上述两个方法会把我们需要的数值挑出来形成一个一维张量
,对于where我们会得到与原来形状一样
的tensor,所以本文只介绍上面两种方法。
方法一:采用布尔索引
该方法我们采用布尔索引进行提取,首先获取一个布尔矩阵mask
,标记每个位置是否符合我们要求,符合则为True,不符则为False,然后我们会把True的位置提取出来。
a = torch.randn(3, 4) print(a) mask = a > 0 print(mask) print(a[mask])
tensor([[ 0.5748, 1.4601, 1.8610, -0.8904], [-1.5891, -1.2431, 0.1356, -0.6111], [-0.5736, -0.7268, -0.2200, 0.4816]]) tensor([[ True, True, True, False], [False, False, True, False], [False, False, False, True]]) tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])
方法二:masked_select()
使用masked_select()方法同样可以实现,只需要将条件传入。
print(a.masked_select(a > 0))
tensor([0.5748, 1.4601, 1.8610, 0.1356, 0.4816])