torch.sum(x, dim, keepdim)
我们使用一些torch模块中的函数时发现,有时会存在参数keepdim,该参数主要是在归并操作时使用的,为的就是保持原来维度不变。
示例:
>>>a = torch.arange(12).reshape(3, 4) >>>print(a) tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) >>>print(torch.sum(a, dim=0, keepdim=True)) >>>print(torch.sum(a, dim=0, keepdim=True).shape) tensor([[12, 15, 18, 21]]) torch.Size([1, 4]) >>>print(torch.sum(a, dim=0, keepdim=False)) >>>print(torch.sum(a, dim=0, keepdim=False).shape) tensor([12, 15, 18, 21]) torch.Size([4])
从上面例子可以看出,如果将其设置为True,那么将归并的维度依旧会保留,与原来的tensor数据维度一致,如果设置为False,那么改维度经过计算归并则会消失。