前言:
最近想跑一些主流的网络感受感受。从github上找到了 deep-learning-models 提供的几个模型,包括:inception-v2, inception-v3, resnet50, vgg16, vgg19 等等。这些代码都是基于 keras 框架,正好我最近有在学 tensorflow 和 keras,所以很想跑跑这些代码。
心动不如行动,万事俱备,只欠把代码跑起来。此时,出现了一些常见的问题,也正好借此机会整理下来。问题如下:
1)_obtain_input_shape() got an unexpected keyword argument 'include_top'
2)Exception: URL fetch failure on https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5: None -- [Errno 110] Connection timed out
本文主要分析和整理了第一个问题的解决方案。
第二个问题就是:在下载模型参数文件的过程中,可能url对应的地址被墙了,导致下载不了。解决办法就是另想办法把 resnet50_weights_tf_dim_ordering_tf_kernels.h5 这个文件下载下来。如果有需要的话,可在本文留言。
Reference:
1)github 源码:https://github.com/fchollet/deep-learning-models
2)模型参数资源:https://github.com/fchollet/deep-learning-models/releases
3)相关博客:http://blog.csdn.net/sinat_26917383/article/details/72982230
4)本文使用的测试数据如下所示:
elephant.jpg
本文使用的 tensorflow 和 keras 的版本:
- tensorflow:
>>> import tensorflow as tf
>>> tf.__version__
'1.1.0'
- keras:
import keras
>>> print keras.__version__
2.0.9
本文实践步骤如下,以 "resnet50.py" 为例:
1)下载 github 源码,源码中使用一张名为“elephant.jpg”的图像作为测试。
2)下载测试数据集,上文有提供链接。3)在源码的目录下执行如下命令:
python resnet50.py
程序报错,如下所示:
Traceback (most recent call last):
File "resnet50.py", line 289, in
model = ResNet50(include_top=True, weights='imagenet')
File "resnet50.py", line 193, in ResNet50
include_top=include_top)
TypeError: _obtain_input_shape() got an unexpected keyword argument 'include_top'
导致程序报错的原因分析:
1)keras.__version__ == 2.0.9 中,函数 _obtain_input_shape() 的形式:
def _obtain_input_shape(input_shape,
default_size,
min_size,
data_format,
require_flatten,
weights=None):
2)deep-learning-models 案例中,调用 _obtain_input_shape() 函数的方式如下:
# Determine proper input shape
input_shape = _obtain_input_shape(input_shape,
default_size=299,
min_size=71,
data_format=K.image_data_format(),
include_top=include_top)
# Determine proper input shape
input_shape = _obtain_input_shape(input_shape,
default_size=224,
min_size=197,
data_format=K.image_data_format(),
require_flatten=include_top)
再次执行命令,就可以成功运行案例代码,如下图所示:
本文的第二个问题:无法正常下载模型参数
模型参数“resnet50_weights_tf_dim_ordering_tf_kernels.h5”下载地址如下:
链接:http://pan.baidu.com/s/1dE1Lh5J 密码:puke
需修改的代码如下:
# WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
# load weights
if weights == 'imagenet':
# if include_top:
# weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels.h5',
# WEIGHTS_PATH,
# cache_subdir='models',
# md5_hash='a7b3fe01876f51b976af0dea6bc144eb')
# else:
# weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
# WEIGHTS_PATH_NO_TOP,
# cache_subdir='models',
# md5_hash='a268eb855778b3df3c7506639542a6af')
weights_path = './resnet50_weights_tf_dim_ordering_tf_kernels.h5'
model.load_weights(weights_path)
Predicted: [[(u'n01871265', u'tusker', 0.65325415), (u'n02504458', u'African_elephant', 0.29492217), (u'n02504013', u'Indian_elephant', 0.048155606), (u'n02422106', u'hartebeest', 0.001847562), (u'n02397096', u'warthog', 0.00034257883)]]
由结果可知:p(tusker) = 0.65, p(African_elephant) = 0.29, p(Indian_elephant) = 0.048 ... ... 其中 tusker 的概率是最高的,所以识别结果为 tusker(有长牙的动物,如:象,野猪等)