Reproduction Issue Log | StackGAN-v2 (in python3) (Part 2)

Reproduction Issue Log

    • 1.cannot import name 'FileWriter' from 'tensorboard'
    • 2.tensor_format in summary.image()
    • 3.Image.fromarray():TypeError("Cannot handle this data type")
    • 4.Handling some warnings

Reproduction Issue Log | StackGAN-v2 (in python3) (Part 1)
Reproduction Issue Log | StackGAN-v2 (in python3) (Part 3)

1.cannot import name 'FileWriter' from 'tensorboard'

Error message

Traceback (most recent call last):
  File "main.py", line 139, in 
    from trainer import condGANTrainer as trainer
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 20, in 
    from tensorboard import FileWriter
ImportError: cannot import name 'FileWriter' from 'tensorboard' (/home/user/aib/python3.7/site-packages/tensorboard/__init__.py)

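The fix is to use tensorboardX or torch.utils.tensorboard instead, making sure the version lines up with your torch package. Both packages, however, trigger further bugs later on; those are more involved and are covered as separate issues (2 and 3) below.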

tensorboardX

pip install tensorboardX==1.4 # version 1.4, to stay consistent with the rest of the environment
#from tensorboard import summary
#from tensorboard import FileWriter
from tensorboardX import summary
from tensorboardX import FileWriter

Reference: https://blog.csdn.net/weixin_38382622/article/details/107906338

torch.utils.tensorboard
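Note that torch.utils.tensorboard does not need a separate install: it ships with torch, so you can import it directly (internally it still imports the standalone tensorboard package, as the traceback in the first attempt below shows). My first attempt failed because pip install pulled the latest version, which did not match my torch 1.4, and the version mismatch caused the error.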

#from tensorboard import summary
#from tensorboard import FileWriter
from torch.utils.tensorboard import summary
from torch.utils.tensorboard import FileWriter

My first, failed attempt is recorded below.

---------------------------------------Divider: first attempt-------------------------------------------
I found someone who hit the same problem while running StackGAN; the suggested fix was to switch to a different tensorboard package, such as torch.utils.tensorboard (pip install tensorboard-pytorch) or tensorboardX (pip install tensorboardX).
Reference: https://ask.csdn.net/questions/1148279

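I tried torch.utils.tensorboard first and got a monstrously complicated error that left me dumbfounded; then I tried tensorboardX, and tensorboard still threw the same error. Well, now I had properly broken it.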

Traceback (most recent call last):
  File "main.py", line 139, in 
    from trainer import condGANTrainer as trainer
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 19, in 
    from torch.utils.tensorboard import summary
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/__init__.py", line 2, in 
    from tensorboard.summary.writer.record_writer import RecordWriter  # noqa F401
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/__init__.py", line 4, in 
    from .writer import FileWriter, SummaryWriter
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/writer.py", line 28, in 
    from .summary import scalar, histogram, image, audio, text
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/summary/__init__.py", line 25, in 
    from tensorboard.summary import v1
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/summary/v1.py", line 24, in 
    from tensorboard.plugins.audio import summary as _audio_summary
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/plugins/audio/summary.py", line 36, in 
    from tensorboard.plugins.audio import metadata
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/plugins/audio/metadata.py", line 21, in 
    from tensorboard.compat.proto import summary_pb2
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/compat/proto/summary_pb2.py", line 15, in 
    from tensorboard.compat.proto import tensor_pb2 as tensorboard_dot_compat_dot_proto_dot_tensor__pb2
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/compat/proto/tensor_pb2.py", line 15, in 
    from tensorboard.compat.proto import resource_handle_pb2 as tensorboard_dot_compat_dot_proto_dot_resource__handle__pb2
  File "/home/user/anaconda3/lib/python3.7/site-packages/tensorboard/compat/proto/resource_handle_pb2.py", line 22, in 
    serialized_pb=_b('\n.tensorboard/compat/proto/resource_handle.proto\x12\x0btensorboard\"r\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\tBn\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3')
  File "/home/user/anaconda3/lib/python3.7/site-packages/google/protobuf/descriptor.py", line 884, in __new__
    return _message.default_pool.AddSerializedFile(serialized_pb)
TypeError: Couldn't build proto file into descriptor pool!
Invalid proto descriptor for file "tensorboard/compat/proto/resource_handle.proto":
  tensorboard.ResourceHandleProto.device: "tensorboard.ResourceHandleProto.device" is already defined in file "tensorboard/src/resource_handle.proto".
  tensorboard.ResourceHandleProto.container: "tensorboard.ResourceHandleProto.container" is already defined in file "tensorboard/src/resource_handle.proto".
  tensorboard.ResourceHandleProto.name: "tensorboard.ResourceHandleProto.name" is already defined in file "tensorboard/src/resource_handle.proto".
  tensorboard.ResourceHandleProto.hash_code: "tensorboard.ResourceHandleProto.hash_code" is already defined in file "tensorboard/src/resource_handle.proto".
  tensorboard.ResourceHandleProto.maybe_type_name: "tensorboard.ResourceHandleProto.maybe_type_name" is already defined in file "tensorboard/src/resource_handle.proto".
  tensorboard.ResourceHandleProto: "tensorboard.ResourceHandleProto" is already defined in file "tensorboard/src/resource_handle.proto".

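This comes down to versions: each field of ResourceHandleProto from tensorboard/compat/proto/resource_handle.proto is reported as "already defined" in tensorboard/src/resource_handle.proto, meaning two different installs are registering the same protobuf definitions under the tensorboard name. So I uninstalled the two packages I had just installed and redid my homework on versions.
Reference: https://blog.csdn.net/wasjrong/article/details/108869519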
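Before reinstalling anything, it can help to see which of these distributions are actually present. The following is a minimal sketch using only the standard library (importlib.metadata, Python 3.8+; on Python 3.7 the importlib_metadata backport offers the same functions). It reads package metadata without importing the packages, so it still works when the import itself crashes as above. The candidate list is my assumption, built from the package names mentioned in this post.

```python
# Minimal sketch: list installed versions of TensorBoard-related packages
# without importing them (importing a broken install is what fails above).
from importlib import metadata  # standard library, Python 3.8+

# Distribution names mentioned in this post; adjust to your environment.
CANDIDATES = ["torch", "tensorboard", "tensorboardX", "tensorboard-pytorch"]


def installed_versions(names):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(CANDIDATES).items():
        print(f"{name:22s} {version or 'not installed'}")
```

Comparing the output with the versions you intend to use (torch 1.4 in this post) shows which package to uninstall or pin.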

2.tensor_format in summary.image()

Error with torch.utils.tensorboard installed

Traceback (most recent call last):
  File "main.py", line 144, in 
    algo.train()
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 778, in train
    count, self.image_dir, self.summary_writer)
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 208, in save_img_results
    sup_real_img = summary.image('real_img', real_img_set)
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/summary.py", line 327, in image
    tensor = convert_to_HWC(tensor, dataformats)
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/utils/tensorboard/_utils.py", line 101, in convert_to_HWC
    tensor shape: {}, input_format: {}".format(tensor.shape, input_format)
AssertionError: size of input tensor and input format are different.
  	tensor shape: (776, 2066, 3), input_format: NCHW

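Analysis
I could not find any tutorial online for this one, so I had to work it out myself. Judging from the error and the code, the problem lies in the variable real_img_set and the function summary.image().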

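  • real_img_set is a NumPy array, and summary.image() converts the array into an image internally. Adding print(real_img_set.shape) to the code gives (776, 2066, 3), i.e. [H, W, C].
  • The source declares image() as image(tag, tensor, rescale=1, dataformats='NCHW'), so by default tensor is expected in [N, C, H, W] layout. The error is raised at tensor = convert_to_HWC(tensor, dataformats), and convert_to_HWC() checks that the length of tensor.shape matches the length of dataformats. real_img_set has only three dimensions against the four letters of 'NCHW', hence the error.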
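The failing check can be reproduced without torch. The function below is a simplified stand-in written for illustration (to_hwc_sketch is a hypothetical name, not the torch source): it performs the length check described above and, for 3-D input, reorders the axes into H, W, C.

```python
import numpy as np


def to_hwc_sketch(arr, input_format):
    """Simplified stand-in for convert_to_HWC() (hypothetical, not the torch source).

    Checks that no dimension letter repeats and that the number of letters
    matches arr.ndim (the check that fails above), then, for 3-D input,
    reorders the axes into (H, W, C).
    """
    input_format = input_format.upper()
    assert len(set(input_format)) == len(input_format), "each dimension letter may appear once"
    assert arr.ndim == len(input_format), (
        "size of input tensor and input format are different. "
        "tensor shape: {}, input_format: {}".format(arr.shape, input_format)
    )
    if arr.ndim != 3:
        raise NotImplementedError("only 3-D inputs are covered by this sketch")
    # Where H, W and C sit in the input; used directly as the transpose order.
    order = [input_format.index(c) for c in "HWC"]
    return arr.transpose(order)


real_img_set = np.zeros((776, 2066, 3), dtype=np.uint8)  # same shape as printed above

try:
    to_hwc_sketch(real_img_set, "NCHW")  # the default dataformats: 3 dims vs 4 letters
except AssertionError as err:
    print("AssertionError:", err)

print(to_hwc_sketch(real_img_set, "HWC").shape)  # (776, 2066, 3)
```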

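I thought about it but could not work out where that N (the batch dimension) was supposed to come from, since real_img_set is a single image, so I simply changed dataformats to 'HWC'. The modification is as follows: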

    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255 #numpy.ndarray object
    real_img_set = real_img_set.astype(np.uint8) #(776, 2066, 3)(H, W, C)
    #sup_real_img = summary.image('real_img', real_img_set) #modified
    sup_real_img = summary.image('real_img', real_img_set, dataformats='HWC')
    summary_writer.add_summary(sup_real_img, count)
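On the question of where N could come from: N is the batch dimension, so one alternative (not what the original code does) would be to keep real_img_set in (C, H, W) order and add a leading batch axis of size 1, giving a 4-D array that matches the default dataformats='NCHW'. The shapes can be checked with plain NumPy; the zero-filled array below is a hypothetical stand-in for the real image grid.

```python
import numpy as np

# Hypothetical stand-in for real_img_set before the transpose: (C, H, W) floats in [0, 1].
chw = np.zeros((3, 776, 2066), dtype=np.float32)

# Fix used above: reorder to (H, W, C), scale to [0, 255], cast to uint8, pass dataformats='HWC'.
hwc = (np.transpose(chw, (1, 2, 0)) * 255).astype(np.uint8)

# Alternative sketch: keep (C, H, W) and add a leading batch axis of size 1,
# which has the four dimensions the default dataformats='NCHW' expects.
nchw = chw[np.newaxis, ...]

print(hwc.shape)   # (776, 2066, 3)
print(nchw.shape)  # (1, 3, 776, 2066)
```

With a single image in the batch, torch.utils.tensorboard should lay it out as a one-image grid, which ought to give the same picture as the 'HWC' fix; I have not verified this across torch versions.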

Background:
summary.image()
Outputs a Summary protocol buffer containing images.

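  • name: the node name, i.e. the name shown in TensorBoard (the tag argument in the torch version).
  • tensor: must be 3-D, [height, width, channels]
    • For channels:
    • channels=1: grayscale image
    • channels=3: RGB image
    • channels=4: RGBA image (Red, Green, Blue, plus Alpha compositing, which encodes transparency)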

Below is the source of summary.image() in torch.utils.tensorboard:

def image(tag, tensor, rescale=1, dataformats='NCHW'):
    """Outputs a `Summary` protocol buffer with images.
    The summary has up to `max_images` summary values containing images. The
    images are built from `tensor` which must be 3-D with shape `[height, width,
    channels]` and where `channels` can be:
    *  1: `tensor` is interpreted as Grayscale.
    *  3: `tensor` is interpreted as RGB.
    *  4: `tensor` is interpreted as RGBA.
    The `name` in the outputted Summary.Value protobufs is generated based on the
    name, with a suffix depending on the max_outputs setting:
    *  If `max_outputs` is 1, the summary value tag is '*name*/image'.
    *  If `max_outputs` is greater than 1, the summary value tags are
       generated sequentially as '*name*/image/0', '*name*/image/1', etc.
    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
        channels]` where `channels` is 1, 3, or 4.
        'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
        The image() function will scale the image values to [0, 255] by applying
        a scale factor of either 1 (uint8) or 255 (float32).
    Returns:
      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
      buffer.
    """
    tag = _clean_tag(tag)
    tensor = make_np(tensor)
    tensor = convert_to_HWC(tensor, dataformats)  #bug1
    # Do not assume that user passes in values in [0, 255], use data type to detect
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).astype(np.uint8)
    image = make_image(tensor, rescale=rescale)  #bug2
    return Summary(value=[Summary.Value(tag=tag, image=image)])
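The scaling lines between #bug1 and #bug2 also explain why the modified code above multiplies by 255 and casts to uint8 itself: the scale factor is chosen from the dtype. Below is a NumPy sketch of just those scaling steps, following the docstring (uint8 → factor 1, any other dtype → factor 255); the function name is hypothetical and it is not torch's _calc_scale_factor itself.

```python
import numpy as np


def to_uint8_like_summary_image(arr):
    """Sketch of the scaling steps in summary.image() shown above.

    Assumes, as the docstring says, that uint8 input is already in [0, 255]
    (scale factor 1) and any other dtype is in [0, 1] (scale factor 255).
    """
    scale_factor = 1 if arr.dtype == np.uint8 else 255
    return (arr.astype(np.float32) * scale_factor).astype(np.uint8)


float_img = np.full((2, 2, 3), 0.5, dtype=np.float32)  # values in [0, 1]
uint8_img = np.full((2, 2, 3), 200, dtype=np.uint8)    # values in [0, 255]

print(to_uint8_like_summary_image(float_img)[0, 0])  # [127 127 127] (127.5 truncated)
print(to_uint8_like_summary_image(uint8_img)[0, 0])  # [200 200 200]
```

Because real_img_set is already uint8 in [0, 255] after the manual conversion, it is not scaled a second time.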

ndarray

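import numpy as np
a1 = np.array([1, 2, 3, 4], dtype=np.complex128)
print("Array type:", type(a1))            # type of the array object itself
print("Element dtype:", a1.dtype)         # data type of the elements
print("Number of elements:", a1.size)     # total number of elements
print("Shape:", a1.shape)                 # shape of the array
print("Number of dimensions:", a1.ndim)   # number of dimensions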

NCHW
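N is the number of images (batch size), C the number of channels, H the height and W the width.
In NCHW, i.e. (N, C, H, W) order, the first element is 000 and the second lies along the W direction, i.e. 001, followed by 002 and 003; then the layout moves along H: 004 005 006 007… Once 019 is reached it moves along C, so next comes 020, then 021 022 … all the way to 319, and only then does it move along N.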
[Figure: NCHW memory layout]
Reference: https://blog.csdn.net/weixin_41847115/article/details/83794551
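The traversal order above is just row-major indexing over (N, C, H, W). In the small sketch below, the sizes N, C, H, W are my assumption, inferred from the numbers in the description (000-003 along W, a channel ending at 019, an image ending at 319), and the offsets are cross-checked against NumPy, whose arrays are row-major (C order) by default. It prints the offsets 1, 4, 20, 319 and 320.

```python
import numpy as np

# Sizes inferred from the description above (an assumption):
# W = 4 (000-003 along one row), H = 5 (one channel ends at 019),
# C = 16 (one image ends at 319); N = 2 is arbitrary.
N, C, H, W = 2, 16, 5, 4


def nchw_offset(n, c, h, w):
    """Flat (row-major) offset of element (n, c, h, w) in an NCHW array."""
    return ((n * C + c) * H + h) * W + w


# Cross-check against NumPy: in a C-ordered arange array, each value equals its offset.
arr = np.arange(N * C * H * W).reshape(N, C, H, W)
for idx in [(0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (0, 15, 4, 3), (1, 0, 0, 0)]:
    assert arr[idx] == nchw_offset(*idx)
    print(idx, "->", nchw_offset(*idx))
```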

3.Image.fromarray():TypeError("Cannot handle this data type")

Error with tensorboardX installed

Traceback (most recent call last):
  File "main.py", line 144, in 
    algo.train()
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 776, in train
    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
  File "/mnt/data3/yc/StackGAN-v2/code/trainer.py", line 208, in save_img_results
    sup_real_img = summary.image('real_img', real_img_set, dataformats='CWH')
  File "/home/user/anaconda3/lib/python3.8/site-packages/torch/utils/tensorboard/summary.py", line 332, in image
    image = make_image(tensor, rescale=rescale)
  File "/home/user/anaconda3/lib/python3.8/site-packages/torch/utils/tensorboard/summary.py", line 370, in make_image
    image = Image.fromarray(tensor)
  File "/home/user/anaconda3/lib/python3.8/site-packages/PIL/Image.py", line 2766, in fromarray
    raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
TypeError: Cannot handle this data type: (1, 1, 776), |u1

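Searching for the fromarray(tensor) TypeError("Cannot handle this data type") turned up tutorials that all gave the same single fix: the array might be a uint16 matrix and needs converting to uint8. Both the StackGAN code and the tensorboard source already do that conversion (and the typestr |u1 in the error means the array is already uint8), so that was not the cause.
So I worked from the error and the source code. The summary side looked fine, which left fromarray() itself.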

#part of the fromarray() code
    arr = obj.__array_interface__
    shape = arr["shape"]
    ndim = len(shape)
    strides = arr.get("strides", None)
    if mode is None:
        try:
            typekey = (1, 1) + shape[2:], arr["typestr"]
        except KeyError as e:
            raise TypeError("Cannot handle this data type") from e
        try:
            mode, rawmode = _fromarray_typemap[typekey]
        except KeyError as e:
            raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e

From the typekey computed in fromarray():
(1, 1) + shape[2:] = (1, 1, 776)
arr["typestr"] = |u1
Looking this up in _fromarray_typemap shows that the problem is the (1, 1, 776) part.

_fromarray_typemap = {
    # (shape, typestr) => mode, rawmode
    # first two members of shape are set to one
    ((1, 1), "|b1"): ("1", "1;8"),
    ((1, 1), "|u1"): ("L", "L"),
    ((1, 1), "|i1"): ("I", "I;8"),
    ((1, 1), "<u2"): ("I", "I;16"),
    ((1, 1), ">u2"): ("I", "I;16B"),
    ((1, 1), "<i2"): ("I", "I;16S"),
    ((1, 1), ">i2"): ("I", "I;16BS"),
    ((1, 1), "<u4"): ("I", "I;32"),
    ((1, 1), ">u4"): ("I", "I;32B"),
    ((1, 1), "<i4"): ("I", "I;32S"),
    ((1, 1), ">i4"): ("I", "I;32BS"),
    ((1, 1), "<f4"): ("F", "F;32F"),
    ((1, 1), ">f4"): ("F", "F;32BF"),
    ((1, 1), "<f8"): ("F", "F;64F"),
    ((1, 1), ">f8"): ("F", "F;64BF"),
    ((1, 1, 2), "|u1"): ("LA", "LA"),
    ((1, 1, 3), "|u1"): ("RGB", "RGB"),
    ((1, 1, 4), "|u1"): ("RGBA", "RGBA"),
}

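The map shows that the array passed to fromarray() must either be 2-D or have the channels as its third dimension. With that, the picture is clear. Tracing the tensor's shape: on entering summary.image() it is (776, 2066, 3) [H, W, C]; with dataformats='CWH' (as in the traceback above), convert_to_HWC() reorders it to (3, 2066, 776) before it is handed to fromarray(). Just as in the previous issue, the fix is to change dataformats in summary.image() to 'HWC':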

    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255 #numpy.ndarray object
    real_img_set = real_img_set.astype(np.uint8) #(776, 2066, 3)(H, W, C)
    #sup_real_img = summary.image('real_img', real_img_set) #modified
    sup_real_img = summary.image('real_img', real_img_set, dataformats='HWC')
    summary_writer.add_summary(sup_real_img, count)
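The whole chain can be reproduced with NumPy alone. The sketch below rebuilds the lookup key that the fromarray() excerpt computes and checks it against the uint8 entries of _fromarray_typemap quoted above; the function name pil_typekey is hypothetical, and transpose(2, 1, 0) is what the reorder-to-HWC rule yields for dataformats='CWH'. It should print RGB for the HWC array and fall through to the TypeError case for the CWH one, whose key is ((1, 1, 776), '|u1').

```python
import numpy as np

# Subset of PIL's _fromarray_typemap quoted above: (shape key, typestr) -> mode.
TYPEMAP_SUBSET = {
    ((1, 1), "|u1"): "L",
    ((1, 1, 2), "|u1"): "LA",
    ((1, 1, 3), "|u1"): "RGB",
    ((1, 1, 4), "|u1"): "RGBA",
}


def pil_typekey(arr):
    """Build the same lookup key that the fromarray() excerpt above computes."""
    shape = arr.__array_interface__["shape"]
    return (1, 1) + shape[2:], arr.__array_interface__["typestr"]


hwc = np.zeros((776, 2066, 3), dtype=np.uint8)  # what real_img_set looks like
wrong = hwc.transpose(2, 1, 0)                  # what dataformats='CWH' produces

for name, arr in [("HWC", hwc), ("CWH", wrong)]:
    key = pil_typekey(arr)
    print(name, arr.shape, key, TYPEMAP_SUBSET.get(key, "-> TypeError"))
```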

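Summary: at this point the code runs end to end; what follows is the handling of some minor issues.
Issues 2 and 3 here are very similar. This one took me two days and had me at my wits' end at first. The upside is that I finally got the whole NCHW business straight.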

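4.Handling some warnings

The warnings are mostly version issues (Python version, torch version, etc.); just change the corresponding function as each message suggests.
"Deprecated" means the function is obsolete and no longer maintained.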
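The old names keep working because they warn and then forward to the new function, which is why these are only warnings. A minimal sketch of that pattern in plain Python; make_deprecated_alias, scale and scale_ are hypothetical names chosen for illustration, not the torch source.

```python
import functools
import warnings


def make_deprecated_alias(new_fn, old_name):
    """Return a wrapper that warns and then forwards to new_fn (hypothetical helper)."""
    @functools.wraps(new_fn)
    def alias(*args, **kwargs):
        warnings.warn(
            f"{old_name} is now deprecated in favor of {new_fn.__name__}.",
            UserWarning,
            stacklevel=2,  # report the caller's line, like trainer.py:66 below
        )
        return new_fn(*args, **kwargs)
    return alias


def scale_(values, factor):  # hypothetical "new" API
    return [v * factor for v in values]


scale = make_deprecated_alias(scale_, "scale")  # hypothetical "old" API

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = scale([1, 2, 3], 2)

print(result)             # [2, 4, 6]
print(caught[0].message)  # scale is now deprecated in favor of scale_.
```

Calling the new name directly (transforms.Resize, nn.init.orthogonal_ and torch.sigmoid in the messages below) avoids the warning altogether.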
(1)transforms.Scale

/home/user/anaconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py:219: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +

(2)nn.init.orthogonal

/mnt/data3/yc/StackGAN-v2/code/trainer.py:66: UserWarning: nn.init.orthogonal is now deprecated in favor of nn.init.orthogonal_.
  nn.init.orthogonal(m.weight.data, 1.0)
/mnt/data3/yc/StackGAN-v2/code/trainer.py:61: UserWarning: nn.init.orthogonal is now deprecated in favor of nn.init.orthogonal_.

(3)nn.functional.sigmoid

/home/user/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:1351: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")

(4)nn.Upsample

/home/user/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py:2503: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  warnings.warn("Default upsampling behavior when mode={} is changed "

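This warning comes from class INCEPTION_V3(nn.Module). Whether to change it depends on whether your code relies on the align_corners behavior.
The paper was written in 2017, while PyTorch 0.4.0 came out in 2018, so I inferred the code was written against the old default (align_corners=True) and added the argument.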

        # --> fixed-size input: batch x 3 x 299 x 299
        #x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
        x = nn.Upsample(size=(299, 299), mode='bilinear', align_corners = True)(x) # modified
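What align_corners changes is where each output pixel samples the input. Below is a 1-D NumPy sketch of the coordinate mapping as I understand PyTorch's linear modes to behave (an illustration, not the actual kernel): with True, the first and last output pixels land exactly on the first and last input pixels; with False, pixels are treated as unit-sized areas, which adds a half-pixel offset. The helper name and the 4 → 8 upsampling are arbitrary choices.

```python
import numpy as np


def source_coords(out_size, in_size, align_corners):
    """Input coordinate sampled by each output pixel in 1-D linear upsampling.

    Assumes out_size > 1. align_corners=True spaces the samples so the
    endpoints coincide; align_corners=False uses the half-pixel mapping.
    Coordinates are clamped to the valid input range.
    """
    dst = np.arange(out_size, dtype=np.float64)
    if align_corners:
        src = dst * (in_size - 1) / (out_size - 1)
    else:
        src = (dst + 0.5) * in_size / out_size - 0.5
    return np.clip(src, 0, in_size - 1)


row = np.array([0.0, 1.0, 2.0, 3.0])
for flag in (True, False):
    coords = source_coords(8, len(row), flag)
    # np.interp performs the linear interpolation between neighbouring pixels.
    print(flag, np.round(np.interp(coords, np.arange(len(row)), row), 3))
```

Upsampling [0, 1, 2, 3] to 8 samples gives evenly spaced values 0, 3/7, 6/7, …, 3 with True, and 0, 0.25, 0.75, …, 2.75, 3 with False, so the same input yields slightly different images under the two settings.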

(5)Implicit dimension choice for softmax

/mnt/data3/yc/StackGAN-v2/code/model.py:43: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  x = nn.Softmax()(x)

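As the warning says, the dim argument has to be given explicitly. Tutorials online say dim=1 is the usual choice; here the output x has shape [24, 1000] (batch x classes), so dim=1 applies the softmax over the 1000 classes of each image.
Reference: https://sakura.blog.csdn.net/article/details/105743380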

        # INCEPTION_V3
        # 299 x 299 x 3
        x = self.model(x) #[24,1000]
        #x = nn.Softmax()(x) #[24,1000]
        x = nn.Softmax(dim=1)(x) #[24,1000], softmax over the 1000 classes # modified
        return x

Background:
Choosing the dim argument of nn.Softmax()
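A quick NumPy sketch of what dim selects, using a hypothetical random array with the [24, 1000] shape noted in the code above: dim=1 normalizes each row (one image's 1000 class scores), while dim=0 would normalize each class column across the 24 images in the batch. The softmax helper here is my own, not torch's. It should print "(24, 1000) True" followed by "True".

```python
import numpy as np


def softmax(x, dim):
    """Numerically stable softmax along axis `dim` (same role as nn.Softmax(dim=...))."""
    shifted = x - x.max(axis=dim, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=dim, keepdims=True)


# Hypothetical logits with the shape noted in the code above: [batch=24, classes=1000].
logits = np.random.default_rng(0).normal(size=(24, 1000))

p = softmax(logits, dim=1)
print(p.shape, np.allclose(p.sum(axis=1), 1.0))  # each image's 1000 scores sum to 1

q = softmax(logits, dim=0)
print(np.allclose(q.sum(axis=0), 1.0))  # dim=0 normalizes across the batch instead
```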

(6)yaml.load()

/mnt/data3/yc/StackGAN-v2/code/miscc/config.py:106: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.

Modified as the warning suggests (Loader=yaml.FullLoader requires PyYAML 5.1 or later):

def cfg_from_file(filename):
    """Load a config file and merge it into the default options."""
    import yaml
    with open(filename, 'r') as f:
        #yaml_cfg = edict(yaml.load(f))
        yaml_cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) #modified

    _merge_a_into_b(yaml_cfg, __C)
