【代码错误记录】显示数据集图片-图片tensor问题

目录标题

  • matplotlib绘图imshow()函数报错“TypeError: Invalid dimensions for image data”

matplotlib绘图imshow()函数报错“TypeError: Invalid dimensions for image data”

错误代码

plt.imshow((img[6, :, :, :].moveaxis(0, 2)))

改为

plt.imshow((img[6, :, :, :]))

报错
TypeError: Invalid dimensions for image data”

修改为:

plt.imshow((img[6, :, :, :].squeeze().numpy().transpose(1,2,0)))

参考
解决这个问题的关键就是理解了imshow函数的参数。
matplotlib.pyplot.imshow()函数的输入需要是二维的numpy或者是第三维度是3或4的numpy,

  • 当第3维深度是1时,使用np.squeeze()函数压缩数据成为二维数组。
  • 因为我在pytorch环境下使用,得到结果的输出是(batch_size,channel,width,height)的tensor,因此我首先需要detach()函数切断反向传播。
  • 需要指出的是,imshow不支持显示tensor,因此,我需要使用.cpu()函数转移到cpu上来。
  • 正如前面说到的,imshow函数的输入需要是二维的numpy或者第三维度是3或4的numpy,
  • 因为我的使用情况比较特殊,还多了一个batch_size维度,不过还好,我设置batch_size仅为1,这时候可以使用.squeeze()函数把1给去掉,得到了是一个(channel,widht,height)的numpy,这显然与imshow的输入要求不符。因此,我们需要使用transpose函数把channel(=3)移动到最后,这也是为什么才有了.transpose(1,2,0)这种用法。当然,如果待显示的图像本身就是channel=1,那么完全可以使用squeeze()函数把其搞掉,直接输入给imshow函数一个二维的numpy.

你可能感兴趣的:(代码错误记录,可视化,python,numpy,matplotlib)