AssertionError: size of input tensor and input format are different.

想用tensorboard添加图片的时候出现以下错误

Traceback (most recent call last):
  File "/usr/local/pycharm-2020.3.5/plugins/python/helpers/pydev/pydevd.py", line 1477, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/usr/local/pycharm-2020.3.5/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/data/chenrj/paper5/main.py", line 142, in 
    acc_test,  epoch_max = train_SemiTime(
  File "/data/chenrj/paper5/optim/pretrain.py", line 95, in train_SemiTime
    test_acc, best_epoch = model.train(tot_epochs=tot_epochs, train_loader=train_loader,
  File "/data/chenrj/paper5/model/semiSOP.py", line 99, in train
    writer.add_images('cor_matrix', cor_matrix, epoch)
  File "/data/chenrj/.local/lib/python3.9/site-packages/torch/utils/tensorboard/writer.py", line 589, in add_images
    image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
  File "/data/chenrj/.local/lib/python3.9/site-packages/torch/utils/tensorboard/summary.py", line 376, in image
    tensor = convert_to_HWC(tensor, dataformats)
  File "/data/chenrj/.local/lib/python3.9/site-packages/torch/utils/tensorboard/_utils.py", line 100, in convert_to_HWC
    assert(len(tensor.shape) == len(input_format)), "size of input tensor and input format are different. \
AssertionError: size of input tensor and input format are different.         tensor shape: (128, 128), input_format: NCHW

原因

这是因为我代码中的数据是一个二维矩阵,而add_images这个函数接受的输入是一个四维的数据,为NCHW,N是数据样本量,C为channel,H和W为图片的长和宽。

解决方法

因为我的项目中是不断生成的特征图想保存下来,因此用add_images()这个函数的话没办法实现(原因解释过了),因此有2种解决方法:

1、将特征图先用一个列表保存下来,然后最后再add_images()。

2、把函数修改为add_image,add_image()函数要求的输入是一个三维的数据,为CHW。

因为我需要查看图片的变化,所以我这里采用第二种方法,将我的二维矩阵添加一维C的维度,然后在

add_image。代码如下

writer.add_image('cor_matrix', cor_matrix.unsqueeze(0), epoch)

打开tensorboard可以通过滑动条查看过往的图像。

AssertionError: size of input tensor and input format are different._第1张图片

 

你可能感兴趣的:(Bug,大数据)