最终效果页面如下：
1. 通过摄像头采集石头、剪刀、布三种手势的图片数据，每种约 100 张。
2. 训练网络。
3. 开始预测，并显示预测结果。
一、新建三个文件
Rock samples:
Paper samples:
Scissors samples:
Once training is complete, click 'Start Predicting' to see predictions, and click 'Stop Predicting' to stop.
// Truncated MobileNet used as the feature extractor; assigned by init().
let mobilenet;
// Classifier head trained on MobileNet activations; assigned by train().
let model;
// Wrapper around the <video id="wc"> element that supplies webcam frames.
const webcam = new Webcam(document.getElementById('wc'));
// Collected activations and labels used for training.
const dataset = new RPSDataset();
// Per-class sample counters shown on the page by handleButton().
let rockSamples = 0;
let paperSamples = 0;
let scissorsSamples = 0;
// While true, the predict() loop keeps classifying frames.
let isPredicting = false;
/**
 * Downloads the pretrained MobileNet v1 model (per its URL, the 0.25-width /
 * 224px variant) and returns a new model that stops at the
 * 'conv_pw_13_relu' layer, so its output is an activation tensor rather than
 * class scores.
 *
 * @returns {Promise<tf.LayersModel>} The truncated feature-extractor model.
 */
async function loadMobilenet() {
  const modelUrl = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
  const fullModel = await tf.loadLayersModel(modelUrl);
  const { output } = fullModel.getLayer('conv_pw_13_relu');
  return tf.model({ inputs: fullModel.inputs, outputs: output });
}
/**
 * Trains a small classifier head on the MobileNet activations collected in
 * `dataset` and stores it in the global `model`.
 *
 * Architecture: flatten -> dense(100, relu) -> dense(3, softmax), compiled
 * with Adam (learning rate 0.0001) and categorical cross-entropy, and fitted
 * for 10 epochs. The loss after each batch is logged to the console.
 *
 * @returns {Promise<void>} Resolves once model.fit has finished all epochs.
 * @throws {Error} If no samples have been collected yet.
 */
async function train() {
  // Without samples dataset.xs is unset and model.fit would fail with an
  // obscure tensor error; fail early with a readable message instead.
  if (dataset.xs == null || dataset.labels.length === 0) {
    throw new Error('No training samples: capture some examples before training.');
  }

  // Labels are re-encoded from scratch on every run. Free the previous
  // one-hot tensor first, because merely nulling the reference leaks it.
  if (dataset.ys != null) {
    dataset.ys.dispose();
  }
  dataset.ys = null;
  dataset.encodeLabels(3);

  // NOTE(review): any previously trained `model` is replaced without
  // dispose(), so its weights stay allocated. Disposing it here is unsafe
  // while an earlier fit() might still be running; revisit if retraining
  // often.
  model = tf.sequential({
    layers: [
      tf.layers.flatten({ inputShape: mobilenet.outputs[0].shape.slice(1) }),
      tf.layers.dense({ units: 100, activation: 'relu' }),
      tf.layers.dense({ units: 3, activation: 'softmax' }),
    ],
  });
  model.compile({ optimizer: tf.train.adam(0.0001), loss: 'categoricalCrossentropy' });

  // Awaited so the promise returned to callers reflects when training ends.
  await model.fit(dataset.xs, dataset.ys, {
    epochs: 10,
    callbacks: {
      onBatchEnd: (batch, logs) => {
        console.log(`LOSS: ${logs.loss.toFixed(5)}`);
      },
    },
  });
}
/**
 * Click handler for the three sample buttons. The button id ("0", "1" or
 * "2") is the class label for rock, paper or scissors.
 *
 * For each click it:
 * - bumps the matching on-page counter;
 * - captures a webcam frame and runs it through the truncated MobileNet;
 * - hands the activation to the dataset.
 *
 * Clicks are ignored before MobileNet has loaded, and buttons with any other
 * id are ignored too.
 *
 * @param {HTMLElement} elem The clicked button.
 */
function handleButton(elem) {
  // mobilenet is assigned asynchronously by init(); ignore early clicks
  // instead of throwing and leaving the counters out of sync.
  if (!mobilenet) {
    return;
  }
  switch (elem.id) {
    case "0":
      rockSamples++;
      document.getElementById("rocksamples").innerText = `Rock samples:${rockSamples}`;
      break;
    case "1":
      paperSamples++;
      document.getElementById("papersamples").innerText = `Paper samples:${paperSamples}`;
      break;
    case "2":
      scissorsSamples++;
      document.getElementById("scissorssamples").innerText = `Scissors samples:${scissorsSamples}`;
      break;
    default:
      // Any other id would parse to a NaN/out-of-range label and corrupt the
      // one-hot encoding later.
      return;
  }
  const label = Number.parseInt(elem.id, 10);
  // tidy() frees the captured frame; only the returned activation survives,
  // and it is handed to the dataset.
  const activation = tf.tidy(() => mobilenet.predict(webcam.capture()));
  dataset.addExample(activation, label);
}
/**
 * Live classification loop. While `isPredicting` is true, it:
 * - grabs a webcam frame;
 * - runs the frame through the MobileNet feature extractor and then the
 *   trained head;
 * - writes the winning class's caption into the element with id "prediction".
 *
 * Between frames it yields to the browser via tf.nextFrame().
 */
async function predict() {
  // Index i holds the caption for class id i; any other id shows "".
  const captions = ["I see Rock", "I see Paper", "I see Scissors"];
  while (isPredicting) {
    // tidy() frees every intermediate tensor except the returned argMax
    // result, which is disposed by hand once its value has been read.
    const winner = tf.tidy(() => {
      const frame = webcam.capture();
      const features = mobilenet.predict(frame);
      return model.predict(features).as1D().argMax();
    });
    const [classId] = await winner.data();
    document.getElementById("prediction").innerText = captions[classId] || "";
    winner.dispose();
    await tf.nextFrame();
  }
}
/**
 * Button handler for "Train". Runs train() and logs any failure instead of
 * leaving an unhandled promise rejection.
 *
 * @returns {Promise<*>} Resolves with train()'s result, or with undefined
 *   after logging if training failed.
 */
function doTraining() {
  return train().catch((err) => {
    console.error('Training failed:', err);
  });
}
/**
 * Button handler for "Start Predicting". Sets `isPredicting` and launches the
 * predict() loop.
 *
 * Does nothing if a loop is already running, so repeated clicks cannot start
 * several concurrent loops. If the loop fails (for example, predicting before
 * a model has been trained), the error is logged and `isPredicting` is
 * cleared so the button can be used again.
 */
function startPredicting() {
  if (isPredicting) {
    return;
  }
  isPredicting = true;
  predict().catch((err) => {
    isPredicting = false;
    console.error('Prediction failed:', err);
  });
}
/**
 * Button handler for "Stop Predicting". Clears `isPredicting`. The running
 * predict() loop re-checks this flag before each frame, so it exits after
 * its current iteration. No extra predict() call is needed, because one
 * started with the flag cleared would return immediately.
 */
function stopPredicting() {
  isPredicting = false;
}
/**
 * Page bootstrap. It:
 * - sets up the webcam;
 * - loads the truncated MobileNet into the global `mobilenet`;
 * - runs one throwaway prediction inside tf.tidy. The result is discarded
 *   and its tensors freed; presumably this warms the model up so the first
 *   real capture is not slow.
 */
async function init() {
  await webcam.setup();
  mobilenet = await loadMobilenet();
  tf.tidy(() => mobilenet.predict(webcam.capture()));
}

// Surface start-up failures (e.g. the webcam or the model failing to load)
// instead of leaving an unhandled promise rejection.
init().catch((err) => {
  console.error('Initialization failed:', err);
});
/**
 * Accumulates training samples for the rock/paper/scissors classifier.
 *
 * - xs: sample tensors concatenated along axis 0 (null until the first sample).
 * - labels: integer class id for each sample, in the same order as xs.
 * - ys: one-hot label tensor built by encodeLabels() (null until then).
 */
class RPSDataset {
  constructor() {
    this.labels = [];
    this.xs = null;
    this.ys = null;
  }

  /**
   * Appends one sample. The dataset takes ownership of `example`: it is
   * either stored directly (first sample) or disposed after being copied into
   * the concatenated tensor. Callers must not use or dispose it afterwards.
   *
   * @param {tf.Tensor} example Sample tensor; presumably a batch of one
   *   activation, given the axis-0 concat (TODO confirm).
   * @param {number} label Integer class id for this sample.
   */
  addExample(example, label) {
    if (this.xs == null) {
      this.xs = tf.keep(example);
    } else {
      const oldX = this.xs;
      this.xs = tf.keep(oldX.concat(example, 0));
      oldX.dispose();
      // concat copied the data into a new tensor; without this the input
      // tensor leaks on every sample after the first.
      example.dispose();
    }
    this.labels.push(label);
  }

  /**
   * (Re)builds `ys` as a one-hot tensor of shape [labels.length, numClasses]
   * from `labels`. Any previous `ys` is disposed and replaced rather than
   * appended to. With no labels, `ys` is left as null.
   *
   * @param {number} numClasses Depth of the one-hot encoding.
   */
  encodeLabels(numClasses) {
    if (this.ys != null) {
      this.ys.dispose();
    }
    this.ys = null;
    if (this.labels.length === 0) {
      return;
    }
    // One oneHot call over all labels replaces the per-label concat loop,
    // which reallocated the whole tensor for every sample (O(n^2)).
    this.ys = tf.keep(tf.tidy(() => tf.oneHot(tf.tensor1d(this.labels, 'int32'), numClasses)));
  }
}
/**
* A class that wraps webcam video elements to capture Tensor4Ds.
*/
class Webcam {
/**
* @param {HTMLVideoElement} webcamElement A HTMLVideoElement representing the
* webcam feed.
*/
constructor(webcamElement) {
this.webcamElement = webcamElement;
}
/**
* Captures a frame from the webcam and normalizes it between -1 and 1.
* Returns a batched image (1-element batch) of shape [1, w, h, c].
*/
capture() {
return tf.tidy(() => {
// Reads the image as a Tensor from the webcam