pytorch中torch.where()使用方法

简介: pytorch中torch.where()使用方法

torch.where()

该函数的用处就是利用判断条件提取指定元素,满足需求的会提出,不符合要求的按照我们设定的值进行填充,说的简单就是会逐位置进行判断条件是否满足,如果满足该位置的值为x对应位置的值,如果不符就是y对应位置的值。

该函数需要传入三个参数:

  • condition:我们需要判断的条件,例如 x > 0
  • x:候选张量x
  • y:候选张量y
a = torch.randn(3, 4)
print(a)
b = torch.arange(12, dtype=torch.float).reshape(3, 4)
print(b)
print(torch.where(a > 0, a, b))
tensor([[ 0.8974,  1.1078, -0.8711,  0.9044],
        [ 0.1937, -0.3344, -0.1034, -0.0874],
        [-0.4632, -1.5329,  1.0019, -0.8950]])
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
tensor([[ 0.8974,  1.1078,  2.0000,  0.9044],
        [ 0.1937,  5.0000,  6.0000,  7.0000],
        [ 8.0000,  9.0000,  1.0019, 11.0000]])

该函数我们定义了两个张量,分别是一个正态分布的a,另外一个是0-11的张量,我们的判断条件是a>0,如果a对应位置的值大于0,那么返回的tensor对应的位置还是a的值,如果小于等于0,那么该位置的值就是张量b对应的位置的值。


目录
相关文章
|
PyTorch 算法框架/工具
pytorch中torch.clamp()使用方法
pytorch中torch.clamp()使用方法
593 0
pytorch中torch.clamp()使用方法
|
并行计算 PyTorch 测试技术
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-2
由于要进行 tensor 的学习,因此,我们先导入我们需要的库。
|
机器学习/深度学习 人工智能 自然语言处理
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-1
PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebook 的人工智能小组开发,不仅能够实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架如 TensorFlow 都不支持的。
|
机器学习/深度学习 人工智能 PyTorch
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch中nn.Parameter()使用方法
pytorch中nn.Parameter()使用方法
1433 1
|
PyTorch 算法框架/工具
pytorch中ImageFolder()使用方法
pytorch中ImageFolder()使用方法
354 0
pytorch中ImageFolder()使用方法
|
PyTorch 算法框架/工具 异构计算
基于Pytorch查看本地或者远程服务器GPU及使用方法
基于Pytorch查看本地或者远程服务器GPU及使用方法
508 0
基于Pytorch查看本地或者远程服务器GPU及使用方法
|
PyTorch 算法框架/工具
pytorch中keepdim参数归并操作使用方法
pytorch中keepdim参数归并操作使用方法
155 0
|
PyTorch 算法框架/工具
pytorch中meter.AverageValueMeter()使用方法
pytorch中meter.AverageValueMeter()使用方法
289 0
|
PyTorch 算法框架/工具
pytorch中meter.ClassErrorMeter()使用方法
pytorch中meter.ClassErrorMeter()使用方法
174 0