查看本系列三种模型写法:
【tensorflow】连续输入的线性回归模型训练代码
【tensorflow】连续输入的神经网络模型训练代码
【tensorflow】连续输入+离散输入的神经网络模型训练代码
from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
from keras import Input, Model, Sequential
from keras.layers import Dense, concatenate, Embedding, LSTM
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
def get_data():
# 设置随机种子,以确保结果可复现(可选)
np.random.seed(0)
# 生成随机数据
data = np.random.rand(10000, 10)
# 正则化数据
scaler = StandardScaler()
data = scaler.fit_transform(data)
# 生成随机数据
target = np.random.rand(10000, 1)
return train_test_split(data, target, test_size=0.1, random_state=42)
data_train, data_val, target_train, target_val = get_data()
# 迭代轮次
train_epochs = 10
# 学习率
learning_rate = 0.0001
# 批大小
batch_size = 200
model = keras.models.Sequential([
keras.layers.Dense(64, activation="relu", input_shape=[10]),
keras.layers.Dense(64, activation="relu"),
keras.layers.Dense(1)
])
model.summary()
model.compile(loss="mse", optimizer=keras.optimizers.Adam(lr=learning_rate))
history = model.fit(data_train, target_train, epochs=train_epochs, batch_size=batch_size, validation_data=(data_val, target_val))
模型训练过程中的输出如下:
get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。
由于target是浮点数,所以我们这个任务就是回归任务了。
使用keras.models.Sequential构建一个序列模型。模型由一系列层按顺序连接而成。在这个例子中,模型由三个全连接层构成。
第一个隐藏层(keras.layers.Dense)具有64个神经元,使用ReLU激活函数,并指定输入形状为[10]。输入形状表示输入数据的维度。
第二个隐藏层也是一个具有64个神经元的全连接层,同样使用ReLU激活函数。
最后一层是输出层,由一个神经元组成,不使用激活函数。
模型的结构是输入层(10维)→隐藏层(64个神经元,ReLU激活函数)→隐藏层(64个神经元,ReLU激活函数)→输出层(1个神经元)。
最后,使用model.compile方法配置模型的损失函数和优化器。在这个例子中,损失函数设置为均方误差(Mean Squared Error,MSE),优化器选择Adam优化算法,并设置学习率为learning_rate。
使用model.fit方法对模型进行训练。传入训练数据data_train和目标数据target_train,设置训练轮次train_epochs、批处理大小batch_size,以及验证集数据(data_val, target_val)。
训练过程中,模型会根据给定的训练数据和目标数据进行参数更新,通过反向传播算法优化模型的权重和偏置。每个训练轮次(epoch)都会对整个训练数据集进行一次完整的训练。训练过程还会使用验证集数据对模型进行评估,以监控模型的性能和验证集上的损失。
训练过程中的损失值和其他指标会被记录在history对象中,可以用于后续的可视化和分析。