The application code is available on GitHub.

For background on what Wide & Deep does, see the translated Google blog post.

The training data is the Census Income Data Set. It contains about 48,000 samples, with attributes such as age, occupation, education, and income. Income is a binary label: either >50K or <=50K. The data is split into roughly 32,000 training samples and 16,000 test samples.

The attributes are as follows:
Field | Values |
---|---|
age | continuous |
workclass | Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked |
fnlwgt | continuous |
education | Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool. |
education-num | continuous |
marital-status | Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse. |
occupation | Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces. |
relationship | Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried. |
race | White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black. |
sex | Female, Male. |
capital-gain | continuous. |
capital-loss | continuous. |
hours-per-week | continuous. |
native-country | United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands. |
The dataset consists of raw features, some of which are strings, so feature processing is needed to convert everything into numeric values.

For feature engineering on continuous and categorical data, see:

- Understanding Feature Engineering (1): Continuous Numeric Data
- Understanding Feature Engineering (2): Categorical Data
- Notes on the Google blog: Introduction to TensorFlow Feature Columns
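As a quick illustration of what a feature column does (a toy sketch, not part of the original example): tf.feature_column.input_layer takes a dict of raw feature tensors plus a list of feature columns and produces the dense numeric tensor that the model actually consumes.

```python
import tensorflow as tf

# Toy data purely for illustration; not from the census dataset.
features = {'age': tf.constant([[23.0], [45.0]])}
age = tf.feature_column.numeric_column('age')

dense = tf.feature_column.input_layer(features, [age])
with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
  print(sess.run(dense))  # [[23.], [45.]]
```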
Continuous numeric features in the dataset, such as age and education_num, can be used directly via tf.feature_column.numeric_column, for example:

```python
age = tf.feature_column.numeric_column('age')
education_num = tf.feature_column.numeric_column('education_num')
capital_gain = tf.feature_column.numeric_column('capital_gain')
capital_loss = tf.feature_column.numeric_column('capital_loss')
hours_per_week = tf.feature_column.numeric_column('hours_per_week')
```
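The age_buckets column used in the base and crossed columns further down is a bucketized version of the continuous age column. A sketch of how it is typically defined (the boundary values follow the official wide_deep example and are an assumption here):

```python
# Bucketize the continuous 'age' column into ranges; boundaries taken from the
# official wide_deep example (an assumption; adjust to your data if needed).
age_buckets = tf.feature_column.bucketized_column(
    age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
```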
Categorical features such as education, marital_status, and occupation can be converted with tf.feature_column.categorical_column_with_vocabulary_list, for example:

```python
education = tf.feature_column.categorical_column_with_vocabulary_list(
    'education', [
        'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
        'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
        '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
    'marital_status', [
        'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
        'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
```
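The relationship and workclass columns referenced in the base and deep columns below are built the same way; a sketch following the same pattern, with the vocabularies taken from the field table above:

```python
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
    'relationship', [
        'Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative',
        'Unmarried'])

workclass = tf.feature_column.categorical_column_with_vocabulary_list(
    'workclass', [
        'Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
        'Local-gov', 'State-gov', 'Without-pay', 'Never-worked'])
```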
Categorical features can also be handled with tf.feature_column.categorical_column_with_hash_bucket, which hashes each string into one of hash_bucket_size integer buckets and therefore needs no explicit vocabulary, for example:

```python
# To show an example of hashing:
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    'occupation', hash_bucket_size=_HASH_BUCKET_SIZE)
```
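Here _HASH_BUCKET_SIZE is a module-level constant in the example script, not a TensorFlow built-in; the official wide_deep example sets it to 1000. For a standalone run you would define it yourself, something like:

```python
# Assumed value, matching the official wide_deep example; tune as needed.
_HASH_BUCKET_SIZE = 1000
```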
Crossed features can also be built, for example:

```python
crossed_columns = [
    tf.feature_column.crossed_column(
        ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
    tf.feature_column.crossed_column(
        [age_buckets, 'education', 'occupation'],
        hash_bucket_size=_HASH_BUCKET_SIZE),
]
```
These columns are then composed into the wide and deep column sets, as follows:

```python
# Wide columns and deep columns.
base_columns = [
    education, marital_status, relationship, workclass, occupation,
    age_buckets,
]

crossed_columns = [
    tf.feature_column.crossed_column(
        ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
    tf.feature_column.crossed_column(
        [age_buckets, 'education', 'occupation'],
        hash_bucket_size=_HASH_BUCKET_SIZE),
]

wide_columns = base_columns + crossed_columns

deep_columns = [
    age,
    education_num,
    capital_gain,
    capital_loss,
    hours_per_week,
    tf.feature_column.indicator_column(workclass),
    tf.feature_column.indicator_column(education),
    tf.feature_column.indicator_column(marital_status),
    tf.feature_column.indicator_column(relationship),
    # To show an example of embedding
    tf.feature_column.embedding_column(occupation, dimension=8),
]
```
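In the official example these lists are returned by a small helper that is later passed to build_estimator as model_column_fn. A minimal sketch of such a wrapper (the name build_model_columns follows the official code; the exact structure here is an assumption):

```python
def build_model_columns():
  """Returns the wide and deep feature column lists defined above."""
  return wide_columns, deep_columns
```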
Training is built on the high-level Estimator API:
```python
def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = model_column_fn()
  hidden_units = [100, 75, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0},
                                    inter_op_parallelism_threads=inter_op,
                                    intra_op_parallelism_threads=intra_op))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)
```
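To put the pieces together, here is a minimal training and evaluation sketch. It assumes the census CSV files (adult.data / adult.test) are available locally, that the underscored column names below (plus income_bracket for the label) match the file layout, and it reuses the build_model_columns helper sketched above; treat paths, names, and hyperparameters as placeholders. Note that the raw adult.test file ships with a header row and a trailing period on each label, which would need to be cleaned first.

```python
import tensorflow as tf

# Column names and CSV defaults; assumed from the field table above.
_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'sex',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket']
_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''],
                        [''], [0], [0], [0], [''], ['']]


def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Reads the CSV file and yields (features, label) batches."""
  def parse_csv(line):
    columns = tf.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    label = features.pop('income_bracket')
    # Binary label: True for '>50K', False otherwise.
    return features, tf.equal(label, '>50K')

  dataset = tf.data.TextLineDataset(data_file)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.map(parse_csv).repeat(num_epochs).batch(batch_size)
  return dataset


model = build_estimator(
    model_dir='/tmp/census_model', model_type='wide_deep',
    model_column_fn=build_model_columns, inter_op=0, intra_op=0)
model.train(input_fn=lambda: input_fn('adult.data', 2, True, 40))
results = model.evaluate(input_fn=lambda: input_fn('adult.test', 1, False, 40))
print(results)
```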