Tensorflow针对不定尺寸的图片读写tfrecord文件总结

介绍

最近在读取tfrecord时,遇到了关于tensorf shape的问题。

我们需要知道,大多数情况下图片进行encode编码保存在tfrecord时 是一个一维张量,shape为(1,)。 而在输入神经网络之前,我们必须要将这个图片张量reshape成一个合乎网络结构需求的三维张量。
在针对这样的需求时,我们会发现,大部分同学会选择在生成tfrecord前就定义好网络的输入shape,例如[224,224,3], 然后将所有的图片先reshape成这个大小,接着存储在tfrecord中。
这种方式的优点在于提前完成的reshape,避免了后续很多的shape uncompatible 的问题,以及后续训练中不用再对图片进行reshape,加快了训练速度。
缺点在于,限制了网络输入尺寸的定义。每修改一次神经网络的输入shape。

当我们需要从存储着不定尺寸图片的tfrecord读取数据时, 我们是无法直接将图片reshape成指定的网络结构输入尺寸的。例如图片大小 [667,1085,3]。显然,我们无法直接将其reshape成 [224,224,3]的。那么我们该如何处理呢?

按照思路,我们应该先将图片的一维tensor 转换成三维tensor, 然后再利用 tf.image库中不同的reshape 操作,将三维图片tensor转换为需要的 tensor大小。

按照这种思路,在这里,我总结了两种读写tfrecord的方式,并对这两种方式的不同点,尤其是容易导致bug的地方进行了整理。

第一种: 利用slim.dataset.Dataset读写tfrecord文件,这种方式常见于利用slim库进行目标检测等网络的实现过程中。
第二种:tf.parse_single_example 是更为常见的一种方式

利用slim.dataset.Dataset读写tfrecord文件

利用这个这个接口读写tfrecord非常的方便。它的神奇之处在于,
它不需要图片宽高的信息,只需要其二进制string tensor。 这个接口会自动返回一个三维图片tensor。 在此基础上,我们可以很方便的对其进行reshape,然后输入神经网络。
具体步骤如下:
在生成tfrecord文件时,我们需要先定义 tf_example的写入格式,然后在将图片文件依据这个写入格式,生成tfrecord文件

  • 定义 tf_example的写入特征
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)

    # 对于可能存在RGBA 4通道的图片进行处理
    image,is_process = process_image_channels(image)

    # 如有必要,那么就在生成tfrecord时即进行resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    # update encode_jpg
    if is_process or resize_size is not None:
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')
        encoded_jpg = bytes_io.getvalue()

    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))
    return tf_example
  • 生成完整的tfrecord文件
    在定义完对应的tf_example 方式后,我们可以遍历图片文件,生成完整的tfrecord文件了。
def generate_tfrecord(annotation_dict, output_path, resize_size=None):
    num_valid_tf_example = 0
    writer = tf.python_io.TFRecordWriter(output_path)
    for image_path, label in annotation_dict.items():
        if not tf.gfile.GFile(image_path):
            print('%s does not exist.' % image_path)
            continue
        tf_example = create_tf_example(image_path, label, resize_size)
        if tf_example:
            writer.write(tf_example.SerializeToString())
            num_valid_tf_example += 1

            if num_valid_tf_example % 100 == 0:
                print('Create %d TF_Example.' % num_valid_tf_example)
    writer.close()
    print('Total create TF_Example: %d' % num_valid_tf_example)

对应着,在读取tfrecord时,slim提供了 slim.dataset.Dataset 的API接口,非常方便对读入的tfrecord数据进行操作。

def get_record_dataset(record_path,
                       reader=None, 
                       num_samples=50000, 
                       num_classes=32):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                              format_key='image/format'),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)

在返回了slim.dataset.Dataset这个slim支持的data封装后, 我们可直接对返回的图片数据进行reshape,保证这个图片张量的shape与网络结构的输入层shape一致。

   dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                 num_classes=FLAGS.num_classes)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # 输出当前tensor的静态shape 和动态shape,与另一种读取方式进行对比
    print("----------tf.shape(image): ",tf.shape(image))
    print("----------image.get_shape(): ",image.get_shape())
    image = _fixed_sides_resize(image, output_height=368, output_width=368)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    #capacity=5*FLAGS.batch_size,
                                    allow_smaller_final_batch=True)

其中,对三维图片张量进行reshape的代码如下

def _fixed_sides_resize(image, output_height, output_width):
    """Resize images by fixed sides.
    
    Args:
        image: A 3-D image `Tensor`.
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor containing the resized image.
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    image = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        image, [output_height, output_width], align_corners=False)
    resized_image = tf.squeeze(resized_image)
    resized_image.set_shape([None, None, 3])
    return resized_image

完成了这几步之后,我们就可以利用image 和 label 进行神经网络训练了。

利用tf.parse_single_example 读写tfrecord文件

这种方式我们需要自己手动将一维的图片tensor,先还原成三维图片tensor。 因为每一张图片的shape不相同。那么我们需要将图片的shape也存入tfrecord文件中。当我们从tfrecord文件中读取时,我们先利用tf.reshape将一维图片张量还原成三维图片张量,再reshape规定的网络输入尺寸。

  • 照例,此处的重点在于tf_example的构建。在这一部分,我将图片的shape作为一个feature,也存入了tfrecord里面。 那么,在对张量的还原时,我们可以利用这个三维的shape tensor,
def create_tf_example(image_path, label, resize_size=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    # 对于RGBA 4通道的图片进行处理
    image,is_process = process_image_channels(image)

    # Resize
    width, height = image.size
    if resize_size is not None:
        if width > height:
            width = int(width * resize_size / height)
            height = resize_size
        else:
            width = resize_size
            height = int(height * resize_size / width)
        image = image.resize((width, height), Image.ANTIALIAS)
    
    img_array = np.asarray(image)
    shape = img_array.shape
    byte_image = image.tobytes()
    
    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image': bytes_feature(byte_image),
            'label': int64_feature(label),
            'img_shape': int64_list_feature(shape)}))
    return tf_example
  • 在完成这个后,我们仍旧可以使用上述提及的generate_tfrecord 函数来生成对应的tfrecord

  • 那么,对应这种方式生成的tfrecord文件,我们该如何读取呢?
    在这里,我给出对应的parse_example函数就足以了。

def parse(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.

    features = {
        'image':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'label':
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
                                                                     dtype=tf.int64)),
        'img_shape': 
            tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.parse_single_example(
        serialized=serialized, features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.decode_raw(image_raw, tf.uint8)
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)
    
    shape = parsed_example['img_shape']
    
    image = tf.reshape(image,shape=shape)
    
    if not (shape[0] == shape[1] == default_img_size):
        image = _fixed_sides_resize(image,default_img_size,default_img_size)
    
    image.set_shape([default_img_size,default_img_size,3])
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label

在这里,读写tfrecord的重要流程就已经展现好了。

对比

这两种方式有一个比较重要的区别,那就是制作tfrecord时存储的图片信息不同。
使用slim api时 我们制作tfrecord 时,相关代码为

    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()

当我们使用第二种方式时,制作tfrecord时存储的图片信息的相关代码如下所示。

image = Image.open(img_dir)
byte_image = image.tobytes()

第一种方式保存的图片信息,其字节数不等于图片的height, width, channel的乘积。 所以不能用 第二种的方式去读取这种方式存储的tfrecord。 会出现 reshape时 维度不对的错误。 当然,使用slim.dataset.Dataset 则不需要考虑这个问题了。 网络上使用slim.dataset.Dataset 来加载tfrecord的方式,都是使用第一种方式存储的tfrecord数据。

第二种方式,其存储的图片字节大小等于图片的height, width, channel的乘积。所以它可以直接用tf.reshape直接将原图矩阵还原回来,然后再进行下一步的reshape操作。

总结

之所以写这篇文章,是因为网络上针对不定尺寸图片tfrecord读取的解决方案不是很完善。
例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
将height, width,channel 分别存入tfrecord,然后按照提问者描述这样是不成功的。
再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解决方案

image_rows = tf.cast(features['rows'], tf.int32)
image_cols = tf.cast(features['cols'], tf.int32)
image_data = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))

这种方式在tf.reshape阶段会报错,因为我们无法将 两个tensor和一个int数值组合起来。最完善的方式是直接将shape作为一个整体存入tfrecord中,最终读取出来就是一个张量了。

你可能感兴趣的:(Tensorflow针对不定尺寸的图片读写tfrecord文件总结)