在 PyTorch 中,可以使用 torch.cat()
函数将多个张量沿着指定维度进行合并。下面是一个例子:
importtorch# 创建三个一维张量(向量)x1=torch.tensor([1, 2, 3]) x2=torch.tensor([4, 5, 6]) x3=torch.tensor([7, 8, 9]) # 使用 torch.cat() 将三个张量合并成一个二维张量(矩阵)result=torch.cat([x1.unsqueeze(dim=0), x2.unsqueeze(dim=0), x3.unsqueeze(dim=0)], dim=0) print(result)
这里我们创建了三个一维张量 x1、x2 和 x3,并使用 unsqueeze()
方法将它们变成二维张量。然后,我们使用 torch.cat()
函数将这三个张量沿着第 0 维(即行)进行合并,得到一个形状为 (3, 3)
的二维张量。
需要注意的是,要使用 torch.cat()
函数对多个张量进行合并,这些张量在除了合并维度以外的所有维度上的形状必须相同。如果有不同形状的张量需要合并,需要先进行形状调整(如上述代码中的 unsqueeze()
方法)。