前端利用tensorflow实现动作捕捉
tensorflow介绍
官方模型https://www.tensorflow.org/js/models
里面可以看到姿势检测,这就是我们要实现动作捕捉关键的库
官方提供三种模型选项
https://github.com/tensorflow/tfjs-models/tree/master/pose-detection
可以自己点demo进行观看:
MoveNet Demo 提供17个点位 帧率可达50帧以上
BlazePose Demo 除了17个点位,还提供了面部手脚额外的点位,总共33个
PoseNet Demo 可检测多个姿势,每个提供17个点位
官方也提供了的相应的源码
源码中会使用stats和dat.gui,看起来会很乱
下面我抽离主要功能进行说明,移除stats性能监控、dat.gui操作这些逻辑
以PoseNet 为例,其他两个也是一样的,使用的差别就是在于用哪个模型而已
文字说明
主要逻辑
1 获取摄像头数据,也可以使用本地视频或者远端视频,总之就是一个video标签
2 需要一个画布canvas,把摄像头数据和点位画上去,一帧画一个,组合起来就是一个视频
3 连接模型,或者说是以模型创建探测器,因为他api叫createDetector
4 探测器创建完后,接着就是探测的对象
--因为是一个画面一个画面的探测的,所以需要先弄一个画布
--这个画布去画视频每一帧
-- 我们只需要把这画布canvas传进去就行
5 调用estimatePoses对画布进行分析,api字面意思是估计姿势,他会返回17个点位
6 拿到点位就可以进行你的业务逻辑了,现在我的业务就是把那些点按一定规律连起来,连起来的线术语叫骨骼
7 最后,一些变量要释放,防止内存泄漏
代码说明
引入依赖包
import { PoseDetector } from '@tensorflow-models/pose-detection';
import * as poseDetection from '@tensorflow-models/pose-detection';
import '@tensorflow/tfjs-backend-webgl';
虽然只使用了两个,但你需要安装这些依赖
"@mediapipe/pose": "^0.5.1635988162",
"@tensorflow-models/pose-detection": "^2.0.0",
"@tensorflow/tfjs-backend-webgl": "^4.1.0",
"@tensorflow/tfjs-converter": "^4.1.0",
"@tensorflow/tfjs-core": "^4.1.0",
html
顶级变量
let videoEl: HTMLVideoElement;
let canvasEl: HTMLCanvasElement;
let canvasCtx: CanvasRenderingContext2D;
let detector: PoseDetector;
let model = poseDetection.SupportedModels.PoseNet;
const DEFAULT_LINE_WIDTH = 2;
const DEFAULT_RADIUS = 4;
const SCORE_THRESHOLD = 0.5;
let requestID: any; // requestAnimationFrame
启动函数
const init = async () => {
// 获取dom
canvasEl = document.getElementById('output') as HTMLCanvasElement;
videoEl = document.getElementById('video') as HTMLVideoElement;
// 获取画布
canvasCtx = canvasEl.getContext('2d')!;
// 设置视频源,这里使用摄像头
const stream = await navigator.mediaDevices.getUserMedia({
audio: false,
video: true,
});
// 设置流
videoEl.srcObject = stream;
// 视频加载后执行
videoEl.onloadeddata = async function () {
// 下一步这里开始
};
};
video onload之后
videoEl.onloadeddata = async function () {
const { width, height } = videoEl.getBoundingClientRect();
canvasEl.width = width;
canvasEl.height = height;
// 加载模型,model 在顶级变量里已经设置为poseDetection.SupportedModels.PoseNet
detector = await poseDetection.createDetector(model, {
quantBytes: 4,
architecture: 'MobileNetV1',
outputStride: 16,
inputResolution: { width, height },
multiplier: 0.75,
});
// 开始检测
startDetect();
};
探测函数
// 开始检测
async function startDetect() {
const video = document.getElementById('video') as HTMLVideoElement;
// 检测画布动作
const poses = await detector.estimatePoses(canvasEl, {
flipHorizontal: false, // 是否水平翻转
maxPoses: 1, // 最大检测人数
// scoreThreshold: 0.5, // 置信度
// nmsRadius: 20, // 非极大值抑制
});
// 绘制视频
canvasCtx.drawImage(video, 0, 0, canvasEl.width, canvasEl.height);
// 画第一个人的姿势 poses[0]
// 画点
drawKeypoints(canvasCtx, poses[0].keypoints);
// 画骨骼
drawSkeleton(canvasCtx, poses[0].keypoints, poses.id);
// 一帧执行一次 可替换为setTimeout方案: setTimeout(()=>startDetect(),1000/16)
requestID = requestAnimationFrame(() => startDetect());
}
画点画线的函数drawKeypoints
drawSkeleton
// 画点
function drawKeypoints(ctx: CanvasRenderingContext2D, keypoints) {
// keypointInd 主要按left middle right 返回索引,left是单数索引,right是双数索引,打印一下你就知道了
const keypointInd = poseDetection.util.getKeypointIndexBySide(model);
ctx.strokeStyle = 'White';
ctx.lineWidth = DEFAULT_LINE_WIDTH;
ctx.fillStyle = 'Red';
for (const i of keypointInd.middle) {
drawKeypoint(keypoints[i]);
}
ctx.fillStyle = 'Green';
for (const i of keypointInd.left) {
drawKeypoint(keypoints[i]);
}
ctx.fillStyle = 'Orange';
for (const i of keypointInd.right) {
drawKeypoint(keypoints[i]);
}
}
function drawKeypoint(ctx: CanvasRenderingContext2D, keypoint) {
// If score is null, just show the keypoint.
const score = keypoint.score != null ? keypoint.score : 1;
if (score >= SCORE_THRESHOLD) {
const circle = new Path2D();
circle.arc(keypoint.x, keypoint.y, DEFAULT_RADIUS, 0, 2 * Math.PI);
ctx.fill(circle);
ctx.stroke(circle);
}
}
// 画骨架
function drawSkeleton(ctx: CanvasRenderingContext2D, keypoints: any, poseId?: any) {
// Each poseId is mapped to a color in the color palette.
const color = 'White';
ctx.fillStyle = color;
ctx.strokeStyle = color;
ctx.lineWidth = DEFAULT_LINE_WIDTH;
poseDetection.util.getAdjacentPairs(model).forEach(([i, j]) => {
const kp1 = keypoints[i];
const kp2 = keypoints[j];
// If score is null, just show the keypoint.
const score1 = kp1.score != null ? kp1.score : 1;
const score2 = kp2.score != null ? kp2.score : 1;
if (score1 >= SCORE_THRESHOLD && score2 >= SCORE_THRESHOLD) {
ctx.beginPath();
ctx.moveTo(kp1.x, kp1.y);
ctx.lineTo(kp2.x, kp2.y);
ctx.stroke();
}
});
}
准备完成,开始执行
onMounted(() => {
init();
});
页面离开前记得释放变量
onUnmounted(() => {
detector.dispose();
detector = null;
cancelAnimationFrame(requestID);
});