可以使用PyTorch的索引操作符[]
和张量的gather()
方法来利用张量B的元素作为索引来检索张量a的元素。
假设a是一个张量,B是一个一维的长为n的张量,可以通过以下方式获取从a中检索出的值:
import torch # 创建张量a和张量B a = torch.tensor([[1, 2], [3, 4], [5, 6]]) B = torch.tensor([1, 0, 1]) # 使用[]索引操作符进行检索 result1 = a[B] # 使用gather()方法进行检索 result2 = a.gather(0, B.unsqueeze(1).expand(-1, a.shape[1])) print(result1) print(result2)
上述代码中,使用[]
索引操作符时,张量B被直接作为下标传递给a进行检索,返回的结果是一个与B形状相同的张量,其中每个元素都对应于a中对应下标的元素。使用gather()
方法时,需要将B扩展成一个二维张量,并在行维度上进行检索,返回的结果也是一个与B形状相同的张量。