Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式

简介: Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式
import torch 
def printt(x,a=""):
    print(x)
    print("{}.dim{}".format(a,x.dim()))
    print("{}.shape{}".format(a,x.shape))
x = torch.arange(48).reshape(2,2,3,4)
printt(x,"x")
y = torch.ones(48).reshape(2,2,3,4)
printt(y,"y")
a = torch.cat((x,y),dim = 0)
printt(a,"a")
b = torch.cat((x,y),dim = 1)
printt(b,"b")
c = torch.cat((x,y),dim = 2)
printt(c,"c")
d = torch.cat((x,y),dim = 3)
printt(d,"d")


x:


tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],
         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]],
        [[[24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35]],
         [[36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]]])
x.dim4
x.shapetorch.Size([2, 2, 3, 4])


y:


tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],
        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
y.dim4
y.shapetorch.Size([2, 2, 3, 4])


a = torch.cat((x,y),dim = 0)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],
        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
a.dim4
a.shapetorch.Size([4, 2, 3, 4])


b = torch.cat((x,y),dim = 1)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
b.dim4
b.shapetorch.Size([2, 4, 3, 4])


c = torch.cat((x,y),dim = 2)


tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],
         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
c.dim4
c.shapetorch.Size([2, 2, 6, 4])


d = torch.cat((x,y),dim = 3)


tensor([[[[ 0.,  1.,  2.,  3.,  1.,  1.,  1.,  1.],
          [ 4.,  5.,  6.,  7.,  1.,  1.,  1.,  1.],
          [ 8.,  9., 10., 11.,  1.,  1.,  1.,  1.]],
         [[12., 13., 14., 15.,  1.,  1.,  1.,  1.],
          [16., 17., 18., 19.,  1.,  1.,  1.,  1.],
          [20., 21., 22., 23.,  1.,  1.,  1.,  1.]]],
        [[[24., 25., 26., 27.,  1.,  1.,  1.,  1.],
          [28., 29., 30., 31.,  1.,  1.,  1.,  1.],
          [32., 33., 34., 35.,  1.,  1.,  1.,  1.]],
         [[36., 37., 38., 39.,  1.,  1.,  1.,  1.],
          [40., 41., 42., 43.,  1.,  1.,  1.,  1.],
          [44., 45., 46., 47.,  1.,  1.,  1.,  1.]]]])
d.dim4
d.shapetorch.Size([2, 2, 3, 8])


a0f3c18dd08b42c4b091cf299c66b8d9.png

目录
打赏
0
0
0
0
691
分享
相关文章
Pytorch学习笔记(一):torch.cat()模块的详解
这篇博客文章详细介绍了Pytorch中的torch.cat()函数,包括其定义、使用方法和实际代码示例,用于将两个或多个张量沿着指定维度进行拼接。
244 0
Pytorch学习笔记(一):torch.cat()模块的详解
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)
210 0
pytorch实现空洞卷积+残差网络实验(torch实现)
pytorch实现空洞卷积+残差网络实验(torch实现)
445 0
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等