tensorfow中的Dataset

转:https://www.jianshu.com/p/aeb54ed224b3

1 Dataset API的导入

TensorFlow 1.3中,Dataset API是放在contrib包中的:tf.contrib.data.Dataset

TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:tf.data.Dataset

下面的示例代码将以TensorFlow 1.4版本为例,如果使用TensorFlow 1.3的话,需要进行简单的修改(即加上contrib)。

2 基本概念:Dataset与Iterator

 


在初学时,我们只需要关注两个最重要的基础类:DatasetIterator

 

Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。

2.1 用tf.data.Dataset.from_tensor_slices创建了一个最简单的Dataset

import tensorflow as tf
import numpy as np

# 创建了一个dataset,这个dataset中含有5个元素,分别是1.0, 2.0, 3.0, 4.0, 5.0
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

如何将这个dataset中的元素取出呢?方法是从Dataset中示例化一个Iterator,然后对Iterator进行迭代。

在非Eager模式下,读取上述dataset中元素的方法为

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))

对应的输出结果应该就是从1.0到5.0。语句iterator = dataset.make_one_shot_iterator()dataset中实例化了一个Iterator,这个Iterator是一个one shot iterator,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。由于这是非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。

如果一个dataset中元素被读取完了,再尝试sess.run(one_element)的话,就会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据的行为是一致的。在实际程序中,可以在外界捕捉这个异常以判断数据是否读取完,请参考下面的代码:

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print("end!")

Eager模式中,创建Iterator的方式有所不同

搭模型更方便了:之前搭模型通常要认真记下每一步Tensor的shape和意义,然后再操作。
现在可以轻松点,边搭边写,忘记形状或者含义的时候可以直接打出来看。
另外流程控制可以使用Python的内建语法,更加直观。

调试时no more sess.run() ! 之前在调试时必须要加上sess.run(),
很麻烦,现在可以直接把变量print出来,亦可使用IDE的监控工具单步调试。

最后,如果之前我们想在自己的程序中用tf开头的函数,需要手动开启Session将结果的Tensor转换成Numpy数组,或者使用官方提供的函数修饰器。
现在只需要用开启这个Eager模式,就可以直接把tf开头的函数当普通函数用了。

通过tfe.Iterator(dataset)的形式直接创建Iterator并迭代。迭代时可以直接取出值,不需要使用sess.run()

import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
for one_element in tfe.Iterator(dataset):
    print(one_element)

2.2 从内存中创建更复杂的Dataset

matrix

tf.data.Dataset.from_tensor_slices的功能不止如此,它的真正作用是切分传入Tensor的第一个维度,生成相应的dataset

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))

传入的数值是一个矩阵,它的形状为(5, 2)tf.data.Dataset.from_tensor_slices就会切分它形状上的第一个维度,最后生成的dataset中一个含有5个元素,每个元素的形状是(2, ),即每个元素是矩阵的一行。

dict

在实际使用中,我们可能还希望Dataset中的每个元素具有更复杂的形式,如每个元素是一个Python中的tuple,或是Python中的dict。例如,在图像识别问题中,一个元素可以是{"image": image_tensor, "label": label_tensor}的形式,这样处理起来更方便。

tf.data.Dataset.from_tensor_slices同样支持创建这种dataset,例如我们可以让每一个元素是一个dict

dataset = tf.data.Dataset.from_tensor_slices(
    {
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                       
        "b": np.random.uniform(size=(5, 2))
    }
)

这时函数会分别切分"a"中的数值以及"b"中的数值,最终dataset中的一个元素就是类似于{"a": 1.0, "b": [0.9, 0.1]}的形式。

tuple

利用tf.data.Dataset.from_tensor_slices创建每个元素是一个tupledataset也是可以的:

dataset = tf.data.Dataset.from_tensor_slices(
  (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2)))
)

3 对Dataset中的元素做变换:Transformation

Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset(类似以Spark中的RDD操作)。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。

常用的Transformation有:
map
batch
shuffle
repeat
下面就分别进行介绍。

(1)map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0

(2)batch

batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:

dataset = dataset.batch(32)

(3)shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)

(4)repeat

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

dataset = dataset.repeat(5)

如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:

dataset = dataset.repeat()

4 综合运用——用Dataset读取图片文件

前面的例子里很多的都是从Ternsor对象中创建Dataset, 所以用Iterator读取到的可能是一些常量数据,比如文件名,数组之类的。但是在真实的世界中,训练数据都是存放在文件中的,比如CSV,JPG,所以我们关心的其实并不是这些文件名本身,还是其中的内容。那么如果我的Tensor中存放的是一些文件名字,怎么用Dataset来读取其中的数据呢?

例子:读入磁盘中的图片和图片相应的label,并将其打乱,组成batch_size=32的训练样本。在训练时重复10个epoch。

# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# 图片文件的列表
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# label[i]就是图片filenames[i]的label
labels = tf.constant([0, 37, ...])

# 此时dataset中的一个元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

# 此时dataset中的一个元素是(image_resized, label)
dataset = dataset.map(_parse_function)

# 此时dataset中的一个元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)

在这个过程中,dataset经历三次转变:

  • 运行dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))后,dataset的一个元素是(filename, label)filename是图片的文件名,label是图片对应的标签。

  • 之后通过map,将filename对应的图片读入,并缩放为28x28的大小。此时dataset中的一个元素是(image_resized, label)

  • 最后,dataset.shuffle(buffersize=1000).batch(32).repeat(10)的功能是:在每个epoch内将图片打乱组成大小为32batch,并重复10次。最终,dataset中的一个元素是(image_resized_batch, label_batch)image_resized_batch的形状为(32, 28, 28, 3),而label_batch的形状为(32, ),接下来我们就可以用这两个Tensor来建立模型了。

5 综合运用——用Dataset读取文本文件(重要!!!)

Dataset提供了一个数据预处理的API map()。 预处理的意思是可以对每一个element进行transformationIteratorget_next()拿到的可能是一个字符串代表某个文件名或者CSV文件里的一行,然后transformation的时候将这个文件的内容读取出来并保存在内存的Tensor对象。

"""
1.了解数据集:特征值futures和标记值label
2.加载数据集:训练数据training和测试数据test
3.设定估算器estimator:深层神经网络分类器DNNClassifier
4.训练模型:喂食函数train_input_fn和训练方法train
5.评估训练出来的模型:喂食函数eval_input_fn和评估方法evaluate
6.应用模型进行预测:classifier.predict
"""

import os
import pandas as pd
import tensorflow as tf

FUTURES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

# 格式化数据文件的目录地址,把数据文件的路径处理成可以使用的train_path和test_path
dir_path = os.path.dirname(os.path.realpath(__file__))
train_path = os.path.join(dir_path, 'D:/git/tensorflow/base_code/data/iris_train.csv')
test_path = os.path.join(dir_path, 'D:/git/tensorflow/base_code/data/iris_test.csv')

# 载入训练数据
train = pd.read_csv(train_path, names=FUTURES, header=0)
train_x, train_y = train, train.pop('Species')

# 载入测试数据
test = pd.read_csv(test_path, names=FUTURES, header=0)
test_x, test_y = test, test.pop('Species')

# 设定特征值的名称
feature_columns = []
for key in train_x.keys():
    feature_columns.append(tf.feature_column.numeric_column(key=key))

# print(train_x)
# print(test_y)
# print(feature_columns)


# 选定估算器:深层神经网络分类器
"""
直接使用tensorflow提供的估算器来从花朵数据推测分类规则。
tf.estimator里面包含了很多估算器,
这里我们使用深层神经网络分类器DNNClassifier(Deep Neural Network Classifier)。
"""
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3)
# hidden_units是用来设定深层神经网络分类器的复杂程度的,n_classes对应我们三种花朵类型。

# 针对训练的喂食函数
def train_input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)  # 每次随机调整数据顺序
    return dataset.make_one_shot_iterator().get_next()
"""
首先我们定义了“喂食函数”train_input_fn,用来将4个特征数据features和植物学家分类结果labels组合在一起,
然后又用把这些花的顺序搅乱shuffle一下,这样以后每次我们进行训练都可能得到不同的结果。
然后,我们设定日志logging输入训练进度信息,稍后如果你觉得它太烦人可以把INFO改为WARN,只输出警告信息。
接着我们就启动了分类器的train训练。这里面的100和1000稍后你可以任意调节尝试不同结果。
"""


# 设定仅输出警告提示,可改为INFO
tf.logging.set_verbosity(tf.logging.WARN)

# 开始训练模型!
batch_size = 100
classifier.train(input_fn=lambda: train_input_fn(train_x, train_y, batch_size), steps=1000)



"""
上面我们把训练数据“喂食”给分类器进行训练,当然计算机也就从数据中猜出了三种花分类的依据,也就是我们训练得到的模型。
但我们并不知道这个“训练”出来的模型是否正确,为了考验一下它,我们把test测试数据交给它,让计算机用这个模型来预测这些test花的分类,
然后我们把预测得到的分类和植物学家为test花标注的分类进行对比,如果模型预测的都正确,我就说这个模型的精确度accuracy是100%。
分类器的evaluate方法是用来评估模型的精确度的,它同样需要一个“喂食函数”eval_input_fn把需要预测的特征值和应该得到的结果label值处理一下。
"""
# 针对测试的喂食函数
def eval_input_fn(features, labels, batch_size):
    features = dict(features)
    inputs = (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

# 评估我们训练出来的模型质量
eval_result = classifier.evaluate(
    input_fn=lambda: eval_input_fn(test_x, test_y, batch_size))

print(eval_result)

# 应用模型
# 支持100次循环对新数据进行分类预测
for i in range(0, 100):
    print('\nPlease enter features: SepalLength,SepalWidth,PetalLength,PetalWidth')
    a, b, c, d = map(float, input().split(','))  # 捕获用户输入的数字
    predict_x = {
        'SepalLength': [a],
        'SepalWidth': [b],
        'PetalLength': [c],
        'PetalWidth': [d],
    }

    # 进行预测
    predictions = classifier.predict(input_fn=lambda: eval_input_fn(predict_x,labels=[0],batch_size=batch_size))

    # 预测结果是数组,尽管实际我们只有一个
    for pred_dict in predictions:
        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]
        print(SPECIES[class_id], 100 * probability)

"""
这段程序可以重复对100朵新花做分类。我们每次需要按照SepalLength,SepalWidth,PetalLength,PetalWidth的顺序把花朵的测量结果输进去,
然后只要回车,人工智能训练出来的模型就会告诉我们这朵花是什么类型(并标注了有多大可能是这个类型),
如下图所示测量数据是6.1,2.3,5.1,0.2的花朵有99.94%的可能是变色鸢尾花Versicolor。
"""

# 输出结果
"""
{'loss': 1.6923486, 'average_loss': 0.05641162, 'global_step': 1000, 'accuracy': 0.96666664}
Please enter features: SepalLength,SepalWidth,PetalLength,PetalWidth
6.1,2.3,5.1,0.2
Versicolor 99.56179857254028
"""

6 Dataset的其它创建方法

除了tf.data.Dataset.from_tensor_slices外,目前Dataset API还提供了另外三种创建Dataset的方式:

  • tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。

  • tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。

  • tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。

7 更多类型的Iterator

在非Eager模式下,最简单的创建Iterator的方法就是通过dataset.make_one_shot_iterator()来创建一个one shot iterator。除了这种one shot iterator外,还有三个更复杂的Iterator

  • initializable iterator
  • reinitializable iterator
  • feedable iterator


作者:7125messi
链接:https://www.jianshu.com/p/aeb54ed224b3
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

你可能感兴趣的:(python,tensorflow,python学习)