PNN模型是发表于2016年的推荐类文章,文章地址为:https://arxiv.org/pdf/1611.00144.pdf
解决了直接将embedding输入给MLP层表达特征交叉不充分的问题,论文使用了两种特殊的乘积方式来表示特征之间的交叉,实验也证明,相比于16年以前的模型也取得了不错的效果,现在主要是学习其一种思想;
(1)模型的整体结构如下图所示,IPNN部分和OPNN部分是整个模型的核心部分;
(2)介绍l2层,为输出层,这一层可以得到预测结果的概率;
(3)接着是一个l2层,其实这两个层就是DNN模型,也是很好理解的;
(4)接着介绍IPNN部分,这部分由一个embedding层通过如下函数进行表示;
其中z为embedding的表示,f其实就是embedding;
经过下列的变换可以得到一个复杂度较低的计算结果;
(4)OPNN部分介绍:
(5)接着损失函数的设计;
论文中比较了几种模型在如下两个数据集中的效果,实验结果也说明PNN模型的效果在当时的模型中还是比较领先的;
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.initializers import glorot_normal
import tensorflow as tf
class ipnn():
def __init__(self, **kwargs):
super(ipnn, self).__init__(**kwargs)
def build(self, input_shape):
super(ipnn, self).build(input_shape)
def call(self, inputs):
embed_list = inputs
row = []
col = []
num = len(inputs)
for i in range(num - 1):
for j in range(i + 1, num):
row.append(i)
col.append(j)
p = tf.concat([embed_list[idx] for idx in row], axis=1)#(batch, num_pairs, embed_size)
q = tf.concat([embed_list[idx] for idx in col], axis=1)
inner_product = p * q
inner_product = tf.reduce_sum(inner_product, aixs=2, keep_dims=True)
return inner_product
class opnn(Layer):
def __init__(self, seed=1024, **kwargs):
self.seed = seed
super(opnn, self).__init__(**kwargs)
def build(self, input_shape):
num_inputs = len(input_shape)
num_pairs = int(num_inputs * (num_inputs - 1) / 2)
input_shape = input_shape[0]
embed_size = input_shape[-1].value
self.kernel = self.add_weight(name='kernel', shape=(embed_size, num_pairs, embed_size),
dtype=tf.float32, initializer=glorot_normal(self.seed), trainable=True)
super(opnn, self).build(input_shape)
def call(self, inputs):
embed_list = inputs
num = embed_list
row = []
col = []
for i in range(num - 1):
for j in range(i, num):
row.append(i)
col.append(j)
p = tf.concat([embed_list[idx] for idx in row], axis=1)
q = tf.concat([embed_list[idx] for idx in col], axis=1)#(batch, num_pairs, embed_size)
p = tf.expand_dims(p, axis=1)
kp = tf.reduce_sum((tf.multiply((tf.transpose(tf.reduce_sum(tf.multiply(p, self.kernel), axis=-1), [0, 2, 1])), q)), axis=-1)
return kp