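Traditional CV approaches usually split the work into two separate steps: feature extraction, and building and training a classifier. A CNN combines both in a single network, and many experiments have shown that a CNN trained on large amounts of data makes a very good feature extractor, which is a form of feature transfer.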
Today I'll show how to extract features with TensorFlow and a pretrained model. We can use the pretrained models published in the official TensorFlow models repository on GitHub:
https://github.com/tensorflow/models/tree/master/research/slim
These models were pretrained on ImageNet, and the repository offers several architectures, including VGG, ResNet, and Inception.
Here we use the lightweight MobileNet_v1 model for feature extraction. First, download the pretrained checkpoint: mobilenet_v1_1.0_224 ckpt
With the checkpoint we can also inspect the structure of the whole network, as well as the feature map produced by each layer.
First, import the required modules:
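# TensorFlow 1.x is required: tf.contrib was removed in TF 2.x
import tensorflow as tf
import numpy as np
import glob
# nets/ lives in models/research/slim, which must be on PYTHONPATH
from nets import mobilenet_v1
slim = tf.contrib.slim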
Next, define a function that takes an image path, reads and decodes the image, preprocesses it, and returns it as a tensor:
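def mobi_parse_fun(x_in, y_label=1):
    # Read the raw bytes of the image file (x_in is a path string tensor)
    img_raw = tf.read_file(x_in)
    # Decode the JPEG into a uint8 tensor of shape (H, W, 3)
    img_decode = tf.io.decode_jpeg(img_raw, channels=3)
    # Resize to the 224 x 224 input expected by mobilenet_v1_1.0_224
    img = tf.image.resize_images(img_decode, [224, 224])
    # Scale pixel values from [0, 255] to [-1, 1]
    img = tf.cast(img, tf.float32) / 127.5 - 1.0
    # y_label is a dummy label; only the image matters for feature extraction
    return img, y_label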
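The last line of mobi_parse_fun rescales pixel values from [0, 255] to [-1, 1], which, as far as I know, is the input range the slim MobileNet checkpoints were trained with (slim uses Inception-style preprocessing for them). A quick pure-Python check of that mapping, with no TensorFlow needed; scale_pixel and unscale_pixel are hypothetical helper names:

```python
# Pure-Python check of the scaling step in mobi_parse_fun (no TensorFlow).

def scale_pixel(x):
    # Same formula as mobi_parse_fun: [0, 255] -> [-1, 1]
    return x / 127.5 - 1.0

def unscale_pixel(y):
    # Inverse mapping: [-1, 1] -> [0, 255]
    return (y + 1.0) * 127.5

for v in (0, 127.5, 255):
    print(v, '->', scale_pixel(v))
# 0 -> -1.0
# 127.5 -> 0.0
# 255 -> 1.0
```

The inverse is handy if you ever want to turn a preprocessed tensor back into a viewable image.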
Next, we use TensorFlow's Dataset module (tf.data) to build the input pipeline:
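# Placeholder that receives the list of image paths at run time
X_in = tf.placeholder(tf.string, None)
# Y_in = tf.placeholder(tf.int32, None)
# One element per path -> parse each element -> batches of one image
train_data = tf.data.Dataset.from_tensor_slices(X_in)
train_data = train_data.map(mobi_parse_fun)
train_data = train_data.batch(1)
# Reinitializable iterator: the graph can be re-run on a new list of paths
iter_ = tf.data.Iterator.from_structure(train_data.output_types,
                                        train_data.output_shapes)
x_batch, y_batch = iter_.get_next()
train_init_op = iter_.make_initializer(train_data)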
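To make the pipeline's behavior concrete, here is a plain-Python analogy (the py_* helpers and fake_parse are hypothetical, and real tf.data stacks tensors rather than building lists): each path becomes one element, map turns it into an (image, label) pair, and batch groups pairs so that x_batch and y_batch each gain a leading batch dimension. With batch(1) as above, every batch holds exactly one image, hence shapes like (1, 224, 224, 3).

```python
# Pure-Python analogy (not TensorFlow) of the pipeline
# from_tensor_slices -> map -> batch used above.

def py_from_slices(items):
    # One element per entry, like Dataset.from_tensor_slices on a 1-D input
    for item in items:
        yield item

def py_map(dataset, fn):
    # Apply the parse function to every element, like Dataset.map
    for item in dataset:
        yield fn(item)

def py_batch(dataset, batch_size):
    # Group consecutive (image, label) elements and stack each component,
    # like Dataset.batch; the last, smaller batch is kept
    buf = []
    for item in dataset:
        buf.append(item)
        if len(buf) == batch_size:
            yield tuple(list(col) for col in zip(*buf))
            buf = []
    if buf:
        yield tuple(list(col) for col in zip(*buf))

def fake_parse(path, label=1):
    # Hypothetical stand-in for mobi_parse_fun: returns (image, label)
    return ('decoded:' + path, label)

paths = ['a.jpg', 'b.jpg', 'c.jpg']
batches = list(py_batch(py_map(py_from_slices(paths), fake_parse), 2))
print(batches)
# [(['decoded:a.jpg', 'decoded:b.jpg'], [1, 1]), (['decoded:c.jpg'], [1])]
```

Keeping the final partial batch mirrors tf.data's default (drop_remainder=False).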
Then build the network from its definition and set the path of the checkpoint:
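# is_training=False puts BatchNorm in inference mode (moving averages)
# and disables dropout, which is what we want for feature extraction
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
    logits, endpoints = mobilenet_v1.mobilenet_v1(x_batch, num_classes=1001,
                                                  is_training=False)
# Raw string so the Windows backslashes are not treated as escapes
ckpt_path = r'D:\Python_Code\mobilenet_v1_1.0_224\mobilenet_v1_1.0_224.ckpt'
saver = tf.train.Saver()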
Next, collect the paths of the images:
# Raw string for the Windows path; glob expands the *.jpg pattern
img_path = r'F:\cute\*.jpg'
img_list = glob.glob(img_path)
Now we can create a session and load the model into it:
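with tf.Session() as sess:
    # Load the pretrained weights into the graph built above
    saver.restore(sess, ckpt_path)

    # Inspect the parameters of every layer
    print('print the trainable parameters: ')
    for var_ in tf.trainable_variables():
        w_val = sess.run(var_)
        print(var_.name, w_val.shape)

    # Feed the image paths and initialize the dataset iterator
    sess.run(train_init_op, feed_dict={X_in: img_list})

    #---------------------------------------------
    # Inspect the feature map of each layer.
    # Fetch all endpoints in ONE sess.run call: every run that touches
    # x_batch pulls a new element from the iterator, so running the
    # endpoints one by one would consume a different image each time.
    print('print the feature maps: ')
    feat_maps = sess.run(endpoints)
    for name_, feat_map in feat_maps.items():
        print(name_, feat_map.shape)

    # AvgPool_1a has shape (batch, 1, 1, 1024); squeeze the two spatial
    # axes to get one 1024-d feature vector per image
    fc_map = endpoints['AvgPool_1a']
    fc_feat = tf.squeeze(fc_map, [1, 2])

    # Restart the iterator so the loop below starts from the first image
    sess.run(train_init_op, feed_dict={X_in: img_list})
    for img_name in img_list:
        print(img_name)
        # Fetch the image and its feature in the same run so they match
        x_bat, y_bat, fc_feature = sess.run([x_batch, y_batch, fc_feat])
        print(x_bat.shape, y_bat.shape)
        print(fc_feature.shape)
        break  # only the first image, for demonstration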
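A note on the sess.run calls in this step: as I understand TF1's iterator semantics, a tensor that depends on iter_.get_next() receives a new element every time a sess.run evaluates it, but only one element per call, shared by everything fetched in that call. So fetching x_batch and fc_feat in separate runs pairs an image with the next image's feature, and running the endpoints one by one consumes one image per endpoint until TensorFlow raises OutOfRangeError. A pure-Python simulation of that behavior (make_session, image and feature are hypothetical stand-ins, not TensorFlow APIs; here running out would raise StopIteration):

```python
# Pure-Python simulation (not TensorFlow) of how a tf.data iterator
# interacts with sess.run: get_next() yields one new element per run call.

def make_session(elements):
    """Return a fake run(fetches) that consumes one element per call."""
    it = iter(elements)

    def run(fetches):
        elem = next(it)  # like get_next(): evaluated once per run call
        return [fetch(elem) for fetch in fetches]

    return run

def image(elem):
    # Stand-in for x_batch
    return elem

def feature(elem):
    # Stand-in for fc_feat computed from that image
    return 'feat(' + elem + ')'

# Separate runs: the image and the feature come from different elements
run = make_session(['img_0', 'img_1', 'img_2'])
(x_bat,) = run([image])
(fc,) = run([feature])
print(x_bat, fc)  # img_0 feat(img_1)

# One run fetching both: they come from the same element
run = make_session(['img_0', 'img_1', 'img_2'])
x_bat, fc = run([image, feature])
print(x_bat, fc)  # img_0 feat(img_0)
```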
Printing the trainable parameters shows the structure of MobileNet_V1:
MobilenetV1/Conv2d_0/weights:0 (3, 3, 3, 32)
MobilenetV1/Conv2d_0/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_0/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_depthwise/depthwise_weights:0 (3, 3, 32, 1)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_pointwise/weights:0 (1, 1, 32, 64)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_depthwise/depthwise_weights:0 (3, 3, 64, 1)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_pointwise/weights:0 (1, 1, 64, 128)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_pointwise/weights:0 (1, 1, 128, 128)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_pointwise/weights:0 (1, 1, 128, 256)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_pointwise/weights:0 (1, 1, 256, 256)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_pointwise/weights:0 (1, 1, 256, 512)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_pointwise/weights:0 (1, 1, 512, 1024)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/depthwise_weights:0 (3, 3, 1024, 1)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/weights:0 (1, 1, 1024, 1024)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Logits/Conv2d_1c_1x1/weights:0 (1, 1, 1024, 1001)
MobilenetV1/Logits/Conv2d_1c_1x1/biases:0 (1001,)
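The listing shows why MobileNet_v1 is lightweight: each block is a depthwise separable convolution, a 3x3 depthwise filter with one filter per input channel (shape (3, 3, C, 1)) followed by a 1x1 pointwise convolution. A small arithmetic check on the Conv2d_13 shapes above, compared against a hypothetical standard 3x3 convolution with the same channel counts (BatchNorm parameters ignored; n_params is a hypothetical helper):

```python
import math

def n_params(shape):
    # Number of weights in a tensor of the given shape
    return math.prod(shape)

# Conv2d_13 as listed: depthwise (3, 3, 1024, 1) + pointwise (1, 1, 1024, 1024)
separable = n_params((3, 3, 1024, 1)) + n_params((1, 1, 1024, 1024))

# Hypothetical standard 3x3 convolution with the same 1024 -> 1024 channels
standard = n_params((3, 3, 1024, 1024))

print(separable, standard, round(standard / separable, 2))
# 1057792 9437184 8.92
```

In general, with a 3x3 kernel, C_in input and C_out output channels, the standard convolution has 9·C_in·C_out weights and the separable one C_in·(9 + C_out), so the ratio 9·C_out / (9 + C_out) approaches 9x as the channel count grows.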
And the feature maps of the network are as follows:
print the feature maps:
Conv2d_0 (1, 112, 112, 32)
Conv2d_1_depthwise (1, 112, 112, 32)
Conv2d_1_pointwise (1, 112, 112, 64)
Conv2d_2_depthwise (1, 56, 56, 64)
Conv2d_2_pointwise (1, 56, 56, 128)
Conv2d_3_depthwise (1, 56, 56, 128)
Conv2d_3_pointwise (1, 56, 56, 128)
Conv2d_4_depthwise (1, 28, 28, 128)
Conv2d_4_pointwise (1, 28, 28, 256)
Conv2d_5_depthwise (1, 28, 28, 256)
Conv2d_5_pointwise (1, 28, 28, 256)
Conv2d_6_depthwise (1, 14, 14, 256)
Conv2d_6_pointwise (1, 14, 14, 512)
Conv2d_7_depthwise (1, 14, 14, 512)
Conv2d_7_pointwise (1, 14, 14, 512)
Conv2d_8_depthwise (1, 14, 14, 512)
Conv2d_8_pointwise (1, 14, 14, 512)
Conv2d_9_depthwise (1, 14, 14, 512)
Conv2d_9_pointwise (1, 14, 14, 512)
Conv2d_10_depthwise (1, 14, 14, 512)
Conv2d_10_pointwise (1, 14, 14, 512)
Conv2d_11_depthwise (1, 14, 14, 512)
Conv2d_11_pointwise (1, 14, 14, 512)
Conv2d_12_depthwise (1, 7, 7, 512)
Conv2d_12_pointwise (1, 7, 7, 1024)
Conv2d_13_depthwise (1, 7, 7, 1024)
Conv2d_13_pointwise (1, 7, 7, 1024)
AvgPool_1a (1, 1, 1, 1024)
Logits (1, 1001)
Predictions (1, 1001)
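The spatial sizes in this listing can be reproduced by hand. Assuming 'SAME' padding (which, as far as I know, is the slim MobileNet_v1 default), a convolution with stride s turns an input of size n into ceil(n / s), and the size halves exactly where the listing shows it does: at Conv2d_0 and at the depthwise convolutions of blocks 2, 4, 6 and 12. A small sketch (STRIDE2, same_out and spatial_sizes are hypothetical names):

```python
import math

# Layers whose convolution uses stride 2, read off the listing above
STRIDE2 = {'Conv2d_0', 'Conv2d_2_depthwise', 'Conv2d_4_depthwise',
           'Conv2d_6_depthwise', 'Conv2d_12_depthwise'}

def same_out(size, stride):
    # Output size of a 'SAME'-padded convolution: ceil(input / stride)
    return math.ceil(size / stride)

def spatial_sizes(input_size=224):
    # Layer order as in the listing: Conv2d_0, then 13 depthwise/pointwise pairs
    names = ['Conv2d_0']
    for i in range(1, 14):
        names += ['Conv2d_%d_depthwise' % i, 'Conv2d_%d_pointwise' % i]
    sizes = {}
    size = input_size
    for name in names:
        size = same_out(size, 2 if name in STRIDE2 else 1)
        sizes[name] = size
    return sizes

sizes = spatial_sizes()
print(sizes['Conv2d_0'], sizes['Conv2d_6_depthwise'], sizes['Conv2d_13_pointwise'])
# 112 14 7
```

AvgPool_1a then pools the final 7x7 map down to 1x1, which matches its (1, 1, 1, 1024) entry.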
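As we can see, AvgPool_1a is the feature map closest to the fully connected (FC) layer; in MobileNet_V1 the FC role is played by the 1x1 convolution Logits/Conv2d_1c_1x1. So we extract the output of AvgPool_1a, squeeze it into a 1024-dimensional vector per image, and use that vector as the feature of the input image.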
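What AvgPool_1a computes here is, in effect, global average pooling: as far as I can tell from the slim source, it is a 7x7 average pool with VALID padding applied to the 7x7 Conv2d_13_pointwise map, so each of the 1024 channels is averaged over all spatial positions, and tf.squeeze(fc_map, [1, 2]) only drops the two size-1 axes. A pure-Python sketch of that reduction on a toy input (global_avg_pool is a hypothetical helper; the batch axis is omitted):

```python
# Pure-Python sketch of AvgPool_1a + squeeze on a single feature map:
# average each channel over the H x W grid, giving one value per channel.

def global_avg_pool(feature_map):
    """feature_map: nested lists of shape [H][W][C]; returns a length-C list."""
    h = len(feature_map)
    w = len(feature_map[0])
    c = len(feature_map[0][0])
    sums = [0.0] * c
    for row in feature_map:
        for pixel in row:
            for k, v in enumerate(pixel):
                sums[k] += v
    return [s / (h * w) for s in sums]

# Toy 2 x 2 map with 3 channels (the real map is 7 x 7 x 1024)
toy = [[[1, 0, 4], [3, 0, 4]],
       [[5, 2, 4], [7, 2, 4]]]
print(global_avg_pool(toy))  # [4.0, 1.0, 4.0]
```

Because the spatial axes are averaged away, the resulting feature length depends only on the channel count (1024 here), not on where in the image a pattern appears.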