Pytorch中使用transforms.ToTensor和transforms.ToPILImage的numpy.ndarray与Tensor和PILImage的转换举例

作为刚接触Pytorch的新人,在刚开始时候处理三种数据tensor,numpy,PIL有点弄晕了,而且中间的转换过程需要注意类型,仅此记录下来,如果能有幸帮到大家那更好了。

进入正题,我们首先看看numpy to tensor的情况:

import numpy as np
from torchvision import transforms as T

x = np.ones((128, 228, 1))   # 这里numpy维度是HWC
print(x.shape)   # (128, 228, 1)  
x_tensor = T.ToTensor()(x) 
print(x_tensor.size())   # torch.Size([1, 128, 228]),这里tensor维度是CHW
print(torch.from_numpy(x).size())   # torch.Size([128, 228, 1])注意这里维度不变
x_numpy = x_tensor.numpy()   # 将tensor转换为numpy,但是维度也不改变(1, 128, 228)

所以这种情况比较好理解,就是将HWC变成了CHW,当然这个过程还进行了矩阵运算将其值进行了归一化。

接下来看看numpy to PIL的情况:

x = np.ones((128, 228, 1))
print(x.shape)
x_pil = T.ToPILImage()(np.uint8(x))   # 注意这里的np.uint8(),为了使其满足类型需求,在这里是必须的
print(x_pil.size)
x_numpy = np.array(x_pil)
print(x_numpy.shape)

# 输出为:
# (128, 228, 1)
# (228, 128)
# (128, 228)

可以看到ToPILImage实际上将HWC变成了WH,去掉了通道这个维度。

接下来看看tensor to PIL的情况:

x = torch.ones((1, 128, 228))
print(x.shape)
x = x.float()   # 有时候从numpy到tensor后再转化为PIL会报错,这一步就是使其满足数据类型,在这里这一句可以去掉。
x_pil = T.ToPILImage()(x)
print(x_pil.size)
x_tensor = T.ToTensor()(x_pil)
print(x_tensor.size())

# 输出为:
# torch.Size([1, 128, 228])
# (228, 128)
# torch.Size([1, 128, 228])

可以看到ToPILImage实际上将CHW变成了WH,也去掉了通道这个维度。

你可能感兴趣的:(笔记)