This is a lightly modified version of the official example, with a few comments added. The goal is mainly to understand how the API is used.
https://github.com/tensorflow/models/tree/r1.9.0/official/wide_deep
tf.__version__ == 1.9.0
The dataset is the Census Income dataset: a binary classification task that predicts whether a person's income exceeds $50K per year.
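Each row has 14 features plus the label. For reference, the first line of adult.data looks like this (before the download script below strips the spaces after the commas):

39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K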
Download the dataset:
from six.moves import urllib

import tensorflow as tf


def _download_and_clean_file(filename, url):
  """Downloads data from url, and makes changes to match the CSV format."""
  print('download %s from %s ...' % (filename, url))
  temp_file, _ = urllib.request.urlretrieve(url)
  with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
    with tf.gfile.Open(filename, 'w') as eval_file:
      for line in temp_eval_file:
        line = line.strip()
        line = line.replace(', ', ',')
        # Skip blank lines and lines without commas.
        if not line or ',' not in line:
          continue
        # In adult.test the labels end with a period ('>50K.');
        # strip it so they match the labels in adult.data.
        if line[-1] == '.':
          line = line[:-1]
        line += '\n'
        eval_file.write(line)
  tf.gfile.Remove(temp_file)
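To fetch both census files, the function can be called like this (a sketch; the UCI download URL below is the one the official data_download.py uses):

DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'

if not tf.gfile.Exists('census_data'):
  tf.gfile.MkDir('census_data')
_download_and_clean_file('census_data/adult.data', DATA_URL + '/adult.data')
_download_and_clean_file('census_data/adult.test', DATA_URL + '/adult.test')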
Run the model:
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                        [0], [0], [0], [''], ['']]

_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}

LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


# def define_wide_deep_flags():
#   """Add supervised learning flags, as well as wide-deep model type."""
#   flags_core.define_base()
#   flags_core.define_benchmark()
#   flags.adopt_module_key_flags(flags_core)
#
#   flags.DEFINE_enum(
#       name="model_type", short_name="mt", default="wide_deep",
#       enum_values=['wide', 'deep', 'wide_deep'],
#       help="Select model topology.")
#
#   flags.set_defaults(data_dir='./census_data',
#                      model_dir='./model',
#                      train_epochs=40,
#                      epochs_between_evals=2,
#                      batch_size=40)
def build_model_columns():
  """Builds a set of wide and deep feature columns."""
  # Continuous columns, used by both the wide and the deep component.
  age = tf.feature_column.numeric_column('age')
  education_num = tf.feature_column.numeric_column('education_num')
  capital_gain = tf.feature_column.numeric_column('capital_gain')
  capital_loss = tf.feature_column.numeric_column('capital_loss')
  hours_per_week = tf.feature_column.numeric_column('hours_per_week')

  # Categorical columns; these produce SparseTensors.
  education = tf.feature_column.categorical_column_with_vocabulary_list(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

  marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

  relationship = tf.feature_column.categorical_column_with_vocabulary_list(
      'relationship', [
          'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
          'Other-relative'])

  workclass = tf.feature_column.categorical_column_with_vocabulary_list(
      'workclass', [
          'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
          'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

  # To show an example of hashing:
  # when a feature has too many categories to enumerate in a vocabulary list,
  # hash bucketing can be used instead.
  occupation = tf.feature_column.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=1000)

  # Transformations.
  age_buckets = tf.feature_column.bucketized_column(
      age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

  # Wide columns and deep columns.
  # The wide model is a linear model over a wide set of *sparse and crossed*
  # feature columns.
  base_columns = [
      # All of these are categorical (sparse) features.
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=1000),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
  ]

  wide_columns = base_columns + crossed_columns

  deep_columns = [
      age,
      education_num,
      capital_gain,
      capital_loss,
      hours_per_week,
      # One-hot encode the categorical columns that have few categories.
      tf.feature_column.indicator_column(workclass),
      tf.feature_column.indicator_column(education),
      tf.feature_column.indicator_column(marital_status),
      tf.feature_column.indicator_column(relationship),
      # To show an example of embedding:
      # a common rule of thumb sets the embedding dimension to roughly
      # (number of categories) ** 0.25.
      tf.feature_column.embedding_column(occupation, dimension=8),
  ]

  return wide_columns, deep_columns
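# A minimal sketch (not needed for training, assuming TF 1.x graph mode) to
# inspect what the deep columns produce. `demo_row` is a made-up one-example
# batch; tf.feature_column.input_layer concatenates the dense representation
# of every deep column into a single tensor.
def demo_deep_columns():
  _, deep_columns = build_model_columns()
  demo_row = {
      'age': [31], 'education_num': [13], 'capital_gain': [0],
      'capital_loss': [0], 'hours_per_week': [40],
      'workclass': ['Private'], 'education': ['Bachelors'],
      'marital_status': ['Never-married'], 'relationship': ['Not-in-family'],
      'occupation': ['Tech-support'],
  }
  net = tf.feature_column.input_layer(demo_row, deep_columns)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # embedding weights
    sess.run(tf.tables_initializer())            # vocabulary lookup tables
    # 5 numeric + (9 + 16 + 7 + 6) one-hot + 8 embedding = 51 dimensions
    print(sess.run(net).shape)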
def build_estimator(model_dir, model_type):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = build_model_columns()
  hidden_units = [100, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0}))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)
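# A note on optimizers: by default the combined model trains the linear part
# with FTRL and the DNN part with Adagrad. Both can be overridden, e.g.
# (a sketch, with arbitrary learning rates):
#   tf.estimator.DNNLinearCombinedClassifier(
#       ...,
#       linear_optimizer=tf.train.FtrlOptimizer(learning_rate=0.1),
#       dnn_optimizer=tf.train.ProximalAdagradOptimizer(learning_rate=0.1))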
def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run data_download.py' % data_file)

  def parse_csv(line):
    print('Parsing', data_file)
    # tf.decode_csv converts a CSV line into a list of Tensors, one per column.
    columns = tf.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  # https://stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-dataset-map-dataset-prefetch-and-dataset-shuffle
  # shuffle does not shuffle the whole dataset; it maintains a buffer of
  # `buffer_size` elements and picks the next element uniformly at random from
  # that buffer. If buffer_size is at least the dataset size you get a uniform
  # shuffle; if it is 1 there is no shuffling at all.
  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  # For example, dataset.shuffle(buffer_size=1000).batch(32).repeat(10)
  # shuffles within each epoch, forms batches of 32, and repeats for 10 epochs.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset
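# A minimal sketch (not needed for training) to peek at one batch from
# input_fn, assuming census_data/adult.data has already been downloaded
# (TF 1.x graph mode, one-shot iterator).
def demo_input_fn():
  dataset = input_fn('census_data/adult.data',
                     num_epochs=1, shuffle=False, batch_size=2)
  features, labels = dataset.make_one_shot_iterator().get_next()
  with tf.Session() as sess:
    f, l = sess.run([features, labels])
    print(f['age'], l)  # e.g. [39 50] [False False]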
def export_model(model, model_type, export_dir):
  """Export to SavedModel format.

  Args:
    model: Estimator object
    model_type: string indicating model type. "wide", "deep" or "wide_deep"
    export_dir: directory to export the model.
  """
  wide_columns, deep_columns = build_model_columns()
  if model_type == 'wide':
    columns = wide_columns
  elif model_type == 'deep':
    columns = deep_columns
  else:
    columns = wide_columns + deep_columns
  feature_spec = tf.feature_column.make_parse_example_spec(columns)
  example_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
  model.export_savedmodel(export_dir, example_input_fn)
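# A sketch of querying the exported SavedModel with tf.contrib.predictor
# (TF 1.x). `saved_model_dir` is assumed to point at the timestamped
# subdirectory that export_savedmodel creates; the feature values here are
# made up. The parsing input receiver expects serialized tf.Example protos.
def demo_serving(saved_model_dir):
  from tensorflow.contrib import predictor

  def _int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

  def _str_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))

  example = tf.train.Example(features=tf.train.Features(feature={
      'age': _int_feature(31), 'education_num': _int_feature(13),
      'capital_gain': _int_feature(0), 'capital_loss': _int_feature(0),
      'hours_per_week': _int_feature(40),
      'workclass': _str_feature('Private'),
      'education': _str_feature('Bachelors'),
      'marital_status': _str_feature('Never-married'),
      'relationship': _str_feature('Not-in-family'),
      'occupation': _str_feature('Tech-support'),
  }))
  # The 'predict' signature takes serialized Examples under the key 'examples'
  # and returns tensors such as 'probabilities' and 'class_ids'.
  predict_fn = predictor.from_saved_model(saved_model_dir,
                                          signature_def_key='predict')
  print(predict_fn({'examples': [example.SerializeToString()]}))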
def run_wide_deep(model_type):
  """Run Wide-Deep training and eval loop.

  Args:
    model_type: string indicating model type. "wide", "deep" or "wide_deep".
  """
  if model_type == 'wide':
    model_dir = 'model/wide'
  elif model_type == 'deep':
    model_dir = 'model/deep'
  else:
    model_dir = 'model/wide_deep'

  data_dir = 'census_data'
  train_epochs = 24
  epochs_between_evals = 2
  batch_size = 40
  print('train_epochs', train_epochs)
  print('epochs_between_evals', epochs_between_evals)
  print('batch_size', batch_size)

  # Clean up the model directory if present
  # shutil.rmtree(model_dir, ignore_errors=True)
  model = build_estimator(model_dir, model_type)

  train_file = os.path.join(data_dir, 'adult.data')
  test_file = os.path.join(data_dir, 'adult.test')
  print('train_file:', train_file)
  print('test_file:', test_file)

  def train_input_fn():
    return input_fn(train_file, epochs_between_evals, True, batch_size)

  def eval_input_fn():
    return input_fn(test_file, 1, False, batch_size)

  # run_params = {
  #     'batch_size': batch_size,
  #     'train_epochs': train_epochs,
  #     'model_type': model_type,
  # }
  # benchmark_logger = logger.config_benchmark_logger(flags_obj)
  # benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
  # loss_prefix = LOSS_PREFIX.get(model_type, '')
  # train_hooks = hooks_helper.get_train_hooks(
  #     flags_obj.hooks, batch_size=batch_size,
  #     tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
  #                     'loss': loss_prefix + 'head/weighted_loss/Sum'})

  # Train and evaluate the model every `epochs_between_evals` epochs.
  for n in range(train_epochs // epochs_between_evals):
    model.train(input_fn=train_input_fn)
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics
    tf.logging.info('Results at epoch %d / %d',
                    (n + 1) * epochs_between_evals, train_epochs)
    tf.logging.info('-' * 60)
    for key in sorted(results):
      tf.logging.info('%s: %s' % (key, results[key]))

    # benchmark_logger.log_evaluation_result(results)
    # if model_helpers.past_stop_threshold(
    #     flags_obj.stop_threshold, results['accuracy']):
    #   break
    # if results['accuracy'] > stop_threshold:
    #   break
def main(_):
  # Set the argument to 'wide', 'deep', or 'wide_deep'.
  run_wide_deep('deep')


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  # define_wide_deep_flags()
  absl_app.run(main)
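Note that export_model is defined above but never called. To also export a SavedModel after training, a line like the following could be appended to run_wide_deep (the 'exported' path is an arbitrary choice):

export_model(model, model_type, os.path.join('exported', model_type))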
Results:
| model     | accuracy | accuracy_baseline | auc      | auc_precision_recall | average_loss | global_step | label/mean | loss     | precision | prediction/mean | recall   |
|-----------|----------|-------------------|----------|----------------------|--------------|-------------|------------|----------|-----------|-----------------|----------|
| wide      | 0.833671 | 0.763774          | 0.880125 | 0.688972             | 0.357315     | 19548       | 0.236226   | 14.25843 | 0.674006  | 0.244115        | 0.573063 |
| deep      | 0.852036 | 0.763774          | 0.902916 | 0.757215             | 0.323559     | 19548       | 0.236226   | 12.91145 | 0.704876  | 0.251214        | 0.642746 |
| wide&deep | 0.851606 | 0.763774          | 0.904219 | 0.76885              | 0.321143     | 19548       | 0.236226   | 12.815   | 0.735974  | 0.244277        | 0.579823 |