问题描述
将如下input作为torch.MaxPool2d的输入
input = torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]])
input = torch.reshape(input, (-1, 1, 5, 5))
报错:
RuntimeError: "max_pool2d" not implemented for 'Long'
解决办法
pytorch中的很多操作不支持Long类型的张量, 只需要把输入的张量改成浮点类型即可
input = torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]], dtype = torch.float32)
input = torch.reshape(input, (-1, 1, 5, 5))