PyTorch torchvision.transforms.Compose: common pitfalls

Consider the following code:

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms

img_path = "./panda.jpg"

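# transform1: (x - 0) / 255, scales [0, 255] to [0, 1]
transform1 = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255])
])
# transform2: (x - 0.5) / 1, shifts [0, 1] to [-0.5, 0.5]
transform2 = transforms.Compose([
    transforms.Normalize(mean=[0.50, 0.50, 0.50], std=[1, 1, 1])
])

img = cv2.imread(img_path)  # read the image: numpy.ndarray, shape (H, W, C), dtype uint8, BGR order
if img is None:
    raise FileNotFoundError(img_path)
img0 = np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)
img0 = torch.from_numpy(img0)        # numpy.ndarray -> torch.Tensor, still uint8
img1 = transform1(img0.float())      # Normalize needs a float tensor
# img2 = transform2(img1)
print("img = ", img)
print("img0 = ", img0)
print("img1 = ", img1)
# print("img2 = ", img2)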

There are several easy-to-miss pitfalls here:

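1. The channel axis sits in a different place. An image read into a numpy array (for example by cv2.imread) is laid out as (H, W, C), with the channels after height and width; np.shape(object) shows this. PyTorch image tensors are laid out as (C, H, W), with the channels before height and width; np.shape(object) or object.shape shows this.

So after reading an image into a numpy array, you need a transpose to move the channel axis to the front before processing the data as a tensor.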

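2. Transforms such as Normalize operate on tensors, so the numpy array obtained when reading the image must first be converted to a tensor, e.g. with torch.from_numpy.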

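3. Normalize and similar operations need floating-point data. torch.from_numpy keeps the image's uint8 dtype, so call .float() first; recent torchvision versions raise a TypeError when Normalize is given an integer tensor.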

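4. Even though the code above applies transform1 and transform2 separately, if the line img2 = transform2(img1) is left uncommented, the printed img1 and img2 both show img2's result. The reason is that older torchvision releases implemented Normalize in place: transform2 overwrote img1 and returned that same tensor. Newer releases give Normalize an inplace parameter that defaults to False; on an older version, pass a copy instead, e.g. img2 = transform2(img1.clone()).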


Below is a complete example of reading a .jpg file and processing it:

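import cv2
import torchvision.transforms as transforms

img_path = "./panda.jpg"

input_transform = transforms.Compose([
    transforms.ToPILImage(),                      # numpy.ndarray (H, W, C), uint8 -> PIL Image
    transforms.Grayscale(num_output_channels=3),  # grayscale, replicated to 3 channels
    transforms.ToTensor(),                        # PIL Image -> float tensor (C, H, W), scaled to [0, 1]
    # ToTensor() has already divided by 255, so no Normalize(mean=[0,0,0], std=[255,255,255])
    # step is needed here; adding one would divide by 255 a second time
    transforms.Normalize(mean=[0.50, 0.50, 0.50], std=[1, 1, 1])  # [0, 1] -> [-0.5, 0.5]
    ]
)

# cv2.imread returns a numpy.ndarray in BGR channel order
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # PIL expects RGB, so Grayscale weights the channels correctly
img1 = input_transform(img)

# inspect the result
print(img1.shape, img1.dtype, img1.min().item(), img1.max().item())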


