import torch def printt(x,a=""): print(x) print("{}.dim{}".format(a,x.dim())) print("{}.shape{}".format(a,x.shape)) x = torch.arange(48).reshape(2,2,3,4) printt(x,"x") y = torch.ones(48).reshape(2,2,3,4) printt(y,"y") a = torch.cat((x,y),dim = 0) printt(a,"a") b = torch.cat((x,y),dim = 1) printt(b,"b") c = torch.cat((x,y),dim = 2) printt(c,"c") d = torch.cat((x,y),dim = 3) printt(d,"d")
x:
tensor([[[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]], [[[24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35]], [[36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]]]]) x.dim4 x.shapetorch.Size([2, 2, 3, 4])
y:
tensor([[[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]], [[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]]]) y.dim4 y.shapetorch.Size([2, 2, 3, 4])
a = torch.cat((x,y),dim = 0)
tensor([[[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]], [[[24., 25., 26., 27.], [28., 29., 30., 31.], [32., 33., 34., 35.]], [[36., 37., 38., 39.], [40., 41., 42., 43.], [44., 45., 46., 47.]]], [[[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]], [[[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]]]) a.dim4 a.shapetorch.Size([4, 2, 3, 4])
b = torch.cat((x,y),dim = 1)
tensor([[[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]], [[[24., 25., 26., 27.], [28., 29., 30., 31.], [32., 33., 34., 35.]], [[36., 37., 38., 39.], [40., 41., 42., 43.], [44., 45., 46., 47.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]]]) b.dim4 b.shapetorch.Size([2, 4, 3, 4])
c = torch.cat((x,y),dim = 2)
tensor([[[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]], [[[24., 25., 26., 27.], [28., 29., 30., 31.], [32., 33., 34., 35.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]], [[36., 37., 38., 39.], [40., 41., 42., 43.], [44., 45., 46., 47.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]]]) c.dim4 c.shapetorch.Size([2, 2, 6, 4])
d = torch.cat((x,y),dim = 3)
tensor([[[[ 0., 1., 2., 3., 1., 1., 1., 1.], [ 4., 5., 6., 7., 1., 1., 1., 1.], [ 8., 9., 10., 11., 1., 1., 1., 1.]], [[12., 13., 14., 15., 1., 1., 1., 1.], [16., 17., 18., 19., 1., 1., 1., 1.], [20., 21., 22., 23., 1., 1., 1., 1.]]], [[[24., 25., 26., 27., 1., 1., 1., 1.], [28., 29., 30., 31., 1., 1., 1., 1.], [32., 33., 34., 35., 1., 1., 1., 1.]], [[36., 37., 38., 39., 1., 1., 1., 1.], [40., 41., 42., 43., 1., 1., 1., 1.], [44., 45., 46., 47., 1., 1., 1., 1.]]]]) d.dim4 d.shapetorch.Size([2, 2, 3, 8])