[tf 2.0] padded_batch_use

参考:

  1. TF 2.0 中直接输出 Dataset 内容（无需 sess.run）
    https://stackoom.com/question/3nd25/%E5%BB%BA%E8%AE%AE%E5%9C%A8tensorflow-%E4%B8%AD%E8%B0%83%E8%AF%95-tf-data-Dataset-%E6%93%8D%E4%BD%9C
  2. padded_batch用法
    https://blog.csdn.net/z2539329562/article/details/89791783
import tensorflow as tf
# Toy input: six rows, each already of length 3.
x = [
    [1, 0, 0],
    [2, 3, 0],
    [4, 5, 6],
    [7, 8, 0],
    [9, 0, 0],
    [0, 1, 0],
]

# Every dataset element is a single 1-D tensor, so padded_shapes is one
# TensorShape rather than a tuple. The unknown (None) dimension lets
# padded_batch pad each batch to the length of its longest row.
padded_shapes = tf.TensorShape([None])

# Slice the outer list: each dataset element is one row of `x`.
dataset = tf.data.Dataset.from_tensor_slices(x)
def rd(dataset):
    """Print every element of ``dataset``, one per line.

    In TF 2.x eager mode a ``tf.data.Dataset`` is an ordinary Python
    iterable, so looping over it yields concrete ``tf.Tensor`` values
    without any ``sess.run`` call. Any other iterable works as well.

    To print only the raw values (e.g. ``[1 0 0]``), print
    ``element.numpy()`` instead of ``element``.

    Args:
        dataset: a ``tf.data.Dataset`` or any other iterable to dump.
    """
    # The for-loop ends cleanly on StopIteration. Unlike the previous
    # bare ``except:``, genuine errors raised while producing elements
    # now propagate instead of being silently swallowed.
    for element in dataset:
        print(element)
# Before batching: one row tensor per element.
rd(dataset)

# Group rows in pairs; each batch is padded to its longest row.
dataset = dataset.padded_batch(batch_size=2, padded_shapes=padded_shapes)

# After batching with batch_size=2: each element is a batch of two rows.
rd(dataset)
# output of the second rd(dataset) call (after padded_batch; the first call's per-row output is omitted)
tf.Tensor(
[[1 0 0]
 [2 3 0]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[4 5 6]
 [7 8 0]], shape=(2, 3), dtype=int32)
tf.Tensor(
[[9 0 0]
 [0 1 0]], shape=(2, 3), dtype=int32)
# Second example: element i is a length-i vector filled with the value i,
# so rows have different lengths and padding actually matters.
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda i: tf.fill([tf.cast(i, tf.int32)], i))
print(dataset)
# <MapDataset shapes: (None,), types: tf.int64>

# padded_shapes=[None]: pad every batch of 4 rows to the length of the
# longest row in that batch (padding_values is left at its default).
dataset = dataset.padded_batch(4, padded_shapes=[None])
print(dataset)
# <PaddedBatchDataset shapes: (None, None), types: tf.int64>

rd(dataset)

[tf 2.0] padded_batch_use_第1张图片
[tf 2.0] padded_batch_use_第2张图片
[tf 2.0] padded_batch_use_第3张图片

你可能感兴趣的:(tensorflow2.0)