Python图像读取,图像的PIL.Image, numpy.darray, Tensor形式相互转换

文章目录

    • Python读取图像的两种推荐方式
      • 使用PIL读取图像
      • 使用opencv-python读取图像
    • numpy.darray 与 Tensor 相互转换
    • PIL.Image 与 numpy.ndarray 相互转换
    • PIL.Image 转换成 Tensor
    • Tensor 转换成 PIL.Image
    • 参考链接

Python读取图像的两种推荐方式

  uint8是无符号八位整型,表示的是区间[0, 255]内的整数。

使用PIL读取图像

from PIL import Image
img = Image.open('./lena.png')  # PIL读入的图像自然就是uint8格式

使用opencv-python读取图像

import cv2
from PIL import Image

img_cv2_color = cv2.imread('lena.png')  # cv2读入的图像默认是uint8格式的numpy.darray
img_cv2_gray = cv2.imread('lena.png', 0)  # img_cv2_gray.shape (512,512)
cv2.imwrite('img_cv2_color.jpg', img_cv2_color)
img = Image.fromarray(img_cv2_color)

  cv2.imread()返回numpy.darray,可直接用Image.fromarray()转换成PIL.Image,读取灰度图像的shape为(H,W),读取彩色图像的shape为(H,W,3)。
  cv2写图像时,输入的灰度图像的shape可以为(H,W)或(H,W,1),输入的彩色图像的shape应该为(H,W,3);
  若要从numpy.ndarray得到PIL.Image,灰度图像的shape必须为(H,W),彩色图像的shape必须为(H,W,3);

numpy.darray 与 Tensor 相互转换

>>> import numpy as np
>>> import torch
>>> a = np.arange(5)
>>> a
array([0, 1, 2, 3, 4])
>>> b = torch.from_numpy(a)  # numpy.darray 转换成 Tensor
>>> b
tensor([0, 1, 2, 3, 4], dtype=torch.int32)
>>> print(a, '\n', b)
 [0 1 2 3 4] 
tensor([0, 1, 2, 3, 4], dtype=torch.int32)

>>> c = torch.arange(5)
>>> c
tensor([0, 1, 2, 3, 4])
>>> d = c.numpy()  # Tensor 转换成 numpy.darray
>>> d
array([0, 1, 2, 3, 4], dtype=int64)
>>> e = np.array(c)  # Tensor 转换成 numpy.darray
>>> e
array([0, 1, 2, 3, 4], dtype=int64)
>>> print(c, '\n', d, '\n', e)
tensor([0, 1, 2, 3, 4]) 
 [0 1 2 3 4] 
 [0 1 2 3 4]

PIL.Image 与 numpy.ndarray 相互转换

from PIL import Image
import numpy as np
img_pil = Image.open('./lena.png')   # PIL读入的图像自然就是uint8格式
a = np.array(img_pil)  # PIL.Image 转换成 numpy.darray
# a = np.asarray(img_pil)  # PIL.Image 转换成 numpy.darray


# 先把numpy.darray转换成np.unit8, 确保像素值取区间[0,255]内的整数
# 灰度图像需保证numpy.shape为(H,W),不能出现channels,可通过执行np.squeeze()剔除channels;
# 彩色图象需保证numpy.shape为(H,W,3)
a = a.astype(np.uint8)  # a.astype('uint8')  # a = np.uint8(a)
# 再转换成PIL Image形式
img = Image.fromarray(a)  # numpy.darray 转换成 PIL.Image

  若要从numpy.ndarray得到PIL.Image,灰度图像的shape必须为(H,W),彩色图像的shape必须为(H,W,3);

PIL.Image 转换成 Tensor

from PIL import Image
import torch
import torchvision.transforms as transforms

path_img = "./lena.png"  # your path to image
img_pil = Image.open(path_img).convert('RGB')  # 0~255
img_transforms = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(norm_mean, norm_std),
    ])
img_tensor = img_transforms(img_pil)
# img_tensor.unsqueeze_(0)  # CHW --> BCHW
# fmap_1 = convlayer1(img_tensor)

Tensor 转换成 PIL.Image

# 先把Tensor转换成numpy.darray,再把numpy.darray 转换成 PIL.Image
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

img_transforms = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(norm_mean, norm_std),
    ])


def transform_invert(img_, transform_train):
    """
    将data 进行反transfrom操作
    :param img_: Tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    # img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    img_ = img_.permute(1, 2, 0)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = np.array(img_)  # 先把Tensor转换成numpy.darray
        img_ -= np.min(img_)
        img_ /= np.max(img_)
        img_ = img_ * 255

    # 再把numpy.darray转换成PIL.Image
    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_


# img_pil = Image.open('./lena.png').convert('RGB')  # 彩色图像
img_pil = Image.open('./lena.png').convert('L')  # 灰度图像
# plt.imshow(img_pil)
# plt.show()
img_tensor = img_transforms(img_pil)
img = transform_invert(img_tensor, img_transforms)
plt.imshow(img)
plt.show()

参考链接

python、PyTorch图像读取与numpy转换_Python_便纵有千种风情 20180615

python处理图像何时要将图像转化为uint8格式?uint8是什么?用array()方法打开图像后图像是什么格式?20190214

(待阅读) cv2.imshow()和plt.imshow()显示的色差问题_Python_Max 20190308

你可能感兴趣的:(PyTorch相关)