把一个 PyTorch 的图像张量转换成 NumPy 格式,并按照正确的维度顺序显示出来

示例代码:

plt.imshow(np.transpose(tensor_denorm.numpy(), (1, 2, 0)))

它的作用是:把一个 PyTorch 的图像张量转换成 NumPy 格式,并按照正确的维度顺序显示出来


一步步解释:

tensor_denorm

这是一个形状为 (3, H, W) 的 PyTorch Tensor,表示一个图像:

  • 3:表示三个颜色通道(RGB)
  • H:图像高度
  • W:图像宽度

PyTorch 中的图像张量格式是 (C, H, W)


.numpy()

这一步把 PyTorch Tensor 转换成 NumPy 数组(前提是 Tensor 在 CPU 上):

tensor_denorm.numpy()

得到一个 NumPy 数组,形状依然是 (3, H, W)


np.transpose(..., (1, 2, 0))

NumPy 默认显示图像的格式是 (H, W, C),也就是:

  • 高度(H)
  • 宽度(W)
  • 通道(C)

所以要把 (3, H, W) 转换成 (H, W, 3),需要换维度顺序:

np.transpose(tensor_denorm.numpy(), (1, 2, 0))

plt.imshow(...)

这是 matplotlib.pyplot 的图像显示函数。它接收一个 (H, W, 3) 的数组并显示出来:

plt.imshow(...)

举个例子:

假设我们有这个张量:

tensor = torch.rand(3, 150, 150)  # 随机图像,3通道 150x150

执行这一步:

plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))

就能把这个随机图像展示出来了。


✅ 总结一句话:

plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))

等价于:

“把 PyTorch 中格式为 (C, H, W) 的图像转成 (H, W, C) 并显示出来”

你可能感兴趣的:(python,pytorch,numpy,人工智能)