【TensorFlow】制作TFrecords中的问题

问题

项目中需要把每张图片的标记向量存入TFrecords文件,参考了图片的存储方法,但会报错,经过几个小时的折腾才发现问题在哪里(汗颜)

代码

	tfrecords_filename = './train2.tfrecords'
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    #重点在于指定整数的类型
    img_raw = np.array([12,14,13,15,16,17],dtype=np.int32)
    img_raw = img_raw.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'label':tf.train.Feature(int64_list = tf.train.Int64List(value=[5])),
        'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))}))
    writer.write(example.SerializeToString())
    writer.close()
    
    filename_queue = tf.train.string_input_producer([tfrecords_filename])  # 读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                        features={
                                            'label': tf.FixedLenFeature([], tf.int64),
                                            'img_raw': tf.FixedLenFeature([], tf.string),
                                        })  # 取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.int32)
    image = tf.reshape(image, [6])
    label = tf.cast(features['label'], tf.int64)
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        example, l = sess.run([image, label])
        print(example, l)

        coord.request_stop()
        coord.join(threads)

在数据读出的时候,默认的整数类型为int32,之前都用的int64,发现输出总是少一半数据,换成int16后又多一半数据,因此,显式指定数据类型是很有必要的。
感谢这篇文章提供的总体框架,总算可以进行下去了~

你可能感兴趣的:(【TensorFlow】制作TFrecords中的问题)