dataset = tf.data.Dataset.from_tensor_slices(img_ids)
img_ids 是包含所有需要训练的图片id集合, 这行代码执行完之后, 在dataset里面的每一个元素都是一个tensor, 每个tensor的值是图片id.
执行shuffle, 在这里先执行shuffle而不是之后执行, 是为了降低memory cost.
dataset = dataset.shuffle(buffer_size=1000).repeat(1)
tensorflow 官网上有关于是先shuffle还是先repeat的介绍, 原文如下:
The
tf.data.Dataset.repeattransformation repeats the input data a finite (or infinite) number of times; each repetition of the data is typically referred to as an *epoch*. The
tf.data.Dataset.shuffletransformation randomizes the order of the dataset's examples.
If the
repeattransformation is applied before the
shuffletransformation, then the epoch boundaries are blurred. That is, certain elements can be repeated before other elements appear even once. On the other hand, if the
shuffletransformation is applied before the repeat transformation, then performance might slow down at the beginning of each epoch related to initialization of the internal state of the
shuffletransformation. In other words, the former (
repeatbefore
shuffle) provides better performance, while the latter (
shufflebefore
repeat) provides stronger ordering guarantees.
翻译过来大概意思就是说, 先repeat再shuffle会模糊每一个epoch的边界, 意思是dataset如果epcoh是>1的话, repeat再shuffle, 会使得某些元素在一个eopch里重复出现好几次, 造成网络在初始训练的时候, 训练前期效果可能没有那么好; 先shuffle再repeat, 虽然可以保证数据的一致性, 但会在每一次repeat开始的时候进行shuffle, 降低运行速度.
dataset = dataset.map(
lambda img_id:tuple(
tf.py_function(
func = paser_func,
inp = [img_id],
Tout = [tf.float32, tf.float32]
)), num_parallel_calls=tf.data.experimental.AUTOTUNE)
这个是使用tf.py_function对python的函数进行封装, 需要声明三个变量, func就是python预处理的函数, inp就是函数的参数, Tout就是python函数返回的数据类型, 如果是np.float32会自动转换成tf.float32. num_parallel_calls函数是为了提高数据map的处理速度, 并行处理,tf.data.experimental.AUTOTUNE可以让程序自动的选择最优的线程并行个数. ps: 在1.x中, inp传进去的img_id, 原本是string型会被转换成byte类型, 直接转下就可以
if type(img_id) == type(b'123'):
img_id = str(img_id, encoding='utf-8')
然而在2.0中, 直接使用img_id = img_id.numpy()
取出byte类型值来, 再进行string值转换.
map操作除了可以有上面的写法外, 还可以去掉lambda,额外使用一个函数再一次封装:
dataset = dataset.map(tf_parse_func, num_parallel_calls = tf.data.experimental.AUTOTUNE)
def tf_parse_func(img_id):
[img, label] = tf.py_function(paser_func, [img_id], [tf.float32, tf.float32])
return img, label
其实就是把lambda函数用普通的函数替换掉, ps: lambda img_id:tuple()中tuple的意思是, lambda封装的函数返回值应该是个tuple类型的, 即要么是个列表, 要么直接返回多个值的意思
dataset = dataset.batch(params['batch_size'], drop_remainder=True)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
for epco in range(10):
for step, (img, label) in enumerate(dataset):
output = net(img)
loss = loss_func(output, label)
Here is a summary of the best practices for designing performant TensorFlow input pipelines: