torch.where()
该函数的用处就是利用判断条件提取指定元素,满足需求的会提出,不符合要求的按照我们设定的值进行填充,说的简单就是会逐位置进行判断条件是否满足,如果满足该位置的值为x对应位置的值,如果不符就是y对应位置的值。
该函数需要传入三个参数:
- condition:我们需要判断的条件,例如 x > 0
- x:候选张量x
- y:候选张量y
a = torch.randn(3, 4) print(a) b = torch.arange(12, dtype=torch.float).reshape(3, 4) print(b) print(torch.where(a > 0, a, b))
tensor([[ 0.8974, 1.1078, -0.8711, 0.9044], [ 0.1937, -0.3344, -0.1034, -0.0874], [-0.4632, -1.5329, 1.0019, -0.8950]]) tensor([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]]) tensor([[ 0.8974, 1.1078, 2.0000, 0.9044], [ 0.1937, 5.0000, 6.0000, 7.0000], [ 8.0000, 9.0000, 1.0019, 11.0000]])
该函数我们定义了两个张量,分别是一个正态分布的a,另外一个是0-11的张量,我们的判断条件是a>0,如果a对应位置的值大于0,那么返回的tensor对应的位置还是a的值,如果小于等于0,那么该位置的值就是张量b对应的位置的值。