原文链接: mobilenet_v2_1.4_224 flowers 数据集分类网络
上一篇: tensorboard 自编码分类网络和vgg19 网络结构可视化
下一篇: 拟合三点共圆 圆心坐标和半径
mobilenet_v2_1.4_224 参数比vgg少很多,整个模型总共30M左右,运行一个批次(32张)图片只需要0.02秒左右,比vgg19快很多很多
下载模型文件
https://github.com/tensorflow/models/tree/master/research/slim/#Pretrained
flowers数据集
链接:https://pan.baidu.com/s/1YGPHQfX56m5bkgpDiOBI2A
提取码:u3hk
下载slim模块,将nets文件夹放入 D:\ProgramData\Anaconda3\Lib\site-packages下,方便在项目中引入
https://github.com/tensorflow/models
查看网络结构,保存网络数据到pb文件
import tensorflow as tf
from nets.mobilenet.mobilenet_v2 import mobilenet_v2_140
from tensorflow.python.framework import graph_util
IMAGE_SIZE = 224
SAVE_PATH = './pb/mobile_simple.pb'

# Build MobileNetV2 (the 1.4 depth-multiplier variant) on an image placeholder
# so the graph can be inspected in TensorBoard and exported.
in_x = tf.placeholder(dtype=tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3), name='in_x')
mobilenet_v2_140(in_x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Save the network structure; close the writer so the event file is
    # flushed to disk instead of being left to an un-flushed background queue.
    writer = tf.summary.FileWriter('./log/', sess.graph)
    writer.close()

    # Save the network data: freeze every variable needed to compute the
    # embedding node into constants.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['MobilenetV2/embedding'])

    # Create the output directory first; writing fails when it is missing.
    tf.gfile.MakeDirs('./pb')
    with tf.gfile.GFile(SAVE_PATH, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
运行,在浏览器中查看结构
tensorboard --logdir=./log
最后卷积层大小为N*7*7*1792
logits 层用了7*7的均值池化将大小转化为N*1792
整个网络大小只有20M左右
官方pb由于含有predict层(1792 × 1001 个权重)所以稍微大一点;按float32每个参数占四个字节计算,predict层多出的大小(单位MB)为:
4 * 1792 * 1001 / 1024 ** 2
6.8427734375
微调网络
1000次训练基本上能达到90%左右准确率
在获得embedding特征后,先通过一个卷积层,将维数减少,之后再放入全连接层,减少参数数目
在这里使用了leaky_relu作为激活函数,因为预处理图像时含有负数,使用relu的话网络很容易死亡
dropout一般不用在最后一层,效果不好
在最后一层使用dropout与不使用相比,效果差距很大:使用dropout时网络不容易收敛,并且准确率也比较差
将预测的分类结果与经过一层softmax后的logits作为输出节点
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.python.framework import graph_util
import time
from nets.mobilenet.mobilenet_v2 import mobilenet_v2_140
# mobilenet_v2 网络预处理
# 也可以放在读取record中做
# def preprocess_input(x):
# x = x.astype(np.float32)
# x /= 128.
# x -= 1.
# return x
# Training schedule.
TRAIN_STEP = 5000  # number of optimisation steps
SHOW_STEP = 10  # print train/test metrics every SHOW_STEP steps

# Input pipeline. A larger batch size (64) ran out of memory (OOM).
BATCH_SIZE = 32
IMAGE_SIZE = 224
CAPACITY = 1024  # shuffle_batch queue capacity
MIN_AFTER = 512  # shuffle_batch min_after_dequeue
TEST_SIZE = 32  # batch size used for the test record
NUM_THREADS = 4  # shuffle_batch reader threads

# Model / optimiser.
N_CLASS = 5
LR = .001  # Adam learning rate

# Data, pretrained checkpoint and export locations.
TRAIN_RECORD = r"D:\data\flowers\train.record"
TEST_RECORD = r"D:\data\flowers\test.record"
CKPT_PATH = r"D:\data\mobilenet_v2_1.4_224\mobilenet_v2_1.4_224.ckpt"
# Kept as an alias of N_CLASS so the label placeholder width and the size of
# the final layer can never drift apart.
num_classes = N_CLASS
SAVE_PATH = './pb/flowers_mobilenet.pb'
def get_data(record_path, batch_size):
    """Build a shuffled, queue-based batch reader over one TFRecord file.

    Every record is parsed into an int64 ``label`` and a raw uint8 ``image``
    byte string. The image is reshaped to IMAGE_SIZE x IMAGE_SIZE x 3, cast
    to float32 and scaled with ``x / 128 - 1``.

    Args:
        record_path: Path of the TFRecord file to read.
        batch_size: Number of examples in each dequeued batch.

    Returns:
        The ``(image_batch, label_batch)`` tensors produced by
        ``tf.train.shuffle_batch``.
    """
    filename_queue = tf.train.string_input_producer([record_path])
    _, serialized = tf.TFRecordReader().read(filename_queue)
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'image': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, feature_spec)
    label = parsed['label']

    # Decode the raw bytes first, then restore the image shape and normalise.
    with tf.variable_scope('img_read'):
        pixels = tf.decode_raw(parsed['image'], tf.uint8)
        pixels = tf.reshape(pixels, (IMAGE_SIZE, IMAGE_SIZE, 3))
        img = tf.cast(pixels, tf.float32) / 128. - 1.
    print('img ', img.shape)  # img (224, 224, 3)

    image_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=batch_size,
        capacity=CAPACITY,
        min_after_dequeue=MIN_AFTER,
        num_threads=NUM_THREADS,
    )
    print(image_batch.shape, label_batch.shape)  # (32, 224, 224, 3) (32,)
    return image_batch, label_batch
def train():
    """Fine-tune a small classification head on top of MobileNetV2 features.

    Builds MobileNetV2-1.4 on an image placeholder, attaches a conv + fully
    connected head under the ``train_net`` scope, restores the backbone
    weights from CKPT_PATH and optimises only the head variables with Adam.
    Train and test metrics are printed every SHOW_STEP steps. After training,
    single-image inference is timed ten times, the graph is frozen to
    SAVE_PATH with ``prediction/prediction`` and ``logits/logits`` as output
    nodes, and the graph is written to ``./log/`` for TensorBoard.
    """
    import os  # local import: only needed to create the export directory

    in_x = tf.placeholder(dtype=tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3), name='in_x')
    in_y = tf.placeholder(tf.float32, (None, N_CLASS), name='in_y')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # Build the backbone, then look its embedding tensor up by name.
    mobilenet_v2_140(in_x)
    embedding = tf.get_default_graph().get_tensor_by_name('MobilenetV2/embedding:0')
    print(embedding.name, embedding.shape)

    with tf.variable_scope('train_net'):
        # Author's note: the preprocessed input contains many negative values,
        # so leaky_relu is used; with plain relu the network died easily.
        with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.leaky_relu):
            # A convolution first shrinks the feature map so that the dense
            # layers need far fewer parameters.
            net = slim.conv2d(embedding, 32, (5, 5))
            net = slim.flatten(net)
            for units in (1024, 512, 256, 128):
                net = slim.fully_connected(net, units)
                net = slim.dropout(net, keep_prob)
            # The logits layer must stay linear: without the explicit
            # activation_fn=None the arg_scope above would pass the logits
            # through leaky_relu before softmax / cross-entropy.
            # No dropout here either: per the author, dropout on the last
            # layer converged slowly (about 0.7 accuracy after 3000+ steps
            # versus about 0.9 without it in the same time).
            net = slim.fully_connected(net, num_classes, activation_fn=None)

    with tf.variable_scope('loss'):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=in_y, logits=net),
            name='loss')

    # Both lists are taken before the optimizer is created, so they contain
    # only the variables that exist at this point.
    train_var_list = [var for var in tf.global_variables() if 'train_net' in var.name]
    net_var_list = [var for var in tf.global_variables() if 'MobilenetV2' in var.name]

    with tf.variable_scope('train'):
        # Only the head is optimised; the backbone variables are left as restored.
        train_op = tf.train.AdamOptimizer(LR).minimize(loss, var_list=train_var_list, name='train')
    with tf.variable_scope('logits'):
        # Not used below in Python, but exported as an output node of the pb.
        logits = tf.nn.softmax(net, name='logits')
    with tf.variable_scope('prediction'):
        prediction = tf.argmax(net, 1, name='prediction')
    with tf.variable_scope('real'):
        real = tf.argmax(in_y, 1, name='real')
    with tf.variable_scope('accuracy'):
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(real, prediction), tf.float32),
            name='accuracy')

    # Report the size of the head: parameter count * 4 bytes (float32), in MiB.
    param_count = 0
    for var in tf.global_variables('train_net'):
        # The shape has to be converted to a plain list before multiplying.
        dims = var.shape.as_list()
        # np.prod instead of np.product (deprecated alias, removed in NumPy 2.0).
        param_count += np.prod(dims)
        print(dims)
        print(var.name, var.shape)
    print(4 * param_count / 1024 ** 2)  # 14.228656768798828

    batch_x, batch_y = get_data(TRAIN_RECORD, BATCH_SIZE)
    test_x, test_y = get_data(TEST_RECORD, TEST_SIZE)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Overwrite the freshly initialised backbone with the pretrained weights.
        saver = tf.train.Saver(net_var_list)
        saver.restore(sess, CKPT_PATH)

        def next_batch(images_op, labels_op):
            """Dequeue one batch and one-hot encode its integer labels."""
            images, labels = sess.run([images_op, labels_op])
            return images, tf.keras.utils.to_categorical(labels, N_CLASS)

        def evaluate(images, onehot):
            """Return [loss, accuracy] for one batch with dropout disabled."""
            return sess.run(
                [loss, accuracy],
                feed_dict={in_x: images, in_y: onehot, keep_prob: 1.})

        # Start the reader threads; without them every dequeue blocks forever.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            for step in range(1, 1 + TRAIN_STEP):
                images, onehot = next_batch(batch_x, batch_y)
                sess.run(train_op, feed_dict={in_x: images, in_y: onehot, keep_prob: .5})
                if step % SHOW_STEP:
                    continue
                loss_val, accuracy_val = evaluate(*next_batch(batch_x, batch_y))
                print('train ', step, loss_val, accuracy_val)
                loss_val, accuracy_val = evaluate(*next_batch(test_x, test_y))
                print('test ', step, loss_val, accuracy_val)

            # Time the graph on a single image.
            for _ in range(10):
                images, _ = next_batch(batch_x, batch_y)
                st = time.time()
                sess.run(prediction, feed_dict={in_x: images[:1], keep_prob: 1.})
                print(time.time() - st)

            # Freeze the graph and save it as a pb file.
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ['prediction/prediction', 'logits/logits'])
            save_dir = os.path.dirname(SAVE_PATH)
            if save_dir:
                # Writing fails when the target directory does not exist.
                tf.gfile.MakeDirs(save_dir)
            with tf.gfile.GFile(SAVE_PATH, mode='wb') as f:
                f.write(constant_graph.SerializeToString())

            # Close the writer so the event file is flushed to disk.
            writer = tf.summary.FileWriter('./log/', sess.graph)
            writer.close()
        finally:
            # Always stop the reader threads, even when training raised;
            # allow them a 30 s grace period to finish.
            coord.request_stop()
            coord.join(threads, 30)
def main():
    """Script entry point: run the fine-tuning job."""
    train()


if __name__ == '__main__':
    main()
训练参数
MobilenetV2/embedding:0 (?, 7, 7, 1792)
train_net/Conv/weights:0 (5, 5, 1792, 32)
train_net/Conv/biases:0 (32,)
train_net/fully_connected/weights:0 (1568, 1024)
train_net/fully_connected/biases:0 (1024,)
train_net/fully_connected_1/weights:0 (1024, 512)
train_net/fully_connected_1/biases:0 (512,)
train_net/fully_connected_2/weights:0 (512, 256)
train_net/fully_connected_2/biases:0 (256,)
train_net/fully_connected_3/weights:0 (256, 128)
train_net/fully_connected_3/biases:0 (128,)
train_net/fully_connected_4/weights:0 (128, 5)
train_net/fully_connected_4/biases:0 (5,)
参数量计算,需要转化为list之后进行计算因为维度信息中含有未知维度
训练参数大小约为14M,加上之前embedding的16M差不多30M左右
# Sum the parameter counts of all variables under the 'train_net' scope and
# print their float32 size (4 bytes per parameter) in MiB.
cnt = 0
for var in tf.global_variables('train_net'):
    # The shape has to be converted to a plain list before multiplying.
    dims = var.shape.as_list()
    # np.prod instead of np.product (deprecated alias, removed in NumPy 2.0).
    cnt += np.prod(dims)
    print(dims)
    print(var.name, var.shape)
print(4 * cnt / 1024 ** 2)  # 14.228656768798828
导出pb文件大小与计算参数大小很接近了
网络结构图,右边是数据流图,可以看到在数据解析之后先进行了img_read操作,也就是预处理操作
读取pb文件,然后对文件夹下的图片进行预测处理,一张照片的处理速度大概在0.02秒左右
import tensorflow as tf
import os
import numpy as np
from PIL import Image
import time
SAVE_DIR = './pb/'
SAVE_PATH = os.path.join(SAVE_DIR, 'flowers_mobilenet.pb')
IMAGE_DIR = './imgs'
IMAGE_SIZE = 224

# The context manager releases the session's resources when the script ends.
with tf.Session() as sess:
    # Load the frozen graph. The training script exported it with all
    # variables converted to constants, so no initializer needs to be run.
    output_graph_def = tf.GraphDef()
    with open(SAVE_PATH, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(
        output_graph_def,
        name='',  # the default name is 'import', which acts like a scope prefix
    )

    in_x = sess.graph.get_tensor_by_name("in_x:0")
    keep_prob = sess.graph.get_tensor_by_name('keep_prob:0')
    prediction = sess.graph.get_tensor_by_name('prediction/prediction:0')
    logits = sess.graph.get_tensor_by_name('logits/logits:0')
    print(in_x.name, prediction.name, logits.name)
    print(in_x.shape, prediction.shape, logits.shape)

    # Sorted so the files are processed in a deterministic order.
    for name in sorted(os.listdir(IMAGE_DIR)):
        path = os.path.join(IMAGE_DIR, name)
        if not os.path.isfile(path):
            continue  # skip sub-directories

        # Force three channels: grayscale, palette or RGBA files would
        # otherwise break the reshape to (1, 224, 224, 3) below.
        with Image.open(path) as im:
            rgb = im.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
        img = np.array(rgb, dtype=np.float32).reshape((1, IMAGE_SIZE, IMAGE_SIZE, 3))
        # Same scaling as the training pipeline: x / 128 - 1.
        img /= 128.
        img -= 1

        st = time.time()
        prediction_val, logits_val = sess.run(
            [prediction, logits], {
                in_x: img,
                keep_prob: 1.  # disable dropout at inference time
            }
        )
        print(name, time.time() - st)
        print(prediction_val, logits_val)
只有一张预测错误,但是依然能够看到top3准确率很高
0- (1).jpg 6.334139585494995
[0] [[1. 0. 0. 0. 0.]]
0- (2).jpg 0.01692676544189453
[0] [[1. 0. 0. 0. 0.]]
0- (3).jpg 0.016925573348999023
[0] [[1. 0. 0. 0. 0.]]
0- (4).jpg 0.0159604549407959
[4] [[0.3088594 0.05328554 0.12708923 0.09731596 0.41344985]]
0- (5).jpg 0.015984773635864258
[0] [[1. 0. 0. 0. 0.]]
1- (1).jpg 0.016924619674682617
[1] [[0. 1. 0. 0. 0.]]
1- (2).jpg 0.016926050186157227
[1] [[0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00 1.9174303e-37]]
1- (3).jpg 0.01592564582824707
[1] [[0. 1. 0. 0. 0.]]
1- (4).jpg 0.01598358154296875
[1] [[0. 1. 0. 0. 0.]]
1- (5).jpg 0.016951322555541992
[1] [[0. 1. 0. 0. 0.]]
2- (1).jpg 0.016955137252807617
[2] [[0. 0. 1. 0. 0.]]
2- (2).jpg 0.016927719116210938
[2] [[0. 0. 1. 0. 0.]]
2- (3).jpg 0.01692485809326172
[2] [[1.3398786e-12 4.7934960e-16 9.9997687e-01 1.1129710e-14 2.3081511e-05]]
2- (4).jpg 0.015955209732055664
[2] [[0. 0. 1. 0. 0.]]
2- (5).jpg 0.016924142837524414
[2] [[0. 0. 1. 0. 0.]]
3- (1).jpg 0.01592850685119629
[3] [[0. 0. 0. 1. 0.]]
3- (2).jpg 0.015956640243530273
[3] [[0. 0. 0. 1. 0.]]
3- (3).jpg 0.016922712326049805
[3] [[0. 0. 0. 1. 0.]]
3- (4).jpg 0.01695418357849121
[3] [[0. 0. 0. 1. 0.]]
3- (5).jpg 0.016954421997070312
[3] [[0. 0. 0. 1. 0.]]
4- (1).jpg 0.017953872680664062
[4] [[0. 0. 0. 0. 1.]]
4- (2).jpg 0.016954660415649414
[4] [[3.8542287e-22 5.6099554e-26 1.1666944e-23 6.9857317e-24 1.0000000e+00]]
4- (3).jpg 0.015956640243530273
[4] [[0. 0. 0. 0. 1.]]
4- (4).jpg 0.01595759391784668
[4] [[2.9233486e-06 3.8212596e-07 1.3178097e-06 1.1761556e-06 9.9999416e-01]]
4- (5).jpg 0.015957117080688477
[4] [[3.3387124e-38 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.0000000e+00]]