可以使用torch.unsqueeze()
函数在指定位置插入一个新的维度。该函数可接受两个参数:要插入维度的张量和要插入的位置索引。
以下是示例代码:
import torch # 创建一个形状为(2, 3)的张量 a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 在第0个维度前插入一个新的维度 b = torch.unsqueeze(a, dim=0) print(b.shape) # 输出结果为: torch.Size([1, 2, 3])
上述代码中,首先创建了一个形状为(2, 3)的张量a
。然后使用torch.unsqueeze()
函数在第0个维度(即行维度)前插入了一个新的维度,得到了一个形状为(1, 2, 3)的新张量b
。
需要注意的是,插入新维度后,原来张量的数据沿着未被插入新维度的维度保持不变,而新维度的大小为1。例如,在上述代码中,b
张量的第0个维度的大小为1,而其他两个维度的大小与a
张量相同。