How to print a TensorFlow dataset

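Sometimes, to debug your input data, you need to print it out. You can do this by walking through the dataset with an iterator.

First, define two functions that iterate over a dataset and print each batch: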

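    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    def print_dataset(self, data_set):
        # One-shot iterator: needs no explicit initialization, but cannot
        # capture stateful objects such as lookup tables or variables.
        iterator = data_set.make_one_shot_iterator()
        next_element = iterator.get_next()
        num_batch = 0
        # MonitoredTrainingSession suppresses the OutOfRangeError raised when
        # the dataset is exhausted, so the loop ends cleanly.
        with tf.train.MonitoredTrainingSession() as sess:
            while not sess.should_stop():
                value = sess.run(next_element)
                num_batch += 1
                print("Num Batch: ", num_batch)
                print("Batch value: ", value)

    def print_dataset2(self, data_set):
        # Initializable iterator: must be initialized explicitly, but works
        # with datasets that capture lookup tables or variables.
        iterator = data_set.make_initializable_iterator()
        next_element = iterator.get_next()
        num_batch = 0
        # The default Scaffold of MonitoredTrainingSession also runs the
        # table initializers, so lookup tables are ready to use.
        with tf.train.MonitoredTrainingSession() as sess:
            sess.run(iterator.initializer)
            while True:
                try:
                    value = sess.run(next_element)
                    num_batch += 1
                    print("Num Batch: ", num_batch)
                    print("Batch value: ", value)
                except tf.errors.OutOfRangeError:
                    # Raised once every element has been consumed.
                    break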
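Both functions above use the TensorFlow 1.x graph API. TF 2.x datasets no longer have the `make_one_shot_iterator()` or `make_initializable_iterator()` methods. Under eager execution, which is the default in TF 2.x, you can iterate over a dataset directly. The sketch below is a minimal example that assumes TensorFlow 2.1 or later, the first release with `as_numpy_iterator()`. `print_dataset_eager` is a hypothetical name:

```python
import tensorflow as tf  # assumes TensorFlow 2.1+ (eager execution by default)


def print_dataset_eager(data_set):
    """Print every element of a tf.data.Dataset and return how many there were.

    In eager mode a Dataset is a Python iterable, so no explicit iterator,
    session, or OutOfRangeError handling is needed.
    """
    num_batch = 0
    # as_numpy_iterator() yields each element converted to NumPy values.
    for value in data_set.as_numpy_iterator():
        num_batch += 1
        print("Num Batch: ", num_batch)
        print("Batch value: ", value)
    return num_batch


if __name__ == "__main__":
    ds = tf.data.Dataset.range(10).batch(4)
    print_dataset_eager(ds)  # 3 batches: [0 1 2 3], [4 5 6 7], [8 9]
```

Returning the batch count is only there to make the helper easy to check. Drop it if you only need the printout.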

The first function does not work when the dataset uses stateful operations such as a table lookup. It fails with:

ValueError: Failed to create a one-shot iterator for a dataset. `Dataset.make_one_shot_iterator()` does not support datasets that capture stateful objects, such as a `Variable` or `LookupTable`. In these cases, use `Dataset.make_initializable_iterator()`. (Original error: Cannot capture a stateful node (name:hash_table, type:HashTableV2) by value.)

In that case, use the second function.
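With a plain `tf.compat.v1.Session` instead of `MonitoredTrainingSession`, nothing initializes lookup tables automatically. You have to run `tf.compat.v1.tables_initializer()` yourself, in addition to `iterator.initializer`. The sketch below assumes TF 2.x and builds its own graph so that the rest of the program can stay in eager mode. `collect_with_initializable_iterator` and the `build_dataset` callable are hypothetical names:

```python
import tensorflow as tf  # assumes TF 2.x, using the tf.compat.v1 graph-mode API


def collect_with_initializable_iterator(build_dataset):
    """Print and return every element of a dataset via an initializable iterator.

    build_dataset is a hypothetical zero-argument callable that creates the
    dataset (and any lookup tables it uses) inside the graph below.
    """
    values = []
    graph = tf.Graph()
    with graph.as_default():
        data_set = build_dataset()
        iterator = tf.compat.v1.data.make_initializable_iterator(data_set)
        next_element = iterator.get_next()
        # A plain Session does not run table initializers on its own.
        table_init = tf.compat.v1.tables_initializer()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(table_init)
            sess.run(iterator.initializer)
            while True:
                try:
                    value = sess.run(next_element)
                    values.append(value)
                    print("Num Batch: ", len(values))
                    print("Batch value: ", value)
                except tf.errors.OutOfRangeError:
                    # Raised once every element has been consumed.
                    break
    return values
```

The dataset is created inside `build_dataset` so that the lookup table and the iterator live in the same graph.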

Printing a plain text dataset with the first function:

    dataset = tf.data.TextLineDataset(file_list)
    print("=======len(file_list)=======", len(file_list))
    print(dataset)
    self.print_dataset(dataset)

After mapping a table lookup over the dataset, use the second function:

    self.text_set = self.text_set.map(
        lambda src, tgt: (self.case_table.lookup(src), self.case_table.lookup(tgt))
    ).prefetch(buffer_size)
    self.print_dataset2(self.text_set)
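In TF 2.x eager mode, the one-shot restriction does not come up when you iterate the dataset directly, so a lookup-mapped dataset can be printed with a plain loop. This is a minimal sketch assuming TensorFlow 2.1 or later. The small `StaticHashTable` and the token pairs are hypothetical stand-ins for the post's `self.case_table` and `self.text_set`:

```python
import tensorflow as tf  # assumes TensorFlow 2.1+ (eager execution by default)

# Hypothetical vocabulary table standing in for the post's self.case_table.
keys = tf.constant(["a", "b", "c"])
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

# Hypothetical (source, target) token pairs, mirroring the src/tgt map above.
pairs = tf.data.Dataset.from_tensor_slices((["a", "b", "z"], ["c", "a", "b"]))
ids = pairs.map(lambda src, tgt: (table.lookup(src), table.lookup(tgt)))

# Unknown tokens ("z") map to default_value, -1.
for num_batch, value in enumerate(ids.as_numpy_iterator(), start=1):
    print("Num Batch: ", num_batch)
    print("Batch value: ", value)
```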

 

Reposted from: https://www.cnblogs.com/ljstu/p/11266099.html
