Node.js 使用 TensorFlow.js 实现垃圾分类练习

t1.js 

const tf = require('@tensorflow/tfjs-node-gpu');

const getData = require('./data');
const TRAIN_PATH = './垃圾分类/垃圾分类/train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';
// Number of leading MobileNet layers reused (frozen) as the feature extractor.
const FROZEN_LAYER_COUNT = 86;

/*
 * Transfer-learning script: loads the training dataset, truncates a pretrained
 * MobileNet to its first FROZEN_LAYER_COUNT layers (frozen), stacks a small
 * trainable classifier on top, trains it, and saves the model under OUT_PUT.
 */
(async () => {
  const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT);

  // Load the pretrained model served at MOBILENET_URL.
  const mobilenet = await tf.loadLayersModel(MOBILENET_URL);

  // Print the model structure (useful when choosing the truncation point).
  mobilenet.summary();

  // Fail with a clear message instead of a TypeError on an undefined layer.
  if (mobilenet.layers.length < FROZEN_LAYER_COUNT) {
    throw new Error(
      `MobileNet has only ${mobilenet.layers.length} layers; expected at least ${FROZEN_LAYER_COUNT}`,
    );
  }

  const model = tf.sequential();

  // Reuse the first FROZEN_LAYER_COUNT layers; marking them non-trainable
  // means only the new head added below is updated during training.
  for (const layer of mobilenet.layers.slice(0, FROZEN_LAYER_COUNT)) {
    layer.trainable = false;
    model.add(layer);
  }

  // Flatten the feature maps into a vector for the dense layers.
  model.add(tf.layers.flatten());

  // Hidden fully connected layer; relu adds non-linearity.
  model.add(tf.layers.dense({
    units: 10,
    activation: 'relu',
  }));

  // One output unit per class; softmax for multi-class classification.
  model.add(tf.layers.dense({
    units: classes.length,
    activation: 'softmax',
  }));

  // Labels are integer class indices (dirIndex in data.js), hence the sparse loss.
  model.compile({
    loss: 'sparseCategoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics: ['acc'],
  });

  // Train on the batched dataset.
  await model.fitDataset(ds, { epochs: 20 });

  // Save topology and weights to <cwd>/OUT_PUT.
  await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})().catch((err) => {
  // Without this handler a failure would surface as an unhandled rejection.
  console.error(err);
  process.exitCode = 1;
});

data.js 

const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");

/**
 * Reads an image file and converts it into a normalized input tensor.
 *
 * @param {string} imgPath - Path of the image file to read.
 * @returns {tf.Tensor} Tensor of shape [1, 224, 224, 3] whose values are
 *   mapped from [0, 255] to [-1, 1].
 */
const img2x = (imgPath) => {
  const buffer = fs.readFileSync(imgPath);

  // tidy() disposes every intermediate tensor; only the returned one survives.
  return tf.tidy(() => {
    // Force 3 channels: grayscale (1) or RGBA (4) images would otherwise make
    // the reshape to [..., 3] below throw.
    const decoded = tf.node.decodeImage(new Uint8Array(buffer), 3);

    // Resize to 224x224.
    const resized = tf.image.resizeBilinear(decoded, [224, 224]);

    // Normalize [0, 255] -> [-1, 1] and add the leading batch dimension.
    return resized
      .cast('float32')
      .sub(255 / 2)
      .div(255 / 2)
      .reshape([1, 224, 224, 3]);
  });
};

/**
 * Builds a batched training dataset from a directory laid out as
 * `<traindir>/<className>/<image>.jpg` and writes the class names to
 * `./<output>/classes.json`.
 *
 * Each class's label is its index in the returned `classes` array.
 *
 * @param {string} traindir - Root directory; each sub-directory is one class.
 * @param {string} output - Directory (relative to cwd) for classes.json; created if missing.
 * @param {object} [options]
 * @param {number} [options.maxPerClass=1000] - Maximum images taken per class.
 * @param {number} [options.batchSize=32] - Number of images per yielded batch.
 * @returns {Promise<{ds: tf.data.Dataset, classes: string[]}>} The dataset
 *   (yields `{xs, ys}` batches) and the ordered class names.
 * @throws {Error} When no .jpg/.jpeg images are found.
 */
const getData = async (traindir, output, { maxPerClass = 1000, batchSize = 32 } = {}) => {
  // Only visible sub-directories are classes. This skips stray entries such as
  // .DS_Store, which the former `slice(1)` assumed always came first. Sorting
  // keeps the label order deterministic across platforms.
  const classes = fs.readdirSync(traindir, { withFileTypes: true })
    .filter((entry) => entry.isDirectory() && !entry.name.startsWith('.'))
    .map((entry) => entry.name)
    .sort();

  // Collect (path, label) pairs; images are decoded lazily per batch.
  const data = [];
  classes.forEach((dir, dirIndex) => {
    fs.readdirSync(`${traindir}/${dir}`)
      .filter((name) => /\.jpe?g$/i.test(name))
      .slice(0, maxPerClass)
      .forEach((filename) => {
        data.push({ imgPath: `${traindir}/${dir}/${filename}`, dirIndex });
      });
  });

  if (data.length === 0) {
    throw new Error(`No .jpg/.jpeg images found under ${traindir}`);
  }

  fs.mkdirSync(`./${output}`, { recursive: true });
  fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));

  const ds = tf.data.generator(function* () {
    // The generator runs once per pass over the dataset, so shuffling here
    // gives every epoch a different batch order.
    tf.util.shuffle(data);
    const count = data.length;
    for (let start = 0; start < count; start += batchSize) {
      const end = Math.min(start + batchSize, count);
      console.log('当前批次', start);
      // tidy() keeps the returned xs/ys and frees the per-image tensors.
      yield tf.tidy(() => {
        const batch = data.slice(start, end);
        const xs = tf.concat(batch.map(({ imgPath }) => img2x(imgPath)));
        const ys = tf.tensor(batch.map(({ dirIndex }) => dirIndex));
        return { xs, ys };
      });
    }
  });

  return { ds, classes };
};

module.exports = getData;

代码和训练图片下载链接：https://www.ljkanka.com/index/t6

你可能感兴趣的:(tensorflow.js笔记,nodejs,js笔记,javascript)