Keras 3.0 in Practice: A Regression Neural Network Model

 

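# Import the required libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression
from keras.models import Sequential
from keras.layers import Dense, Input

# Generate synthetic data
# make_regression creates 100 samples with 1 feature each; noise=0.1 is the
# standard deviation of the Gaussian noise added to the targets
X, y = make_regression(n_samples=100, n_features=1, noise=0.1)

# Reshape y into a column vector of shape (100, 1)
y = y.reshape(-1, 1)

# Split the data into a training set (80%) and a test set (20%) with train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the neural network model
model = Sequential()

# Declare the input: each sample has 1 feature
# (Keras 3 prefers an explicit Input over the legacy input_dim argument on Dense)
model.add(Input(shape=(1,)))

# First hidden layer: fully connected (Dense) with 64 units and ReLU activation
model.add(Dense(64, activation='relu'))

# Second hidden layer: fully connected with 32 units and ReLU activation
model.add(Dense(32, activation='relu'))

# Output layer: fully connected with 1 unit and a linear activation, as usual for regression
model.add(Dense(1, activation='linear'))

# Print the model summary
model.summary()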

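# Compile the model: mean squared error (MSE) as the loss, stochastic gradient descent (SGD)
# as the optimizer, and MSE as the reported metric (the same quantity as the loss, which is
# why both columns in the training log show identical values)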
model.compile(loss='mean_squared_error', optimizer='sgd', metrics=['mean_squared_error'])

# Train the model for 100 epochs (full passes over the training data) with a batch size of 32
model.fit(X_train, y_train, epochs=100, batch_size=32)

# Predict on the test set
y_pred = model.predict(X_test)

# Scatter plot of true values (blue) and predicted values (red) against X
plt.scatter(X_test, y_test, color='blue', label='True')
plt.scatter(X_test, y_pred, color='red', label='Predicted')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.show()
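To make the compile/fit step less of a black box, here is a minimal pure-Python sketch of what mini-batch SGD on an MSE loss does. It is a deliberate simplification and an assumption on my part, not what Keras runs internally: a single linear unit y_hat = w * x + b instead of the 64-32-1 network, hand-derived gradients instead of automatic differentiation, and no momentum. The default learning rate of 0.01 mirrors the default of Keras's `SGD` optimizer; `sgd_fit` and `mse` are hypothetical helper names.

```python
# Minimal sketch: mini-batch SGD minimizing mean squared error for a single
# linear unit y_hat = w * x + b (NOT the 64-32-1 network from the article).
import random


def mse(y_true, y_pred):
    """Mean squared error: the average of the squared residuals."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def sgd_fit(xs, ys, epochs=100, batch_size=32, learning_rate=0.01, seed=0):
    """Hypothetical helper: fit w and b with plain mini-batch SGD on the MSE loss."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)  # reshuffle each epoch, as fit() does by default
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            n = len(batch)
            # Gradients of (1/n) * sum((w*x + b - y)^2) with respect to w and b
            grad_w = sum(2 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / n
            grad_b = sum(2 * (w * xs[i] + b - ys[i]) for i in batch) / n
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b


# Usage: 80 noise-free points on the line y = 3x + 2
xs = [i / 40 - 1 for i in range(80)]
ys = [3 * x + 2 for x in xs]
w, b = sgd_fit(xs, ys, epochs=300, learning_rate=0.1)
print(round(w, 2), round(b, 2))  # 3.0 2.0
print(mse(ys, [w * x + b for x in xs]) < 1e-6)  # True
```

The larger learning rate in the usage example only speeds up convergence on this toy problem. As in the Keras log, 80 samples with a batch size of 32 give 3 parameter updates per epoch.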

 

Model structure

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 64)                128       
                                                                 
 dense_1 (Dense)             (None, 32)                2080      
                                                                 
 dense_2 (Dense)             (None, 1)                 33        
                                                                 
=================================================================
Total params: 2,241
Trainable params: 2,241
Non-trainable params: 0
_________________________________________________________________
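The Param # column follows from the Dense layer formula params = input_dim × units + units (one weight per input-unit pair plus one bias per unit). All of them count as trainable because no layer is frozen. A quick check in Python, where `dense_params` is a hypothetical helper rather than a Keras function:

```python
def dense_params(input_dim, units, use_bias=True):
    """Parameter count of a fully connected layer: weight matrix plus biases."""
    return input_dim * units + (units if use_bias else 0)


# (input_dim, units) for the three Dense layers in the model above
layer_shapes = [(1, 64), (64, 32), (32, 1)]
counts = [dense_params(i, u) for i, u in layer_shapes]
print(counts)       # [128, 2080, 33]
print(sum(counts))  # 2241
```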

 

Training progress

Epoch 1/100
3/3 [==============================] - 0s 4ms/step - loss: 214.8241 - mean_squared_error: 214.8241
Epoch 2/100
3/3 [==============================] - 0s 2ms/step - loss: 195.0716 - mean_squared_error: 195.0716
Epoch 3/100
3/3 [==============================] - 0s 2ms/step - loss: 123.4095 - mean_squared_error: 123.4095
Epoch 4/100
3/3 [==============================] - 0s 3ms/step - loss: 24.0782 - mean_squared_error: 24.0782
Epoch 5/100
3/3 [==============================] - 0s 2ms/step - loss: 12.7566 - mean_squared_error: 12.7566
Epoch 6/100
3/3 [==============================] - 0s 3ms/step - loss: 7.9639 - mean_squared_error: 7.9639
Epoch 7/100
3/3 [==============================] - 0s 2ms/step - loss: 34.8655 - mean_squared_error: 34.8655
Epoch 8/100
3/3 [==============================] - 0s 3ms/step - loss: 9.1391 - mean_squared_error: 9.1391
Epoch 9/100
3/3 [==============================] - 0s 2ms/step - loss: 5.0033 - mean_squared_error: 5.0033
Epoch 10/100
3/3 [==============================] - 0s 3ms/step - loss: 3.7178 - mean_squared_error: 3.7178
Epoch 11/100
3/3 [==============================] - 0s 2ms/step - loss: 8.3191 - mean_squared_error: 8.3191
Epoch 12/100
3/3 [==============================] - 0s 3ms/step - loss: 15.9171 - mean_squared_error: 15.9171
Epoch 13/100
3/3 [==============================] - 0s 2ms/step - loss: 2.6201 - mean_squared_error: 2.6201
Epoch 14/100
3/3 [==============================] - 0s 3ms/step - loss: 1.9299 - mean_squared_error: 1.9299
Epoch 15/100
3/3 [==============================] - 0s 2ms/step - loss: 1.4617 - mean_squared_error: 1.4617
Epoch 16/100
3/3 [==============================] - 0s 3ms/step - loss: 4.1653 - mean_squared_error: 4.1653
Epoch 17/100
3/3 [==============================] - 0s 2ms/step - loss: 0.9672 - mean_squared_error: 0.9672
Epoch 18/100
3/3 [==============================] - 0s 2ms/step - loss: 0.5275 - mean_squared_error: 0.5275
Epoch 19/100
3/3 [==============================] - 0s 2ms/step - loss: 0.4061 - mean_squared_error: 0.4061
Epoch 20/100
3/3 [==============================] - 0s 2ms/step - loss: 0.5154 - mean_squared_error: 0.5154
Epoch 21/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1824 - mean_squared_error: 0.1824
Epoch 22/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1629 - mean_squared_error: 0.1629
Epoch 23/100
3/3 [==============================] - 0s 3ms/step - loss: 0.2634 - mean_squared_error: 0.2634
Epoch 24/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1077 - mean_squared_error: 0.1077
Epoch 25/100
3/3 [==============================] - 0s 3ms/step - loss: 0.1250 - mean_squared_error: 0.1250
Epoch 26/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1474 - mean_squared_error: 0.1474
Epoch 27/100
3/3 [==============================] - 0s 3ms/step - loss: 0.1953 - mean_squared_error: 0.1953
Epoch 28/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0625 - mean_squared_error: 0.0625
Epoch 29/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0814 - mean_squared_error: 0.0814
Epoch 30/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1829 - mean_squared_error: 0.1829
Epoch 31/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1177 - mean_squared_error: 0.1177
Epoch 32/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0517 - mean_squared_error: 0.0517
Epoch 33/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0456 - mean_squared_error: 0.0456
Epoch 34/100
3/3 [==============================] - 0s 2ms/step - loss: 0.1010 - mean_squared_error: 0.1010
Epoch 35/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0919 - mean_squared_error: 0.0919
Epoch 36/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0734 - mean_squared_error: 0.0734
Epoch 37/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0617 - mean_squared_error: 0.0617
Epoch 38/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0574 - mean_squared_error: 0.0574
Epoch 39/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0338 - mean_squared_error: 0.0338
Epoch 40/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0297 - mean_squared_error: 0.0297
Epoch 41/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0404 - mean_squared_error: 0.0404
Epoch 42/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0288 - mean_squared_error: 0.0288
Epoch 43/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0686 - mean_squared_error: 0.0686
Epoch 44/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0294 - mean_squared_error: 0.0294
Epoch 45/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0372 - mean_squared_error: 0.0372
Epoch 46/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0195 - mean_squared_error: 0.0195
Epoch 47/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0550 - mean_squared_error: 0.0550
Epoch 48/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0309 - mean_squared_error: 0.0309
Epoch 49/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0160 - mean_squared_error: 0.0160
Epoch 50/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0185 - mean_squared_error: 0.0185
Epoch 51/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0313 - mean_squared_error: 0.0313
Epoch 52/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0154 - mean_squared_error: 0.0154
Epoch 53/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0139 - mean_squared_error: 0.0139
Epoch 54/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0136 - mean_squared_error: 0.0136
Epoch 55/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0127 - mean_squared_error: 0.0127
Epoch 56/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0159 - mean_squared_error: 0.0159
Epoch 57/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0416 - mean_squared_error: 0.0416
Epoch 58/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0164 - mean_squared_error: 0.0164
Epoch 59/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0161 - mean_squared_error: 0.0161
Epoch 60/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0164 - mean_squared_error: 0.0164
Epoch 61/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0120 - mean_squared_error: 0.0120
Epoch 62/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0125 - mean_squared_error: 0.0125
Epoch 63/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0137 - mean_squared_error: 0.0137
Epoch 64/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0104 - mean_squared_error: 0.0104
Epoch 65/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0102 - mean_squared_error: 0.0102
Epoch 66/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0151 - mean_squared_error: 0.0151
Epoch 67/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0155 - mean_squared_error: 0.0155
Epoch 68/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0111 - mean_squared_error: 0.0111
Epoch 69/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0099 - mean_squared_error: 0.0099
Epoch 70/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0105 - mean_squared_error: 0.0105
Epoch 71/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0091 - mean_squared_error: 0.0091
Epoch 72/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0097 - mean_squared_error: 0.0097
Epoch 73/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0242 - mean_squared_error: 0.0242
Epoch 74/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0192 - mean_squared_error: 0.0192
Epoch 75/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0099 - mean_squared_error: 0.0099
Epoch 76/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0085 - mean_squared_error: 0.0085
Epoch 77/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0100 - mean_squared_error: 0.0100
Epoch 78/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0092 - mean_squared_error: 0.0092
Epoch 79/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0090 - mean_squared_error: 0.0090
Epoch 80/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0098 - mean_squared_error: 0.0098
Epoch 81/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0085 - mean_squared_error: 0.0085
Epoch 82/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0087 - mean_squared_error: 0.0087
Epoch 83/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0099 - mean_squared_error: 0.0099
Epoch 84/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0128 - mean_squared_error: 0.0128
Epoch 85/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0092 - mean_squared_error: 0.0092
Epoch 86/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0118 - mean_squared_error: 0.0118
Epoch 87/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0097 - mean_squared_error: 0.0097
Epoch 88/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0097 - mean_squared_error: 0.0097
Epoch 89/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0084 - mean_squared_error: 0.0084
Epoch 90/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0084 - mean_squared_error: 0.0084
Epoch 91/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0117 - mean_squared_error: 0.0117
Epoch 92/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0109 - mean_squared_error: 0.0109
Epoch 93/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0094 - mean_squared_error: 0.0094
Epoch 94/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0083 - mean_squared_error: 0.0083
Epoch 95/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0104 - mean_squared_error: 0.0104
Epoch 96/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0100 - mean_squared_error: 0.0100
Epoch 97/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0086 - mean_squared_error: 0.0086
Epoch 98/100
3/3 [==============================] - 0s 3ms/step - loss: 0.0087 - mean_squared_error: 0.0087
Epoch 99/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0075 - mean_squared_error: 0.0075
Epoch 100/100
3/3 [==============================] - 0s 2ms/step - loss: 0.0101 - mean_squared_error: 0.0101
1/1 [==============================] - 0s 90ms/step
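The "3/3" on every epoch line and the final "1/1" line (printed by `model.predict`) are step counts, i.e. numbers of batches, obtained by ceiling division. A small sketch of the arithmetic, assuming the 80/20 split above and Keras's default prediction batch size of 32 when `predict` is called without `batch_size`:

```python
import math

n_train, n_test = 80, 20   # 100 samples split with test_size=0.2
fit_batch_size = 32        # batch_size passed to model.fit
predict_batch_size = 32    # assumed: Keras default when predict() gets no batch_size

steps_per_epoch = math.ceil(n_train / fit_batch_size)
last_batch = n_train - (steps_per_epoch - 1) * fit_batch_size
predict_steps = math.ceil(n_test / predict_batch_size)

print(steps_per_epoch, last_batch)  # 3 16
print(predict_steps)                # 1
```

So each epoch performs three SGD updates, the last one on a smaller batch of 16 samples.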

 

Scatter plot of true values (True) vs. predicted values (Predicted)

 

[Figure 1: true values (blue) and predicted values (red) on the test set]
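The scatter plot gives a visual impression of the fit. To attach a number to it, one option is to compute the test-set MSE and the coefficient of determination R² = 1 − SS_res / SS_tot. Keras's `model.evaluate(X_test, y_test)` and scikit-learn's `mean_squared_error` / `r2_score` provide these directly. The pure-Python sketch below only spells out the formulas; `mse` and `r2` are hypothetical helpers, and the four sample values are made up for illustration, not results from the run above.

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


# Illustrative values only (not from the article's run)
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.1, 1.9, 3.2, 3.8]
print(round(mse(y_true, y_pred), 4))  # 0.025
print(round(r2(y_true, y_pred), 4))   # 0.98

# With the script above, flatten the (20, 1) arrays first, e.g.:
# mse(y_test.ravel().tolist(), y_pred.ravel().tolist())
```

An R² close to 1 means the predictions explain almost all of the variance in the targets, which is what points lying on top of each other in the scatter plot suggest.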
