The main difficulty with gaze estimation is that the data is extremely hard to annotate. Using a recently open-sourced dataset, I experimented with training it as a regression problem.
Full project source code: https://github.com/ycdhqzhiai/Gaze-PFLD
Dataset used: TEyeD: Over 20 million real-world eye images with Pupil, Eyelid, and Iris 2D and 3D Segmentations, 2D and 3D Landmarks, 3D Eyeball, Gaze Vector, and Eye Movement Types
import os
import cv2
import glob
import numpy as np
import argparse
import json
## Note: the file names use a 5-digit frame index, so at most 99999 frames can be saved per video; at 30 fps that is roughly 55 minutes of video.
def parse_args():
parser = argparse.ArgumentParser(description="EyeGaze datasets")
parser.add_argument("--video_path", type=str, default='DIKABLISVIDEOS', help='videos path')
parser.add_argument("--annotations",type=str, default='ANNOTATIONS', help='videos label path including gaze_vec iris_lm_2D lid_lm_2D pupil_lm_2D')
parser.add_argument("--images",type=str, default='images', help='save_path')
parser.add_argument("--draw_img",type=str, default='draw_img', help='save_path')
parser.add_argument("--blind",type=str, default='blind', help='save_path')
parser.add_argument("--json",type=str, default='json', help='save_path')
args = parser.parse_args()
return args
def mkd(path):
if not os.path.exists(path):
os.makedirs(path)
def judge_exists(path):
if os.path.exists(path):
return False
return True
def log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements):
b1 = judge_exists(agaze_vec)
b2 = judge_exists(airis_lm_2D)
b3 = judge_exists(alid_lm_2D)
b4 = judge_exists(apupil_lm_2D)
b5 = judge_exists(aeye_movements)
if b1:
print('gaze_vec not found!!! EXIT')
if b2:
print('iris_lm_2D not found!!! EXIT')
if b3:
print('lid_lm_2D not found!!! EXIT')
if b4:
print('pupil_lm_2D not found!!! EXIT')
if b5:
print('eye_movements not found!!! EXIT')
if b1 or b2 or b3 or b4 or b5:
return False
return True
def main():
args = parse_args()
video_list = glob.glob(os.path.join(args.video_path, '*.mp4'))
for video in video_list:
name = os.path.split(video)[1]
# if not '5_2' in name:
# continue
images_dir = os.path.join(args.images, name)
draw_img_dir = os.path.join(args.draw_img, name)
blind_dir = os.path.join(args.blind, name)
json_dir = os.path.join(args.json, name)
mkd(images_dir)
mkd(draw_img_dir)
mkd(blind_dir)
mkd(json_dir)
agaze_vec = os.path.join(args.annotations, name+'gaze_vec.txt')
airis_lm_2D = os.path.join(args.annotations, name+'iris_lm_2D.txt')
alid_lm_2D = os.path.join(args.annotations, name+'lid_lm_2D.txt')
apupil_lm_2D = os.path.join(args.annotations, name+'pupil_lm_2D.txt')
aeye_movements = os.path.join(args.annotations, name+'eye_movements.txt')
        flag = log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements)
        if not flag:
            exit()
with open(agaze_vec, 'r') as fgaze_vec:
lgaze_vec = fgaze_vec.readlines()[1:]
with open(airis_lm_2D, 'r') as firis_lm_2D:
liris_lm_2D = firis_lm_2D.readlines()[1:]
with open(alid_lm_2D, 'r') as flid_lm_2D:
llid_lm_2D = flid_lm_2D.readlines()[1:]
with open(apupil_lm_2D, 'r') as fpupil_lm_2D:
lpupil_lm_2D = fpupil_lm_2D.readlines()[1:]
with open(aeye_movements, 'r') as feye_movements:
leye_movements = feye_movements.readlines()[3:]
cap = cv2.VideoCapture(video)
num = 0
while 1:
ret, frame = cap.read()
if not ret:
break
src = frame.copy()
save_src = '{}/{}_{:0>5d}.jpg'.format(images_dir, name[:-4], num)
save_draw = '{}/{}_{:0>5d}.jpg'.format(draw_img_dir, name[:-4], num)
save_blind = '{}/{}_{:0>5d}.jpg'.format(blind_dir, name[:-4], num)
save_json = '{}/{}_{:0>5d}.json'.format(json_dir, name[:-4], num)
eye_movements = leye_movements[num].strip()[2:3]
gaze_vec = np.array([float(x) for x in lgaze_vec[num].strip().split(';')[1:3]])
            iris_lm_2D = np.array([float(x) for x in liris_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1,2)   # iris landmarks (the middle ring)
            lid_lm_2D = np.array([float(x) for x in llid_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1,2)     # eyelid landmarks (outermost contour)
            pupil_lm_2D = np.array([float(x) for x in lpupil_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1,2) # pupil landmarks (innermost contour)
num += 1
if eye_movements == '1':
continue
eye_c = np.mean(pupil_lm_2D, axis=0).astype(int)
for index in range(iris_lm_2D.shape[0]):
x_y = iris_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0,255,0),-1)  # green
for index in range(lid_lm_2D.shape[0]):
x_y = lid_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (255,0,0),-1)  # blue
for index in range(pupil_lm_2D.shape[0]):
x_y = pupil_lm_2D[index]
                cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0,0,255),-1)  # red
cv2.circle(frame, tuple(eye_c), 1, (255,255,255),-1)
            cv2.line(frame, tuple(eye_c), tuple(eye_c+(gaze_vec*100).astype(int)), (0,255,255), 1)  # gaze ray in yellow
label_dict = {
'gaze_vec':gaze_vec.tolist(), 'iris_lm_2D':iris_lm_2D.tolist(), 'lid_lm_2D':lid_lm_2D.tolist(), 'pupil_lm_2D':pupil_lm_2D.tolist()}
if -1 in gaze_vec:
cv2.imwrite(save_blind, frame)
                with open(save_json.replace(args.json, args.blind, 1), 'w') as dump_f:  # save the label next to the blind image
json.dump(label_dict,dump_f)
else:
if num % 3 == 0:
cv2.imwrite(save_src, src)
with open(save_json, 'w') as dump_f:
json.dump(label_dict,dump_f)
cv2.imwrite(save_draw, frame)
if __name__ == '__main__':
main()
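For reference, the slicing above encodes an assumption about TEyeD's semicolon-separated annotation lines: the first field is the frame identifier, the gaze file then carries the gaze-vector components (only x and y are kept here), and the landmark files carry one extra bookkeeping field before the x;y pairs, with a trailing ';' producing an empty last field. A minimal sketch with hypothetical sample lines (not copied from TEyeD) to illustrate the indices:
import numpy as np
# Hypothetical sample lines, only meant to illustrate the slicing used in the script above.
gaze_line = "0;0.12;-0.34;0.93;"                 # frame id, then gaze vector components
gaze_xy = np.array([float(v) for v in gaze_line.strip().split(';')[1:3]])  # keep x/y only
lm_line = "0;2;10.0;20.0;11.0;21.0;"             # frame id, one extra field, then x;y pairs
lm = np.array([float(v) for v in lm_line.strip().split(';')[2:-1]]).reshape(-1, 2)
print(gaze_xy)   # [ 0.12 -0.34]
print(lm)        # [[10. 20.] [11. 21.]]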
PFLD is used to train the gaze-estimation model: the PFLDInference backbone predicts the landmarks, and the AuxiliaryNet head predicts the gaze vector (the two are combined in the Gaze_PFLD wrapper below).
import os
import glob
import json
import math
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from models.pfld import PFLDInference, AuxiliaryNet

def preprocess_unityeyes_image(img, json_data, datasets, input_width, input_height):
ow = 160
oh = 96
# Prepare to segment eye image
ih, iw = img.shape[:2]
ih_2, iw_2 = ih/2.0, iw/2.0
heatmap_w = int(ow/2)
heatmap_h = int(oh/2)
#img = cv2.resize(im, (im.shape[1]*3, im.shape[0]*3))
#img = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
if datasets == 'B':
gaze = np.array(json_data['gaze'])
landmarks = np.array(json_data['landmarks'])
left_corner = landmarks[0]
right_corner = landmarks[4]
eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
eye_middle = landmarks[24].astype(int)
elif datasets == 'E':
gaze = np.array(json_data['gaze_vec'])
left_corner = np.array(json_data['lid_lm_2D'])[0]
right_corner = np.array(json_data['lid_lm_2D'])[33]
eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
eye_middle = np.mean([np.amin(np.array(json_data['iris_lm_2D']), axis=0), np.amax(np.array(json_data['iris_lm_2D']), axis=0)], axis=0)
landmarks = np.concatenate((np.array(json_data['lid_lm_2D']), np.array(json_data['iris_lm_2D']), np.array(json_data['pupil_lm_2D']), eye_middle.reshape(1,2)))
else:
        print('UnityEyes preprocessing not implemented!!! EXIT')
exit()
crop_img, lad = get_img(img, landmarks)
crop_img = cv2.resize(crop_img, (input_width,input_height))
# if 1:
# print(crop_img.shape)
# for (x, y) in lad:
# color = (0, 255, 0)
# cv2.circle(crop_img, (int(round(x*crop_img.shape[1])), int(round(y*crop_img.shape[0]))), 1, color, -1, lineType=cv2.LINE_AA)
# #crop_img = cv2.resize(crop_img, (160,96))
# cv2.imshow('c', crop_img)
# cv2.waitKey(0)
# exit()
return crop_img, lad, gaze
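get_img is not shown in this post (it lives in the repository); based on how its outputs are used here (the crop is resized to the network input, and the commented-out debug code multiplies lad by the crop width and height), a rough sketch of what it is expected to do, under those assumptions, could be:
import numpy as np

def get_img(img, landmarks, margin=10):
    """Hypothetical sketch: crop a box around the landmarks and return the crop
    together with the landmarks normalized to [0, 1] relative to that crop."""
    h, w = img.shape[:2]
    x_min = max(int(landmarks[:, 0].min()) - margin, 0)
    y_min = max(int(landmarks[:, 1].min()) - margin, 0)
    x_max = min(int(landmarks[:, 0].max()) + margin, w)
    y_max = min(int(landmarks[:, 1].max()) + margin, h)
    crop = img[y_min:y_max, x_min:x_max]
    lad = (landmarks - [x_min, y_min]) / [x_max - x_min, y_max - y_min]
    return crop, lad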
class EyesDataset(data.Dataset):
def __init__(self, datasets, dataroot, transforms=None, input_width=160, input_height=112):
self.dataroot = dataroot
self.datasets = datasets
self.input_width = input_width
self.input_height = input_height
self.transforms = transforms
if datasets == 'U':
            self.img_paths = glob.glob(os.path.join(dataroot, 'UnityEyes/images', '*.jpg'))
elif datasets == 'E':
            self.img_paths = glob.glob(os.path.join(dataroot, 'Eye200W/images', '*.jpg'))
elif datasets == 'B':
            self.img_paths = glob.glob(os.path.join(dataroot, 'BL_Eye/images', '*.jpg'))
self.img_paths = sorted(self.img_paths)
self.json_paths = []
for img_path in self.img_paths:
json_files = img_path.replace('images', 'json').replace('.jpg', '.json')
self.json_paths.append(json_files)
def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()
full_img = cv2.imread(self.img_paths[index])
with open(self.json_paths[index]) as f:
json_data = json.load(f)
eye, landmarks, gaze = preprocess_unityeyes_image(full_img, json_data, self.datasets, self.input_width, self.input_height)
if self.transforms:
eye = self.transforms(eye)
return eye, landmarks, gaze
def __len__(self):
return len(self.img_paths)
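A minimal usage sketch for the dataset class (the data root and loader settings below are placeholders, not values from the repository):
import torchvision
from torch.utils.data import DataLoader

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = EyesDataset(datasets='E', dataroot='./data', transforms=transform)  # './data' is a placeholder
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
for eye, landmarks, gaze in train_loader:
    # eye: (B, C, H, W) float tensor; landmarks: (B, N, 2) with N = 51 in this setup; gaze: (B, 2)
    break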
class Gaze_PFLD(nn.Module):
def __init__(self):
super(Gaze_PFLD, self).__init__()
self.lad = PFLDInference()
self.gaze = AuxiliaryNet()
def forward(self, x):
features, landmark = self.lad(x)
gaze = self.gaze(features)
return landmark, gaze
class PFLDLoss(nn.Module):
def __init__(self):
super(PFLDLoss, self).__init__()
self.gaze_loss = nn.MSELoss()
def forward(self, landmark_gt,
landmarks, gaze_pred, gaze):
lad_loss = wing_loss(landmark_gt, landmarks)
gaze_loss = self.gaze_loss(gaze_pred, gaze)
return gaze_loss*1000, lad_loss
def wing_loss(y_true, y_pred, w=10.0, epsilon=2.0, N_LANDMARK=51):
y_pred = y_pred.reshape(-1, N_LANDMARK, 2)
y_true = y_true.reshape(-1, N_LANDMARK, 2)
x = y_true - y_pred
c = w * (1.0 - math.log(1.0 + w / epsilon))
absolute_x = torch.abs(x)
losses = torch.where(w > absolute_x,
w * torch.log(1.0 + absolute_x / epsilon),
absolute_x - c)
loss = torch.mean(torch.sum(losses, axis=[1, 2]), axis=0)
return loss
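To make the loss wiring concrete, here is a minimal, hypothetical training step (the optimizer, learning rate, batch size and input shape are assumptions; the input is taken as a 3-channel 112x160 image to match the test script below):
import torch

model = Gaze_PFLD()
criterion = PFLDLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # placeholder hyper-parameters

# dummy batch: images, 51 normalized ground-truth landmarks, 2-D gaze vectors
imgs = torch.zeros(8, 3, 112, 160)
landmark_gt = torch.zeros(8, 51, 2)
gaze_gt = torch.zeros(8, 2)

landmarks, gaze_pred = model(imgs)
gaze_loss, lad_loss = criterion(landmark_gt, landmarks, gaze_pred, gaze_gt)
loss = gaze_loss + lad_loss   # the gaze term is already scaled by 1000 inside PFLDLoss
optimizer.zero_grad()
loss.backward()
optimizer.step()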
import argparse
import numpy as np
import cv2
import torch
import torchvision
from models.pfld import PFLDInference, AuxiliaryNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
checkpoint = torch.load(args.model_path, map_location=device)
print(checkpoint.keys())
pfld_backbone = PFLDInference().to(device)
auxiliarynet = AuxiliaryNet().to(device)
pfld_backbone.load_state_dict(checkpoint['pfld_backbone'])
auxiliarynet.load_state_dict(checkpoint["auxiliarynet"])
pfld_backbone.eval()
auxiliarynet.eval()
pfld_backbone = pfld_backbone.to(device)
auxiliarynet = auxiliarynet.to(device)
transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()])
img = cv2.imread('5.png')
img = cv2.resize(img, (img.shape[1]*1, img.shape[0]*1))
height, width = img.shape[:2]
input = cv2.resize(img, (160,112))
input = transform(input).unsqueeze(0).to(device)
features, landmarks = pfld_backbone(input)
gaze = auxiliarynet(features)
pre_landmark = landmarks[0]
#print(pre_landmark.shape)
pre_landmark = pre_landmark.cpu().detach().numpy().reshape(
-1, 2) * [width, height]
gaze = gaze.cpu().detach().numpy()[0]
c_pos = pre_landmark[-1,:]
cv2.line(img, tuple(c_pos.astype(int)), tuple(c_pos.astype(int)+(gaze*400).astype(int)), (0,255,0), 1)
for (x, y) in pre_landmark.astype(np.int32):
cv2.circle(img, (x, y), 1, (0, 0, 255))
cv2.imshow('gaze estimation', img)
cv2.imwrite('gaze.jpg', img)
cv2.waitKey(0)
def parse_args():
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--model_path',
default="./checkpoint/snapshot/checkpoint_epoch_13.pth.tar",
type=str)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)
# from __future__ import absolute_import
# from __future__ import division
# from __future__ import print_function
import argparse
import sys
import time
from models.pfld import Gaze_PFLD
import torch
import torch.nn as nn
import models
# def load_model_weight(model, checkpoint):
# state_dict = checkpoint['model_state_dict']
# # strip prefix of state_dict
# if list(state_dict.keys())[0].startswith('module.'):
# state_dict = {k[7:]: v for k, v in checkpoint['model_state_dict'].items()}
# model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
# # check loaded parameters and created model parameters
# for k in state_dict:
# if k in model_state_dict:
# if state_dict[k].shape != model_state_dict[k].shape:
# print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
# k, model_state_dict[k].shape, state_dict[k].shape))
# state_dict[k] = model_state_dict[k]
# else:
# print('Drop parameter {}.'.format(k))
# for k in model_state_dict:
# if not (k in state_dict):
# print('No param {}.'.format(k))
# state_dict[k] = model_state_dict[k]
# model.load_state_dict(state_dict, strict=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default="./checkpoint/snapshot/checkpoint.pth.tar", help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[112, 160], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
device = "cpu"
print("=====> load pytorch checkpoint...")
checkpoint = torch.load(opt.weights, map_location=torch.device('cpu'))
nstack = checkpoint['nstack']
nfeatures = checkpoint['nfeatures']
nlandmarks = checkpoint['nlandmarks']
net = Gaze_PFLD().to(device)
net.load_state_dict(checkpoint['gaze_pfld'])
img = torch.zeros(1, 1, *opt.img_size).to(device)
print(img.shape)
landmarks, gaze = net.forward(img)
f = opt.weights.replace('.pth.tar', '.onnx') # filename
torch.onnx.export(net, img, f,export_params=True, verbose=False, opset_version=12, input_names=['inputs'])
    # ONNX export
try:
import onnx
from onnxsim import simplify
print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pth.tar', '.onnx') # filename
torch.onnx.export(net, img, f, verbose=False, opset_version=11, input_names=['images'],
output_names=['output'])
# Checks
onnx_model = onnx.load(f) # load onnx model
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, f)
print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
except Exception as e:
print('ONNX export failure: %s' % e)
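Once the export succeeds, the graph can be sanity-checked with onnxruntime. A hedged sketch: it assumes onnxruntime is installed, the default checkpoint path above, and the input name 'images' from the second export call:
import numpy as np
import onnxruntime as ort

onnx_path = "./checkpoint/snapshot/checkpoint.onnx"   # produced from the default --weights path above
sess = ort.InferenceSession(onnx_path)
dummy = np.zeros((1, 1, 112, 160), dtype=np.float32)  # same shape as the dummy export input
outputs = sess.run(None, {'images': dummy})
for out in outputs:
    print(out.shape)  # landmark and gaze output shapes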