pytorch使用记录(五) 关于tensor、PIL以及numpy转换的问题

最近在运行程序的时候一直出现如下错误:

  File "/home/daydayjump/Glow/glow/trainer.py", line 172, in train
    self.writer.add_image("1_prob/{}".format(bi), plot_prob([y_pred[bi], y_true[bi]], ["pred", "true"]).cuda().float(), self.global_step)
  File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/writer.py", line 427, in add_image
    image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
  File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/summary.py", line 216, in image
    image = make_image(tensor, rescale=rescale)
  File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/summary.py", line 254, in make_image
    image = Image.fromarray(tensor)
  File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/PIL/Image.py", line 2517, in fromarray
    raise TypeError("Cannot handle this data type")
TypeError: Cannot handle this data type

具体而言,是在使用tensorboardX的SummaryWriter.add_imge()时,tensor的格式出现了问题。不支持这个数据类型。

在网上搜教程,有些说是PIL中Image.fromarray要求numpy的类型为uint8,然后就按照这个方向进行修改将tensor转换为numpy类型并强制转换为uint8类型。(具体tensor和numpy的类型转换会在下文讲解)结果发现依然报错。

后来经过不断输出tensor的形状才发现了问题。原来是在pytorch中tensor默认是CHW,而PIL中是HWC。在tensorboardX中的SummaryWriter.add_imge()的函数默认dataformats='CHW',并存在convert_to_HWC操作,致使自己本来是就是HWC的tensor,变成了WCH的numpy类型。从而导致在PIL的fromarray操作无法识别数据类型。

查看tensor或者array的形状的命令如下:

a.size()   #  for torch.Tensor
b.shape()  #  for numpy.array

总结一句话就是tensor的形状不对,没有根据自己tensor的形状修改dataformats,所以一定要明确自己的tensor或者array的形状是CHW还是HWC

 

对应措施:

1、修改tensorboardX的SummaryWriter.add_image()中的dataformats参数为自己输入的形状。

2、修改自己的输入tensor或者array为'CHW'形状,这样就可以使用默认参数。

 

此外,在tensorboardX的SummaryWriter.add_image()函数中,输入可以为:

img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data

而且在后续的操作中会修改成符合PIL函数操作的类型,所以可以放心。 

----------------------------------------------------------------分割线----------------------------------------------------------------------------------------------

接下来,总结一下torch.Tensor、numpy.array本身类型转换和相互转换,以及torch.Tensor与PIL Image之间的转换。

torch.Tensor自身类型转换

pytorch使用记录(五) 关于tensor、PIL以及numpy转换的问题_第1张图片

上图是tensor各种类型的对应关系。

(1)自身类型转换,直接在tensor后加.long()就可以转换成torch.LongTensor类型。即可以通过:

data.float()   # 转换成 FloatTensor类型
data.byte()    # 转换成 ByteTensor类型
data.int()     # 转换成 IntTensor类型

此外,可以使用type()函数,当不给定参数时,返回tensor的类型;给定参数则进行强制转换,如: 

data.type(torch.ByteTensor)  # 转换成ByteTensor类型

(2)cpu和gpu的转换:

data.cpu()    # 转换成cpu类型
data.cuda()   # 转换成gpu类型

 numpy.array自身类型转换

np类型转换需要用到astype()函数,修改dtype属性具体如下:

a.astype(np.int64)    #  转换数据类型int64
a.dtype               #  查看数据类型

需要注意使用astype()函数不会修改原array的dtype属性,返回了一个新的array。 

torch.Tensor与numpy.array的转换

(1) tensor 转换成 array :

new_data = data.numpy()   # data为tensor类型,new_data为array类型,二者指向同一地址。

(2)array 转换成 tensor :

new_data = torch.from_numpy(data)  # data为array类型,new_data为tensor类型,二者指向同一地址。

torch.Tensor、numpy.array与PIL Image格式的转换

 这需要用到torchvision.transforms()函数,具体操作如下。

(1)tensor转换成PIL Image格式, 使用transforms.ToPILImage()函数,需要注意:

Converts a torch.*Tensor of shape C x H x W to a PIL Image 

形状必须是CHW,最好是FloatTensor。

img = transforms.ToPILImage()(in_tensor)  # in_tensor 是待转换的tensor

(2)array转换成PIL Image格式 ,也使用transforms.ToPILImage()函数,需要注意:  

Converts a numpy ndarray of shape H x W x C to a PIL Image

形状必须是HWC,而且要求dtype=uint8, range[0, 255] 。

img = transforms.ToPILImage()(in_array)   # in_array 是待转换的array

 (3)PIL Image转换成tensor格式,使用transforms.ToTensor()函数,也就是将HWC形状的PIL Image 转换成CHW形状的tensor。

tensor = transforms.ToTensor()(img)  # img 是待转换的PIL Image

注意,得到的tensor类型为torch.FloatTensor。

 

 

 

 

 

 

 

你可能感兴趣的:(tensor,array,PIL,类型转换,pytorch学习与使用)