好球还是坏球(棒球中术语),用tensorflow.js预测一下?

在这篇文章中,我们将使用TensorFlow.js,D3.js和网络的力量来可视化训练模型的过程,以预测棒球数据中的坏球(蓝色区域)和好球(橙色区域)。在整个训练过程中,我们将一步一步的将模型预测出的好球区域动态的展示出来。您可以通过访问Observable notebook网站在浏览器中运行此模型。

体育方面的高级指标

如今的职业体育环境里充满了大量的数据。这些数据被团队、业余爱好者和粉丝应用于各种案例。感谢像TensorFlow这样的框架,使得这些数据集可以应用于机器学习领域。

美国职业棒球大联盟高级媒体(MLBAM)发布了一个可供公众研究的大型数据集。该数据集包含有关过去几年在MLB游戏中投掷的投球的传感器信息。从这个数据集中挑选了一个包含5000个样本(2,500个坏球和2,500个好球)的训练集用于此处实验。

以下是训练数据的具体数据格式示例:

以下是绘制好球区域时的训练数据分布。蓝点被标记为坏球,橙点被标记为好球(标注来自大联盟裁判员)

使用TensorFlow.js构建模型

TensorFlow.js将机器学习带入JavaScript和Web领域。我们将使用这个优秀的框架来构建一个深度神经网络模型。这个模型将能够以大联盟裁判的精确度来区分好球和坏球。

该模型从PITCHf/x中选出以下评测指标进行训练:

  • 协调球越过本垒的位置('px'和'pz')
  • 击球手站在球场的哪一侧
  • 击球区(击球手的躯干)的高度,以英尺为单位。
  • 击球区底部的高度(击球手的膝盖)以英尺为单位
  • 该次击球是好球还是坏球(由裁判员判定的)

结构

我们将使用TensorFlow.js的Layers API定义此模型。Layers API基于Keras,对以前使用过Keras框架的人来说应该很熟悉:

const model = tf.sequential();

// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));

// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(0.01),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

加载和准备数据

精选的训练集可以在GitHub gist获取。该数据集是CSV格式的,需要下载下来在本地转换成符合TensorFlow.js的格式。

const data = [];
csvData.forEach((values) => {
  // 'logit' data uses the 5 fields:
  const x = [];
  x.push(parseFloat(values.px));
  x.push(parseFloat(values.pz));
  x.push(parseFloat(values.sz_top));
  x.push(parseFloat(values.sz_bot));
  x.push(parseFloat(values.left_handed_batter));
  // The label is simply 'is strike' or 'is ball':
  const y = parseInt(values.is_strike, 10);
  data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);

解析CSV数据后,需要将JS类型转换为Tensor batches才能进行训练和评估。有关此过程的详细信息,请参阅code lab。TensorFlow.js团队正在开发一种新的数据API接口,以便使数据获取在将来变得更容易。

训练模型

让我们把前期的准备都综合起来吧。定义好了模型,准备好了训练数据,现在我们将要开始训练了。以下的异步方法训练了一批训练样本并更新热图:

// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
  const history = await model.fit(batches[index].x, batches[index].y, {
    epochs: 1,
    shuffle: false,
    validationData: [batches[index].x, batches[index].y],
    batchSize: CONSTANTS.BATCH_SIZE
  });

  // Don't block the UI frame by using tf.nextFrame()
  await tf.nextFrame();
  updateHeatmap();
  await tf.nextFrame();
}

可视化模型的准确度

使用来自均匀放置在本垒板上方的 4英尺x4英尺 栅格的预测矩阵来构建热图。在每个训练步骤之后将该矩阵传递到模型中以检查模型的准确度。使用D3库将预测结果呈现为热图。

建立预测矩阵

热图中所使用的预测矩阵从本垒板的中间开始,向左和向右各延伸2英尺宽,高度从本垒板的底部到4英尺高。好球区域位于本垒板上方1.5至3.5英尺之间。下图在二维平面上呈现出各个矩阵之间的关系:

将预测矩阵与模型一起使用

当每个批次的训练数据都在模型中训练之后,我们将预测矩阵传递到模型中,这样就可以去预测好球和坏球了。

function predictZone() {
  const predictions = model.predictOnBatch(predictionMatrix.data);
  const values = predictions.dataSync();

  // Sort each value so the higher prediction is the first element in the array:
  const results = [];
  let index = 0;
  for (let i = 0; i < values.length; i++) {
    let list = [];
    list.push({value: values[index++], strike: 0});
    list.push({value: values[index++], strike: 1});
    list = list.sort((a, b) => b.value - a.value);
    results.push(list);
  }
  return results;
}

使用D3生成热图

我们可以使用D3来显示预测结果。50x50尺寸的每个元素在SVG中呈现为10px x 10px的矩形。每个矩形的颜色取决于预测结果(好球或坏球)以及模型对该结果的确定程度(从50%-100%)。以下代码段显示了如何使用D3 svg 矩形组去更新数据:

function updateHeatmap() {
  rects.data(generateHeatmapData());
  rects
    .attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
    .attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
    .attr('width', CONSTANTS.HEATMAP_SIZE)
    .attr('height', CONSTANTS.HEATMAP_SIZE)
    .style('fill', (coord) => {
      if (coord.strike) {
        return strikeColorScale(coord.value);
      } else {
        return ballColorScale(coord.value);
      }
  });
}

有关使用D3绘制热图的完整详细信息,请参阅此部分。

总结

如今web前端有许多令人惊叹的库和工具来创建可视化视觉效果。把这些与机器学习的强大功能和TensorFlow.js相结合,可以使开发人员创建一些非常有趣的demo。

注:本文为译文,点击此处预览原文

你可能感兴趣的:(神经网络,深度学习,人工智能,机器学习,tensorflow)