TensorFlow2.0入门到进阶2.12 ——函数API实现wide&deep模型

文章目录

  • 1、wide&deep原理
  • 2、代码

1、wide&deep原理

wide&deep模型:https://blog.csdn.net/caoyuan666/article/details/105869670

2、代码

函数式API 在创建模型时就像调用函数一样,将上一层结果像函数变量一样输入的下一层的函数中:

#复合函数:f(x)=h(g(x))
input = keras.layers.Input(shape=x_train.shape[1:])
hidden1=keras.layers.Dense(30,activation='relu')(input)
hidden2=keras.layers.Dense(30,activation='relu')(hidden1)
 
#将wide和deep数据拼接
concat = keras.layers.concatenate([input,hidden2])
output = keras.layers.Dense(1)(concat)

#由于函数式API没有将模型返回保存,所以需要使用model将模型固化下来
model = keras.models.Model(inputs=[input],
                         outputs=[output])

model.summary()
model.compile(loss='mean_squared_error',
              optimizer='adam',)

网络结构结构:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 8)]          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 30)           270         input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           930         dense[0][0]                      
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 38)           0           input_1[0][0]                    
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            39          concatenate[0][0]                
==================================================================================================
Total params: 1,239
Trainable params: 1,239
Non-trainable params: 0
__________________________________________________________________________________________________

wide层:input(对于输入数据只经过一层input)
deep层:hidden2(经过两层隐层,相对较深,这里只是举例,其实有点前)

你可能感兴趣的:(tensorflow,深度学习,神经网络,机器学习,python)