解析tf.data.Dataset.interleave()的用法
dataset = tf.data.Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
print(list(dataset.as_numpy_iterator()))
dataset = dataset.interleave(
map_func=lambda x: tf.data.Dataset.from_tensors(x).repeat(3),
cycle_length=4,
block_length=2,
)
print(list(dataset.as_numpy_iterator()))
"""
输入:[1, 2, 3, 4, 5]
# 每次取block_length个元素
[1, 1,
2, 2,
3, 3,
4, 4,
1, 2,
3, 4,
5, 5,
5,
]
# 具体逻辑如下:
1. 先从原始dataset取出cycle_length个元素
>> [1, 2, 3,4] # 第一次取
>> [5] # 第二次取
2. 然后对cycle_length个元素运用map_func,即取出的这cycle_length元素分别重复repeat次,构成新的dataset
>> [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]] # 构成的新dataset-1
>> [[5, 5, 5]] # 构成的新dataset-2
3. 从新的dataset上轮流取出block_length个元素
>> [[1, 1], [2, 2], [3, 3], [4, 4]] # 新dataset-1
>> [[1], [2], [3], [4]] # 新dataset-1
>> [[5, 5]] # 新dataset-2
>> [[5]] # 新dataset-2
4. 当新的dataset中的元素取完后,再从原始dataset取出cycle_length个元素,即重复第一步
>> [1, 1, 2, 2, 3, 3, 4, 4] # # 新dataset-1
>> [1, 2, 3, 4] # # 新dataset-1
>> [5, 5] # # 新dataset-2
>> [5] # # 新dataset-2
输出:[1, 1, 2, 2, 3, 3, 4, 4, 1, 2, 3, 4, 5, 5, 5]
"""