Facial landmark detection with HRNet-Facial-Landmark-Detection

Original: https://github.com/HRNet/HRNet-Facial-Landmark-Detection

My fork: https://github.com/jacke121/HRNet-Facial-Landmark-Detection

With a 256*256 input, a forward pass takes 100+ ms on a GTX 1070; the landmarks are accurate on some faces and off on others.
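
Note that timing one forward pass with time.time() alone can overstate latency: CUDA launches are asynchronous and the first call pays cudnn-autotune and initialization costs. A minimal sketch of a fairer measurement (timed_forward is my hypothetical helper, not part of the repo):

import time
import torch

def timed_forward(model, inp, warmup=5, iters=20):
    # Warm-up passes absorb first-call overhead (cudnn autotune, lazy init).
    with torch.no_grad():
        for _ in range(warmup):
            model(inp)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            model(inp)
        torch.cuda.synchronize()  # wait for the GPU before reading the clock
    return (time.time() - start) / iters * 1000.0  # ms per forward pass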

Accuracy is not great overall, and the test script below is partly to blame: it feeds the model the whole image with scale=1 and the image center, while the model was trained on face crops whose center and scale come from annotated face boxes. Running a face detector first and deriving center/scale from its box (see the sketch below) should help.
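
A sketch of deriving center and scale from a detected face box, matching the scale*200 reference-box convention used by crop()/get_transform() in the script below (box_to_center_scale and its 1.25 padding factor are my assumptions; the exact convention depends on the training annotations):

import numpy as np
import torch

def box_to_center_scale(x1, y1, x2, y2, padding=1.25):
    # Center of the face box, shaped like the script's `center` tensor.
    center = torch.Tensor([[(x1 + x2) / 2.0, (y1 + y2) / 2.0]])
    # scale * 200 is the side of the reference square; pad the box slightly.
    scale = np.asarray([max(x2 - x1, y2 - y1) * padding / 200.0], dtype=np.float32)
    return center, scale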

Weights download: https://github.com/HRNet/HRNet-Facial-Landmark-Detection/issues/8

Downloading the weights requires OneDrive; see also:

https://blog.csdn.net/jacke121/article/details/100656907


Face datasets:

https://wywu.github.io/projects/LAB/WFLW.html

COFW dataset:

https://download.csdn.net/download/u013510846/10798589

https://download.csdn.net/download/jrneymar/9849264

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng([email protected])
# ------------------------------------------------------------------------------

import os
import argparse
import time
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

# Make the repo root importable before pulling in the lib package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import lib.models as models
from lib.config import config, update_config
from lib.core.evaluation import decode_preds, compute_nme
from lib.utils import utils


def parse_args():

    parser = argparse.ArgumentParser(description='Train Face Alignment')

    parser.add_argument('--cfg', help='experiment configuration filename',
                        default='experiments/wflw/face_alignment_wflw_hrnet_w18.yaml', type=str)
    parser.add_argument('--model-file', help='model parameters', default="HR18-WFLW.pth", type=str)

    args = parser.parse_args()
    update_config(config, args)
    return args

def inference(config, data_loader, model):
    """Evaluate the model over a dataset loader and report NME (kept from the repo's test script)."""
    num_classes = config.MODEL.NUM_JOINTS
    predictions = torch.zeros((len(data_loader.dataset), num_classes, 2))

    model.eval()

    nme_count = 0
    nme_batch_sum = 0
    count_failure_008 = 0
    count_failure_010 = 0

    with torch.no_grad():
        for i, (inp, target, meta) in enumerate(data_loader):
            output = model(inp)
            score_map = output.data.cpu()
            # Decode 64x64 heatmaps back to original-image coordinates.
            preds = decode_preds(score_map, meta['center'], meta['scale'], [64, 64])

            # Normalized Mean Error per sample.
            nme_temp = compute_nme(preds, meta)

            # Failure counters at the usual 8% / 10% NME thresholds.
            count_failure_008 += (nme_temp > 0.08).sum()
            count_failure_010 += (nme_temp > 0.10).sum()

            nme_batch_sum += np.sum(nme_temp)
            nme_count += preds.size(0)
            for n in range(score_map.size(0)):
                predictions[meta['index'][n], :, :] = preds[n, :, :]

    nme = nme_batch_sum / nme_count
    return nme, predictions
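
# A minimal sketch of the NME metric that compute_nme reports, assuming WFLW's
# 98-point layout where the inter-ocular distance is taken between the outer
# eye corners (indices 60 and 72); the real implementation lives in
# lib/core/evaluation.py.
def nme_sketch(preds, targets):
    # preds, targets: (N, 98, 2) arrays in original-image coordinates
    interocular = np.linalg.norm(targets[:, 60, :] - targets[:, 72, :], axis=1)
    point_error = np.linalg.norm(preds - targets, axis=2).mean(axis=1)
    return point_error / interocular
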
def get_transform(center, scale, output_size, rot=0):
    """
    General image processing functions
    """
    # Generate transformation matrix
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(output_size[1]) / h
    t[1, 1] = float(output_size[0]) / h
    t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
    t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -output_size[1]/2
        t_mat[1, 2] = -output_size[0]/2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t
def transform_pixel(pt, center, scale, output_size, invert=0, rot=0):
    # Transform pixel location to different reference
    t = get_transform(center, scale, output_size, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int) + 1
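
# Worked example (a sketch): with scale=1 the reference box is 200 px, so a
# 256x256 crop scales coordinates by 256/200 = 1.28 and the box center maps to
# roughly the crop center. transform_pixel uses a 1-indexed pixel convention,
# hence the -1/+1 above and the one-pixel rounding drift below.
def _transform_example():
    center, scale, out = [100, 100], 1.0, [256, 256]
    p_crop = transform_pixel([100, 100], center, scale, out)        # ~[127, 127]
    p_back = transform_pixel(p_crop, center, scale, out, invert=1)  # ~[99, 99]
    return p_crop, p_back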

def crop(img, center, scale, output_size, rot=0):
    center_new = center.clone()

    # Preprocessing for efficient cropping
    ht, wd = img.shape[0], img.shape[1]
    sf = scale * 200.0 / output_size[0]
    if sf < 2:
        sf = 1
    else:
        # Downscale very large inputs first so the crop below stays cheap.
        new_size = int(np.floor(max(ht, wd) / sf))
        new_ht = int(np.floor(ht / sf))
        new_wd = int(np.floor(wd / sf))
        if new_size < 2:
            return torch.zeros(output_size[0], output_size[1], img.shape[2]) \
                        if len(img.shape) > 2 else torch.zeros(output_size[0], output_size[1])
        else:
            # cv2.resize replaces the removed scipy.misc.imresize (cv2 takes (w, h)).
            img = cv2.resize(img, (new_wd, new_ht))
            center_new[0] = center_new[0] * 1.0 / sf
            center_new[1] = center_new[1] * 1.0 / sf
            scale = scale / sf

    # Upper left point
    ul = np.array(transform_pixel([0, 0], center_new, scale, output_size, invert=1))
    # Bottom right point
    br = np.array(transform_pixel(output_size, center_new, scale, output_size, invert=1))

    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]

    new_img = np.zeros(new_shape, dtype=np.float32)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]

    if rot != 0:
        # Rotate around the crop center (cv2 replaces the removed
        # scipy.misc.imrotate), then remove the rotation padding.
        rot_mat = cv2.getRotationMatrix2D((new_img.shape[1] / 2.0, new_img.shape[0] / 2.0), rot, 1.0)
        new_img = cv2.warpAffine(new_img, rot_mat, (new_img.shape[1], new_img.shape[0]))
        new_img = new_img[pad:-pad, pad:-pad]
    new_img = cv2.resize(new_img, (output_size[1], output_size[0]))
    return new_img
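
# Usage sketch with hypothetical values: crop a face whose ~250 px box is
# centered at (300, 220); scale follows the box_size / 200 convention used
# throughout this file. The result is a (256, 256, 3) float32 array with
# values still in the 0-255 range.
#   face = crop(img_rgb, torch.Tensor([300.0, 220.0]), 250 / 200.0, [256, 256])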

def main():

    args = parse_args()
    input_size = [256, 256]
    output_size = [64, 64]   # heatmap resolution
    # ImageNet normalization statistics
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    logger, final_output_dir, tb_log_dir = utils.create_logger(config, args.cfg, 'test')


    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus)
    model.cuda()
    # load weights (map_location='cpu' keeps this working regardless of the GPU setup)
    state_dict = torch.load(args.model_file, map_location='cpu')
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    model.module.load_state_dict(state_dict)
    model.eval()



    path = r'D:\data\photo\faces/'
    files = os.listdir(path)
    for file in files:
        image_path = path + file
        img = cv2.imread(image_path)

        # In the repo's dataset class these come from the annotation file
        # (landmarks_frame); here we fall back to the whole image with scale=1,
        # which is why results are rough -- see box_to_center_scale above.
        scale = np.asarray([1, 1], dtype=np.float32)

        # img.shape is (h, w, c): width is shape[1], height is shape[0]
        center_w = img.shape[1] // 2
        center_h = img.shape[0] // 2
        center = torch.Tensor([[center_w, center_h]])


        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        r = 0  # no rotation at test time

        img = crop(img_rgb, center[0], scale[0], input_size, rot=r)

        img = img.astype(np.float32)
        img = (img / 255.0 - mean) / std          # normalize with ImageNet statistics
        img = img.transpose([2, 0, 1])            # HWC -> CHW
        img = torch.Tensor(np.asarray([img]))     # add batch dimension


        meta = {'center': center, 'scale': scale}
        start = time.time()
        with torch.no_grad():
            output = model(img.cuda())
        torch.cuda.synchronize()  # CUDA is asynchronous; sync before reading the clock
        print('time', time.time() - start)
        score_map = output.data.cpu()
        preds = decode_preds(score_map, meta['center'], meta['scale'], output_size)

        # Draw the predicted landmarks on the original-resolution image.
        for pred in preds[0]:
            cv2.circle(img_rgb, (int(pred[0]), int(pred[1])), 2, (0, 0, 255), -1)

        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        cv2.imshow('landmarks', img_bgr)
        cv2.waitKey()



if __name__ == '__main__':
    main()

