源码连接:TensorFlow.js实现鸢尾花种类预测-机器学习文档类资源-CSDN下载
创建index.html入口文件,跳转到script主文件。
在script.js文件夹中利用预先准备好的脚本生成鸢尾花数据集,包括训练集和验证集,并打印查看。
import {getIrisData, IRIS_CLASSES} from "./data.js";
window.onload = () => {
// 加载数据
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.2);
// 打印查看数据集
xTrain.print();
yTrain.print();
xTest.print();
yTest.print();
// 打印鸢尾花种类类别
console.log(IRIS_CLASSES);
}
getIrisData(0.2):获取数据集的时候,将20%的数据当成测试集,剩下的80%当成训练集。
xTrain:训练集的特征值。
yTrain:训练集的目标值。
xTest:验证集的特征值。
yTest:验证集的目标值。
可以在控制台查看到结果:
其中特征矩阵里面的四个值分别表示:花萼的长度、花萼的宽度、花瓣的长度、花瓣的宽度。
目标值矩阵采用one-hot编码形式。
初始化一个神经网络模型,为神经网络模型添加两层,配置模型的损失函数、激活函数、优化器、添加准确度度量。
// 定义网络模型
const model = tf.sequential();
// 添加隐藏层
model.add(tf.layers.dense({
units: 10,
inputShape: [xTrain.shape[1]],
activation: 'relu'
}));
// 添加输出层
model.add(tf.layers.dense({
units: 3,
activation: 'softmax'
}));
// 配置模型
model.compile({
loss: "categoricalCrossentropy",
optimizer: tf.train.adam(0.1),
metrics: ['accuracy']
});
训练结果需要等待,所以采用异步方式训练。
await model.fit(xTrain, yTrain,{
epochs: 100,
batchSize: 32,
validationData: [xTest, yTest],
callbacks: tfvis.show.fitCallbacks(
{name: '训练效果'},
['loss', 'val_loss', 'acc', 'val_acc'],
{callbacks: ['onEpochEnd']}
)
});
训练结果:
编写前端界面输入待预测数据,使用训练好的模型进行预测,将输出的Tensor转成普通数据并显示。
在index.html中编写form表单,用来输入预测数据。
输入数据的顺序不能错,因为我们训练数据的顺序就是花萼长度、花萼宽度、花瓣长度、花瓣宽度。
在Script.js中编写predict预测函数。
window.predict = (form) => {
// 将表单获取的到数据转成Tensor
const input = tf.tensor([[
form.a.value * 1,
form.b.value * 1,
form.c.value * 1,
form.d.value * 1,
]]);
// 预测
const pred = model.predict(input);
alert(`预测结果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`)
}
预测结果:gif动图有点模糊,可以自己动手试试看哦。