图像分割tf-record

转 https://zhuanlan.zhihu.com/p/32078191


出于想把程序放在美团云的dls上跑的目的,学习了一下如何使用TF自带的数据格式TFRecord(吐槽一下美团云一次只能上传20个文件的设定,想把一个数据集上传上去真的太麻烦了)。说是针对分割问题的使用方法是因为在分割问题里label也是图片,所以就只有图片,其实TFRecord的使用方法非常灵活,只要能组织起自己的数据就可以随便用起来。

TFRecord的格式

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的:

message Example{
    Features features = 1;
}; 

message Features{
    map<string, Feature> feature = 1;
};

message Feature{
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
}; 

这个格式也是需要在程序中使用的,仅做了解。

将已有的数据集转换成TFRecord格式

  1. 定义一个将已有的数据转换成Feature数据结构的函数(官方教程中的函数)
  2. 将数据集中的文件名写入一个txt文件中(为了在程序中方便读取),可以这样来组织:
    文件名1.jpg 文件名1.png
    文件名2.jpg 文件名2.png
    ......
  3. 创建一个句柄来读这个txt文件
  4. 定义一个TFRecord的writer
  5. 逐个文件来写入TFRecoder文件
import tensorflow as tf

TXT_PATH = './dataset.txt'
TFRECORD_PATH = './dataset.tfrecord'

# 1.定义一个将已有的数据转换成Feature数据结构的函数(官方教程中的函数)
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 3.创建一个句柄来读这个txt文件
f = open(TXT_PATH)

# 4.定义一个TFRecord的writer
writer = tf.python_io.TFRecordWriter(TFRECORD_PATH)

# 5.逐个文件来写入TFRecoder文件
for i in f.readlines():
    # 每一行的文件名是用空格隔开的,所以需要使用split方法把string映射成list
    item = i.split()
    
    # 这里可能有些教程里会使用skimage或者opencv来读取文件,但是我试了一下opencv的方法
    # 一个400MB左右的训练集先用cv2.imread转换成numpy数组再转成string,最后写入TFRecord
    # 得到的文件有17GB,所以这里推荐使用FastGFile方法,生成的tfrecord文件会变小,  
    # 唯一不好的一点就是这种方法需要在读取之后再进行一次解码。
    img = tf.gfile.FastGFile(item[0], 'rb').read()
    label = tf.gfile.FastGFile(item[1], 'rb').read()

    # 按照第一部分中Example Protocol Buffer的格式来定义要存储的数据格式
    example = tf.train.Example(features=tf.train.Features(feature={
        'raw_image': _bytes_feature(img),
        'label': _bytes_feature(label)
    }))
    # 最后将example写入TFRecord文件
    writer.write(example.SerializeToString())

writer.close()
f.close()

因为FastGFile读取的是图片没有解码过的的原始数据,所以在使用存在tfrecord中的这些原始数据时,需要对读取出来的图片原始数据进行解码。

读取TFRecord格式的数据

  1. 把所有的TFRecord文件名列表写入队列中(只有一个就写一个文件名在列表中,多个就写多个)
  2. 创建一个读取器
  3. 将队列中的tfrecord文件读取为example格式
  4. 根据定义数据的方式对应说明读取的方式
  5. 对图片进行解码
# 1. 把所有的TFRecord文件名列表写入队列中(只有一个就写一个文件名在列表中,多个就写多个)
queue = tf.train.string_input_producer([TFRECORD_PATH], shuffle=True)

# 2. 创建一个读取器
reader = tf.TFRecordReader()

# 3. 将队列中的tfrecord文件读取为example格式
_, serialized_example = reader.read(queue)

# 4. 根据定义数据的方式对应说明读取的方式
features = tf.parse_single_example(serialized_example, 
                                   features={
                                       'raw_image': tf.FixedLenFeature([], tf.string)
                                       'label': tf.FixedLenFeature([], tf.string)
                                   })
img = features['raw_image']
label = features['label']

# 5. 对图片进行解码
img = tf.image_decode_jpeg(img, channels=3)
label = tf.image_decode_png(label, channels=1)

# 然后就可以用tf.train.batch方法来生成一个batch的image和label啦
img_batch, label_batch = tf.train.batch([img, label], BATCH_SIZE)

你可能感兴趣的:(图像分割tf-record)