TensorFlow读数据

在TensorFlow中读数据一般有三种方法:

  • 使用placeholder读内存中的数据
  • 使用queue读硬盘中的数据
  • 使用Dataset读内存个硬盘中的数据

基本概率

由于第三种方法在语法上更简洁,因此本文主要介绍第三种方法。官方给出的Dataset API类图:

TensorFlow读数据_第1张图片
image.png

其中终于重要的两个基础类: DatesetIterator
Dateset是具有相同类型的“元素”的有序表,元素可以是向量、字符串、图片等。

从内存中创建Dataset

以数字元素为例:


TensorFlow读数据_第2张图片
例1

从Dataset中实例化一个Iterator,然后对Iterator进行迭代。

iterator = dataset.make_one_shot_iterator() 

从dataset中实例化一个iterator,是“one shot iterator”,即只能从头到尾读取一次。

one_element = iterator.get_next()

从iterator中取出一个元素, one_element是一个tensor,因此需要调用sess.run(one_element)取出值。

如果元素被读取完了,再sess.run(one_element)会抛出tf.errors.OutOfRangeError异常。解决方法:使用 dataset.repeat()

更复杂的输入形式,例如,在图像识别的应用中,一个元素可以使{“image”:image_tensor, “label”:lable_tensor}

dataset = tf.data.Dataset.from_tensor_slices(
    {
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                       
        "b": np.random.uniform(size=(5, 2))
    }
)

最终dataset中的一个元素为{"a": 1.0, "b": [0.9, 0.1]}的形式。
或者

dataset = tf.data.Dataset.from_tensor_slices(
  (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2)))
)

对Dataset中的元素做变换:Transformation

一个Dataset通过Transformation变成一个新的Dataset。常用的操作有:

  • map
  • batch
  • shuffle
  • repeat

下面分别来介绍以上几个操作。
(1)map
map接收一个函数,dataset中的每个元素都可以作为这个函数的输入,并将函数的返回值作为新的dataset,例如:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0

(2)batch
将多个元素组合成batch,例如:

dataset = dataset.batch(32)

(3)shuffle
打乱dataset中的元素,参数buffersize表示打乱时buffer的大小。

dataset = dataset.shuffle(buffer_size=10000)

(4)repeat
将整个序列重复多次,只用用来处理epoch。如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出。tf.errors.OutOfRangeError异常:

dataset = dataset.repeat(5)

例子:读磁盘图片与对应的label

读入磁盘中的图片和图片相应的label,并将其打乱,组成batch_size=32的训练样本。在训练时重复10个epoch。

# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# 图片文件的列表
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# label[i]就是图片filenames[i]的label
labels = tf.constant([0, 37, ...])

# 此时dataset中的一个元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

# 此时dataset中的一个元素是(image_resized, label)
dataset = dataset.map(_parse_function)

# 此时dataset中的一个元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)
# 此时dataset中的一个元素是(image_resized_batch, label_batch)
# image_resized_batch的形状为(32, 28, 28, 3), label_batch的形状为(32, )

你可能感兴趣的:(TensorFlow读数据)