可以使用 torch.cat()
方法将两个张量沿着指定的维度进行拼接,示例如下:
import torch # 定义两个张量 a = torch.tensor([[1, 2, 3]]) b = torch.tensor([4]) # 沿着第二维进行拼接 c = torch.cat((a, b.unsqueeze(0)), dim=1) print(c) # 输出 tensor([[1, 2, 3, 4]])
在这里,我们将 b
张量通过 unsqueeze()
方法扩展为二维张量,然后再和 a
张量进行拼接。由于 b
张量只有一个元素,所以在拼接时要沿着第二维进行操作。
注意,如果 b
张量是一维张量,则需要使用 unsqueeze()
方法将其转换为二维张量,例如 b.unsqueeze(0)
或者 b.unsqueeze(1)
。