Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式

简介: Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式
import torch 
def printt(x,a=""):
    print(x)
    print("{}.dim{}".format(a,x.dim()))
    print("{}.shape{}".format(a,x.shape))
x = torch.arange(48).reshape(2,2,3,4)
printt(x,"x")
y = torch.ones(48).reshape(2,2,3,4)
printt(y,"y")
a = torch.cat((x,y),dim = 0)
printt(a,"a")
b = torch.cat((x,y),dim = 1)
printt(b,"b")
c = torch.cat((x,y),dim = 2)
printt(c,"c")
d = torch.cat((x,y),dim = 3)
printt(d,"d")


x:


tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]],
        [[[24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35]],
         [[36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]]])
x.dim4
x.shapetorch.Size([2, 2, 3, 4])


y:


tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],
        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
y.dim4
y.shapetorch.Size([2, 2, 3, 4])


a = torch.cat((x,y),dim = 0)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],
        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
a.dim4
a.shapetorch.Size([4, 2, 3, 4])


b = torch.cat((x,y),dim = 1)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
b.dim4
b.shapetorch.Size([2, 4, 3, 4])


c = torch.cat((x,y),dim = 2)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
c.dim4
c.shapetorch.Size([2, 2, 6, 4])


d = torch.cat((x,y),dim = 3)


tensor([[[[ 0.,  1.,  2.,  3.,  1.,  1.,  1.,  1.],
          [ 4.,  5.,  6.,  7.,  1.,  1.,  1.,  1.],
          [ 8.,  9., 10., 11.,  1.,  1.,  1.,  1.]],
         [[12., 13., 14., 15.,  1.,  1.,  1.,  1.],
          [16., 17., 18., 19.,  1.,  1.,  1.,  1.],
          [20., 21., 22., 23.,  1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.,  1.,  1.,  1.,  1.],
          [28., 29., 30., 31.,  1.,  1.,  1.,  1.],
          [32., 33., 34., 35.,  1.,  1.,  1.,  1.]],
         [[36., 37., 38., 39.,  1.,  1.,  1.,  1.],
          [40., 41., 42., 43.,  1.,  1.,  1.,  1.],
          [44., 45., 46., 47.,  1.,  1.,  1.,  1.]]]])
d.dim4
d.shapetorch.Size([2, 2, 3, 8])


a0f3c18dd08b42c4b091cf299c66b8d9.png

目录
相关文章
|
机器学习/深度学习 自然语言处理 PyTorch
【深度学习】实验12 使用PyTorch训练模型
【深度学习】实验12 使用PyTorch训练模型
151 0
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(一):torch.cat()模块的详解
这篇博客文章详细介绍了Pytorch中的torch.cat()函数,包括其定义、使用方法和实际代码示例,用于将两个或多个张量沿着指定维度进行拼接。
97 0
Pytorch学习笔记(一):torch.cat()模块的详解
|
7月前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于全连接网络构建RNN并生成人名
【PyTorch实战演练】基于全连接网络构建RNN并生成人名
63 0
|
7月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
机器学习/深度学习 算法 PyTorch
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
172 0
|
机器学习/深度学习 算法 PyTorch
Pytorch全连接神经网络实现手写数字识别
Pytorch全连接神经网络实现手写数字识别
143 0
|
机器学习/深度学习 传感器 算法
pytorch实现循环神经网络实验
pytorch实现循环神经网络实验
275 0
|
算法 PyTorch 算法框架/工具
pytorch实现空洞卷积+残差网络实验(torch实现)
pytorch实现空洞卷积+残差网络实验(torch实现)
407 0
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch实现卷积神经网络实验
pytorch实现卷积神经网络实验
241 0
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch实现前馈神经网络实验(手动实现)
pytorch实现前馈神经网络实验(手动实现)
314 0

热门文章

最新文章