JavaScript之机器学习5:Tensorflow.js 多分类任务

案例:鸢尾花(iris)分类
操作步骤

  1. 加载IRIS数据集(训练集与验证集)
  2. 定义模型结构:带有softmax的多层神经网络
    • 初始化一个神经网络模型
    • 为神经网络模型添加两个层
    • 设计层的神经元个数,inputShape,激活函数
  3. 训练模型并预测
    • 交叉熵损失函数与准确度度量

JavaScript之机器学习5:Tensorflow.js 多分类任务_第1张图片
JavaScript之机器学习5:Tensorflow.js 多分类任务_第2张图片
主要示例代码:

 <!-- index.html -->
<form action="" onsubmit="predict(this); return false;">
    花萼长度:<input type="text" name="a"><br>
    花萼宽度:<input type="text" name="b"><br>
    花瓣长度:<input type="text" name="c"><br>
    花瓣宽度:<input type="text" name="d"><br>
    <button type="submit">预测</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getIrisData, IRIS_CLASSES } from './data';

window.onload = async() => {
    //分别代表训练集和验证集的特征和标签  
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);  // 15%的数据用于验证集
    // xTrain.print();
    // yTrain.print();
    // xTest.print();
    // yTest.print();
    // console.log(IRIS_CLASSES);
    // 定义模型结构
    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 10,
        inputShape:[xTrain.shape[1]], // 特征长度:4
        activation: 'sigmoid'
    }));
    model.add(tf.layers.dense({
        units: 3,
        activation:'softmax'
    }));

    model.compile({
        loss:'categoricalCrossentropy',
        optimizer: tf.train.adam(0.1),
        metrics: ['accuracy']
    });

    await model.fit(xTrain, yTrain, {
        epochs: 100,
        validationData: [xTest, yTest],
        callbacks: tfvis.show.fitCallbacks(
            {name:'训练效果'},
            ['loss','val_loss','acc','val_acc'],
            {callbacks:['onEpochEnd']}
        )
    });

    window.predict = (form) => {
        const input = tf.tensor([[
            form.a.value *1,
            form.b.value *1,
            form.c.value *1,
            form.d.value *1,
        ]]);
        const pred = model.predict(input);
        alert(`预测结果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`)
    }
};

你可能感兴趣的:(JavaScript,tensorflow,机器学习,神经网络)