wide&deep model: https://blog.csdn.net/caoyuan666/article/details/105869670
Implementing the wide&deep model with the functional API
Implementing the wide&deep model with the subclassing API
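This experiment uses the housing price prediction dataset. If you are not familiar with it, see:
Getting started with TensorFlow through a simple housing price regression project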
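Multiple inputs are typically used when a model takes more than one set of input features.
Checking the dataset's dimensions shows that it has 8 features in total: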
from sklearn.datasets import fetch_california_housing
housing=fetch_california_housing()
#print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)
Output:
(20640, 8)
(20640,)
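The training code later in this post uses standardized arrays named x_train_scaled, x_valid_scaled and x_test_scaled, which come from the earlier housing-price post and are not rebuilt here. As a minimal sketch of how such splits could be produced, assuming sklearn's train_test_split and StandardScaler (the helper name split_and_scale and the seed value are hypothetical choices, not necessarily the author's):

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def split_and_scale(x, y, seed=7):
    """Hypothetical helper: train/valid/test split plus standardization.

    The scaler is fitted on the training split only and then applied to the
    validation and test splits, so their statistics do not leak into training.
    """
    # First carve off a test set, then split the rest into train/valid
    x_train_all, x_test, y_train_all, y_test = train_test_split(
        x, y, random_state=seed)
    x_train, x_valid, y_train, y_valid = train_test_split(
        x_train_all, y_train_all, random_state=seed)

    scaler = StandardScaler()
    x_train_scaled = scaler.fit_transform(x_train)  # fit on train only
    x_valid_scaled = scaler.transform(x_valid)
    x_test_scaled = scaler.transform(x_test)
    return (x_train_scaled, x_valid_scaled, x_test_scaled,
            y_train, y_valid, y_test)


# Usage with the housing data loaded above:
# (x_train_scaled, x_valid_scaled, x_test_scaled,
#  y_train, y_valid, y_test) = split_and_scale(housing.data, housing.target)
```

With sklearn's default test_size of 0.25 at both steps, the 20,640 housing rows would end up as 11,610 training, 3,870 validation and 5,160 test rows.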
To demonstrate, here is a model with two inputs of dimension 5 and 6:
# Multiple inputs, functional API
from tensorflow import keras

input_wide = keras.layers.Input(shape=[5])   # wide path: 5 features
input_deep = keras.layers.Input(shape=[6])   # deep path: 6 features
hidden1 = keras.layers.Dense(30, activation='relu')(input_deep)
hidden2 = keras.layers.Dense(30, activation='relu')(hidden1)
# join the raw wide features with the deep representation: 5 + 30 = 35
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(1)(concat)
model = keras.models.Model(inputs=[input_wide, input_deep],
                           outputs=output)
model.summary()
model.compile(loss='mean_squared_error',
              optimizer='adam')
Model structure:
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_6 (InputLayer)            [(None, 6)]          0
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 30)           210         input_6[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 5)]          0
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 30)           930         dense_6[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 35)           0           input_5[0][0]
                                                                 dense_7[0][0]
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 1)            36          concatenate_2[0][0]
==================================================================================================
Total params: 1,176
Trainable params: 1,176
Non-trainable params: 0
__________________________________________________________________________________________________
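The Param # column can be checked by hand: a Dense layer with n inputs and u units has n*u weights plus u biases, while the Input and Concatenate layers have no parameters. A small sketch (the function name dense_params is only for illustration) reproduces the numbers in the summary above:

```python
def dense_params(n_inputs, units):
    """Parameters of a Dense layer: a weight matrix plus one bias per unit."""
    return n_inputs * units + units


hidden1 = dense_params(6, 30)           # deep input (6) -> 30 units
hidden2 = dense_params(30, 30)          # 30 -> 30 units
concat_width = 5 + 30                   # wide input (5) + hidden2 output (30)
output = dense_params(concat_width, 1)  # 35 -> 1 unit
total = hidden1 + hidden2 + output      # Input/Concatenate add 0 parameters

print(hidden1, hidden2, concat_width, output, total)  # 210 930 35 36 1176
```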
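The dataset has 8 features, so they have to be split between the two inputs: the wide input takes the first 5 features (columns 0 to 4) and the deep input takes the last 6 (columns 2 to 7), which means the 3 features in the middle (columns 2 to 4) are fed to both inputs: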
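# x_train_scaled, x_valid_scaled, x_test_scaled, y_train and y_valid are assumed
# to be the standardized splits from the earlier housing-price post
callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]
# wide input: first 5 features (columns 0-4); deep input: last 6 (columns 2-7)
x_train_scaled_wide = x_train_scaled[:, :5]
x_train_scaled_deep = x_train_scaled[:, 2:]
x_valid_scaled_wide = x_valid_scaled[:, :5]
x_valid_scaled_deep = x_valid_scaled[:, 2:]
x_test_scaled_wide = x_test_scaled[:, :5]
x_test_scaled_deep = x_test_scaled[:, 2:]
# the input list must follow the order of inputs=[input_wide, input_deep]
history = model.fit([x_train_scaled_wide, x_train_scaled_deep], y_train,
                    epochs=20,
                    validation_data=([x_valid_scaled_wide, x_valid_scaled_deep], y_valid),
                    callbacks=callbacks)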
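To see exactly which columns each input receives, the same two slices can be applied to a tiny dummy array whose values are simply the column indices (the array is made up purely for illustration):

```python
import numpy as np

# One dummy row whose values are the column indices 0..7 of the 8 features
x = np.arange(8).reshape(1, 8)

x_wide = x[:, :5]   # same slice as x_train_scaled[:, :5]
x_deep = x[:, 2:]   # same slice as x_train_scaled[:, 2:]

print(x_wide)  # [[0 1 2 3 4]]
print(x_deep)  # [[2 3 4 5 6 7]]

# Columns that appear in both inputs
shared = np.intersect1d(x_wide, x_deep)
print(shared)  # [2 3 4]
```

The slice widths (5 and 6) must match the shape=[5] and shape=[6] declared on the two Input layers, otherwise fit will reject the arrays.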
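Multiple outputs: keeping the same two inputs, a second output head (output2) is attached directly to hidden2, so the model returns two predictions:
# Multiple outputs, functional API
input_wide = keras.layers.Input(shape=[5])
input_deep = keras.layers.Input(shape=[6])
hidden1 = keras.layers.Dense(30, activation='relu')(input_deep)
hidden2 = keras.layers.Dense(30, activation='relu')(hidden1)
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(1)(concat)     # main output: wide + deep
output2 = keras.layers.Dense(1)(hidden2)   # auxiliary output: deep path only
model = keras.models.Model(inputs=[input_wide, input_deep],
                           outputs=[output, output2])
model.summary()
# a single loss is applied to both outputs and the two losses are summed
model.compile(loss='mean_squared_error',
              optimizer='adam')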
Model structure:
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            [(None, 6)]          0
__________________________________________________________________________________________________
dense (Dense)                   (None, 30)           210         input_2[0][0]
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 5)]          0
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           930         dense[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 35)           0           input_1[0][0]
                                                                 dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            36          concatenate[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            31          dense_1[0][0]
==================================================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
__________________________________________________________________________________________________
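The two-output model above is only compiled, not trained. Training it needs one target per output; since both heads here predict the same house price, the same target can simply be passed twice. The sketch below rebuilds the same architecture and trains it for one epoch on random stand-in data. The arrays, their sizes and the loss_weights values are illustrative assumptions, not taken from the post; loss_weights is an optional Keras compile argument that weights the main output more heavily than the auxiliary one.

```python
import numpy as np
from tensorflow import keras

# Same two-input, two-output architecture as above
input_wide = keras.layers.Input(shape=(5,))
input_deep = keras.layers.Input(shape=(6,))
hidden1 = keras.layers.Dense(30, activation='relu')(input_deep)
hidden2 = keras.layers.Dense(30, activation='relu')(hidden1)
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(1)(concat)     # main output (wide + deep)
output2 = keras.layers.Dense(1)(hidden2)   # auxiliary output (deep only)
model = keras.models.Model(inputs=[input_wide, input_deep],
                           outputs=[output, output2])

# One loss per output; total loss = 0.9 * loss_main + 0.1 * loss_aux
# (the 0.9 / 0.1 weights are an illustrative assumption)
model.compile(loss=['mean_squared_error', 'mean_squared_error'],
              loss_weights=[0.9, 0.1],
              optimizer='adam')

# Random stand-in data; with the real data these would be
# x_train_scaled_wide, x_train_scaled_deep and y_train
rng = np.random.default_rng(0)
x_wide = rng.normal(size=(64, 5)).astype('float32')
x_deep = rng.normal(size=(64, 6)).astype('float32')
y = rng.normal(size=(64, 1)).astype('float32')

# One target per output: the same y is used for both heads
history = model.fit([x_wide, x_deep], [y, y], epochs=1, verbose=0)

# predict returns one array per output, in the order given to outputs=[...]
y_pred_main, y_pred_aux = model.predict([x_wide, x_deep], verbose=0)
print(y_pred_main.shape, y_pred_aux.shape)  # (64, 1) (64, 1)
```

With the real data, validation_data would likewise need a pair of targets, e.g. [y_valid, y_valid].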