pytorch的torchvision无法保存真正的灰度图像
在图像相关的深度学习模型中,有时候需要保存训练中的图像或者测试得到的图像(尤其是低级视觉任务比如去噪、超分辨率、压缩重建等),一般使用如下方式进行图像保存(torchvision.utils中的save_image()函数):
torchvision.utils.save_image(output.data,'%d.bmp'% (idx), padding=0)
但这种方式只能保存RGB彩色图像,如果网络的输出是单通道灰度图像,则该函数依然会输出三个通道,每个通道的数值都是相同的,即“伪灰度图像”,虽然从视觉效果上看不出区别,但是图像所占内存比正常情况大了两倍。
那么如何保存真正的灰度图像?首先转入save_image()函数进行探究:
def save_image(tensor, filename, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0):
"""Save a given Tensor into an image file.
Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(filename)
可以发现其使用了图像处理模块PIL,并且使用make_grid()函数进行处理,转入make_grid()函数可以看到一句代码:
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)
这表示如果张量的通道为1,则输出就会拼接三个相同的该张量形成三通道图像。问题发现。
解决方式
在save_image()函数中,把最后一句保存图像的代码改为:
im.convert('L').save(filename)
即先将其转化为灰度图,再保存。
注
对于彩色图转灰度图的原理,一般是使用如下公式进行转换:
Gray = 0.29900 * R + 0.58700 * G + 0.11400 * B
由于make_grid()函数产生的是三通道相同的伪灰度图像,所以经过上式计算得到的灰度图是符合网络输出的张量的,没有产生任何损失。