Keras Flatten的input_shape问题

在fine tune Keras Applications中给出的分类CNN Model的时候,如果在Model的top层之上加入Flatten层就会出现错误。可能的报错信息类似下面的内容:

$ python3 ./train.py
Using TensorFlow backend.
Found 60000 images belonging to 200 classes.
Found 20000 images belonging to 200 classes.
# 略过一些信息...
Creating TensorFlow device (/device:GPU:0) ->
(device: 0, name: GeForce GTX 1080, pci bus id: 0000:02:00.0, compute capability: 6.1)

# ↓↓↓ 错误出现 ↓↓↓
Traceback (most recent call last):
  File "./train.py", line 51, in 
    x = Flatten()(x)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 636, in __call__
    output_shape = self.compute_output_shape(input_shape)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/core.py", line 490, in
    compute_output_shape
    '(got ' + str(input_shape[1:]) + '. '
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536).
Make sure to pass a complete "input_shape" or "batch_input_shape" argument
to the first layer in your model.
# ↑↑↑ 错误结束 ↑↑↑

出错的代码行是x = Flatten()(x),错误提示为ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.

Flatten()(x)希望参数拥有确定shape属性,实际得到的参数xshape属性是(None, None, 1536),很明显不符合要求。同时,错误提示信息中也给出了修正错误的方法Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model。即,在Model的第一层给出确定的input_shapebatch_input_shape。那么,如何在Keras中解决该问题呢?

以Keras Applications中的VGG16为例,我们只需要在其初始化的时候,给出具体的input_shape就可以了。例如,Keras给出的VGG16模型输入层图像尺寸是(224, 224)的,所以如果使用TensorFlow的channels_last数据格式,则初始化代码为:

vgg16 = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
x = vgg16.output
x = Flatten()(x)
...

注意,因为要fine tune模型,对模型分类的种类和类别数进行重新定义,所以include_top=False,这样返回的模型不包括VGG16的全连接层和输出层。


所以该类似问题,只需指定input_shape参数

对于tensorflow后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(224,224,3))

对于theano后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(3,224,224))


参考:

  1. Unable to fine tune Keras vgg16 model - input shape issue
  2. Keras Applications
  3. 简书

主要参考作者:Aspirinrin
链接:https://www.jianshu.com/p/ec188fa1cca1
 

你可能感兴趣的:(Keras,卷积神经网络,深度学习)