tensorflow官网Cifar-10改为自己的TFRecords数据集

已经改完了,中间有些过程记得比较模糊,能想起哪些就记下来哪些吧。官网Cifar-10都是已经下载好的数据集,所以一般是以bin或者压缩文件的形式存在,这一点可以在cifar10_train文件的最后一段main函数中cifar10.maybe_download_and_extract()体现,但是我们在训练过程中不会再去把自己的数据集压缩成文件再解压训练,都是直接生成TFRecords文件训练,所以改成这种方式还是很通用的。


改代码的关键是输入。在这里cifar10_input.py已经可以完全废掉了,没用,而且封装了很多层。捋顺代码的过程中发现各个文件之间的依赖关系是这样的:

cifar10_eval → cifar10_train → cifar10 → cifar10_input

所以既然要废掉最后一个,那就必须在cifar10中把输入处理妥当。最关键的函数就是cifar10的inputs。原代码中cifar10_train用的是distorted inputs,在这里为了跑通demo姑且不要distorted(因为这一步应该在预处理过程中完成,即先distorted再生成tfrecords),然后把inputs改好就行了。


def inputs(data_dir, tfrecords_file_names):

  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """

  files = [os.path.join(data_dir, tfrecords_file_names)]
  
  #files = tf.train.match_filenames_once("cats_vs_dogs.tfrecords")
  filename_queue = tf.train.string_input_producer(files, shuffle=True) 
    #filename_queue = tf.train.string_input_producer(['cats_vs_dogs.tfrecords'])
  reader = tf.TFRecordReader()
  _,serialized_example = reader.read(filename_queue)
    
  # 解析读取的样例
  features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw' : tf.FixedLenFeature([], tf.string),
        })  

  images = tf.decode_raw(features['image_raw'], tf.uint8)
  images = tf.reshape(images, [128,128,3])
  images = tf.cast(images, tf.float32) * (1. / 255) - 0.5

  labels = tf.cast(features['label'],tf.int32)
  
  min_after_dequeue = 10
  capacity = min_after_dequeue + 3 * FLAGS.batch_size
  batch_size = FLAGS.batch_size

  image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
                                                    batch_size = batch_size, 
                                                    capacity=capacity,
                                                    min_after_dequeue=min_after_dequeue)
   
  if FLAGS.use_fp16:
    images = tf.cast(image_batch, tf.float16)
    labels = tf.cast(label_batch, tf.float16)
  
  return image_batch, label_batch

需要注意的是这里用files = [os.path.join(data_dir, tfrecords_file_names)] 代替了match_filenames_once,因为感觉后者不稳定,经常出莫名其妙的错误。而用前者的话路径直接指定,出错可能性较低,但是这里的两个参数就需要inputs函数参数传递而来了,我这里把这两个当做参数,替代了原先指示是否是测试集的那个eval_data。
补充:
images = tf.reshape(images, [128,128,3]) 这句里面的128*128*3非常重要,代表的是图像的输入尺寸,当时怎么压缩进tfrecords里面的,此时就应该怎么读出来。
改好这个以后,cifar10_train中的distorted_inputs改成inputs,传递相应的路径和tfrecords文件名参数,训练就行了。注意每回训练都会把之前的checkpoints强制删除,我这里把main函数改成了交互式。
def main(argv=None):  # pylint: disable=unused-argument
  #cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    save_file = input("Checkpoints and logs exist, do you want to delete it?\nyes/no\n")
    if save_file == "yes":
     tf.gfile.DeleteRecursively(FLAGS.train_dir)
     tf.gfile.MakeDirs(FLAGS.train_dir)
     train()
    elif save_file == "no":
     print("System exits, please save your checkpoint files manually and run this again.")
    else:
     print("Invalid user command, system exists.")
  else:
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()

cifar10_eval同理,其他没有做太大改动。目前在猫狗大战上训练能跑通,运行时间太长还没出结果,后续会再尝试多类分类。












你可能感兴趣的:(Deep,Learning)