This paper is straightforward, and plenty of translations of it exist online. With TensorFlow's built-in estimator library the implementation itself is simple; the genuinely tricky part is understanding the feature engineering, for which see Lesson 9: tensorflow feature engineering: feature_column. A minimal sketch of what a feature column does is shown right below.
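Before the demo, here is a toy example (my own sketch, not part of the demo code) of how a vocabulary column plus indicator_column turns raw strings into the dense tensor an estimator consumes:
import tensorflow as tf

# Toy example: a vocabulary feature column maps raw strings to a dense one-hot tensor.
gender = tf.feature_column.categorical_column_with_vocabulary_list('gender', ['Female', 'Male'])
features = {'gender': tf.constant([['Male'], ['Female']])}
dense = tf.feature_column.input_layer(features, [tf.feature_column.indicator_column(gender)])

with tf.Session() as sess:
    sess.run(tf.tables_initializer())  # vocabulary lookup uses an internal table
    print(sess.run(dense))  # [[0. 1.], [1. 0.]] -- 'Female' is index 0, 'Male' is index 1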
The complete demo follows:
# coding:utf-8
"""
wide and deep
"""
from framework.data_input import IDataInput
from framework.inference import IInference
from framework.train import ITrain
from framework.eval import IEval
import pandas as pd
import tensorflow as tf
import common
from hook import LoggerHook
class DataInput(IDataInput):
def __init__(self,
csv_column_names,
label_column_name,
input_file_paths,
shuffle,
batch_size,
example_per_epoch_num,
parallel_thread_num=16):
super(DataInput, self).__init__(input_file_paths,
batch_size,
example_per_epoch_num,
parallel_thread_num=parallel_thread_num)
self._csv_column_names = csv_column_names
self._label_column_name = label_column_name
self._shuffle = shuffle
def read_data(self):
input_file_path = self._input_file_paths[0]
df = pd.read_csv(tf.gfile.Open(input_file_path),
names=self._csv_column_names,
skipinitialspace=True,
skiprows=1)
        # dropna returns a new DataFrame; without the assignment the NaN rows are kept
        df = df.dropna(axis=0, how='any')
        # ">50K" matches both the train labels (">50K") and the test labels (">50K.")
        label = df[self._label_column_name].apply(lambda x: ">50K" in x).astype(int)
print(df.head())
print(label.head())
return tf.estimator.inputs.pandas_input_fn(
x=df,
y=label,
batch_size=self._batch_size,
shuffle=self._shuffle,
num_threads=self._parallel_thread_num,
            num_epochs=self._example_per_epoch_num  # None repeats indefinitely (train); 1 gives a single pass (eval)
)
    # The remaining IDataInput hooks are unused here; the pandas input_fn above covers them.
    def _preprocess_data(self, record):
pass
def _generate_train_batch(self, train_data, label, shuffle=True):
pass
def _read_data_from_queue(self, file_path_queue):
pass
class Inference(IInference):
def __init__(self, model_dir, model_type, linear_feature_columns, dnn_feature_columns):
super(Inference, self).__init__()
self._model_dir = model_dir
self._model_type = model_type
self._linear_feature_columns = linear_feature_columns
        self._dnn_feature_columns = dnn_feature_columns
def inference(self, data):
if self._model_type == 'wide':
return tf.estimator.LinearClassifier(feature_columns=self._linear_feature_columns,
model_dir=self._model_dir)
elif self._model_type == 'deep':
            return tf.estimator.DNNClassifier(feature_columns=self._dnn_feature_columns,
model_dir=self._model_dir,
hidden_units=[100, 50])
elif self._model_type == 'wide_n_deep':
return tf.estimator.DNNLinearCombinedClassifier(model_dir=self._model_dir,
linear_feature_columns=self._linear_feature_columns,
                                                            dnn_feature_columns=self._dnn_feature_columns,
dnn_hidden_units=[100, 50])
else:
raise RuntimeError('no %s model type' % self._model_type)
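# For reference, the combined model from the paper predicts
# P(Y=1|x) = sigma(w_wide^T [x, phi(x)] + w_deep^T a^(lf) + b):
# DNNLinearCombinedClassifier adds the linear part's logit to the DNN part's
# logit (the final hidden activations a^(lf) times the output weights) before
# the sigmoid, and trains both parts jointly.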
class Train(ITrain):
def __init__(self, model_type, model_dir):
super(Train, self).__init__()
self._model_type = model_type
self._model_dir = model_dir
@property
def model_dir(self):
return self._model_dir
@property
def model_type(self):
return self._model_type
def train(self):
data_input = DataInput(common.CSV_COLUMNS,
common.LABEL_COLUMN_NAME,
['./input/adult.data'],
shuffle=True,
batch_size=128,
example_per_epoch_num=None)
input_fn = data_input.read_data()
if self._model_type == 'wide':
inference = Inference(self._model_dir,
self._model_type,
common.base_columns + common.crossed_columns, None)
elif self._model_type == 'deep':
inference = Inference(self._model_dir,
self._model_type,
None,
common.deep_columns)
elif self._model_type == 'wide_n_deep':
inference = Inference(self._model_dir,
self._model_type,
                                  common.crossed_columns,  # note: the official TF tutorial also feeds base_columns to the wide part here
common.deep_columns)
else:
raise RuntimeError('model type error: ' + self._model_type)
model = inference.inference(None)
logger_hook = LoggerHook()
model.train(input_fn=input_fn, hooks=[logger_hook], steps=2000)
return model
class Eval(IEval):
def __init__(self, model):
super(Eval, self).__init__(None, None, None, 128)
self._model = model
def accuracy(self, predict_results, labels):
pass
def read_test_data_set(self):
pass
def predict(self, test_data_batch):
pass
def eval(self):
data_input = DataInput(common.CSV_COLUMNS,
common.LABEL_COLUMN_NAME,
['./input/adult.test'],
shuffle=False,
batch_size=128,
example_per_epoch_num=1)
test_data_input_fn = data_input.read_data()
results = self._model.evaluate(input_fn=test_data_input_fn,
steps=None)
for key in sorted(results):
print("%s: %s" % (key, results[key]))
# coding:utf-8
"""
common
"""
import tensorflow as tf
GENDER = 'gender'
EDUCATION = 'education'
MARITAL_STATUS = 'marital_status'
RELATIONSHIP = 'relationship'
WORK_CLASS = 'workclass'
OCCUPATION = 'occupation'
NATIVE_COUNTRY = 'native_country'
AGE = 'age'
EDUCATION_NUM = 'education_num'
CAPITAL_GAIN = 'capital_gain'
CAPITAL_LOSS = 'capital_loss'
HOURS_PER_WEEK = 'hours_per_week'
CSV_COLUMNS = [
AGE, WORK_CLASS, "fnlwgt", EDUCATION, EDUCATION_NUM,
MARITAL_STATUS, OCCUPATION, RELATIONSHIP, "race", GENDER,
CAPITAL_GAIN, CAPITAL_LOSS, HOURS_PER_WEEK, NATIVE_COUNTRY,
"income_bracket"
]
gender = tf.feature_column.categorical_column_with_vocabulary_list(GENDER, ['Female', 'Male'])
education = tf.feature_column.categorical_column_with_vocabulary_list(
    EDUCATION, [
        "Bachelors", "HS-grad", "11th", "Masters", "9th",
        "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
        "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
        "Preschool", "12th"
    ])
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
MARITAL_STATUS, [
"Married-civ-spouse", "Divorced", "Married-spouse-absent",
"Never-married", "Separated", "Married-AF-spouse", "Widowed"
])
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
RELATIONSHIP, [
"Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
"Other-relative"
])
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
WORK_CLASS, [
"Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
"Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
])
occupation = tf.feature_column.categorical_column_with_hash_bucket(
OCCUPATION, hash_bucket_size=1000)
native_country = tf.feature_column.categorical_column_with_hash_bucket(NATIVE_COUNTRY,
hash_bucket_size=1000)
age = tf.feature_column.numeric_column(AGE)
education_num = tf.feature_column.numeric_column(EDUCATION_NUM)
capital_gain = tf.feature_column.numeric_column(CAPITAL_GAIN)
capital_loss = tf.feature_column.numeric_column(CAPITAL_LOSS)
hours_per_week = tf.feature_column.numeric_column(HOURS_PER_WEEK)
age_buckets = tf.feature_column.bucketized_column(age,
boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
base_columns = [
gender, education, marital_status, relationship, workclass, occupation,
native_country, age_buckets,
]
crossed_columns = [
tf.feature_column.crossed_column([EDUCATION, OCCUPATION], hash_bucket_size=1000),
tf.feature_column.crossed_column(
[age_buckets, EDUCATION, OCCUPATION], hash_bucket_size=1000),
tf.feature_column.crossed_column(
[NATIVE_COUNTRY, OCCUPATION], hash_bucket_size=1000)
]
deep_columns = [
tf.feature_column.indicator_column(workclass),
tf.feature_column.indicator_column(education),
tf.feature_column.indicator_column(gender),
tf.feature_column.indicator_column(relationship),
tf.feature_column.embedding_column(native_country, dimension=8),
tf.feature_column.embedding_column(occupation, dimension=8),
age,
education_num,
capital_gain,
capital_loss,
hours_per_week
]
LABEL_COLUMN_NAME = "income_bracket"
MODEL_DIR = './output'
WIDE_MODEL_DIR = MODEL_DIR + '/wide'
DEEP_MODEL_DIR = MODEL_DIR + '/deep'
WIDE_N_DEEP_DIR = MODEL_DIR + '/wide_deep'
WIDE_MODEL_TYPE = 'wide'
DEEP_MODEL_TYPE = 'deep'
WIDE_N_DEEP_MODEL_TYPE = 'wide_n_deep'
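A note on crossed_column: conceptually it concatenates the string values of its component columns and hashes the combination into hash_bucket_size buckets, so distinct crosses can collide. A toy illustration (my own hypothetical helper, not the TF internals, which use a fingerprint hash rather than Python's hash()):
def cross_bucket(values, hash_bucket_size=1000):
    # Hypothetical stand-in for what a crossed_column does: join the component
    # values and hash the combination into a fixed number of buckets.
    key = '_X_'.join(values)
    return hash(key) % hash_bucket_size

print(cross_bucket(['Bachelors', 'Tech-support']))   # an education x occupation cross
print(cross_bucket(['Masters', 'Exec-managerial']))  # a different cross; may collide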
# coding:utf-8
"""
hook
"""
import tensorflow as tf
import time
import datetime
class LoggerHook(tf.train.SessionRunHook):
def __init__(self):
super(LoggerHook, self).__init__()
self._step = -1
self._start_time = time.time()
self._log_frequency = 10
    def begin(self):
        # Called once when the session is created; reset the counters so that
        # a reused hook starts fresh.
        self._step = -1
        self._start_time = time.time()
        self._log_frequency = 10
def before_run(self, run_context):
self._step += 1
        # The loss tensors are fetched alongside the training op; their values come
        # back in run_values.results inside after_run once the step finishes.
        loss_tensors = tf.get_collection(tf.GraphKeys.LOSSES)
        return tf.train.SessionRunArgs(loss_tensors)
def after_run(self, run_context, run_values):
if self._step % self._log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
            loss_values = run_values.results
            for i, loss in enumerate(loss_values):
                print(i, ':', loss)
print('-' * 40)
            examples_per_sec = self._log_frequency * 128 / duration  # 128 matches the batch_size used in train()
            sec_per_batch = float(duration / self._log_frequency)
            format_str = '%s: step %d (%.1f examples/sec; %.3f sec/batch)'
            print(format_str % (datetime.datetime.now(), self._step,
                                examples_per_sec, sec_per_batch))
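To see why before_run and after_run pair up, this is roughly the call order a TF 1.x MonitoredSession follows (a simplified sketch, not the exact internals):
# hook.begin()                                 # once, when the session is created
# for each training step:
#     args = hook.before_run(run_context)      # extra fetches (here: the LOSSES collection)
#     results = session.run([train_op] + args.fetches)
#     hook.after_run(run_context, run_values)  # run_values.results holds the fetched losses
# hook.end(session)                            # once, after the last step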
# coding:utf-8
"""
main
"""
from wide_and_deep import Train
from wide_and_deep import Eval
import common
import tensorflow as tf
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
# model_type = common.WIDE_MODEL_TYPE
# model_type = common.DEEP_MODEL_TYPE
model_type = common.WIDE_N_DEEP_MODEL_TYPE
if model_type == common.WIDE_MODEL_TYPE:
train = Train(common.WIDE_MODEL_TYPE, common.WIDE_MODEL_DIR)
elif model_type == common.DEEP_MODEL_TYPE:
train = Train(common.DEEP_MODEL_TYPE, common.DEEP_MODEL_DIR)
elif model_type == common.WIDE_N_DEEP_MODEL_TYPE:
train = Train(common.WIDE_N_DEEP_MODEL_TYPE, common.WIDE_N_DEEP_DIR)
else:
raise RuntimeError("error model type")
    # the else branch above raises, so train is always bound here
    model = train.train()
    wd_eval = Eval(model=model)
    wd_eval.eval()
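If everything is wired up correctly, eval() prints one line per metric returned by Estimator.evaluate; for the canned classifiers in TF 1.x this typically includes accuracy, auc, average_loss, loss and global_step (the exact key set varies by version).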