tf.data.Dataset.map()函数的理解

在用dataset读取tfrecord的时候,看到别人的代码里面基本都有tf.data.Dataset.map()这个部分,而且前面定义了解析tfrecord的函数decord_example(example)之后,在后面的的map里面直接就dataset.map(decord_example)这样使用,并没有给example赋值。
具体代码在这里:

def decode_example(example, resize_height, resize_width, label_nums):
    dics={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'label':tf.FixedLenFeature([],tf.int64)
         }
    parsed_example = tf.parse_single_example(serialized=example, features=dics)

    tf_image=tf.decode_raw(parsed_example['image_raw'], out_type=tf.uint8)  # 这个其实就是图像的像素模式,之前我们使用矩阵来表示图像

    tf_image=tf.reshape(tf_image, shape=[resize_height, resize_width, 3])  # 对图像的尺寸进行调整,调整成三通道图像

    tf_image=tf.cast(tf_image,tf.float32)*(1./255)  # 对图像进行归一化以便保持和原图像有相同的精度

    tf_label=tf.cast(parsed_example['label'],tf.int64)

    tf_label=tf.one_hot(tf_label, label_nums,on_value=1,off_value=0)  # 将label转化成用one_hot编码的格式

    return tf_image, tf_label


def create_dataset(tfrecords_file, batch_size, resize_height, resize_width, num_class):

    dataset = tf.data.TFRecordDataset(tfrecords_file)

    # dataset = dataset1.map(decode_example)

    dataset = dataset.map(lambda x: decode_example(x, resize_height, resize_width, num_class))

    dataset = dataset.shuffle(20000).batch(batch_size)

    # dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

    return dataset

对于这点我是百思不得其解。翻了一下午的博客论坛之后,看到一个比较合理的解释,结合我自己的理解,在这里写出来:在使用dataset = tf.data.TFRecordDataset(tfrecords_file)生成一个新的dataset后,这个dataset已经含有在decord_example(example)中的example所需要的参数,看起来虽然没有传参,但其实参数是在内部进行了传递。因此可以直接用dataset = dataset1.map(decode_example)

如果我们还想在map里添加额外的参数,就要用lambda表达式,也就是dataset = dataset.map(lambda x: decode_example(x, resize_height, resize_width, num_class))。而这里的x看起来没有外部参数传进去,但其实是和上面所说的一样。在dataset = tf.data.TFRecordDataset(tfrecords_file)创建dataset之后,x也就是example所需要的参数都已经在dataset里了。

借用一个例子就是:

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))

这里x的参数就是上面ds = tf.data.Dataset.range(5)所创建出来的0~4的值。

参考链接:
[1] https://blog.csdn.net/nofish_xp/article/details/83116779?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.nonecase

[2] https://stackoverflow.com/questions/46263963/how-to-map-a-function-with-additional-parameter-using-the-new-dataset-api-in-tf1

你可能感兴趣的:(学习感悟)