cv2读入的BGR图像转换成torch.tensor格式,torch.tensor格式的图像转换成ndarray格式并保存

代码如下:

包含两个函数

1.ndarray2tensor :

        img(W,H,C):cv2读取的一张图像,ndarray格式,注意读取的为BGR图像,函数中没有转换成RGB图像,如需自行可以更换;

        输出(1,C,W,H):torch.Tensor格式

2.torch2ndarray_save:

        input(1,C,W,H):torch.Tensor格式

        filename:str格式,保存的路径

         denormal:bool格式,是否反标准化

注意:标准化和反标准化已经实例出来,可参考使用

import cv2
import torch
import numpy as np
import os
import torchvision

c_mean = [0.4914, 0.4822, 0.4465]
c_std = [0.2023, 0.1994, 0.2010]
de_mean = [-mean / std for mean, std in zip(c_mean, c_std)]
de_std = [1 / std for std in c_std]
normalize = torchvision.transforms.Normalize(c_mean, c_std)
denormalize = torchvision.transforms.Normalize(de_mean, de_std)

def ndarray2tensor(img):
    img = cv2.resize(img, (256, 256))
    img = img / 255.
    img = torch.tensor(img, dtype=torch.float32)
    img = img.unsqueeze(0)
    img = img.permute(0, 3, 1, 2)
    img = normalize(img)
    return img

def torch2ndarray_save(input:torch.Tensor, filename, denormal = True):
    assert (len(input.shape) == 4 and input.shape[0] == 1)
    input = input.clone().detach()
    if input.is_cuda == True:
        input = input.to(torch.device('cpu'))
    if denormal:
        input = denormalize(input) * 255
    input = torch.tensor(input, dtype=torch.int)
    img = input.permute(0, 2, 3, 1)
    img = torch.reshape(img, (img.shape[1], img.shape[2], img.shape[3])).numpy()
    cv2.imwrite(filename,img)

你可能感兴趣的:(python,深度学习,机器学习)