torchvision.utils.save_image()保存tensor显示图片异常问题解决

用torchvision.utils.save_image()保存图片时出现异常

有些像素点会显示为全黑(灰度图),如下图所示,第一张和第三张图

torchvision.utils.save_image()保存tensor显示图片异常问题解决_第1张图片
刚开始以为是图像数据分布范围的问题,在保存之前输出图像tensor的最大max和最小min值,出现了 -0.0x和1.0x的数值,说明图像的像素范围超出了0-1。

读源码

可是通过读utils.save_image()的源码发现,就算超出0-1也不应该出现这种问题,源码中存在如下部分代码

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

grid可以理解为图片张量,这段代码

  1. 首先将 grid 张量中的每个元素乘以255。这一步将原来在0到1范围内的图像数据转换到0到255的范围内
  2. 对 grid 张量中的每个元素加上0.5。这一步可能是为了进行亮度调整或将值偏移至正数范围内
  3. 将 grid 张量中的每个元素限制在0到255的范围内。小于0的值将被设置为0,大于255的值将被设置为255
  4. 后面的不重要
    源码在将所有像素乘255之后,已经将数据每个像素范围限制在了0-255之间

问题解决

经过查看其他成功的代码源码中的注释发现。大多在使用 torchvision.utils.save_image时直接将4Dtensor图片和保存路径传入给 save_image()函数就行,不会出现问题。

且utils.save_image接收四维tensor ,B C H W

如源码所示
在这里插入图片描述

而我在保存之前进行了降维处理,降成了三维(squeeze(0)是降维)

torchvision.utils.save_image()保存tensor显示图片异常问题解决_第2张图片

于是删掉后面的squeeze(0),问题解决

torchvision.utils.save_image()保存tensor显示图片异常问题解决_第3张图片

结果如图所示

torchvision.utils.save_image()保存tensor显示图片异常问题解决_第4张图片

你可能感兴趣的:(深度学习,python,pytorch)