第十三课 wide&deep模型

这篇 paper 本身很简单,网上也有很多人翻译过。使用 TensorFlow 自带的库来实现其实也很简单;比较难的地方在于对特征工程部分的理解,请参考《第九课 tensorflow 特征工程: feature_column》。

下面是具体的实践demo:

# coding:utf-8
"""
wide and deep
"""

from framework.data_input import IDataInput
from framework.inference import IInference
from framework.train import ITrain
from framework.eval import IEval
import pandas as pd
import tensorflow as tf
import common
from hook import LoggerHook


class DataInput(IDataInput):
    """Builds an Estimator ``input_fn`` from a CSV file via pandas.

    Only the first entry of ``input_file_paths`` is read. The remaining
    constructor arguments are forwarded to ``IDataInput``, which is assumed
    to store them as ``_input_file_paths``, ``_batch_size``,
    ``_example_per_echo_num`` and ``_parallel_thread_num`` (TODO confirm
    against the base class; note the ``echo`` spelling used below).
    """

    def __init__(self,
                 csv_column_names,
                 label_column_name,
                 input_file_paths,
                 shuffle,
                 batch_size,
                 example_per_epoch_num,
                 parallel_thread_num=16):
        """Store the CSV schema and shuffle flag; delegate the rest to the base.

        Args:
            csv_column_names: Column names assigned to the CSV columns.
            label_column_name: Name of the column holding the label text.
            input_file_paths: List of file paths; only element 0 is used.
            shuffle: Whether the produced ``input_fn`` shuffles examples.
            batch_size: Batch size of the produced ``input_fn``.
            example_per_epoch_num: Forwarded to the base class; ``read_data``
                passes the stored value on as ``num_epochs``.
            parallel_thread_num: Reader thread count for the ``input_fn``.
        """
        super(DataInput, self).__init__(input_file_paths,
                                        batch_size,
                                        example_per_epoch_num,
                                        parallel_thread_num=parallel_thread_num)
        self._csv_column_names = csv_column_names
        self._label_column_name = label_column_name
        self._shuffle = shuffle

    def read_data(self):
        """Load the CSV and wrap it in a pandas-backed ``input_fn``.

        Returns:
            The callable produced by ``tf.estimator.inputs.pandas_input_fn``.
        """
        input_file_path = self._input_file_paths[0]

        # Close the file handle as soon as pandas has consumed it.
        with tf.gfile.Open(input_file_path) as csv_file:
            df = pd.read_csv(csv_file,
                             names=self._csv_column_names,
                             skipinitialspace=True,
                             skiprows=1)

        # dropna returns a new frame; the result must be re-bound, otherwise
        # incomplete rows survive and a NaN label breaks the `in` test below.
        df = df.dropna(axis=0, how='any')

        # Binary label: 1 when the label text contains ">50K", else 0.
        label = df[self._label_column_name].apply(lambda x: ">50K" in x).astype(int)

        print(df.head())
        print(label.head())

        return tf.estimator.inputs.pandas_input_fn(
            x=df,
            y=label,
            batch_size=self._batch_size,
            shuffle=self._shuffle,
            num_threads=self._parallel_thread_num,
            # NOTE(review): the "examples per epoch" value is used as the
            # epoch count here; callers in this file pass None or 1.
            num_epochs=self._example_per_echo_num
        )

    def _preprocess_data(self, record):
        """Not used by the pandas-based pipeline; intentionally a no-op."""
        pass

    def _generate_train_batch(self, train_data, label, shuffle=True):
        """Not used by the pandas-based pipeline; intentionally a no-op."""
        pass

    def _read_data_from_queue(self, file_path_queue):
        """Not used by the pandas-based pipeline; intentionally a no-op."""
        pass


class Inference(IInference):
    """Factory that builds the canned TensorFlow estimator for a model type."""

    def __init__(self, model_dir, model_type, linear_feature_columns, dnn_feature_columns):
        """Remember where the model lives and which feature columns feed it.

        Args:
            model_dir: Directory handed to the estimator as ``model_dir``.
            model_type: One of ``'wide'``, ``'deep'`` or ``'wide_n_deep'``.
            linear_feature_columns: Columns for the linear (wide) part.
            dnn_feature_columns: Columns for the DNN (deep) part.
        """
        super(Inference, self).__init__()
        self._model_dir = model_dir
        self._model_type = model_type
        self._linear_feature_columns = linear_feature_columns
        # Attribute name keeps its historical spelling.
        self._dnn_featrue_columns = dnn_feature_columns

    def inference(self, data):
        """Create the estimator matching ``self._model_type``.

        Args:
            data: Unused by this implementation.

        Returns:
            A ``LinearClassifier``, ``DNNClassifier`` or
            ``DNNLinearCombinedClassifier`` instance.

        Raises:
            RuntimeError: If the model type is not one of the three known ones.
        """
        kind = self._model_type
        hidden_layers = [100, 50]

        if kind == 'wide':
            return tf.estimator.LinearClassifier(
                feature_columns=self._linear_feature_columns,
                model_dir=self._model_dir)

        if kind == 'deep':
            return tf.estimator.DNNClassifier(
                feature_columns=self._dnn_featrue_columns,
                model_dir=self._model_dir,
                hidden_units=hidden_layers)

        if kind == 'wide_n_deep':
            return tf.estimator.DNNLinearCombinedClassifier(
                model_dir=self._model_dir,
                linear_feature_columns=self._linear_feature_columns,
                dnn_feature_columns=self._dnn_featrue_columns,
                dnn_hidden_units=hidden_layers)

        raise RuntimeError('no %s model type' % self._model_type)


class Train(ITrain):
    """Trains one of the wide / deep / wide-and-deep estimators."""

    def __init__(self, model_type, model_dir):
        """Record which model to train and where its files are stored.

        Args:
            model_type: ``'wide'``, ``'deep'`` or ``'wide_n_deep'``.
            model_dir: Directory passed through to the estimator.
        """
        super(Train, self).__init__()
        self._model_type = model_type
        self._model_dir = model_dir

    @property
    def model_dir(self):
        """Directory passed to the estimator."""
        return self._model_dir

    @property
    def model_type(self):
        """Model type string given at construction."""
        return self._model_type

    def train(self):
        """Build the input pipeline and estimator, then train for 2000 steps.

        Returns:
            The trained estimator.

        Raises:
            RuntimeError: If the model type is unknown. The training CSV has
                already been read by the time this is raised.
        """
        train_input = DataInput(common.CSV_COLUMNS,
                                common.LABEL_COLUMN_NAME,
                                ['./input/adult.data'],
                                shuffle=True,
                                batch_size=128,
                                example_per_epoch_num=None)
        input_fn = train_input.read_data()

        # Pick the (linear, dnn) feature-column pair for the requested model.
        kind = self._model_type
        if kind == 'wide':
            linear_columns = common.base_columns + common.crossed_columns
            dnn_columns = None
        elif kind == 'deep':
            linear_columns = None
            dnn_columns = common.deep_columns
        elif kind == 'wide_n_deep':
            # NOTE(review): only the crossed columns feed the linear part here,
            # whereas the 'wide' model also uses base_columns - confirm intent.
            linear_columns = common.crossed_columns
            dnn_columns = common.deep_columns
        else:
            raise RuntimeError('model type error: ' + self._model_type)

        builder = Inference(self._model_dir, kind, linear_columns, dnn_columns)
        estimator = builder.inference(None)

        estimator.train(input_fn=input_fn, hooks=[LoggerHook()], steps=2000)
        return estimator


class Eval(IEval):
    """Evaluates a trained estimator on the test CSV and prints its metrics."""

    def __init__(self, model):
        """Keep a reference to the estimator to evaluate.

        Args:
            model: Object exposing ``evaluate(input_fn=..., steps=...)``.
        """
        super(Eval, self).__init__(None, None, None, 128)
        self._model = model

    def accuracy(self, predict_results, labels):
        """Unused here; metrics come from the estimator's ``evaluate``."""
        pass

    def read_test_data_set(self):
        """Unused here; ``eval`` builds its own input function."""
        pass

    def predict(self, test_data_batch):
        """Unused here; intentionally a no-op."""
        pass

    def eval(self):
        """Run one pass over ``./input/adult.test`` and print every metric."""
        test_input = DataInput(common.CSV_COLUMNS,
                               common.LABEL_COLUMN_NAME,
                               ['./input/adult.test'],
                               shuffle=False,
                               batch_size=128,
                               example_per_epoch_num=1)
        metrics = self._model.evaluate(input_fn=test_input.read_data(),
                                       steps=None)

        # Print in key order so the output is stable between runs.
        for metric_name in sorted(metrics):
            metric_value = metrics[metric_name]
            print("%s: %s" % (metric_name, metric_value))

# coding:utf-8
"""
common
"""

import tensorflow as tf

# ---------------------------------------------------------------------------
# Column names of the CSV data set.
# ---------------------------------------------------------------------------
GENDER = 'gender'
EDUCATION = 'education'
MARITAL_STATUS = 'marital_status'
RELATIONSHIP = 'relationship'
WORK_CLASS = 'workclass'
OCCUPATION = 'occupation'
NATIVE_COUNTRY = 'native_country'
AGE = 'age'
EDUCATION_NUM = 'education_num'
CAPITAL_GAIN = 'capital_gain'
CAPITAL_LOSS = 'capital_loss'
HOURS_PER_WEEK = 'hours_per_week'

# Full header, in file order, applied when the CSV is parsed.
CSV_COLUMNS = [
    AGE,
    WORK_CLASS,
    "fnlwgt",
    EDUCATION,
    EDUCATION_NUM,
    MARITAL_STATUS,
    OCCUPATION,
    RELATIONSHIP,
    "race",
    GENDER,
    CAPITAL_GAIN,
    CAPITAL_LOSS,
    HOURS_PER_WEEK,
    NATIVE_COUNTRY,
    "income_bracket",
]

# ---------------------------------------------------------------------------
# Categorical columns with a fixed vocabulary.
# ---------------------------------------------------------------------------
gender = tf.feature_column.categorical_column_with_vocabulary_list(
    key=GENDER,
    vocabulary_list=['Female', 'Male'])

education = tf.feature_column.categorical_column_with_vocabulary_list(
    key=EDUCATION,
    vocabulary_list=[
        "Bachelors", "HS-grad", "11th", "Masters", "9th",
        "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
        "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
        "Preschool", "12th",
    ])

marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
    key=MARITAL_STATUS,
    vocabulary_list=[
        "Married-civ-spouse", "Divorced", "Married-spouse-absent",
        "Never-married", "Separated", "Married-AF-spouse", "Widowed",
    ])

relationship = tf.feature_column.categorical_column_with_vocabulary_list(
    key=RELATIONSHIP,
    vocabulary_list=[
        "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
        "Other-relative",
    ])

workclass = tf.feature_column.categorical_column_with_vocabulary_list(
    key=WORK_CLASS,
    vocabulary_list=[
        "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
        "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked",
    ])

# ---------------------------------------------------------------------------
# Categorical columns without a fixed vocabulary: hashed into 1000 buckets.
# ---------------------------------------------------------------------------
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    key=OCCUPATION, hash_bucket_size=1000)
native_country = tf.feature_column.categorical_column_with_hash_bucket(
    key=NATIVE_COUNTRY, hash_bucket_size=1000)

# ---------------------------------------------------------------------------
# Numeric columns.
# ---------------------------------------------------------------------------
age = tf.feature_column.numeric_column(key=AGE)
education_num = tf.feature_column.numeric_column(key=EDUCATION_NUM)
capital_gain = tf.feature_column.numeric_column(key=CAPITAL_GAIN)
capital_loss = tf.feature_column.numeric_column(key=CAPITAL_LOSS)
hours_per_week = tf.feature_column.numeric_column(key=HOURS_PER_WEEK)

# Age discretised at the listed boundaries.
age_buckets = tf.feature_column.bucketized_column(
    source_column=age,
    boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

# ---------------------------------------------------------------------------
# Column groups consumed by the models.
# ---------------------------------------------------------------------------
base_columns = [
    gender,
    education,
    marital_status,
    relationship,
    workclass,
    occupation,
    native_country,
    age_buckets,
]

crossed_columns = [
    tf.feature_column.crossed_column(
        keys=[EDUCATION, OCCUPATION], hash_bucket_size=1000),
    tf.feature_column.crossed_column(
        keys=[age_buckets, EDUCATION, OCCUPATION], hash_bucket_size=1000),
    tf.feature_column.crossed_column(
        keys=[NATIVE_COUNTRY, OCCUPATION], hash_bucket_size=1000),
]

deep_columns = [
    # Vocabulary-backed categoricals wrapped as indicator columns.
    tf.feature_column.indicator_column(categorical_column=workclass),
    tf.feature_column.indicator_column(categorical_column=education),
    tf.feature_column.indicator_column(categorical_column=gender),
    tf.feature_column.indicator_column(categorical_column=relationship),

    # Hashed categoricals wrapped as 8-dimensional embedding columns.
    tf.feature_column.embedding_column(categorical_column=native_country, dimension=8),
    tf.feature_column.embedding_column(categorical_column=occupation, dimension=8),

    # Numeric columns are used as they are.
    age,
    education_num,
    capital_gain,
    capital_loss,
    hours_per_week,
]

# ---------------------------------------------------------------------------
# Label, output directories and model-type identifiers.
# ---------------------------------------------------------------------------
LABEL_COLUMN_NAME = "income_bracket"

MODEL_DIR = './output'
WIDE_MODEL_DIR = '%s/wide' % MODEL_DIR
DEEP_MODEL_DIR = '%s/deep' % MODEL_DIR
WIDE_N_DEEP_DIR = '%s/wide_deep' % MODEL_DIR

WIDE_MODEL_TYPE = 'wide'
DEEP_MODEL_TYPE = 'deep'
WIDE_N_DEEP_MODEL_TYPE = 'wide_n_deep'
# coding:utf-8
"""
hook
"""

import tensorflow as tf
import time
import datetime


class LoggerHook(tf.train.SessionRunHook):
    """Session hook that periodically prints the collected losses and throughput.

    Args:
        log_frequency: Print once every this many ``run`` calls (default 10).
        batch_size: Examples per step, used only for the examples/sec
            estimate (default 128).

    Raises:
        ValueError: If ``log_frequency`` is not positive.
    """

    def __init__(self, log_frequency=10, batch_size=128):
        super(LoggerHook, self).__init__()
        if log_frequency <= 0:
            raise ValueError('log_frequency must be positive, got %r' % (log_frequency,))
        self._log_frequency = log_frequency
        self._batch_size = batch_size
        self._reset()

    def _reset(self):
        """Restart the step counter and the timing window."""
        self._step = -1
        self._last_logged_step = -1
        self._start_time = time.time()

    def begin(self):
        """Reset counters so the hook can be reused across sessions."""
        self._reset()

    def before_run(self, run_context):
        """Count the step and ask the session to also fetch the loss tensors."""
        self._step += 1
        # The tensors in the LOSSES collection are fetched together with the
        # training op; their values come back through run_values in after_run.
        loss_tensors = tf.get_collection(tf.GraphKeys.LOSSES)
        return tf.train.SessionRunArgs(loss_tensors)

    def after_run(self, run_context, run_values):
        """Every ``log_frequency`` steps, print the losses and the throughput."""
        if self._step % self._log_frequency != 0:
            return

        now = time.time()
        duration = now - self._start_time
        self._start_time = now

        # Steps actually covered by `duration`: 1 for the very first log,
        # log_frequency afterwards. Using a fixed log_frequency here would
        # overstate the throughput of the first report.
        steps_elapsed = self._step - self._last_logged_step
        self._last_logged_step = self._step

        for index, loss in enumerate(run_values.results):
            print(index, ':', loss)
        print('-' * 40)

        if duration > 0:
            examples_per_sec = steps_elapsed * self._batch_size / duration
        else:
            # Timer resolution too coarse to measure the window.
            examples_per_sec = float('inf')
        sec_per_batch = float(duration) / steps_elapsed

        format_str = ('%s: step %d, loss = todo (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.datetime.now(), self._step,
                            examples_per_sec, sec_per_batch))
# coding:utf-8
"""
main
"""

from wide_and_deep import DataInput
from wide_and_deep import Train
from wide_and_deep import Eval
import common
import logging
import tensorflow as tf

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # Switch the model under training by choosing one of:
    #   common.WIDE_MODEL_TYPE, common.DEEP_MODEL_TYPE,
    #   common.WIDE_N_DEEP_MODEL_TYPE
    model_type = common.WIDE_N_DEEP_MODEL_TYPE

    # Each model type writes its files into its own directory.
    model_dir_by_type = {
        common.WIDE_MODEL_TYPE: common.WIDE_MODEL_DIR,
        common.DEEP_MODEL_TYPE: common.DEEP_MODEL_DIR,
        common.WIDE_N_DEEP_MODEL_TYPE: common.WIDE_N_DEEP_DIR,
    }
    if model_type not in model_dir_by_type:
        raise RuntimeError("error model type")

    trainer = Train(model_type, model_dir_by_type[model_type])
    trained_model = trainer.train()

    evaluator = Eval(model=trained_model)
    evaluator.eval()

你可能感兴趣的:(第十三课 wide&deep模型)