TensorFlow抽取cnn中某一层特征

      深度学习具有强大的特征表达能力。有时候我们训练好分类模型,并不想用来进行分类,而是用来提取特征用于其他任务,比如相似图片计算。接下来讲下如何使用TensorFlow提取特征。

1.必须在模型中命名好要提取的那一层,num_filters_total为提取特征维度,即特征个数,如下

    net = _global_avg(net, pool_size=net.get_shape()[1:-1], strides=1)
    net = tf.reshape(net, [-1, num_filters_total], name='reshape_feature')

2.通过调用sess.run()来获取reshape_feature层特征

    feature = graph.get_operation_by_name("reshape_feature").outputs[0]

    batch_predictions, batch_feature = \

    sess.run([predictions, feature], {input_x: x_test_batch, dropout_keep_prob: 1.0}

 

你可能感兴趣的:(TensorFlow,深度学习)