Learning TensorFlow Models: Wide & Deep (2), an Analysis of the Official Code

This post walks through the application code provided on GitHub.
For what Wide & Deep is used for, see the translated Google blog post.

Dataset

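The training data is the Census Income Data Set.
The dataset contains about 48,000 samples. Its attributes include age, occupation, education, and income, where income is a binary label: either >50K or <=50K. The data is split into roughly 32,000 training samples and 16,000 test samples.
The attributes are: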

Field Values
age continuous
workclass Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked
fnlwgt continuous
education Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num continuous
marital-status Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex Female, Male.
capital-gain continuous.
capital-loss continuous.
hours-per-week continuous.
native-country United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

The data files read during training are comma-separated and can be browsed online.
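
To make the file format concrete, here is a minimal sketch of turning one comma-separated row into a feature dictionary and a boolean label. The column names are an assumption: they follow the table above with hyphens replaced by underscores, plus a final `income` column. The sample row is only illustrative of the format, and `parse_line` is a hypothetical helper, not part of the official code.

```python
import csv
import io

# Assumed column order: the table above, hyphens replaced by underscores,
# followed by the label column.
CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'sex',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income',
]
CONTINUOUS = {'age', 'fnlwgt', 'education_num',
              'capital_gain', 'capital_loss', 'hours_per_week'}


def parse_line(line):
    """Split one CSV row into (features, label)."""
    # skipinitialspace drops the blank that follows each comma.
    values = next(csv.reader(io.StringIO(line), skipinitialspace=True))
    row = dict(zip(CSV_COLUMNS, values))
    features = {name: (float(value) if name in CONTINUOUS else value)
                for name, value in row.items()}
    # The label is True when income is above 50K.
    label = features.pop('income') == '>50K'
    return features, label


# Illustrative row in the file's format.
sample = ('39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, '
          'Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K')
features, label = parse_line(sample)
```

Continuous fields become floats and categorical fields stay as strings, which is the split the feature columns in the next section rely on.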

Feature Engineering

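The dataset contains only raw features, some of them still strings, so every feature has to be processed into a numeric form.
For feature engineering on continuous and categorical data, see:
Understanding Feature Engineering (1): Continuous Numeric Data
Understanding Feature Engineering (2): Categorical Data
Notes on the Google blog: An Introduction to TensorFlow Feature Columns
The continuous numeric features in the dataset above, such as age and education_num, can be used directly: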

  age = tf.feature_column.numeric_column('age')
  education_num = tf.feature_column.numeric_column('education_num')
  capital_gain = tf.feature_column.numeric_column('capital_gain')
  capital_loss = tf.feature_column.numeric_column('capital_loss')
  hours_per_week = tf.feature_column.numeric_column('hours_per_week')

Categorical features such as education, marital_status, and occupation can be converted with tf.feature_column.categorical_column_with_vocabulary_list, for example:

  education = tf.feature_column.categorical_column_with_vocabulary_list(
      'education', [
          'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
          'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
          '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

  marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
      'marital_status', [
          'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
          'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
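
As a rough illustration of what a vocabulary-list column does, the sketch below maps each string to its position in the list. It is plain Python, not TensorFlow's implementation. `vocab_id` is a hypothetical helper, and returning -1 for unknown values mirrors the column's default `default_value` only as an assumption of this sketch.

```python
# Vocabulary copied from the education column above.
EDUCATION_VOCAB = [
    'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
    'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
    '5th-6th', '10th', '1st-4th', 'Preschool', '12th']


def vocab_id(value, vocab, default=-1):
    """Return the index of value in vocab, or default if it is unknown."""
    try:
        return vocab.index(value)
    except ValueError:
        return default


masters_id = vocab_id('Masters', EDUCATION_VOCAB)
unknown_id = vocab_id('Kindergarten', EDUCATION_VOCAB)
```

The integer id is what the wide part weights directly and what the deep part later turns into a one-hot or embedded vector.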

When the full vocabulary is not known in advance, they can instead be handled with tf.feature_column.categorical_column_with_hash_bucket, for example:

  # To show an example of hashing:
  occupation = tf.feature_column.categorical_column_with_hash_bucket(
      'occupation', hash_bucket_size=_HASH_BUCKET_SIZE)
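
The idea behind hash bucketing can be sketched in a few lines: hash the string and take the result modulo the number of buckets. This sketch uses MD5 purely as a stand-in so the result is stable across runs; TensorFlow uses its own fingerprint function, so the actual bucket ids will differ. The bucket size of 1000 is an assumed value, since the excerpt does not show what `_HASH_BUCKET_SIZE` is set to.

```python
import hashlib

HASH_BUCKET_SIZE = 1000  # assumed value, not shown in the excerpt


def hash_bucket(value, num_buckets=HASH_BUCKET_SIZE):
    """Map a string to a stable bucket id in [0, num_buckets)."""
    digest = hashlib.md5(value.encode('utf-8')).hexdigest()
    return int(digest, 16) % num_buckets


bucket = hash_bucket('Tech-support')
```

Different occupations may collide in the same bucket. That is the trade-off for not having to enumerate the vocabulary.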

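Crossed features can also be built. Here age_buckets is a bucketized version of age that is defined in the official code but not shown in this excerpt: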

  crossed_columns = [ 
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'],
          hash_bucket_size=_HASH_BUCKET_SIZE),
  ]
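
The crossed columns above combine a bucketized age with two categorical features. The plain-Python sketch below shows both steps under stated assumptions: the age boundaries and the bucket size are illustrative values, the `_X_` separator is arbitrary, and MD5 stands in for TensorFlow's own hashing, so the ids will not match what the library produces.

```python
import bisect
import hashlib

# Assumed values for illustration; neither appears in the excerpt above.
AGE_BOUNDARIES = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]
HASH_BUCKET_SIZE = 1000


def age_bucket(age):
    """Index of the bucket holding age; each boundary starts a new bucket."""
    return bisect.bisect_right(AGE_BOUNDARIES, age)


def crossed_bucket(values, num_buckets=HASH_BUCKET_SIZE):
    """Hash the combination of several feature values into one bucket id."""
    key = '_X_'.join(str(v) for v in values)
    digest = hashlib.md5(key.encode('utf-8')).hexdigest()
    return int(digest, 16) % num_buckets


# One sparse id for the (age bucket, education, occupation) combination.
cross_id = crossed_bucket([age_bucket(39), 'Bachelors', 'Adm-clerical'])
```

Because the cross is a single id, the wide part can learn one weight per combination, which is how it memorizes specific feature co-occurrences.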

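These columns are then assembled into the wide and deep feature sets (relationship and workclass are categorical columns defined in the same way as education):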

  # Wide columns and deep columns.
  base_columns = [
      education, marital_status, relationship, workclass, occupation,
      age_buckets,
  ]

  crossed_columns = [
      tf.feature_column.crossed_column(
          ['education', 'occupation'], hash_bucket_size=_HASH_BUCKET_SIZE),
      tf.feature_column.crossed_column(
          [age_buckets, 'education', 'occupation'],
          hash_bucket_size=_HASH_BUCKET_SIZE),
  ]

  wide_columns = base_columns + crossed_columns

  deep_columns = [
      age,
      education_num,
      capital_gain,
      capital_loss,
      hours_per_week,
      tf.feature_column.indicator_column(workclass),
      tf.feature_column.indicator_column(education),
      tf.feature_column.indicator_column(marital_status),
      tf.feature_column.indicator_column(relationship),
      # To show an example of embedding
      tf.feature_column.embedding_column(occupation, dimension=8),
  ]
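
The deep part cannot consume sparse ids directly, which is why the categorical columns are wrapped in indicator_column or embedding_column. The sketch below is a plain-Python illustration of the difference, not TensorFlow's implementation. The table size of 1000 and the random initialization are assumptions; only the embedding dimension of 8 comes from the code above.

```python
import random

EMBEDDING_DIM = 8   # matches dimension=8 in the code above
NUM_BUCKETS = 1000  # assumed hash bucket size


def one_hot(index, depth):
    """Indicator-style encoding: a vector of zeros with a single 1.0."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


# Embedding-style encoding: each id owns a small dense vector. In a real
# model this table is learned during training; here it is random.
rng = random.Random(0)
embedding_table = [[rng.uniform(-0.1, 0.1) for _ in range(EMBEDDING_DIM)]
                   for _ in range(NUM_BUCKETS)]


def embed(bucket_id):
    """Look up the dense vector for a sparse id."""
    return embedding_table[bucket_id]


marital_vector = one_hot(2, depth=7)   # 7 marital_status values
occupation_vector = embed(417)         # hypothetical occupation bucket id
```

One-hot vectors grow with the vocabulary, so they suit small vocabularies such as marital_status. An embedding keeps the input at a fixed, small width even for a large hashed feature such as occupation.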

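Training with Estimator

Training is built on the high-level Estimator API. Depending on model_type, the function below returns a linear (wide) classifier, a DNN (deep) classifier, or the combined wide and deep classifier.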

import tensorflow as tf  # this excerpt targets the TensorFlow 1.x API


def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = model_column_fn()
  hidden_units = [100, 75, 50, 25] 

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0}, 
                                    inter_op_parallelism_threads=inter_op,
                                    intra_op_parallelism_threads=intra_op))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)
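
To show how the combined model joins its two halves, here is a minimal sketch under the assumption that the wide and deep parts each produce one logit and that the two logits are summed before the sigmoid. The weights and the deep logit are made-up numbers, and the helper names are hypothetical. This illustrates the idea, not the DNNLinearCombinedClassifier internals.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def wide_logit(active_ids, weights, bias=0.0):
    """Linear part: sum the weights of the sparse features that are active."""
    return bias + sum(weights.get(i, 0.0) for i in active_ids)


def combined_probability(wide, deep):
    """Add the wide and deep logits, then squash to P(income > 50K)."""
    return sigmoid(wide + deep)


weights = {3: 0.8, 417: -0.3}  # hypothetical learned weights per sparse id
wide = wide_logit([3, 417], weights)
p = combined_probability(wide, deep=0.25)  # deep logit is a made-up value
```

In this view, the 'wide' and 'deep' model types correspond to using only one of the two logits.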

