你可以使用 PyTorch 中的 torch.cat()
函数将两个二维张量沿着第一维度拼接在一起。例如,假设有两个张量 x
和 y
,它们的形状都为 (m, n)
,你可以将它们按照第一维度拼接在一起,形成一个新的张量 z
,方法如下:
import torch # 创建两个形状相同的张量 x = torch.randn(3, 4) y = torch.randn(3, 4) # 按照第一维度拼接 z = torch.cat([x, y], dim=0) # 打印拼接后的张量形状 print(z.shape)
输出:
torch.Size([6, 4])
在这个例子中,torch.cat()
函数的第一个参数是一个列表,包含要拼接的张量 x
和 y
,第二个参数是拼接的维度,即第一维度。拼接后的张量 z
的形状为 (6, 4)
,因为两个原始张量的第一维度都是 3
,拼接后就变成了 6
。