问题描述
使用matplotlib显示彩色图像出现问题
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-25-edd857df93f0> in <module> 20 # import numpy as np 21 # # np_img = np.array(pil_img) ---> 22 plt.imshow(np.array(img)) 23 # plt.imshow(img.permute(1,2,0)) 24 plt.show() TypeError: Invalid shape (3, 224, 224) for image data
原因分析:
使用matplotlib显示彩色图像需要数据的维度为 【width, height, channel】,就是224 * 224 * 3
报错原因是我这里的tensor的维度为 3 * 224 * 224
x_train_tensor = torch.from_numpy(x_train) y_train_tensor = torch.from_numpy(y_train)
解决方案:
将tensor或者数组的维度交换即可
可以使用permute函数,这个函数的参数就是我们交换之后新维度的排序,下面为1,2,0就是我们需要将原来1和2维度的内容排在前面,而通道维度放在最后
img.permute(1,2,0)
或者还可以使用transpose函数直接交换维度
img.transpose(0,2)