【深度学习】tensorflow加载VGG16的网络结构和模型参数

文件介绍

synset.txt:标签列表
vgg16-20160129.tfmodel:pre-trained vgg16的网络结构和结点参数

定义输入placeholder

images = tf.placeholder("float", [None, 224, 224, 3])

加载模型

with open("model/vgg16-20160129.tfmodel", mode='rb') as f:
  fileContent = f.read()

创建Graph,导入pre-trained模型

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
tf.import_graph_def(graph_def, input_map={ "images": images })
graph = tf.get_default_graph()

定义feed_dict

for i in ['cat.jpg','airplane.jpg','zebra.jpg','pig.jpg']:
  img=load_image('model/pic/'+i)
  plt.imshow(img)
  plt.show()
  imgs.append(img)
img_num=len(imgs)

batch = np.array(imgs).reshape((img_num, 224, 224, 3))
assert batch.shape == (img_num, 224, 224, 3)
feed_dict = { images: batch }

进行预测

prob_tensor = graph.get_tensor_by_name("import/prob:0")
prob = sess.run(prob_tensor, feed_dict=feed_dict)

结果展示如下

【深度学习】tensorflow加载VGG16的网络结构和模型参数_第1张图片

完整代码在我的GitHub上:https://github.com/mjDelta/tensorflow-examples/blob/master/load_vgg16.py

pre-trained model百度云分享 链接:https://pan.baidu.com/s/1mhEzH4s 密码:u7ap

你可能感兴趣的:(深度学习)