记 第一次 用自己的训练集 在tensorflow上训练CNN的 坑

花了三天的时间,才成功在down下来的tensorflow代码上跑通了自己的CNN,心累

数据输入

这个问题几乎占据了一天半的时间。不得不说,tensorflow真的是超级超级烦,尤其是对于我们这些初学者,官方文档及其晦涩又没有例子,教程上还老拿MNIST和cifar_10这种做好的数据集说事,完全不知道图片该如何输入。

后来还是用TFRecords,先把图片加上标签制作成二进制文件,使用时再直接用reader读

下面是官方教程里的原话

标准TensorFlow格式

另一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。

从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 您也可以将这个例子跟fully_connected_feed的版本加以比较。

。。。还是不会用

后来发现了一篇博客写的很好,而且它添加标签的方法也非常巧妙:
http://blog.csdn.net/u012759136/article/details/52232266

然而不知道为啥我用jpg和jpeg的图片时转换都报错了,换成png才可以

数据格式问题

好吧,基本上通过tf.train.batch后,输入图片的shape都会转换成(batch_size,width,height,channels),label是(batch,int),这两个都是tensor。
如果我用教程里跑cifar_10的代码跑,这种格式基本上就可以了,因为最后输出的logits是类似[-1.11,2.32]的格式:

  • cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  • 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较

但是,如果用教程里跑MNIST的cnn代码,有很多地方要注意

  • 数据集是feed输入的,feed的数据格式是有要求的The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays
    解决:img,label = sess.run[img,label],用返回值
  • cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量
  • 原代码通过batch=mnist.train.next_batch(50)提供数据,但是tensorflow里TFRecordReader可以够记住tfrecord的位置,并且始终能返回下一个,所以不需要自己next_batch

让他训练

一堆乱七八糟的错误,本来觉得挺坑的,做完了觉得简直弱智。。总之,如果程序一直停留的一个地方,既没卡,也没报错,那基本上就是数据没输进去。。好好看下代码现在run的是什么,返回值是什么,一步一步减小run的深度,然后看是哪个地方漏了什么,TFRecords是不是又不小心修改了,反正我被它坑了好几次

对了!

如果使用了队列,改代码的时候千万不能把这一句删了,不然你什么也看不到,程序又停那了
tf.train.start_queue_runners(sess=sess)

搞了三天,真正说的时候觉得也没什么难的,大概这就是掌握了吧。。

你可能感兴趣的:(记 第一次 用自己的训练集 在tensorflow上训练CNN的 坑)