最近在运行程序的时候一直出现如下错误:
File "/home/daydayjump/Glow/glow/trainer.py", line 172, in train
self.writer.add_image("1_prob/{}".format(bi), plot_prob([y_pred[bi], y_true[bi]], ["pred", "true"]).cuda().float(), self.global_step)
File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/writer.py", line 427, in add_image
image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/summary.py", line 216, in image
image = make_image(tensor, rescale=rescale)
File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/tensorboardX/summary.py", line 254, in make_image
image = Image.fromarray(tensor)
File "/home/daydayjump/anaconda3/lib/python3.7/site-packages/PIL/Image.py", line 2517, in fromarray
raise TypeError("Cannot handle this data type")
TypeError: Cannot handle this data type
具体而言,是在使用tensorboardX的SummaryWriter.add_imge()时,tensor的格式出现了问题。不支持这个数据类型。
在网上搜教程,有些说是PIL中Image.fromarray要求numpy的类型为uint8,然后就按照这个方向进行修改将tensor转换为numpy类型并强制转换为uint8类型。(具体tensor和numpy的类型转换会在下文讲解)结果发现依然报错。
后来经过不断输出tensor的形状才发现了问题。原来是在pytorch中tensor默认是CHW,而PIL中是HWC。在tensorboardX中的SummaryWriter.add_imge()的函数默认dataformats='CHW',并存在convert_to_HWC操作,致使自己本来是就是HWC的tensor,变成了WCH的numpy类型。从而导致在PIL的fromarray操作无法识别数据类型。
查看tensor或者array的形状的命令如下:
a.size() # for torch.Tensor
b.shape() # for numpy.array
总结一句话就是tensor的形状不对,没有根据自己tensor的形状修改dataformats,所以一定要明确自己的tensor或者array的形状是CHW还是HWC
对应措施:
1、修改tensorboardX的SummaryWriter.add_image()中的dataformats参数为自己输入的形状。
2、修改自己的输入tensor或者array为'CHW'形状,这样就可以使用默认参数。
此外,在tensorboardX的SummaryWriter.add_image()函数中,输入可以为:
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
而且在后续的操作中会修改成符合PIL函数操作的类型,所以可以放心。
----------------------------------------------------------------分割线----------------------------------------------------------------------------------------------
接下来,总结一下torch.Tensor、numpy.array本身类型转换和相互转换,以及torch.Tensor与PIL Image之间的转换。
上图是tensor各种类型的对应关系。
(1)自身类型转换,直接在tensor后加.long()就可以转换成torch.LongTensor类型。即可以通过:
data.float() # 转换成 FloatTensor类型
data.byte() # 转换成 ByteTensor类型
data.int() # 转换成 IntTensor类型
此外,可以使用type()函数,当不给定参数时,返回tensor的类型;给定参数则进行强制转换,如:
data.type(torch.ByteTensor) # 转换成ByteTensor类型
(2)cpu和gpu的转换:
data.cpu() # 转换成cpu类型
data.cuda() # 转换成gpu类型
np类型转换需要用到astype()函数,修改dtype属性具体如下:
a.astype(np.int64) # 转换数据类型int64
a.dtype # 查看数据类型
需要注意使用astype()函数不会修改原array的dtype属性,返回了一个新的array。
(1) tensor 转换成 array :
new_data = data.numpy() # data为tensor类型,new_data为array类型,二者指向同一地址。
(2)array 转换成 tensor :
new_data = torch.from_numpy(data) # data为array类型,new_data为tensor类型,二者指向同一地址。
这需要用到torchvision.transforms()函数,具体操作如下。
(1)tensor转换成PIL Image格式, 使用transforms.ToPILImage()函数,需要注意:
Converts a torch.*Tensor of shape C x H x W to a PIL Image
形状必须是CHW,最好是FloatTensor。
img = transforms.ToPILImage()(in_tensor) # in_tensor 是待转换的tensor
(2)array转换成PIL Image格式 ,也使用transforms.ToPILImage()函数,需要注意:
Converts a numpy ndarray of shape H x W x C to a PIL Image
形状必须是HWC,而且要求dtype=uint8, range[0, 255] 。
img = transforms.ToPILImage()(in_array) # in_array 是待转换的array
(3)PIL Image转换成tensor格式,使用transforms.ToTensor()函数,也就是将HWC形状的PIL Image 转换成CHW形状的tensor。
tensor = transforms.ToTensor()(img) # img 是待转换的PIL Image
注意,得到的tensor类型为torch.FloatTensor。