改代码的关键是输入。在这里cifar10_input.py已经可以完全废掉了,没用,而且封装了很多层。捋顺代码的过程中发现各个文件之间的依赖关系是这样的:
cifar10_eval → cifar10_train → cifar10 → cifar10_input
所以既然要废掉最后一个,那就必须在cifar10中把输入处理妥当。最关键的函数就是cifar10的inputs。原代码中cifar10_train用的是distorted inputs,在这里为了跑通demo姑且不要distorted(因为这一步应该在预处理过程中完成,即先distorted再生成tfrecords),然后把inputs改好就行了。
def inputs(data_dir, tfrecords_file_names):
"""Construct input for CIFAR evaluation using the Reader ops.
Args:
eval_data: bool, indicating if one should use the train or eval data set.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
Raises:
ValueError: If no data_dir
"""
files = [os.path.join(data_dir, tfrecords_file_names)]
#files = tf.train.match_filenames_once("cats_vs_dogs.tfrecords")
filename_queue = tf.train.string_input_producer(files, shuffle=True)
#filename_queue = tf.train.string_input_producer(['cats_vs_dogs.tfrecords'])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
# 解析读取的样例
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image_raw' : tf.FixedLenFeature([], tf.string),
})
images = tf.decode_raw(features['image_raw'], tf.uint8)
images = tf.reshape(images, [128,128,3])
images = tf.cast(images, tf.float32) * (1. / 255) - 0.5
labels = tf.cast(features['label'],tf.int32)
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * FLAGS.batch_size
batch_size = FLAGS.batch_size
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size = batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
if FLAGS.use_fp16:
images = tf.cast(image_batch, tf.float16)
labels = tf.cast(label_batch, tf.float16)
return image_batch, label_batch
需要注意的是这里用files = [os.path.join(data_dir, tfrecords_file_names)] 代替了match_filenames_once,因为感觉后者不稳定,经常出莫名其妙的错误。而用前者的话路径直接指定,出错可能性较低,但是这里的两个参数就需要inputs函数参数传递而来了,我这里把这两个当做参数,替代了原先指示是否是测试集的那个eval_data。
补充:
images = tf.reshape(images, [128,128,3]) 这句里面的128*128*3非常重要,代表的是图像的输入尺寸,当时怎么压缩进tfrecords里面的,此时就应该怎么读出来。
改好这个以后,cifar10_train中的distorted_inputs改成inputs,传递相应的路径和tfrecords文件名参数,训练就行了。注意每回训练都会把之前的checkpoints强制删除,我这里把main函数改成了交互式。
def main(argv=None): # pylint: disable=unused-argument
#cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
save_file = input("Checkpoints and logs exist, do you want to delete it?\nyes/no\n")
if save_file == "yes":
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
train()
elif save_file == "no":
print("System exits, please save your checkpoint files manually and run this again.")
else:
print("Invalid user command, system exists.")
else:
tf.gfile.MakeDirs(FLAGS.train_dir)
train()