PyTorch: 张量的拼接、切分、索引

简介: PyTorch: 张量的拼接、切分、索引
本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。

一、张量拼接与切分

1.1 torch.cat

功能:将张量按维度dim 进行拼接

  • tensors : 张量序列
  • dim: 要拼接的维度
 t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])

(2,3) -> (2,6)

这里的dim维度与axis相同,0代表列,1代表行。

1.2 torch.stack

功能:在新创建的维度 dim 上进行拼接(会拓宽原有的张量维度)

  • tensors:张量序列
  • dim:要拼接的维度

    t = torch.ones((2, 3))

    t_stack = torch.stack([t, t, t], dim=2)

    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

可见,它在新的维度上进行了拼接。

参数[t, t, t]的意思就是在第n个维度上拼接成这个样子。

t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 3, 3])
# 在第二维度上进行了拼接
Process finished with exit code 0

1.3 torch.chunk

功能:将张量按维度 dim 进行平均切分

返回值:张量列表

注意事项:若不能整除,最后一份张量小于其他张量。

  • input : 要切分的张量
  • chunks 要切分的份数
  • dim 要切分的维度

code

    # cut into 3
    a = torch.ones((2, 7))  # 7
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

可知,切分是7/3向上取整,每份是3,最后剩下的维度直接输出即可。

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

1.4 torch.split

torch.split(Tensor, split_size_or_sections, dim)

功能:将张量按维度 dim 进行切分

返回值:张量列表

  • tensor : 要切分的张量
  • split_size_or_sections 为 int 时,表示
    每一份的长度;为 list 时,按 list 元素切分
  • dim 要切分的维度

code:

    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

是按照指定长度list进行切分的。注意list中长度总和必须为原张量在改维度的大小,不然会报错。

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

二、张量索引

2.1 torch.index_select

torch.index_select(input, dim, index, out=None)

功能:在维度dim 上,按 index 索引数据

返回值:依index 索引数据拼接的张量

  • input : 要索引的张量
  • dim 要索引的维度
  • index 要索引数据的序号

code:

    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # if float will report an error
    t_select = torch.index_select(t, dim=0, index=idx)
    print(idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))

可见idx是一个存储序号的张量,而torch.index_select通过该张量索引原tensor并且拼接返回。

tensor([0, 2])
t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
t_select:
tensor([[4, 5, 0],
        [2, 5, 8]])

2.2 torch.masked_select

功能:按mask 中的 True 进行索引

返回值:一维张量(无法确定true的个数,因此也就无法显示原来的形状,因此这里返回一维张量)

  • input : 要索引的张量
  • mask 与 input 同形状的布尔类型张量
    t = torch.randint(0, 9, size=(3, 3))
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

通过掩码来索引。

tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
mask:
tensor([[ True,  True,  True],
        [ True, False,  True],
        [ True,  True, False]])
t_select:
tensor([4, 5, 0, 5, 1, 2, 5]) 

Process finished with exit code 0
目录
相关文章
|
4天前
|
算法 PyTorch 算法框架/工具
Pytorch - 张量转换拼接
使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。
|
4天前
|
存储 机器学习/深度学习 PyTorch
Pytorch-张量形状操作
PyTorch中,张量形状操作至关重要,如reshape用于改变维度而不变元素,transpose/permute用于维度交换,view改形状需内存连续,squeeze移除单维度,unsqueeze添加维度。这些函数帮助数据适应神经网络层间的转换。例如,reshape能调整数据适配层的输入,transpose用于矩阵转置或多维排列,而squeeze和unsqueeze则用于处理单维度。理解并熟练运用这些工具是深度学习中必要的技能。
|
4天前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch-张量
PyTorch 是Facebook AI团队开发的深度学习框架,其核心是张量,它是同类型数据的多维数组。张量可以通过`torch.tensor()`、`torch.Tensor()`、指定类型如`IntTensor`等创建。张量操作包括线性(`torch.arange`, `torch.linspace`)、随机(`torch.randn`, `torch.manual_seed`)和全0/1张量(`torch.zeros`, `torch.ones`)。张量间可进行阿达玛积(逐元素相乘),类型转换用`type()`或`double()`。
|
4天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch-张量基础操作
张量是一个多维数组,它是标量、向量和矩阵概念的推广。在深度学习中,张量被广泛用于表示数据和模型参数。
|
4天前
|
并行计算 PyTorch 算法框架/工具
pytorch张量的创建
• 张量(Tensors)类似于NumPy的ndarrays ,但张量可以在GPU上进行计算。从本质上来说,PyTorch是一个处理张量的库。一个张量是一个数字、向量、矩阵或任何n维数组。
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch】-了解张量(Tensor)
【PyTorch】-了解张量(Tensor)
|
1月前
|
机器学习/深度学习 存储 PyTorch
PyTorch深度学习基础:张量(Tensor)详解
【4月更文挑战第17天】本文详细介绍了PyTorch中的张量,它是构建和操作深度学习数据的核心。张量是多维数组,用于存储和变换数据。PyTorch支持CPU和GPU张量,后者能加速大规模数据处理。创建张量可通过`torch.zeros()`、`torch.rand()`或直接从Python列表转换。张量操作包括数学运算、切片和拼接。在深度学习中,张量用于神经网络模型的构建和训练。理解张量对于掌握PyTorch至关重要。
|
1月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
12天前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
12天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现
【从零开始学习深度学习】30. 神经网络中批量归一化层(batch normalization)的作用及其Pytorch实现