MobileNet训练分类网络

MobileNet训练分类网络

1.环境准备

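  • 安装tensorflow(下文脚本使用 TensorFlow 1.x 的 API,如 tf.Session、tf.GraphDef、tf.gfile;若安装的是 2.x 版本,需要改用 tf.compat.v1 中的对应接口):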

    # The scripts below use TensorFlow 1.x APIs (tf.Session, tf.GraphDef, tf.gfile); on TF 2.x use tf.compat.v1
    pip install tensorflow

  • 验证安装成功:

    # Start an interactive interpreter, import TensorFlow and echo its version string
    python
    import tensorflow as tf
    tf.__version__
  • 用 git clone 下载 tensorflow 源码仓库,后文会用到其中 tensorflow/examples 目录下的 image_retraining/retrain.py 和 label_image/label_image.py。

2.数据集文件结构

  • 数据

    ../examples/image_retraining/ :
    --dataset
      --class1
        --image1.jpg
        --image2.jpg
        ...
      --class2
        --image1.jpg
        ...
      ...
  • 标签

    ../label.txt :(每行一个类别名,与 dataset 下的子文件夹名对应;训练时 retrain.py 会把类别标签写入 --output_labels 指定的文件,检测时读取的就是该文件)
    class1
    class2
    ...

3.训练脚本

  • ../examples/image_retraining/retrain.py 的主函数中定义了各项训练参数(也可以通过命令行传入),训练命令例如:

    # Every line-continuation backslash must be the LAST character on its line:
    # a trailing space after "\" ends the command early and the remaining
    # options are then executed as a separate (invalid) command.
    python retrain.py \
      --image_dir /home/wz/TensorFlow/tensorflow/tensorflow/examples/image_retraining/dataset \
      --intermediate_store_frequency 1000 \
      --how_many_training_steps 50000 \
      --learning_rate 0.001 \
      --train_batch_size 8 \
      --testing_percentage 0 \
      --eval_step_interval 20 \
      --architecture mobilenet_1.0_224 \
      --output_graph /home/wz/TensorFlow/tensorflow/tensorflow/examples/image_retraining/tmp/output_graph.pb \
      --intermediate_output_graphs_dir /home/wz/TensorFlow/tensorflow/tensorflow/examples/image_retraining/tmp/intermediate_graph \
      --output_labels /home/wz/TensorFlow/tensorflow/tensorflow/examples/image_retraining/tmp/output_labels.txt \
      --model_dir /home/wz/TensorFlow/tensorflow/tensorflow/examples/image_retraining/tmp/imagenet
  • 选择网络

    architecture 参数用于指定网络模型,可选 inception 以及各个版本的 mobilenet。以 mobilenet_1.0_224 为例:1.0 是控制网络复杂度的超参数,即 MobileNet 论文中的宽度乘子(width multiplier)α,取值越小,网络参数越少;224 表示输入图像尺寸为 224×224。

4.检测脚本

  • ../examples/label_image/label_image.py 脚本的主函数中同样有各项参数的设置。

  • 注意 input_height 和 input_width 要与训练所用 mobilenet 版本的输入尺寸一致(例如用 mobilenet_1.0_128 训练的模型对应 128)。

  • 注意 input_mean=128 是图像像素值归一化时要减去的均值,input_std=128 是减去均值后要除以的值,即 normalized = (pixel − input_mean) / input_std,这样 [0, 255] 的像素值被映射到约 [−1, 1)。

  • 在 label_image.py 的基础上稍加改动,得到批量检测图片的脚本 detect.py(对目录中的每张图片分类,把 labels[1] 得分高于 labels[0] 的图片复制到 save_path):

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import sys
    import time
    import os
    import shutil
    import numpy as np
    import tensorflow as tf
    
    
    def load_graph(model_file):
        """Load a serialized (frozen) GraphDef file into a fresh tf.Graph.

        Args:
          model_file: path of the binary .pb GraphDef file.

        Returns:
          A new tf.Graph holding the imported nodes. tf.import_graph_def is
          called without an explicit name, so every imported node is prefixed
          with "import/".
        """
        with open(model_file, "rb") as pb_file:
            serialized = pb_file.read()

        frozen_def = tf.GraphDef()
        frozen_def.ParseFromString(serialized)

        loaded = tf.Graph()
        with loaded.as_default():
            tf.import_graph_def(frozen_def)
        return loaded
    
    
    def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                    input_mean=0, input_std=255):
        """Decode an image file and return it as a normalized float batch.

        The decoder is chosen by file extension (case-insensitive): PNG, GIF
        (first frame) and BMP are decoded as 3-channel images, anything else
        as 3-channel JPEG. The image is cast to float32, given a leading batch
        dimension, bilinearly resized to input_height x input_width and
        normalized as (pixel - input_mean) / input_std.

        Args:
          file_name: path of the image to read.
          input_height: output height in pixels.
          input_width: output width in pixels.
          input_mean: value subtracted from every pixel.
          input_std: value every pixel is divided by after the subtraction.

        Returns:
          The evaluated result as a numpy array of shape
          (1, input_height, input_width, 3).
        """
        lower_name = file_name.lower()
        # Build the preprocessing ops in a private graph and close the session
        # afterwards. Building them in the default graph (as before) made that
        # graph grow with every call and leaked one Session per image.
        graph = tf.Graph()
        with graph.as_default():
            file_reader = tf.read_file(file_name, "file_reader")
            if lower_name.endswith(".png"):
                image_reader = tf.image.decode_png(file_reader, channels=3,
                                                   name='png_reader')
            elif lower_name.endswith(".gif"):
                # decode_gif yields (frames, h, w, 3); keep only the first frame
                # so animated GIFs do not produce a 5-D tensor below.
                image_reader = tf.image.decode_gif(file_reader,
                                                   name='gif_reader')[0]
            elif lower_name.endswith(".bmp"):
                # Force 3 channels so 4-channel (RGBA) BMPs match the model input.
                image_reader = tf.image.decode_bmp(file_reader, channels=3,
                                                   name='bmp_reader')
            else:
                image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                                    name='jpeg_reader')
            float_caster = tf.cast(image_reader, tf.float32)
            dims_expander = tf.expand_dims(float_caster, 0)
            resized = tf.image.resize_bilinear(dims_expander,
                                               [input_height, input_width])
            normalized = tf.divide(tf.subtract(resized, [input_mean]),
                                   [input_std], name="normalized")

        with tf.Session(graph=graph) as sess:
            return sess.run(normalized)
    
    
    def load_labels(label_file):
        """Read class labels from a text file, one label per line, in file order.

        Args:
          label_file: path of a label file such as the output_labels.txt
            written by retrain.py (--output_labels).

        Returns:
          A list of label strings with trailing whitespace/newlines stripped.
        """
        # Use the GFile as a context manager so the handle is closed promptly;
        # previously it stayed open until garbage collection.
        with tf.gfile.GFile(label_file) as label_fp:
            return [line.rstrip() for line in label_fp.readlines()]
    
    
    if __name__ == "__main__":
        # Batch classification: run every image in --file_path through the
        # retrained graph and copy the ones scored higher for labels[1] than
        # for labels[0] into --save_path. The defaults are the values that were
        # previously hard-coded, so running without arguments behaves as before.
        parser = argparse.ArgumentParser(description="Batch image classification.")
        parser.add_argument(
            "--file_path", default="/home/wrz/ADS/test5/neg_resize",
            help="directory containing the images to classify")
        parser.add_argument(
            "--save_path", default="/home/wrz/ADS/test5/bad-128-4000/neg",
            help="directory that receives copies of the flagged images")
        parser.add_argument(
            "--model_file",
            default="/home/wrz/TensorFlow/tensorflow/tensorflow/examples/"
                    "image_retraining/tmp1.0_128/pb/"
                    "intermediate_graphintermediate_4000.pb",
            help="frozen GraphDef produced by retrain.py")
        parser.add_argument(
            "--label_file",
            default="/home/wrz/TensorFlow/tensorflow/tensorflow/examples/"
                    "image_retraining/tmp1.0_128/output_labels.txt",
            help="label file, one class name per line")
        parser.add_argument("--input_height", type=int, default=128)
        parser.add_argument("--input_width", type=int, default=128)
        parser.add_argument("--input_mean", type=float, default=128)
        parser.add_argument("--input_std", type=float, default=128)
        parser.add_argument("--input_layer", default="input")
        parser.add_argument("--output_layer", default="final_result")
        args = parser.parse_args()

        graph = load_graph(args.model_file)
        # Labels, graph tensors and the session are the same for every image,
        # so they are set up once here instead of inside the loop.
        labels = load_labels(args.label_file)
        input_tensor = graph.get_operation_by_name(
            "import/" + args.input_layer).outputs[0]
        output_tensor = graph.get_operation_by_name(
            "import/" + args.output_layer).outputs[0]

        # shutil.copyfile fails if the destination directory does not exist.
        if not os.path.isdir(args.save_path):
            os.makedirs(args.save_path)

        count = 0
        # Inference runs with the GPU disabled, as in the original configuration.
        config = tf.ConfigProto(device_count={'GPU': 0})
        with tf.Session(graph=graph, config=config) as sess:
            for img in os.listdir(args.file_path):
                file_name = os.path.join(args.file_path, img)
                if not os.path.isfile(file_name):
                    continue  # skip sub-directories and other non-files
                count += 1
                t = read_tensor_from_image_file(file_name,
                                                input_height=args.input_height,
                                                input_width=args.input_width,
                                                input_mean=args.input_mean,
                                                input_std=args.input_std)
                results = np.squeeze(sess.run(output_tensor, {input_tensor: t}))

                # labels[0]=neg, labels[1]=pos. With the default paths the
                # inputs are negatives, so a higher pos score is a wrong
                # detection: report it and copy the image to save_path.
                if results[0] < results[1]:
                    print(labels[0], results[0], labels[1], results[1],
                          " detect wrong=====================================================",
                          file_name)
                    shutil.copyfile(file_name, os.path.join(args.save_path, img))
                    print(count)

你可能感兴趣的:(神经网络基础)