【系列学习】4.数据通用处理 —— tf工程化项目实战

【tensorflow数据集介绍】

tensorflow的数据集格式主要有4种:

  1. 内存对象数据集:适用于少量的数据集输入
  2. TFRecord数据集
  3. Dataset数据集
  4. tf.keras数据集

通常训练比较多的做法是把样本处理成TFRecord,或者Dataset数据集。
今天主要讲下怎么处理成Dataset数据集。其便捷性在于生成了Dataset对象后,可直接在上面做shuffle/map/iterate等操作。
在tf.data.Dataset中,有3中方法可以将内存中的数据转化为Dataset:

  • tf.data.Dataset.from_tensor: 根据内存对象生成Dataset对象,对象中只能有一个元素。
  • tf.data.Dataset.from_tensor_slices: 根据内存对象生成Dataset对象,对象是列表、元组、字典、Numpy数组等类型。
  • tf.data.Dataset.from_generator: 根据生成器生成Dataset对象。

我们通常使用tf.data.Dataset.from_tensor_slices。

【Dataset接口使用套路】

  1. 生成Dataset对象
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  1. 对Dateset对象中的样本进行变换操作,支持的变化操作有:shuffle,map,batch,repeat等
dataset = dataset.map(_parsefun)
  1. 创建Dataset迭代器
def getone(dataset):
  iterator = dataset.make_one_shot_iterator()
  one_element = iterator.get_next()
  return one_element
  1. 在session中取出数据
with tf.Session() as sess:
  value = sess.run(one_element)

【实战数据处理代码mydataset.py分析】

1.主函数 creat_dataset_fromdir
def creat_dataset_fromdir(directory, batch_size, isTrain = True):
  filenames, labels = list_images(directory)
  num_classes = len(set(labels))
  dataset = creat_bacthed_dataset(filenames, labels, batch_size, isTrain)
  return dataset, num_classes 
  • 从给定的数据路径加载样本和标签
  • 根据标签统计分类数量
  • 处理成batch data
2.根据给定路径加载样本&标签 list_images
def list_images(directory):
  labels = os.listdir(directory)
  labels.sort()
  files_and_labels = []
  for label in labels:
    for f in os.listdir(os.path.join(directory, label)):
      if f[0] == '.':
        continue
      if 'jpg' in f.lower() or 'png' in f.lower():
        files_and_labels.append((os.path.join(directory, label, f),label))
  
  filenames, labels = zip(*files_and_labels)
  filenames = list(filenames)
  labels = list(labels)
  unique_labels = list(set(labels))

  label_to_int = {}
  for i, label in enumerate(sorted(unique_labels)):
    label_to_int[label] = i+1
  
  labels = [label_to_int[l] for l in labels]
  return filenames, labels

样本/标签存放的路径如下:


image.png

directory下每一个分类有一个文件夹,按照每个标签遍历下面的图片,以(图片路径,对应标签)的形式存到一个总数组里。
然后通过zip分别得到两个list,一个是全部的图片路径,一个是对应的全部标签。
然后把语义的标签处理为数字后返回两个处理好的list

3.创建批数据 creat_bacthed_dataset

这里会用到tf.data提供的方法

def creat_bacthed_dataset(filenames, labels, batch_size, isTrain=True):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_parse_function, num_parallel_calls = num_workers)

  if isTrain == True:
    dataset = dataset.shuffle(buffer_size=len(filenames))
    dataset = dataset.map(training_preprocess, num_parallel_calls=num_workers)
  else:
    dataset = dataset.map(val_preprocess, num_parallel_calls=num_workers)
  return dataset.batch(batch_size)

可以看到,这边主要用了 tf.data.Dataset.from_tensor_slices来创建Dataset数据集,具体上面已经讲过了,这是创建套路中的第一步,第二部是对数组进行变化操作,这边对数据做了一个通用_parse_function的处理,并行数量可配置,这步处理是根据图片路径解析图片的步骤。

4.解析图片操作 _parse_function
  image_string = tf.read_file(filename)
  image = tf.image.decode_jpeg((image_string, channels=3))
  return image, label

回到3,解析完图片后,根据isTrain函数采用不同的处理(training_preprocess/val_preprocess),最后再batch一下,返回Dataset对象,如5所示。

5.training_preprocess/val_preprocess
image_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=True)
image_eval_preprocessing_fn = preprocessing_factory.get_preprocessing('nasnet_mobile', is_training=False)

def training_preprocess(image, label):
  image = image_preprocessing_fn(image, image_size, image_size)
  return image, label

def val_preprocess(image, label):
  image = image_eval_preprocessing_fn(image, image_size, image_size)
  return image, label

mydataset.py就是返回了这样一个Dataset数据集~比较简单,在其他的tensorflow训练项目中,我们可以对mydataset稍作修改重复使用。

你可能感兴趣的:(【系列学习】4.数据通用处理 —— tf工程化项目实战)