小白读deepFm CTR预估(shenweichen代码)

作者的代码可见:https://github.com/shenweichen/DeepCTR

本文分析的数据全部借用文章中的数据。见



PART1 main函数

step1:读取数据,代码中使用的是pandas的read_csv文件的读取方法。

step2:将离线型特征、连续型特征的特征名、标签名分别放到不同的list中。

        为了后续画图方便,没有将所有特征用到,只用了前两个连续特征和前两个离散特征



这样实际上我们得到了三个list。如果我们用该代码,使用的特征名和他给的特征名不一致的,我们可以直接显示给定特征名。例如

    sparse_features = ['C1', 'C2']
    dense_features  = ['I1', 'I2']
    target          = ['label']

step3:基础特征工程

        这也是在一般的机器学习中我们也会做的一个操作就是特征工程,只不过深度学习的特征工程较机器学习简单,一般我们不去人工挖掘一些交叉的特征,而是交给embedding层去做处理。
        一般有哪些特征工程的方法,可以参考:https://www.zhihu.com/question/29316149
该论文中只做了两种处理,一是对离线型特征做LabelEncoder,该方法是将特征中的文本信息转成数字。例如:

from pandas import DataFrame
from sklearn.preprocessing import LabelEncoder
import pandas as pd
data = {'水果':['苹果','梨','草莓'],
       '数量':[3,2,5],
       '价格':[10,9,8]}
#根据dict创建dataframe
df = DataFrame(data)
lbe = LabelEncoder()
df['水果'] = lbe.fit_transform(data['水果'])

        可以看到苹果被重新编码为1,梨重新编码为0,草莓重新编码为2。若使用的时候,自己的数据集中的数据本身就是数字型的,其实这步就可以省略。
        对于连续型特征,由于涉及到线性运算,若某一维度的值特别大后,就会导致该特征对模型整体影响偏高。所以为了消除这种影响,一般会进行归一化处理。本文采用的min-max归一化。

mms = MinMaxScaler(feature_range=(0, 1))
data[dense_features] = mms.fit_transform(data[dense_features])

将每列的特征压缩在0-1范围之类。当前还有0均值1方差的归一化处理方式。

step4:数据处理(格式转换)

sparseFeat = [SparseFeat(feat, vocabulary_size=data[feat].nunique(), embedding_dim=4) for feat in sparse_features]
denseFeat =  [DenseFeat(feat, 1,) for feat in dense_features]
#其中vocabulary_size表示的是词典的词大小,即离散变量它的取值有多少个。
#例如对于性别这个离散特征,它的取值有'female', 'male'和未识别。则有三种可取值。

        其中nunique()可以求dataframe的某列中不同元素的个数。用该方法就可以求出离散特征可取值的个数。还是对之前的苹果和梨的例子。可以看到

print(df['水果'].nunique()) #结果为3

其实sparseFeat 和denseFeat 只是将之前的特征名封装成了一个类。对于离散型特征C1,要指定该特征可取值的个数,即vocabulary_size,还有embedding_dim即经过embedding后希望得到的维度。对于embedding_dim其实有个疑问:

if embedding_dim == "auto":
       embedding_dim =  6 * int(pow(vocabulary_size, 0.25))

        对于embedding后的维度,用户可以通过显式指定embedding_dim,也可以设置为auto,这样会根据上述的公式得到维度。也就是词典的大小开两次根号后的六倍。但是为什么这么处理呢?


PART2 DeepFM方法

        为了直观地看出代码中的模型整个框架,先用C1、C2、I1、I2两个连续型特征和两个离散型特征。使用如下代码

plot_model(model, to_file='model_deepfm.png', show_shapes=True,show_layer_names=True)

画出模型框架图,结果如下:


        可以看出,整个框架分为三个部分,FM,DNN,Linear三个部分。为了更好了理解每一个部分,我们拆分来理解。我们先看一下DeepFM的输入
        首先理解FM模型,该模型可以理解为是在LR模型基础上的一种改进吧。因为线性模型仅考虑特征的线性组合,但是有时候特征与特征之间也存在一定的关系,由此应运出该模型。首先推导出模型表达式
FM模块


这样就推导出了FM的计算公式,简记为和平方-平方和。接着我们看代码中对应FM模型的部分

fm_logit = add_func([FM()(concat_func(v, axis=1))
                         for k, v in group_embedding_dict.items() if k in fm_group])

首先得理解group_embedding_dict到底是什么?其对应的代码如下:

def input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed, prefix='', seq_mask_zero=True,
                               support_dense=True, support_group=False):
    sparse_feature_columns = list(
        filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
    varlen_sparse_feature_columns = list(
        filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

    embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, init_std, seed, prefix=prefix,
                                                    seq_mask_zero=seq_mask_zero)
    group_sparse_embedding_dict = embedding_lookup(embedding_matrix_dict, features, sparse_feature_columns)
    dense_value_list = get_dense_input(features, feature_columns)
    if not support_dense and len(dense_value_list) > 0:
        raise ValueError("DenseFeat is not supported in dnn_feature_columns")

    sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, varlen_sparse_feature_columns)
    group_varlen_sparse_embedding_dict = get_varlen_pooling_list(sequence_embed_dict, features,
                                                                 varlen_sparse_feature_columns)
    group_embedding_dict = mergeDict(group_sparse_embedding_dict, group_varlen_sparse_embedding_dict)
    if not support_group:
        group_embedding_dict = list(chain.from_iterable(group_embedding_dict.values()))
    return group_embedding_dict, dense_value_list

你可能感兴趣的:(小白读deepFm CTR预估(shenweichen代码))