tensorflow使用train_image_classifier来训练数据(修改整理)

看了几篇关于cnn的文章,感觉那种大模型的cnn真的不适合个人去使用,自己也没有那么强悍的显卡,也没有足够的数据和时间

还是用迁移学习比较好,这里说一下用的模型,inception_v3是谷歌的cnn框架。这个框架有22层深,用tensorboard看的时候是比较大的(相比于letnet和alxnet),这个框架运算量并不大,而且很多卷积层的权值基本上可以不用改变,可以说使用起来非常的方便。

他降低参数有两点 第一是去除了最后的全连接层,采用全局平均池化层(将图片尺寸变为1*1)来取代它。全连接层基本上占据了alxnet和vggnet 90%的参数量,为什么呢?因为卷积核并不多,而三层全连接层(Alxnet)的参数量是非常恐怖的,第一层就以万计。而且参数过多,数据量少的话会过拟合,效果并不好。

第二是Inception V1中精心设计Inception moudle级高了参数的利用率,这个结构的思路借鉴于VGGnet,VGGnet首次实现了多个小卷积核的同时使用,替换了Alxnet的第一层11*11的卷积核,而Inception的卷积核尺寸更小,参数利用率越高

下面我来说一下怎么使用,主要是参考讲座 炼数成金,但是对这个里面的bug进行了修改。

首先,下载数据集合,数据集我用flowers的,事实上后来我才发现,官方提供了直接针对flowes的代码。


这里面的是花的5个种类

这里有一个txt文件,是output_labels.txt是所有花的名称,放在flower_photo目录下



然后生成tfrecord文件

先上代码再解释吧

[python]  view plain  copy
  1. # coding: utf-8  
  2.   
  3. import tensorflow as tf  
  4. import os  
  5. import random  
  6. import math  
  7. import sys  
  8. import types  
  9. from PIL import Image  
  10.   
  11. #验证集数量  
  12. _NUM_TEST = 300  
  13. #随机种子  
  14. _RANDOM_SEED = 0  
  15. #数据块 把图片进行分割,对于数据量比较大的时候使用  
  16. _NUM_SHARDS = 5  
  17. #数据集路径  
  18. DATASET_DIR = 'D:/Tensorflow/flower_photos/flowers'  
  19. #标签和文件名字  
  20. LABELS_FILENAME = 'D:/Tensorflow/flower_photos/output_labels.txt'  
  21.   
  22. #定义tfrecord文件的路径和名字  
  23. def _get_dataset_filename(dataset_dir,split_name,shard_id):  
  24.     output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name,shard_id,_NUM_SHARDS)  
  25.     return os.path.join(dataset_dir,output_filename)  
  26.   
  27. #判断tfrecord文件是否存在  
  28. def _datase_exists(dataset_dir):  
  29.     for split_name in ['train','test']:  
  30.         for shard_id in range(_NUM_SHARDS):  
  31.             #定义tfrecord文件的路径+名字  
  32.             output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
  33.         if not tf.gfile.Exists(output_filename):  
  34.             return False  
  35.     return True  
  36.   
  37.   
  38. #获取所有文件以及分类  传入图片的路径  
  39. def _get_filenames_and_classes(dataset_dir):  
  40.     #数据目录  
  41.     directories = []  
  42.     #分类名称  
  43.     class_names = []  
  44.     for filename in os.listdir(dataset_dir):  
  45.         #合并文件路径  
  46.         path = os.path.join(dataset_dir,filename)  
  47.         #判断该路径是否为目录  
  48.         if os.path.isdir(path):  
  49.             #加入数据目录  
  50.             directories.append(path)  
  51.             #加入类别名称  
  52.             class_names.append(filename)  
  53.     photo_filenames = []  
  54.     #循环每个分类的文件夹  
  55.     for directory in directories:  
  56.         for filename in os.listdir(directory):  
  57.             path = os.path.join(directory,filename)  
  58.             #把图片加入图片列表  
  59.             photo_filenames.append(path)  
  60.     return photo_filenames,class_names  
  61.   
  62. def int64_feature(values):  
  63.     if not isinstance(values,(tuple,list)):  
  64.         values = [values]  
  65.         #print(values)  
  66.     return tf.train.Feature(int64_list=tf.train.Int64List(value=values))  
  67.   
  68. def bytes_feature(values):  
  69.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))  
  70.   
  71.   
  72. def image_to_tfexample(image_data,image_format,class_id):  
  73.     return tf.train.Example(features=tf.train.Features(feature={  
  74.         'image/encoded': bytes_feature(image_data),  
  75.         'image/format' : bytes_feature(image_format),  
  76.         'image/class/label' : int64_feature(class_id)  
  77.     }))  
  78.   
  79.   
  80. def write_label_file(labels_to_class_names,dataset_dir,filename='label.txt'):  
  81.     #拼接目录  
  82.     labels_file_name = os.path.join(dataset_dir,filename)  
  83.     print(dataset_dir)  
  84.     #with open(labels_file_name,'w') as f:  
  85.     with tf.gfile.Open(labels_file_name,'w') as f:  
  86.         for label in labels_to_class_names:  
  87.             class_name = labels_to_class_names[label]  
  88.             f.write('%d;%s\n'%(label,class_name))  
  89.   
  90.   
  91. #把数据转为TFRecord格式  
  92. def _convert_dataset(split_name,filenames,class_names_to_ids,dataset_dir):  
  93.     #assert 断言   assert expression 相当于 if not expression raise AssertionError  
  94.     assert split_name in ['train','test']  
  95.     #计算每个数据块有多少个数据  
  96.     num_per_shard = int(len(filenames) / _NUM_SHARDS)  
  97.     with tf.Graph().as_default():  
  98.         with tf.Session() as sess:  
  99.             for shard_id in range(_NUM_SHARDS):  
  100.                 #定义tfrecord文件的路径+名字  
  101.                 output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
  102.                 with tf.python_io.TFRecordWriter(output_filename) as tfrecore_writer:  
  103.                     #每一个数据块开始的位置  
  104.                     start_ndx = shard_id * num_per_shard  
  105.                     #每一个数据块最后的位置  
  106.                     end_ndx = min((shard_id+1) * num_per_shard,len(filenames))  
  107.   
  108.                     for i in range(start_ndx,end_ndx):  
  109.                         try:  
  110.                             sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1,len(filenames),shard_id))  
  111.                             sys.stdout.flush()  
  112.                             #读取图片  
  113.                             #image_data = tf.gfile.FastGFile(filenames[i],'rb').read()  
  114.                             img = Image.open(filenames[i])  
  115.                             #img = img.resize((224, 224))  
  116.                             img_raw = img.tobytes()  
  117.                              #获取图片的类别名称  
  118.                             class_name = os.path.basename(os.path.dirname(filenames[i]))  
  119.                             #找到类别名称对应的id  
  120.                             class_id = class_names_to_ids[class_name]  
  121.                             #生成tfrecord文件  
  122.                             example = image_to_tfexample(img_raw, b'jpg',class_id)  
  123.                            # print(filenames[i])  
  124.                             tfrecore_writer.write(example.SerializeToString())  
  125.                         except IOError as e:  
  126.                             print("Could not read: ",filenames[i])  
  127.                             print("Error: ",e)  
  128.                             print("Skip it \n")  
  129.   
  130.     sys.stdout.write('\n')  
  131.     sys.stdout.flush()  
  132.   
  133.   
  134. if __name__=='__main__':  
  135.     #判断tfrecord文件是否存在  
  136.     if _datase_exists(DATASET_DIR):  
  137.         print('tfrecord 文件已经存在')  
  138.     else :  
  139.         #获取图片以及分类  
  140.         photo_filenames,class_names = _get_filenames_and_classes(DATASET_DIR)  
  141.         #print(class_names)  
  142.         #把分类转为字典格式 ,类似于{'house':3,'flower':1,'plane':4}  
  143.         class_names_to_ids = dict(zip(class_names,range(len(class_names))))  
  144.         print(class_names_to_ids)  
  145.         #把数据切为训练集和测试集  
  146.         random.seed(_RANDOM_SEED)  
  147.         random.shuffle(photo_filenames)  
  148.         training_filenames = photo_filenames[_NUM_TEST:]  
  149.         testing_filenames = photo_filenames[:_NUM_TEST]  
  150.        # print(training_filenames[0])  
  151.         #数据转换  
  152.         _convert_dataset('train',training_filenames,class_names_to_ids,DATASET_DIR)  
  153.         _convert_dataset('test',testing_filenames,class_names_to_ids,DATASET_DIR)  
  154.   
  155.         #输出labels文件  
  156.         labels_to_class_names = dict(zip(range(len(class_names)),class_names))  
  157.         write_label_file(labels_to_class_names,DATASET_DIR)  

思路很简单,就是读取图片然后分割,最后转换成tfrecord格式的文件,说一下需要修改的地方(我说了就不用自己找了。。。。)


这两个都是刚才说到的,一个是图片存放的位置,一个是标签文件,为了生成一个类似于字典的txt,其他的不用改,如果

你想改这里的名字的话,那么你后面读取的时候要改官方给你的py文件,还是省省吧。

默认会在你的图片的目录下生成tfrecord文件和labels标签,

为了好看,我把他们移出来,单独放一个文件夹。


然后我们要特别看一下官方给你的几个py文件,如果你只用官方给的例子像测试下的话可以跳过。


首先是这个dataset_factory 这个要改,


原来是没有这个的,你要加上这个,datasets是你所在的这个目录,myimages自然就是你要自己写的py文件了


这里新加上最后一个字典,'image'只是个名字或者叫标识,myimages是你的py文件

然后我们来看看我们自己写的myimages

由于我用的是flowes的图片,你会发现官方给了你一个flowers.py所以你可以参考这个写一下。

下面上一下我的myimages文件,

[python]  view plain  copy
  1. from __future__ import absolute_import  
  2. from __future__ import division  
  3. from __future__ import print_function  
  4.   
  5. import os  
  6. import tensorflow as tf  
  7.   
  8. from datasets import dataset_utils  
  9.   
  10. slim = tf.contrib.slim  
  11.   
  12. _FILE_PATTERN = 'image_%s_*.tfrecord'  
  13.   
  14. SPLITS_TO_SIZES = {'train'3320'validation'350}  
  15.   
  16. _NUM_CLASSES = 5  
  17.   
  18. _ITEMS_TO_DESCRIPTIONS = {  
  19.     'image''A color image of varying size.',  
  20.     'label''A single integer between 0 and 4',  
  21. }  
  22.   
  23.   
  24. def get_split(split_name, dataset_dir, file_pattern=None, reader=None):  
  25.   
  26.   if split_name not in SPLITS_TO_SIZES:  
  27.     raise ValueError('split name %s was not recognized.' % split_name)  
  28.   
  29.   if not file_pattern:  
  30.     file_pattern = _FILE_PATTERN  
  31.   file_pattern = os.path.join(dataset_dir, file_pattern % split_name)  
  32.   
  33.   if reader is None:  
  34.     reader = tf.TFRecordReader  
  35.   
  36.   keys_to_features = {  
  37.       'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),  
  38.       'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),  
  39.       'image/class/label': tf.FixedLenFeature(  
  40.           [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),  
  41.   }  
  42.   
  43.   items_to_handlers = {  
  44.       'image': slim.tfexample_decoder.Image(),  
  45.       'label': slim.tfexample_decoder.Tensor('image/class/label'),  
  46.   }  
  47.   
  48.   decoder = slim.tfexample_decoder.TFExampleDecoder(  
  49.       keys_to_features, items_to_handlers)  
  50.   
  51.   labels_to_names = None  
  52.   if dataset_utils.has_labels(dataset_dir):  
  53.     labels_to_names = dataset_utils.read_label_file(dataset_dir)  
  54.   
  55.   return slim.dataset.Dataset(  
  56.       data_sources=file_pattern,  
  57.       reader=reader,  
  58.       decoder=decoder,  
  59.       num_samples=SPLITS_TO_SIZES[split_name],  
  60.       items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,  
  61.       num_classes=_NUM_CLASSES,  
  62.       labels_to_names=labels_to_names)  
[python]  view plain  copy
  1. 你会发现这里,这个可前面生成tfrecord的名字是有对应关系的。  

这个文件大致意思就是读取下tfrecord文件,然后分割下,有的用来train,有的用来test


接下来可以进行train了

train.bat写在slim这个文件夹下

这里我附上我的train然后讲解下参数

[html]  view plain  copy
  1. python train_image_classifier.py ^  
  2. --train_dir=D:/Tensorflow/flower_photos/train ^  
  3. --dataset_name=image ^  
  4. --dataset_split_name=train ^  
  5. --dataset_dir=D:/Tensorflow/flower_photos/flowers/tfrecord ^  
  6. --batch_size=5 ^  
  7. --max_number_of_steps=10000 ^  
  8. --model_name=inception_v3 ^  
  9. --clone_on_cpu=true ^  
  10. pause  

第一个是你的train_iamge_classifier的位置,这里用的是相对位置

第二个是新建的空文件夹,训练完的数据会放到这个文件夹下

第三个特点的,你在生成tfrecord   的时候切分数据的train和test中的train

第四个是你的tfrecord文件的位置,里面必须要有labels.txt

第五个是分批训练的,主要用于显存不够,不能够一次性存放足够多的数据

第六个是训练的次数,不设置的情况下会一直执行

第七个是训练的模型  这里使用inception_v3模型

第八个很重要,我之前一直报错,问了好多人,上了各种网站都没查出来,这个应该是有些cpu版本的tensorflow才能处理的数据,在GPU上无法计算,所以要开启能够使用cpu的这个选项,如果是cpu版本的tensorflow应该没有问题。

第九个 pause 好像没什么用,改退出还是会退出,所以还是从命令窗口开始执行吧。

训练完之后在你的train文件夹下会生成数据



然后在slim目录下新建一个bbb.py

[python]  view plain  copy
  1. import os  
  2. import tensorflow as tf  
  3. import tensorflow.contrib.slim as slim  
  4.   
  5. from nets import inception  
  6. from nets import inception_v1  
  7. from nets import inception_v3  
  8. from nets import nets_factory  
  9.   
  10. from tensorflow.python.framework import graph_util  
  11. from tensorflow.python.platform import gfile  
  12. from google.protobuf import text_format  
  13.   
  14. checkpoint_path = tf.train.latest_checkpoint('D:/Tensorflow/flower_photos/train')  
  15. with tf.Graph().as_default() as graph:  
  16.     input_tensor = tf.placeholder(tf.float32, shape=(None2992993), name='input_image')  
  17.     with tf.Session() as sess:  
  18.         #  with tf.variable_scope('model') as scope:  
  19.         with slim.arg_scope(inception.inception_v3_arg_scope()):  
  20.             logits, end_points = inception.inception_v3(input_tensor, num_classes=5, is_training=False)  
  21.   
  22.     saver = tf.train.Saver()  
  23.     saver.restore(sess, checkpoint_path)  
  24.   
  25.     output_node_names = 'InceptionV3/Predictions/Reshape_1'  
  26.   
  27.     input_graph_def = graph.as_graph_def()  
  28.     output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(","))  
  29.     with open('D:/Tensorflow/flower_photos/output_graph_nodes.txt''w') as f:  
  30.         f.write(text_format.MessageToString(output_graph_def))  
  31.   
  32.     output_graph = 'D:/Tensorflow/flower_photos/train/inception_v3_final.pb'  
  33.     with gfile.FastGFile(output_graph, 'wb') as f:  
  34.         f.write(output_graph_def.SerializeToString())  

执行后会在train目录下生成pb文件,这个是tensorflow保存和读取的模型文件。

然后我们来使用他来识别。


用到的命令整理:

rm -rf /home/leo/Downloads/tmp/train_dir/*
python train_image_classifier.py \
    --train_dir=/home/leo/Downloads/tmp/train_dir \
    --dataset_name=dish \
    --dataset_split_name=train \
    --dataset_dir=/home/leo/Downloads/train_datas/smallDataSetTest5_9/output_tfrecord \
    --model_name=inception_resnet_v2 \
    --max_number_of_steps=100000 \
    --batch_size=6 \
    --learning_rate=0.0001 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=60 \
    --save_summaries_secs=60 \
    --log_every_n_steps=10 \
    --optimizer=rmsprop \
    --weight_decay=0.00004
不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。
fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。

如果只使用CPU则加上--clone_on_cpu=True。


验证checkpoint:

python eval_image_classifier.py \
    --checkpoint_path=/home/leo/Downloads/tmp/train_dir \
    --eval_dir=/home/leo/Downloads/tmp/eval_logs \
    --dataset_name=dish \
    --dataset_split_name=validation \
    --dataset_dir=/home/leo/Downloads/train_datas/smallDataSetTest5_9/output_tfrecord \
    --model_name=inception_resnet_v2


其他常用训练命令:

############################################################
使用train image classifier 训练 inception_resnet_v2  using fine-tune
############################################################
python train_image_classifier.py \
    --train_dir=/home/leo/Downloads/tmp/train_dir_220_dish_inception_resnet_v2 \
    --dataset_dir=/home/leo/Downloads/train_datas/18_5_14_220_tfrecord/output_tfrecord \
    --dataset_name=dish \
    --dataset_split_name=train \
    --model_name=inception_resnet_v2 \
    --checkpoint_path=/home/leo/Downloads/pretrained_models/inception_resnet_v2_2016_08_30/inception_resnet_v2_2016_08_30.ckpt \
    --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
    --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
############################################################
使用train image classifier 训练 mobile net v1  using fine-tune
############################################################
python train_image_classifier.py \
    --train_dir=/home/leo/Downloads/tmp/train_dir_220_dish_mobilenet_v1 \
    --dataset_dir=/home/leo/Downloads/train_datas/18_5_14_220_tfrecord/output_tfrecord \
    --dataset_name=dish \
    --dataset_split_name=train \
    --model_name=mobilenet_v1 \
    --checkpoint_path=/home/leo/Downloads/pretrained_models/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
    --checkpoint_exclude_scopes=MobilenetV1/Logits \
    --trainable_scopes=MobilenetV1/Logits

############################################################
使用train image classifier 训练 mobile net v2  from scratch
############################################################
python train_image_classifier.py \
    --train_dir=/home/leo/Downloads/tmp/train_dir_220_dish_mobilenet_v2 \
    --dataset_dir=/home/leo/Downloads/train_datas/18_5_14_220_tfrecord/output_tfrecord \
    --dataset_name=dish \
    --dataset_split_name=train \
    --model_name=mobilenet_v2 \
    --train_image_size=224 \
    --learning_rate=0.0001 \
    --learning_rate_decay_type=fixed \








你可能感兴趣的:(AI,AI实战派,tensorflow,分类器,神经网络,训练数据集)