在用cnn完成oxflower17分类任务时遇到的一些问题和解决方案,记录一下。
开始按照Tensorflow创建和读取17flowers数据集[1]的方法用TFRecord读取数据集,由于返回的是tf.Tensor格式的数据,不能feed给网络,报如下错误:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("shuffle_batch:0", shape=(64, 224, 224, 3), dtype=float32) which was passed to the feed with key Tensor("input:0", shape=(?, 224, 224, 3), dtype=float32).
数据集是feed输入的,feed的数据格式是有要求的 。解决:img,label = sess.run([img,label]),用返回值。
加了用sess.run返回的语句后,程序每次运行到这句就会挂起,既不执行也不报错。后来分析了原因是tf的数据线程没有启动,导致数据流图没办法计算,整个程序就卡在哪里。tensorflow的计算和数据读入是异步的,合理的方式是主线程进行模型的训练,然后开一个数据读入线程异步读入数据.tensorflow会在内存中维护一个队列,然后数据线程异步从磁盘中将样本推入队列当中。并且,因为tensorflow的训练和读数据是异步的,故即使当前没有数据进来,tensorflow也没办法报错,因为可能接下来会有数据进队列,所以,tensorflow就一直处于等待的状态。cited from.tensorflow 程序挂起的原因,即整个进程不报错又不执行的原因[2]
使用tf.train.range_input_producer(epoch_size, shuffle=False),会默认将QueueRunner
添加到全局图中,必须用tf.train.start_queue_runners(sess=sess),去启动该线程。然后使用coord = tf.train.Coordinator()去做一些线程的同步工作。也可以如[2]所说尝试使用sv = tf.train.Supervisor()。
train_x_batch, train_y_batch=get_batch('trn1.tfrecords',config.batch_size)
# Start input enqueue threads.
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord)#启动QueueRunner, 此时文件名队列已经进队
try:
for i in range(5):
# Run training steps or whatever
x_batch,y_batch=sess.run([train_x_batch,train_y_batch])
except tf.errors.OutOfRangeError:
print ('Done -- epoch limit reached')
finally:
# When done, ask the threads to stop.
coord.request_stop()
# Wait for threads to finish.
coord.join(threads)
feed_dict=feed_data(x_batch,y_batch,config.dropout_keep_prob)
有关TensorFlow中的队列与线程,可以参照这一篇文章:Tensorflow笔记(基础篇):队列与线程[3]
但到了读取验证集时,又报如下错:
tensorflow.python.framework.errors_impl.CancelledError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed.
[[Node: shuffle_batch/random_shuffle_queue_enqueue = QueueEnqueueV2[Tcomponents=[DT_FLOAT, DT_INT64], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, sub, ParseSingleExample/ParseSingleExample:1)]]
TensorFlow官方文档在数据读取部分中说:在某些时候,样本出队的操作可能会得到一个tf.OutOfRangeError
的错误。这其实是TensorFlow的“文件结束”(EOF) ———— 这就意味着已经达到了最大训练迭代数,已经没有更多可用的样本了。cited from.Tensorflow官方文档中文版-数据读取[4]。这就说明可能是因为我的validation集数据太少所以队列关闭。
所以转换思路,从tfrecord文件创建Dataset。参考:Tensorflow学习笔记–创建Dataset[5],Tensorflow 学习笔记:Input Pipeline - Dataset[6]:
#创建:
def make_dataset_from_tfrecord_file(filename,batch_size):
filename_list=[filename]
dataset=tf.data.TFRecordDataset(filename_list).repeat()
dataset=dataset.map(__parse_function)
dataset=dataset.batch(batch_size)
return dataset
def __parse_function(example_proto):
features={"img_raw":tf.FixedLenFeature([],tf.string),
"label":tf.FixedLenFeature([],tf.int64)}
parsed_features=tf.parse_single_example(example_proto,features)
img=tf.decode_raw(parsed_features['img_raw'],tf.uint8)
img=tf.reshape(img,[224,224,3])
img=tf.cast(img,tf.float32)*(1./255)-0.5
label=tf.cast(parsed_features['label'],tf.int64)
label=tf.one_hot(label,17,1,0)
return img,label
#读取
def read(filename,batch_size):
dataset=make_dataset_from_tfrecord_file(filename,batch_size)
iterator=dataset.make_initializable_iterator()
return iterator.get_next()
#运行时
sess.run(iterator.initializer)
x_batch,y_batch=sess.run([x_batch_,y_batch_])
一开始我把get_next()放在循环内,出现了如下警告:
UserWarning: An unusually high number of `Iterator.get_next()` calls was detected. This often indicates that `Iterator.get_next()` is being called inside a training loop, which will cause gradual slowdown and eventual resource exhaustion. If this is the case, restructure your code to call `next_element = iterator.get_next()` once outside the loop, and use `next_element` as the input to some computation that is invoked inside the loop
就是说建议把get_next()放在循环外面,每次只要用iterator读取数据就可以了。
这里还涉及一个知识点就是读取Dataset的时候,有多种迭代器,这里我们用的是可初始化迭代器make_initializable_iterator ,允许Dataset中存在占位符,这样可以在数据需要输出的时候,再进行feed操作。其他迭代器如dataset.make_one_shot_iterator()
迭代器,one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了;reinitializable 迭代器和Iterator.from_string_handle 迭代器,详细参见:Dataset的用法简析
更多参考:
技术分享|TensorFlow初学者在使用过程中可能遇到的问题及解决办法
TensorFlow图像数据处理
tensorflow花卉识别(含模型保存及调用、详细注释)
舆情监控系统——step2.CNN-基于tensorFlow实现