视线估计Gaze-Estimation PFLD实现

视线估计Gaze-Estimation PFLD实现



使用这个数据集:TEyeD: Over 20 million real-world eye images with Pupil, Eyelid, and Iris 2D and 3D Segmentations, 2D and 3D Landmarks, 3D Eyeball, Gaze Vector, and Eye Movement Types

  • 数据集预处理
import os
import cv2
import glob
import numpy as np
import argparse
import json

def parse_args():
    """Parse the command-line options of the dataset preprocessing script.

    Returns:
        argparse.Namespace with the string attributes ``video_path``,
        ``annotations``, ``images``, ``draw_img``, ``blind`` and ``json``.
    """
    # (flag, default, help) for every option; all of them are plain strings.
    option_table = (
        ("--video_path", 'DIKABLISVIDEOS', 'videos path'),
        ("--annotations", 'ANNOTATIONS', 'videos label path including gaze_vec iris_lm_2D lid_lm_2D pupil_lm_2D'),
        ("--images", 'images', 'save_path'),
        ("--draw_img", 'draw_img', 'save_path'),
        ("--blind", 'blind', 'save_path'),
        ("--json", 'json', 'save_path'),
    )
    parser = argparse.ArgumentParser(description="EyeGaze datasets")
    for flag, default_value, help_text in option_table:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    return parser.parse_args()

def mkd(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    Args:
        path: directory to create. Nothing happens when it already exists.
    """
    if not os.path.exists(path):
        # exist_ok covers the case where the directory appears between the
        # check above and this call.
        os.makedirs(path, exist_ok=True)

def judge_exists(path):
    """Report whether ``path`` is missing.

    Note the polarity: the result is ``True`` when nothing exists at ``path``
    and ``False`` when the path exists.
    """
    is_missing = not os.path.exists(path)
    return is_missing

def log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements):
    """Check that all five annotation files exist and report the missing ones.

    Every missing file gets its own message on stdout, in argument order.

    Returns:
        True when all five paths exist, False when at least one is missing.
    """
    checks = (
        (agaze_vec, 'gaze_vec not found!!! EXIT'),
        (airis_lm_2D, 'iris_lm_2D not found!!! EXIT'),
        (alid_lm_2D, 'lid_lm_2D not found!!! EXIT'),
        (apupil_lm_2D, 'pupil_lm_2D not found!!! EXIT'),
        (aeye_movements, 'eye_movements not found!!! EXIT'),
    )
    # Probe every path first, then print, mirroring check-then-report order.
    missing_flags = [judge_exists(path) for path, _ in checks]
    for (_, message), is_missing in zip(checks, missing_flags):
        if is_missing:
            print(message)
    return not any(missing_flags)

def main():
    """Split every video into frames and write per-frame label files.

    For each ``*.mp4`` under ``--video_path`` the five annotation files
    (``<video file name>gaze_vec.txt``, ``...iris_lm_2D.txt``,
    ``...lid_lm_2D.txt``, ``...pupil_lm_2D.txt``, ``...eye_movements.txt``)
    are read from ``--annotations``. Videos with a missing annotation file
    are skipped. For every frame the landmarks and gaze direction are drawn
    on the image and the labels are collected in a dict, then:

    * frames whose gaze vector contains ``-1`` are written (drawn image and
      JSON) to the ``--blind`` directory;
    * of the remaining frames, every one whose 1-based index is a multiple
      of 3 is written as raw image (``--images``), JSON (``--json``) and
      drawn image (``--draw_img``).
    """
    args = parse_args()
    video_list = glob.glob(os.path.join(args.video_path, '*.mp4'))
    for video in video_list:
        name = os.path.split(video)[1]
        stem = name[:-4]  # file name without the ".mp4" suffix
        images_dir = os.path.join(args.images, name)
        draw_img_dir = os.path.join(args.draw_img, name)
        blind_dir = os.path.join(args.blind, name)
        json_dir = os.path.join(args.json, name)

        # Annotation files are named "<video file name><kind>.txt".
        agaze_vec = os.path.join(args.annotations, name+'gaze_vec.txt')
        airis_lm_2D = os.path.join(args.annotations, name+'iris_lm_2D.txt')
        alid_lm_2D = os.path.join(args.annotations, name+'lid_lm_2D.txt')
        apupil_lm_2D = os.path.join(args.annotations, name+'pupil_lm_2D.txt')
        aeye_movements = os.path.join(args.annotations, name+'eye_movements.txt')

        flage = log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements)
        if not flage:
            # log() already printed which annotation file is missing.
            continue

        # cv2.imwrite fails silently and open() raises when the target
        # directory does not exist, so create all outputs up front.
        for out_dir in (images_dir, draw_img_dir, blind_dir, json_dir):
            os.makedirs(out_dir, exist_ok=True)

        # The first line of each landmark/gaze file is skipped (header);
        # the eye-movement file skips its first three lines.
        with open(agaze_vec, 'r') as fgaze_vec:
            lgaze_vec = fgaze_vec.readlines()[1:]
        with open(airis_lm_2D, 'r') as firis_lm_2D:
            liris_lm_2D = firis_lm_2D.readlines()[1:]
        with open(alid_lm_2D, 'r') as flid_lm_2D:
            llid_lm_2D = flid_lm_2D.readlines()[1:]
        with open(apupil_lm_2D, 'r') as fpupil_lm_2D:
            lpupil_lm_2D = fpupil_lm_2D.readlines()[1:]
        with open(aeye_movements, 'r') as feye_movements:
            leye_movements = feye_movements.readlines()[3:]

        # Stop at the shortest label list instead of raising IndexError when
        # the video holds more frames than there are annotation lines.
        n_labels = min(len(lgaze_vec), len(liris_lm_2D), len(llid_lm_2D),
                       len(lpupil_lm_2D), len(leye_movements))

        cap = cv2.VideoCapture(video)
        num = 0
        try:
            while num < n_labels:
                ret, frame = cap.read()
                if not ret:
                    break  # end of video
                src = frame.copy()  # untouched copy; `frame` gets drawn on
                frame_file = '{}_{:0>5d}'.format(stem, num)
                save_src = '{}/{}.jpg'.format(images_dir, frame_file)
                save_draw = '{}/{}.jpg'.format(draw_img_dir, frame_file)
                save_blind = '{}/{}.jpg'.format(blind_dir, frame_file)
                save_json = '{}/{}.json'.format(json_dir, frame_file)
                # Built from blind_dir directly; the former
                # replace('json\\', 'blind\\') only matched Windows paths.
                save_blind_json = '{}/{}.json'.format(blind_dir, frame_file)

                # NOTE(review): takes the third character of the line; this
                # presumably assumes a fixed-width line layout -- confirm
                # against the annotation file format.
                eye_movements = leye_movements[num].strip()[2:3]
                gaze_vec = np.array([float(x) for x in lgaze_vec[num].strip().split(';')[1:3]])
                # Iris: the middle ring.
                iris_lm_2D = np.array([float(x) for x in liris_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)
                # Eyelid: the outermost contour.
                lid_lm_2D = np.array([float(x) for x in llid_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)
                # Pupil: the innermost ring.
                pupil_lm_2D = np.array([float(x) for x in lpupil_lm_2D[num].strip().split(';')[2:-1]]).reshape(-1, 2)
                num += 1

                if eye_movements == '1':
                    # NOTE(review): the statement under this condition was
                    # lost; skipping the frame is assumed -- confirm which
                    # eye-movement type the code '1' denotes.
                    continue

                eye_c = np.mean(pupil_lm_2D, axis=0).astype(int)
                for x_y in iris_lm_2D:
                    cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0, 255, 0), -1)  # green
                for x_y in lid_lm_2D:
                    cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (255, 0, 0), -1)  # blue
                for x_y in pupil_lm_2D:
                    cv2.circle(frame, (int(x_y[0]), int(x_y[1])), 1, (0, 0, 255), -1)  # red

                # Plain Python ints keep OpenCV's point parsing happy.
                centre = tuple(eye_c.tolist())
                gaze_tip = tuple((eye_c + (gaze_vec * 100).astype(int)).tolist())
                cv2.circle(frame, centre, 1, (255, 255, 255), -1)
                cv2.line(frame, centre, gaze_tip, (0, 255, 255), 1)  # yellow

                label_dict = {
                    'gaze_vec': gaze_vec.tolist(),
                    'iris_lm_2D': iris_lm_2D.tolist(),
                    'lid_lm_2D': lid_lm_2D.tolist(),
                    'pupil_lm_2D': pupil_lm_2D.tolist(),
                }

                if -1 in gaze_vec:
                    # A -1 component marks the frame as having no usable gaze.
                    cv2.imwrite(save_blind, frame)
                    with open(save_blind_json, 'w') as dump_f:
                        json.dump(label_dict, dump_f)
                elif num % 3 == 0:
                    # Keep every third frame only (num is already incremented).
                    cv2.imwrite(save_src, src)
                    with open(save_json, 'w') as dump_f:
                        json.dump(label_dict, dump_f)
                    cv2.imwrite(save_draw, frame)
        finally:
            cap.release()
# Run the preprocessing only when executed as a script, not on import.
if __name__ == '__main__':
    main()



  • dataloader
def preprocess_unityeyes_image(img, json_data, datasets, input_width, input_height):
    """Crop the eye region of one sample and collect its labels.

    Args:
        img: image array to crop (passed to ``get_img``).
        json_data: decoded label dict. Dataset ``'B'`` reads the keys
            ``'gaze'`` and ``'landmarks'``; dataset ``'E'`` reads
            ``'gaze_vec'``, ``'lid_lm_2D'``, ``'iris_lm_2D'`` and
            ``'pupil_lm_2D'``.
        datasets: dataset id, ``'B'`` or ``'E'``.
        input_width: width of the returned crop.
        input_height: height of the returned crop.

    Returns:
        ``(crop_img, lad, gaze)``: the crop produced by ``get_img`` resized
        to ``(input_width, input_height)``, the landmarks returned by
        ``get_img`` and the gaze label as a numpy array.

    Raises:
        NotImplementedError: for any other dataset id (UnityEyes handling is
            not written).
    """
    if datasets == 'B':
        gaze = np.array(json_data['gaze'])
        landmarks = np.array(json_data['landmarks'])
    elif datasets == 'E':
        gaze = np.array(json_data['gaze_vec'])
        lid = np.array(json_data['lid_lm_2D'])
        iris = np.array(json_data['iris_lm_2D'])
        pupil = np.array(json_data['pupil_lm_2D'])
        # Centre of the iris bounding box, appended as the last landmark.
        eye_middle = np.mean([np.amin(iris, axis=0), np.amax(iris, axis=0)], axis=0)
        landmarks = np.concatenate((lid, iris, pupil, eye_middle.reshape(1, 2)))
    else:
        # Previously this fell through to an unbound `landmarks`.
        raise NotImplementedError(
            'UnityEyes do not write!!! (unsupported datasets value: {!r})'.format(datasets))

    crop_img, lad = get_img(img, landmarks)
    crop_img = cv2.resize(crop_img, (input_width, input_height))
    return crop_img, lad, gaze

class EyesDataset(data.Dataset):
    """Eye images paired with their JSON label files.

    The images are expected in ``<dataroot>/<set>/images/*.jpg`` and the
    labels in ``<dataroot>/<set>/json/<same stem>.json``, where ``<set>`` is
    ``UnityEyes`` (``'U'``), ``Eye200W`` (``'E'``) or ``BL_Eye`` (``'B'``).

    Args:
        datasets: dataset id, one of ``'U'``, ``'E'``, ``'B'``.
        dataroot: directory that contains the dataset folders.
        transforms: optional callable applied to the cropped eye image.
        input_width: width the eye crop is resized to.
        input_height: height the eye crop is resized to.

    Raises:
        ValueError: if ``datasets`` is not a known dataset id.
    """

    # Dataset id -> folder name below ``dataroot``.
    _DATASET_DIRS = {'U': 'UnityEyes', 'E': 'Eye200W', 'B': 'BL_Eye'}

    def __init__(self, datasets, dataroot, transforms=None, input_width=160, input_height=112):
        if datasets not in self._DATASET_DIRS:
            raise ValueError('unknown datasets id: {!r}'.format(datasets))
        self.dataroot = dataroot
        self.datasets = datasets
        self.input_width = input_width
        self.input_height = input_height
        self.transforms = transforms

        base_dir = os.path.join(dataroot, self._DATASET_DIRS[datasets])
        # The pattern must be relative: an absolute component such as
        # '/*.jpg' makes os.path.join drop everything before it.
        self.img_paths = sorted(glob.glob(os.path.join(base_dir, 'images', '*.jpg')))
        # One label path per image, index-aligned with img_paths.
        json_dir = os.path.join(base_dir, 'json')
        self.json_paths = [
            os.path.join(json_dir, os.path.splitext(os.path.basename(img_path))[0] + '.json')
            for img_path in self.img_paths
        ]

    def __getitem__(self, index):
        """Return ``(eye, landmarks, gaze)`` for the sample at ``index``."""
        if torch.is_tensor(index):
            index = index.tolist()

        full_img = cv2.imread(self.img_paths[index])
        if full_img is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError('cannot read image: {}'.format(self.img_paths[index]))
        with open(self.json_paths[index]) as f:
            json_data = json.load(f)
        eye, landmarks, gaze = preprocess_unityeyes_image(
            full_img, json_data, self.datasets, self.input_width, self.input_height)
        if self.transforms:
            eye = self.transforms(eye)
        return eye, landmarks, gaze

    def __len__(self):
        """Number of images found for the selected dataset."""
        return len(self.img_paths)
  • model
class Gaze_PFLD(nn.Module):
    """Landmark backbone followed by a gaze head, wrapped as one module.

    ``forward`` feeds the input through ``PFLDInference`` (which yields
    features and landmarks) and passes the features to ``AuxiliaryNet`` to
    obtain the gaze prediction.
    """

    def __init__(self):
        super().__init__()
        self.lad = PFLDInference()
        self.gaze = AuxiliaryNet()

    def forward(self, x):
        """Return ``(landmark, gaze)`` for the input batch ``x``."""
        backbone_out = self.lad(x)
        features, landmark = backbone_out
        return landmark, self.gaze(features)
  • loss
class PFLDLoss(nn.Module):
    """Combined training loss: MSE on the gaze and wing loss on landmarks."""

    def __init__(self):
        super().__init__()
        self.gaze_loss = nn.MSELoss()

    def forward(self, landmark_gt,
                landmarks, gaze_pred, gaze):
        """Return ``(gaze_loss * 1000, landmark_loss)``.

        The gaze term is the MSE between ``gaze_pred`` and ``gaze`` scaled
        by 1000; the landmark term is ``wing_loss(landmark_gt, landmarks)``.
        """
        landmark_term = wing_loss(landmark_gt, landmarks)
        gaze_term = self.gaze_loss(gaze_pred, gaze)
        return gaze_term*1000, landmark_term
def wing_loss(y_true, y_pred, w=10.0, epsilon=2.0, N_LANDMARK=51):
    """Wing loss between two sets of 2D landmarks.

    Both tensors are reshaped to ``(-1, N_LANDMARK, 2)``. Per coordinate the
    loss is ``w * ln(1 + |d| / epsilon)`` while ``|d| < w`` and ``|d| - c``
    otherwise, with ``c = w * (1 - ln(1 + w / epsilon))``. The per-sample
    sums are averaged over the batch.

    Returns:
        Scalar tensor with the mean per-sample loss.
    """
    target = y_true.reshape(-1, N_LANDMARK, 2)
    prediction = y_pred.reshape(-1, N_LANDMARK, 2)
    abs_diff = torch.abs(target - prediction)

    # Offset that makes the linear part meet the log part at |d| == w.
    offset = w * (1.0 - math.log(1.0 + w / epsilon))
    log_branch = w * torch.log(1.0 + abs_diff / epsilon)
    linear_branch = abs_diff - offset
    per_coord = torch.where(abs_diff < w, log_branch, linear_branch)

    per_sample = per_coord.sum(dim=(1, 2))
    return per_sample.mean(dim=0)


import argparse
import numpy as np
import cv2
import torch
import torchvision
from models.pfld import PFLDInference, AuxiliaryNet

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def main(args):
    """Run landmark and gaze inference on ``5.png`` and visualise the result.

    The image is resized to 160x112 for the networks; the predicted
    landmarks are scaled back to the original image size, drawn as red
    dots, and the gaze is drawn as a green line starting at the last
    landmark. The annotated image is saved as ``gaze.jpg`` and shown in a
    window until a key is pressed.

    Args:
        args: namespace with ``model_path``, the checkpoint file to load.

    Raises:
        FileNotFoundError: if ``5.png`` cannot be read.
    """
    checkpoint = torch.load(args.model_path, map_location=device)
    pfld_backbone = PFLDInference().to(device)
    auxiliarynet = AuxiliaryNet().to(device)
    # NOTE(review): the weight-loading lines were missing; the checkpoint
    # keys below are assumed to match the variable names -- confirm against
    # the training script that wrote the checkpoint.
    pfld_backbone.load_state_dict(checkpoint['pfld_backbone'])
    auxiliarynet.load_state_dict(checkpoint['auxiliarynet'])
    # Inference mode: freeze batch-norm statistics / disable dropout.
    pfld_backbone.eval()
    auxiliarynet.eval()

    # NOTE(review): the transform list was truncated; ToTensor is assumed.
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    img = cv2.imread('5.png')
    if img is None:
        # cv2.imread returns None instead of raising on failure.
        raise FileNotFoundError("cannot read image '5.png'")
    height, width = img.shape[:2]

    net_input = transform(cv2.resize(img, (160, 112))).unsqueeze(0).to(device)
    with torch.no_grad():
        features, landmarks = pfld_backbone(net_input)
        gaze = auxiliarynet(features)

    # Scale the predicted landmarks by the original image size.
    pre_landmark = landmarks[0].cpu().detach().numpy().reshape(
        -1, 2) * [width, height]
    gaze = gaze.cpu().detach().numpy()[0]

    # The gaze line starts at the last landmark.
    c_pos = pre_landmark[-1, :].astype(int)
    gaze_tip = c_pos + (gaze * 400).astype(int)
    cv2.line(img, (int(c_pos[0]), int(c_pos[1])),
             (int(gaze_tip[0]), int(gaze_tip[1])), (0, 255, 0), 1)
    for (x, y) in pre_landmark.astype(np.int32):
        cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255))

    # Save first so the result survives even without a display.
    cv2.imwrite('gaze.jpg', img)
    cv2.imshow('gaze estimation', img)
    # imshow only renders once the GUI event loop runs.
    cv2.waitKey(0)

def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``model_path``, the checkpoint file that
        ``main`` loads.
    """
    parser = argparse.ArgumentParser(description='Testing')
    # main() reads args.model_path, so the option has to be declared here.
    parser.add_argument('--model_path', type=str,
                        default="./checkpoint/snapshot/checkpoint.pth.tar",
                        help='path of the trained checkpoint')
    args = parser.parse_args()
    return args

# Script entry point: parse the options and run the inference demo.
if __name__ == "__main__":
    args = parse_args()
    main(args)

视线估计Gaze-Estimation PFLD实现_第1张图片

3.export onnx

# from __future__ import absolute_import
# from __future__ import division
# from __future__ import print_function
import argparse
import sys
import time
from models.pfld import Gaze_PFLD

import torch
import torch.nn as nn
import models

# def load_model_weight(model, checkpoint):
#     state_dict = checkpoint['model_state_dict']
#     # strip prefix of state_dict
#     if list(state_dict.keys())[0].startswith('module.'):
#         state_dict = {k[7:]: v for k, v in checkpoint['model_state_dict'].items()}

#     model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()

#     # check loaded parameters and created model parameters
#     for k in state_dict:
#         if k in model_state_dict:
#             if state_dict[k].shape != model_state_dict[k].shape:
#                 print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
#                     k, model_state_dict[k].shape, state_dict[k].shape))
#                 state_dict[k] = model_state_dict[k]
#         else:
#             print('Drop parameter {}.'.format(k))
#     for k in model_state_dict:
#         if not (k in state_dict):
#             print('No param {}.'.format(k))
#             state_dict[k] = model_state_dict[k]
#     model.load_state_dict(state_dict, strict=False)

# Export the trained Gaze_PFLD network to a simplified ONNX file that is
# written next to the checkpoint (".pth.tar" replaced by ".onnx").
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default="./checkpoint/snapshot/checkpoint.pth.tar", help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[112, 160], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to (size, size)

    device = "cpu"
    print("=====> load pytorch checkpoint...")
    checkpoint = torch.load(opt.weights, map_location=torch.device('cpu'))

    net = Gaze_PFLD().to(device)
    # The checkpoint has to be applied before exporting, otherwise the ONNX
    # file contains randomly initialised weights.
    # NOTE(review): the checkpoint layout is not visible here; both layouts
    # below are assumptions -- confirm against the training script.
    if 'model_state_dict' in checkpoint:
        net.load_state_dict(checkpoint['model_state_dict'])
    elif 'pfld_backbone' in checkpoint and 'auxiliarynet' in checkpoint:
        net.lad.load_state_dict(checkpoint['pfld_backbone'])
        net.gaze.load_state_dict(checkpoint['auxiliarynet'])
    else:
        raise KeyError('no known weight entry in checkpoint, keys: %s' % list(checkpoint))
    net.eval()

    # NOTE(review): the dummy input has a single channel -- confirm that it
    # matches the channel count the network was trained with.
    img = torch.zeros(opt.batch_size, 1, *opt.img_size).to(device)
    with torch.no_grad():
        landmarks, gaze = net(img)  # dry run to validate the input shape

    # ONNX export
    try:
        import onnx
        from onnxsim import simplify

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pth.tar', '.onnx')  # filename
        # NOTE(review): the output names were lost; they are assumed to
        # follow the order returned by forward().
        torch.onnx.export(net, img, f, export_params=True, verbose=False, opset_version=11,
                          input_names=['images'], output_names=['landmarks', 'gaze'])

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, f)
        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)
