# TensorFlow: build a small graph (建个小图)

def Dataset(file_pattern, batch_size, num_epochs=1):
    """Build a batched tf.data pipeline over CSV input.

    Args:
      file_pattern: forwarded unchanged to
        tf.data.experimental.make_csv_dataset as its ``file_pattern``.
      batch_size: number of rows per batch.
      num_epochs: number of passes over the data. The default of 1 yields a
        finite dataset, which callers that drain it until exhaustion rely on.

    Returns:
      The dataset produced by make_csv_dataset. The LABEL_NAME column is split
      off as the label, so each element is a (features, label) pair.
    """
    logging.info('Creating Dataset from %s', file_pattern)
    # Collect the options in one place so they are easy to scan and tweak.
    # num_rows_for_inference=10: column types are inferred from a 10-row sample.
    csv_kwargs = {
        'file_pattern': file_pattern,
        'batch_size': batch_size,
        'label_name': LABEL_NAME,
        'num_epochs': num_epochs,
        'num_rows_for_inference': 10,
    }
    return tf.data.experimental.make_csv_dataset(**csv_kwargs)

def input_fn():
    """Build and return a fresh CSV dataset over ``expanded``.

    ``collect_unique_tokens`` calls ``input_fn()`` inside its own ``tf.Graph``.
    This must therefore be a zero-argument callable that builds the dataset
    on demand. Binding an already-built Dataset here (as before) fails with
    "'Dataset' object is not callable", and its ops would live in the wrong
    graph anyway.
    """
    return Dataset(expanded, FLAGS.batch_size)

def collect_unique_tokens(input_fn):
    """Collect the distinct values seen in each categorical column.

    Builds the input pipeline in a fresh graph, then drains it once in a
    TF1-style session, accumulating the values of every column listed in
    CATEGORICAL_COLUMNS.

    Args:
      input_fn: zero-argument callable returning a dataset whose elements are
        (features, labels) pairs, with ``features`` a dict keyed by column
        name. It is called inside the new graph so that the dataset's ops
        belong to that graph. The dataset must be finite (e.g. built with
        num_epochs=1): the loop below only stops on OutOfRangeError, so an
        infinitely repeating dataset never returns.

    Returns:
      dict mapping each name in CATEGORICAL_COLUMNS to the set of values
      observed for that column.
      NOTE(review): for string columns Session.run presumably yields ``bytes``,
      so these sets likely hold bytes rather than str. Confirm against
      downstream consumers.
    """
    logging.info('Creating vocabulary...')
    vocabulary_dict = {item: set() for item in CATEGORICAL_COLUMNS}
    graph = tf.Graph()
    with graph.as_default():
        # TF2 equivalent: tf.compat.v1.data.make_one_shot_iterator(input_fn()).
        iterator = input_fn().make_one_shot_iterator()
        # Labels are irrelevant to the vocabulary; only features are fetched.
        t_features, _ = iterator.get_next()
    with tf.Session(graph=graph) as sess:
        while True:
            # Keep the try body to the one call that signals end-of-data.
            try:
                features = sess.run(t_features)
            except tf.errors.OutOfRangeError:
                break
            for item in CATEGORICAL_COLUMNS:
                vocabulary_dict[item].update(features[item])
    return vocabulary_dict

# 你可能感兴趣的 (you may also be interested in): DL, tools