PyTorch 图像处理:Tensor、Numpy、PIL格式转换以及图像显示


Author :Horizon Max

编程技巧篇:各种操作小结

机器视觉篇:会变魔术 OpenCV

深度学习篇:简单入门 PyTorch

神经网络篇:经典网络模型

算法篇:再忙也别忘了 LeetCode


重点

cv.imshow() :显示图像是 BGR格式的
plt.imshow() :图像显示是 RGB格式的

Tensor :存储的数据分布在 [0, 1]
Numpy :存储的数据分布在 [0, 255]


CIFAR-10数据集

数据集为 RGB格式的;
在使用 opencv 显示时需要先转换成 BGR格式;
在使用 plt显示时 无需 转换格式;

示例:

dict = unpickle('./dataset/cifar-10-batches-py/test_batch')

img = dict[b'data']
image = img[0]
image = image.reshape(3, 32, 32).transpose(1, 2, 0)
cv_show('image', image)

r, g, b = cv.split(image)
pic = cv.merge([b, g, r])
cv_show('pic', pic)

左侧图为原图,右侧图为失真图 :
PyTorch 图像处理:Tensor、Numpy、PIL格式转换以及图像显示_第1张图片


格式转换

Tensor ==> Numpy

import torch
import torchvision
import pickle
import cv2 as cv

transform_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

transform_picture = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),                   # 转换成Tensor格式
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


def cv_show(name, img):
    cv.imshow(name, img)
    cv.waitKey(0)
    cv.destroyAllWindows()


def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict


dict = unpickle('./dataset/cifar-10-batches-py/test_batch')

img = dict[b'data']
image = img[0]
image = image.reshape(3, 32, 32).transpose(1, 2, 0)
print(image)

image_Tensor = transform_tensor(image).unsqueeze(0)
print(image_Tensor)            # 没有数据归一化操作

image_Tensor_Nor = transform_picture(image).unsqueeze(0)
print(image_Tensor_Nor)        # 有数据归一化操作
[[[158 112  49]
  [159 111  47]
  [165 116  51]
  ...
  [ 24  77 124]
  [ 34  84 129]
  [ 21  67 110]]]
  
tensor([[[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],
          [0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],
          [0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],
          ...,
          [0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],
          [0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],
          [0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]]])
          
tensor([[[[ 0.6338,  0.6531,  0.7694,  ...,  0.2267,  0.0134, -0.1804],
          [ 0.5174,  0.4981,  0.6531,  ...,  0.2073, -0.0060, -0.1223],
          [ 0.4981,  0.4981,  0.6338,  ...,  0.2654,  0.0910, -0.1029],
          ...,
          [ 1.2319,  0.6661,  0.4515,  ...,  0.6271, -0.1143,  0.2564],
          [ 1.0563,  0.6661,  0.3540,  ...,  0.3149,  0.2369, -0.1338],
          [ 0.9003,  0.6856,  0.3540,  ...,  0.1979,  0.2954, -0.0753]]]])

Tensor 转 Numpy : (用于显示图像)

mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

def tensor_numpy(image):
    clean = image.clone().detach().cpu().squeeze(0)        # 去掉batch通道 (batch, C, H, W) --> (C, H, W)
    clean[0] = clean[0] * std[0] + mean[0]                 # 数据去归一化
    clean[1] = clean[1] * std[1] + mean[1]
    clean[2] = clean[2].mul(std[2]) + mean[2]
    clean = np.around(clean.mul(255))                     # 转换到颜色255 [0, 1] --> [0, 255]
    clean = np.uint8(clean).transpose(1, 2, 0)            # 跟换三通道 (C, H, W) --> (H, W, C)
    r, g, b = cv.split(clean)                             # RGB 通道转换
    clean = cv.merge([b, g, r])
    return clean

Num = tensor_numpy(image_Tensor_Nor)
print(Num)

如果使用 cv.imshow() 需要使用上面的 RGB 通道转换;
如果使用 plt.imshow() 不需要使用上面的 RGB 通道转换;


示例:

plt.imshow('image', image)
plt.show()

RGB格式:

[[[158 112  49]
  [159 111  47]
  [165 116  51]
  ...
  [ 24  77 124]
  [ 34  84 129]
  [ 21  67 110]]]

PyTorch 图像处理:Tensor、Numpy、PIL格式转换以及图像显示_第2张图片

RGB通道转换:

r, g, b = cv.split(image)
pic = cv.merge([b, g, r])
plt.imshow('image', pic)
plt.show()

BGR格式:

[[[ 49 112 158]
  [ 47 111 159]
  [ 51 116 165]
  ...
  [124  77  24]
  [129  84  34]
  [110  67  21]]]

PyTorch 图像处理:Tensor、Numpy、PIL格式转换以及图像显示_第3张图片

PyTorch save_image()

torchvision.utils.save_image()

最后再来看一下pytorch自带的函数是如何进行格式转换保存图片的 :

def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) 

    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


你可能感兴趣的:(各种操作小结,PyTorch,opencv,matplotlib,格式转换)