错误:
Traceback (most recent call last):
File "/home/nianxiongdi/algorithm/deform-conv/scripts/scaled_mnist1.py", line 94, in <module>
inputs, outputs, model = get_deform_cnn(use_cpu=False, print_summary=True)
File "/home/nianxiongdi/algorithm/deform-conv/deform_conv/cnn.py", line 405, in get_deform_cnn
conv_block_1 = buildConv2DBlock(inputs, 64, 1, 2)
File "/home/nianxiongdi/algorithm/deform-conv/deform_conv/cnn.py", line 251, in buildConv2DBlock
conv2d = ConvOffset2D(filters, name='conv12_offset')(conv2d)
File "/home/nianxiongdi/anaconda3/envs/py36/lib/python3.6/site-packages/keras/engine/topology.py", line 583, in __call__
previous_mask = _collect_previous_mask(inputs)
File "/home/nianxiongdi/anaconda3/envs/py36/lib/python3.6/site-packages/keras/engine/topology.py", line 2737, in _collect_previous_mask
mask = node.output_masks[tensor_index]
大概代码:
A.py
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, Lambda, Layer, BatchNormalization, Activation
from tensorflow.keras import backend as K
B.py
from keras.layers import Conv2D
class ConvOffset2D(Conv2D):
    """Conv2D subclass configured with fixed settings for an offset layer.

    The parent convolution is always built with ``filters * 2`` output
    filters, a 3x3 kernel, ``'same'`` padding, no bias term and a
    zero-initialised kernel.
    """

    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Configure the underlying convolution.

        Args:
            filters: Base filter count; the parent layer is created with
                twice this number of filters.
            init_normal_stddev: Not used by the active code path; it is
                only referenced by the disabled RandomNormal initializer
                mentioned below.
            **kwargs: Forwarded unchanged to the parent ``Conv2D``
                constructor (for example ``name``).
        """
        # Stored before the parent constructor runs, exactly as the
        # original snippet does.
        self.filters = filters
        super().__init__(
            self.filters * 2,
            (3, 3),
            padding='same',
            use_bias=False,
            # TODO gradients are near zero if init is zeros
            # Alternative kept for reference:
            #   kernel_initializer=RandomNormal(0, init_normal_stddev)
            kernel_initializer='zeros',
            **kwargs
        )
A.py 调用了 B.py，出错的原因是两个文件导入的 Keras 类型不一致（A.py 使用 tensorflow.keras，B.py 使用独立的 keras 包）:
调用到: `self.assert_input_compatibility(inputs)`，其作用是检查层与传入的输入是否兼容（Checks compatibility between the layer and provided inputs）。
解决问题:
把 B.py 与 A.py 文件中 import 的来源进行统一（都使用 tensorflow.keras），即将 B.py 修改为:
from tensorflow.keras.layers import Conv2D
class ConvOffset2D(Conv2D):
    """Conv2D subclass configured with fixed settings for an offset layer.

    The parent convolution is always built with ``filters * 2`` output
    filters, a 3x3 kernel, ``'same'`` padding, no bias term and a
    zero-initialised kernel.
    """

    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Configure the underlying convolution.

        Args:
            filters: Base filter count; the parent layer is created with
                twice this number of filters.
            init_normal_stddev: Not used by the active code path; it is
                only referenced by the disabled RandomNormal initializer
                mentioned below.
            **kwargs: Forwarded unchanged to the parent ``Conv2D``
                constructor (for example ``name``).
        """
        # Stored before the parent constructor runs, exactly as the
        # original snippet does.
        self.filters = filters
        super().__init__(
            self.filters * 2,
            (3, 3),
            padding='same',
            use_bias=False,
            # TODO gradients are near zero if init is zeros
            # Alternative kept for reference:
            #   kernel_initializer=RandomNormal(0, init_normal_stddev)
            kernel_initializer='zeros',
            **kwargs
        )