可以使用 PyTorch 中的 .squeeze()
函数来去掉张量中大小为 1 的维度。如果要删除最后一个维度,可以指定参数 dim=-1
,即对最后一个维度进行处理。下面是示例代码:
import torch x = torch.randn(2, 3, 1) y = x.squeeze(dim=-1) print(x.size()) # 输出 torch.Size([2, 3, 1]) print(y.size()) # 输出 torch.Size([2, 3])
在上述代码中,我们首先创建了一个形状为 (2, 3, 1)
的张量 x
,其中最后一个维度大小为 1。然后使用 .squeeze()
函数将其转换为形状为 (2, 3)
的张量 y
,即已经去掉了最后一个维度。注意,在调用 .squeeze()
函数时,需要指定要去掉的维度,否则函数会默认去掉所有大小为 1 的维度。