pytorch 给定概率分布的张量,如何利用这个概率进行重复\不重复采样?

简介: 在 PyTorch 中,可以使用 torch.distributions.Categorical 来基于给定的概率分布进行采样。

在 PyTorch 中,可以使用 torch.distributions.Categorical 来基于给定的概率分布进行采样。

下面是一个示例:

import torch
import torch.distributions as dist
# 创建一个大小为 (1, n) 的一行张量表示概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 使用 Categorical 分布进行采样
m = dist.Categorical(probs)
samples = m.sample((10,))   # 生成 10 个样本
print(samples)

在上述示例中,我们首先创建了一个大小为 (1, n) 的一行张量 probs 表示概率分布。然后,我们使用 torch.distributions.Categorical 类来创建一个 Categorical 分布对象 m。该分布由给定的概率分布 probs 定义。接下来,我们使用 sample() 方法从分布中生成 10 个样本,并将其存储在 samples 中。最后,我们打印输出了生成的样本。

请注意,sample() 方法返回的张量的形状由传递给它的参数决定。在上述示例中,我们用 (10,) 指定了要生成 10 个样本,所以返回的张量的形状为 (10,)。如果没有指定参数,则默认生成单个样本。此外,Categorical 分布还提供了 log_prob() 方法,用于计算给定样本的对数概率。



在上述基础上,继续要求不重复抽样


如果需要基于给定的概率分布进行不重复采样,可以使用 torch.multinomial() 函数以及循环来实现。

下面是一个示例:

import torch
# 创建一个大小为 (1, n) 的一行张量表示概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 确定要生成的样本数量
num_samples = 3
# 初始化空列表用于存储已经出现过的样本索引
sampled_indices = []
# 循环采样直到得到足够数量的不同样本
while len(sampled_indices) < num_samples:
    # 使用 multinomial 函数生成一个样本索引
    index = torch.multinomial(probs, 1).item()
    # 如果该索引在 sampled_indices 中没有出现过,则将其加入 sampled_indices 列表中
    if index not in sampled_indices:
        sampled_indices.append(index)
# 将采样结果存储在样本张量中
samples = probs[sampled_indices]
print(samples)

在上述示例中,我们首先创建了一个大小为 (1, n) 的一行张量 probs 表示概率分布。接下来,我们确定要生成的样本数量,并初始化一个空列表 sampled_indices,用于存储已经出现过的样本索引。然后,我们使用 torch.multinomial() 函数生成一个样本索引,并判断该索引是否已经在 sampled_indices 中出现过。如果该索引没有出现过,则将其加入 sampled_indices 列表中,否则继续循环生成新的样本索引。最后,我们将采样结果存储在样本张量 samples 中,并打印输出。

请注意,上述代码只实现了一种简单的不重复抽样方法。如果需要进行更高效的不重复抽样,可以使用其他算法,如 Fisher-Yates shuffle 算法等。



或者


可以使用PyTorch的torch.multinomial()函数来进行不重复抽样。该函数接受一个概率分布张量和要抽取样本的数量作为输入,并返回一个整数张量,表示从概率分布中抽取的样本的索引。如果希望进行不重复抽样,可以在调用torch.multinomial()函数时将参数replacement设置为False。例如:

import torch
# 创建概率分布张量
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 进行不重复抽样
samples = torch.multinomial(probs, num_samples=3, replacement=False)
print(samples)

输出结果将是一个长度为3的整数张量,表示从概率分布中抽取的三个不重复样本的索引。

相关文章
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
|
2月前
|
存储 并行计算 PyTorch
探索PyTorch:张量数值计算
探索PyTorch:张量数值计算
|
2月前
|
机器学习/深度学习 并行计算 PyTorch
探索PyTorch:张量的创建和数值计算
探索PyTorch:张量的创建和数值计算
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
探索PyTorch:张量的类型转换,拼接操作,索引操作,形状操作
|
2月前
|
PyTorch 算法框架/工具 Python
Pytorch学习笔记(十):Torch对张量的计算、Numpy对数组的计算、它们之间的转换
这篇文章是关于PyTorch张量和Numpy数组的计算方法及其相互转换的详细学习笔记。
46 0
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:
82 4
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
|
5月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8758 3
|
4月前
|
机器学习/深度学习 算法 PyTorch
【深度学习】TensorFlow面试题:什么是TensorFlow?你对张量了解多少?TensorFlow有什么优势?TensorFlow比PyTorch有什么不同?该如何选择?
关于TensorFlow面试题的总结,涵盖了TensorFlow的基本概念、张量的理解、TensorFlow的优势、数据加载方式、算法通用步骤、过拟合解决方法,以及TensorFlow与PyTorch的区别和选择建议。
284 2
|
4月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
35 0
|
6月前
|
算法 PyTorch 算法框架/工具
Pytorch - 张量转换拼接
使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。