在Pytorch中数据加载是通过torch.utils.data.dataset
与torch.utils.data.dataloader
完成的,而在TensorFlow中现在主推的是使用tf.data
实现数据加载。此前,在TensorFlow中读取数据一般有两种方法:
新出的数据加载工具更加简洁高效,也是后面主推的方式,若是使用TensorFlow的Eager模式就必须使用这种数据加载方式,在这个模式下的数据加载也会存在一些细微不同(Eager模式下丢掉了Session,可以像python Debug一样调试程序,对于数据迭代可以使用python的内部函数iter
实现)。
下面是tf.data.Dataset
的类继承关系图:
一般用到的是tf.data.Dataset.from_tensor_slices
完成数据加载,但是根据上图的继承关系也提供了另外3种数据加载方式:
tf.data.TextLineDataset()
:这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。tf.data.FixedLengthRecordDataset()
:这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。tf.data.TFRecordDataset()
:顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。使用tf.data.Dataset
实现数据加载只需要使用调用一个函数就可以了,类似于下面的形式:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
在上面的数据集定义好之后会有一些函数可供选择:
tf.data.Dataset.map(f, num_parallel_calls)
tf.data.Dataset.batch(batch_size)
tf.data.Dataset.repeat(count)
tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
。在实际使用中,基本可以不传count参数,无限重复这个数据集。使用过TensorFlow的大家都会知道,TF通过计算图将计算的定义和执行分隔开, 这是一种声明式(declaretive)的编程模型。确实,这种静态图的执行模式优点很多,但是在debug时确实非常不方便(类似于对编译好的C语言程序调用,此时是我们无法对其进行内部的调试),因此有了Eager Execution。
引入的Eager Execution模式后,TensorFlow就拥有了类似于Pytorch一样动态图模型能力,我们可以不必再等到see.run(*)才能看到执行结果,可以方便在IDE随时调试代码,查看OPs执行结果。在代码中添加一句话就可以实现Eager模式的启用:
tf.enable_eager_execution()
使用make_one_shot_iterator()
初始化迭代
def get_item_by_tf(file_path, label):
img_byte = tf.io.read_file(file_path)
img_decode = tf.image.decode_jpeg(img_byte)
img_decode = tf.cast(img_decode, dtype=tf.float32)
img_regular = tf.divide(tf.subtract(img_decode, 127.0), 255.0)
# img_resize = tf.image.resize_images(img_decode)
return img_regular, label
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = 32
dataset = dataset.map(get_item_by_tf, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
next_op = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
try:
while True:
img_data = sess.run(next_op)
print(img_data[0].shape)
except tf.errors.OutOfRangeError:
print("load end!")
在遇到一些特殊情况下使用TF自带的读取工具无法有效读取数据,需要使用到第三方的书读取,这里就是要使用到tf.py_func(tf.numpy_function)
来进行读取,只需要将对应的回调函数进行替换就行了:
def cv_reader_call(file_path):
file_path = file_path.decode()
img = cv2.imread(file_path, 0)[:,:,np.newaxis]
return img
def get_item_by_cv(file_path):
# return tf.numpy_function(cv_reader_call, [file_path], [tf.uint8])
return tf.py_func(cv_reader_call, [file_path], [tf.uint8])
使用make_initializable_iterator()
初始化迭代
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = tf.placeholder(dtype=tf.int64, shape=[])
dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
data_iter = dataset.make_initializable_iterator()
next_op = data_iter.get_next()
with tf.Session() as sess:
sess.run(data_iter.initializer, feed_dict={batch_size_var: 32})
tf.enable_eager_execution()
if __name__ == "__main__":
data_root = "E:/DataBase/MNIST/MNIST_DataSet/imgs_train/"
img_list = os.listdir(data_root)
img_list = [os.path.join(data_root, item) for item in img_list]
img_list = tf.constant(img_list)
dataset = tf.data.Dataset.from_tensor_slices(img_list)
batch_size_var = 32
dataset = dataset.map(get_item_by_cv, num_parallel_calls=mlt.cpu_count())
dataset = dataset.batch(batch_size=2 * batch_size_var).shuffle(buffer_size=10 * batch_size_var).repeat()
data_iter = iter(dataset)
while True:
try:
img_data = next(data_iter)
print(img_data[0].shape)
except StopIteration:
print("load end!")