tensorflow官方有个姿态估计项目,这个输入和openpose还有点不一样,这里写个单人情况下的模型输出解析方案。
国际惯例,参考博客:
博客: 使用 TensorFlow.js 在浏览器端上实现实时人体姿势检测
tensorflow中posnet的IOS代码
不要下载官方overview网址下的posenet模型multi_person_mobilenet_v1_075_float.tflite
,要去下载IOS端的posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite
模型,在github上一搜有一堆,文末放网盘下载地址。
先载入必要的工具包:
import numpy as np
import tensorflow as tf
import cv2 as cv
import matplotlib.pyplot as plt
import time
使用tflite载入模型文件
model = tf.lite.Interpreter('posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite')
model.allocate_tensors()
input_details = model.get_input_details()
output_details = model.get_output_details()
看看输入输出分别是什么
print(input_details)
print(output_details)
'''
[{'name': 'sub_2', 'index': 93, 'shape': array([ 1, 257, 257, 3], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}]
[{'name': 'MobilenetV1/heatmap_2/BiasAdd', 'index': 87, 'shape': array([ 1, 9, 9, 17], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/offset_2/BiasAdd', 'index': 90, 'shape': array([ 1, 9, 9, 34], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/displacement_fwd_2/BiasAdd', 'index': 84, 'shape': array([ 1, 9, 9, 32], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/displacement_bwd_2/BiasAdd', 'index': 81, 'shape': array([ 1, 9, 9, 32], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}]
'''
很容易看出输入是(257,257)尺寸的彩色图像。
输出就比较麻烦了,有两块:(9,9,17)的称为heatmap
的热度图;(9,9,34)的称为offset的偏移图。其实想想也能知道,热度图定位关节的大概位置,用偏移图做进一步的矫正。接下来逐步分析怎么利用这两个输出将关节位置定位的。
必须将图像resize一下再丢进去,但是tensorflowjs里面说不用resize的方法,我还没试过。
img = cv.imread('../../photo/1.jpeg')
input_img = tf.reshape(tf.image.resize(img, [257,257]), [1,257,257,3])
floating_model = input_details[0]['dtype'] == np.float32
if floating_model:
input_img = (np.float32(input_img) - 127.5) / 127.5
model.set_tensor(input_details[0]['index'], input_img)
start = time.time()
model.invoke()
print('time:',time.time()-start)
output_data = model.get_tensor(output_details[0]['index'])
offset_data = model.get_tensor(output_details[1]['index'])
heatmaps = np.squeeze(output_data)
offsets = np.squeeze(offset_data)
print("output shape: {}".format(output_data.shape))
'''
time: 0.12212681770324707
output shape: (1, 9, 9, 17)
'''
可视化变换后的图
show_img = np.squeeze((input_img.copy()*127.5+127.5)/255.0)[:,:,::-1]
show_img = np.array(show_img*255,np.uint8)
plt.imshow(show_img)
plt.axis('off')
一句话概括原理:热度图将图像划分网格,每个网格的得分代表当前关节在此网格点附近的概率;偏移图代表xy两个坐标相对于网格点的偏移情况。
假设提取第2个关节的坐标位置:
先得到最可能的网格点:
i=1
joint_heatmap = heatmaps[...,i]
max_val_pos = np.squeeze(np.argwhere(joint_heatmap==np.max(joint_heatmap)))
remap_pos = np.array(max_val_pos/8*257,dtype=np.int32)
把offset
加上去,前1-17是x坐标偏移,后18-34是y坐标偏移
refine_pos = np.zeros((2),dtype=int)
refine_pos[0] = int(remap_pos[0] + offsets[max_val_pos[0],max_val_pos[1],i])
refine_pos[1] = int(remap_pos[1] + offsets[max_val_pos[0],max_val_pos[1],i+heatmaps.shape[-1]])
可视化看看
show_img = np.squeeze((input_img.copy()*127.5+127.5)/255.0)[:,:,::-1]
show_img = np.array(show_img*255,np.uint8)
plt.figure(figsize=(8,8))
plt.imshow(cv.circle(show_img,(refine_pos[1],refine_pos[0]),2,(0,255,0),-1))
因为上面是把原图resize乘(257,257)以后的坐标,所以根据原图的缩放系数,重新映射回去
ratio_x = img.shape[0]/257
ratio_y = img.shape[1]/257
refine_pos[0]=refine_pos[0]*ratio_x
refine_pos[1]=refine_pos[1]*ratio_y
可视化
show_img1 = img[:,:,::-1]
plt.figure(figsize=(8,8))
plt.imshow(cv.circle(show_img1.copy(),(refine_pos[1],refine_pos[0]),2,(0,255,0),-1))
上面是提取单个关节的,写成函数提取所有关节的坐标就是
def parse_output(heatmap_data,offset_data):
joint_num = heatmap_data.shape[-1]
pose_kps = np.zeros((joint_num,2),np.uint8)
for i in range(heatmap_data.shape[-1]):
joint_heatmap = heatmap_data[...,i]
max_val_pos = np.squeeze(np.argwhere(joint_heatmap==np.max(joint_heatmap)))
remap_pos = np.array(max_val_pos/8*257,dtype=np.int32)
pose_kps[i,0] = int(remap_pos[0] + offset_data[max_val_pos[0],max_val_pos[1],i])
pose_kps[i,1] = int(remap_pos[1] + offset_data[max_val_pos[0],max_val_pos[1],i+joint_num])
return pose_kps
画图的函数也很容易
def draw_kps(show_img,kps):
for i in range(kps.shape[0]):
cv.circle(show_img,(kps[i,1],kps[i,0]),2,(0,255,0),-1)
return show_img
画出来瞅瞅
kps = parse_output(heatmaps,offsets)
plt.figure(figsize=(8,8))
plt.imshow(draw_kps(show_img.copy(),kps))
plt.axis('off')
模型文件:链接:https://pan.baidu.com/s/1heRKFFz28yvpAmvFqDeAXw 密码:5tuw
博客代码:链接:https://pan.baidu.com/s/1Y7WXfQ4WC9QyOGkkN2-kUQ 密码:ono0