flask+PRNet实现3D人脸重建换脸服务

1. 换脸流程

采用三维重建的方式重建出参考图的3Dshape,获得相应的颜色空间(identity);重建视频中人脸3Dshape,提取相应的vertices(shape);结合参考图的颜色空间以及目标图的vertices,渲染出更换了identity的face。

离线服务提取待换脸视频中人脸图片的3D顶点信息,存放于redis中。由于顶点信息至少需要以float32精度存放,如果把整段视频的顶点信息存放在同一个key下,单个value会非常大,故把视频拆分成很多段存放于redis中;在取redis中信息时,用多线程的方式取出,然后进行换脸服务,比单个串行服务速度快很多。

2. 涉及的技术

人脸三维重建,图像渲染,图像补全,边缘检测,人脸分割

人脸三维重建:网络采用PRNet

3. 存在的问题及解决办法

抖动:视频中对人脸更换后出现抖动,通过对人脸检测框进行平滑处理可以有效降低抖动程度,确定抖动由人脸检测精度低造成,目前采用face++人脸检测接口进行人脸检测

边缘伪影:由换脸mask造成的边缘伪影,通过设置模板mask以及人脸分割,精确得到换脸mask

眼镜:采用图像补全技术,用边缘检测方法获取mask,根据得到的mask对图像进行补全

眼睛转动及嘴巴张闭:由于图像重建后的眼睛和嘴巴是固定的,在换脸mask上去掉相应区域,保留视频中的人眼和嘴巴

参考图和原图存在色差:通过颜色校正,调整图像亮度

4. 接口说明

版本: 1.0

描述:传入base64编码的二进制图片数据和视频名,把检测到的人脸通过3D人脸重建替换视频中人脸,并根据换脸后的视频帧合成视频

请求方式:post

请求链接:xxxxxxxxxx:9775/ai/v1/FaceSwap

图片要求:

              图片格式:JPG(JPEG),PNG

              图片像素尺寸:最小 200*200 像素,最大 4096*4096像素

5.  整体架构方案

flask+PRNet实现3D人脸重建换脸服务_第1张图片

6. 接口设计

接口请求参数: 

参数名

必选

类型

说明

requestId String 用于区分每一次请求的唯一的字符串id
inputImage

String 图片的base64值
videoName String 视频名称
token String 服务鉴权标识,AI组统一分配
userId String 用户id

 

接口返回结果示例:

{
    "code": 0,
    "msg": "success",
    "data": {
       "requestId": "100022" ,
       "faceSwapRes": "True",
       "timeUsed": "30.11962342262268"
    }
 }

 

接口返回参数说明:

参数名

类型

说明

参数名

类型

说明

requestId
String 用户请求唯一标识
faceSwapRes
String

换脸服务返回结果,成功True,失败False

timeUsed
String 整个请求所花费的时间,单位为秒

 

接口状态码code:

状态码

状态说明

0 成功
2 未检测出人脸
3 鉴权失败
4 参数无效
5 图片尺寸超出范围
6

请求异常

7. 代码如下:

# encoding:utf-8
from meinheld import server
from flask import Flask, request
from skimage.io import imread, imsave
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED, as_completed
import logging
from logging.handlers import TimedRotatingFileHandler
import json
import base64
import hashlib
from threading import Thread, Lock
from PIL import Image
from io import BytesIO
from conf import config
import os
from api import PRN
from glass_judge import *
# from utils.render import render_texture,render_texture_v1
# from utils.estimate_pose import rotate_pos
import cv2
import redis
# from face_segmentation.face_segment import FaceSegment
from face_segmentation.face_segment import FaceSegmentFCN
from mesh.render import render_colors
from faceDetect.face_detection import FaceDetector, FaceTracker
from face_align import FaceAligner_v1
from Pluralistic.FaceEdit import FaceEditor, CropLayer

app = Flask(__name__)

def setLog():
    """Configure root logging: INFO level plus an hourly-rotating log file.

    The file lives under ``log/`` and carries the process start time in its
    name. It rotates every hour and at most 72 rotated files are kept.
    """
    log_dir = "log"
    # The file handler opens its file right away and does not create missing
    # parent directories, so make sure the directory exists first.
    os.makedirs(log_dir, exist_ok=True)

    log_fmt = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    started_at = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    fh = TimedRotatingFileHandler(
        filename=log_dir + "/run_faceswap_server" + started_at + ".log",
        when="H", interval=1,
        backupCount=72)
    fh.setFormatter(logging.Formatter(log_fmt))
    logging.basicConfig(level=logging.INFO)
    logging.getLogger().addHandler(fh)


# Install the logging handlers before any model is loaded.
setLog()

# NOTE(review): these variables are set after the framework imports above;
# confirm they still take effect for the libraries that read them.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Only referenced by the commented-out landmark model below.
model_path = "./checkpoint/snapshot/checkpoint_epoch_1000.pth.tar"

# faceLandMark = FaceLandmarkModel(model_path)
# segmentor = FaceSegment('./face_segmentation/checkpoints/model.pt')
fcn_segmentor = FaceSegmentFCN('./face_segmentation/weights/Keras_FCN8s_face_seg_YuvalNirkin.h5')
MODEL_PATH = './faceDetect/model_new.pb'
face_detector = FaceDetector(MODEL_PATH, gpu_memory_fraction=0.25, visible_device_list='0')
face_aligner = FaceAligner_v1()
# cv2.dnn_registerLayer('Crop', CropLayer)
prn = PRN(is_dlib=True)
editor = FaceEditor()
# Pool that runs the per-redis-key workers submitted by faceSwap().
executor = ThreadPoolExecutor(config.threadPoolSize)

# Create the client object connected to the redis database.
pool = redis.ConnectionPool(host=config.redisHost, port=config.redisPort, password=config.redisPassword,
                            max_connections=config.maxConnections)
redisDb = redis.Redis(connection_pool=pool)

# Shared mutable state. It is module-level, so it is shared by every request
# handled by this process.
lock = Lock()  # guards writes to imageList in swapThread()
swap_threads = []  # joined by faceSwapMethod(); only filled by commented-out code
frame_dict_list = dict()  # filled by get_redis1()
all_task = list()  # futures submitted by faceSwap()
imageList = [""]*5000  # base64 JPEG per frame, indexed by frame number
frame_count_all = 0  # reset by faceSwap(), incremented by the redis workers
# Defaults that swapThread() overwrites and gen_swap_face() reads when it
# opens the video writer.
fps = 25
w = 255
h = 255


def colorTransfer(src, dst, mask=None):
    """Shift the mean colour of ``dst`` to the mean colour of ``src``.

    Only the pixels selected by ``mask`` are read and modified; the means are
    computed per channel over exactly those pixels.

    Args:
        src: Image supplying the target mean colour.
        dst: Image to correct; it is not modified.
        mask: Optional array whose non-zero entries select the pixels to use.
            When omitted, every pixel of ``dst`` is used.

    Returns:
        A copy of ``dst`` in which the selected pixels have been shifted and
        clipped to the range 0..255.
    """
    if mask is None:
        rows, cols = dst.shape[:2]
        # Index pairs must be (row, col). The previous (col, row) order went
        # out of bounds, or addressed the wrong pixels, on non-square images.
        row_idx, col_idx = np.indices((rows, cols))
        maskIndices = (row_idx.ravel(), col_idx.ravel())
    else:
        # Indices of the non-black pixels of the mask.
        maskIndices = np.where(mask != 0)
    transferredDst = np.copy(dst)

    if maskIndices[0].size == 0:
        # Nothing selected: avoid taking the mean of an empty array.
        return transferredDst

    # Fancy indexing returns the pixels inside the non-black mask area; widen
    # to int32 so the mean shift cannot wrap around in uint8.
    maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32)
    maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32)
    meanSrc = np.mean(maskedSrc, axis=0)
    meanDst = np.mean(maskedDst, axis=0)
    maskedDst = np.clip(maskedDst - meanDst + meanSrc, 0, 255)
    transferredDst[maskIndices[0], maskIndices[1]] = maskedDst

    return transferredDst


def swapThread(alpha, new_colors, frame_key, frame_val, videoPath):
    """Render the face-swapped version of one video frame.

    Reads ``<videoPath><frame>.jpg`` and ``<videoPath><frame>_new_mask.jpg``,
    renders the stored vertices with ``new_colors``, blends the result into
    the frame and stores the base64-encoded JPEG in the module-level
    ``imageList`` at the frame's index. Also updates the module-level
    ``fps``, ``h`` and ``w``.

    Args:
        alpha: Blend weight of the rendered face inside the mask.
        new_colors: Per-vertex colours passed to ``render_colors``.
        frame_key: Either the literal ``"fps"`` or a key whose part before
            the first ``":"`` is the frame number.
        frame_val: For ``"fps"`` an object exposing ``.get("fps")``;
            otherwise the text of a dict holding ``"vertices"`` and ``"fps"``.
        videoPath: Path prefix (including the trailing separator) of the
            frame and mask images.

    Returns:
        True when the frame was processed, or when the entry was the
        ``"fps"`` metadata. False when an input image is missing or the
        result could not be encoded.
    """
    global fps
    global h
    global w
    global imageList

    start_time = time.time()
    logging.info(f"frame_key is:  {str(frame_key)}")

    if frame_key == "fps":
        # Metadata entry, not a frame. Previously execution fell through and
        # crashed in eval() because this value is not a string.
        fps = frame_val.get("fps")
        return True

    frame_count = frame_key.split(":")[0]
    # SECURITY: eval() executes whatever text is stored in redis. This is only
    # acceptable while the store is written exclusively by the trusted offline
    # job; switch to a data-only format if that ever changes.
    frame_val = eval(frame_val)
    vertices = frame_val.get("vertices")
    logging.info("vertices")
    fps = int(float(frame_val.get("fps")))

    new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
    image = cv2.imread(videoPath + str(frame_count) + ".jpg")
    # cv2.imread returns None for a missing or unreadable file.
    if new_mask is None or image is None or image.size == 0:
        return False

    new_mask = cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY)
    new_mask = np.where(new_mask < 1, 0, 1)

    [h, w, _] = image.shape
    # frombuffer replaces the deprecated binary mode of np.fromstring. It
    # returns a read-only view, hence the explicit copy.
    vertices = np.frombuffer(vertices, dtype=np.float32)
    vertices = vertices.astype(np.float32).copy()
    vertices = vertices.reshape((43867, -1))  # (43867, 3)
    new_image = render_colors(vertices, prn.triangles, new_colors, h, w)  # 3D face rendering
    new_image = (255 * new_image).astype(np.uint8)
    # Correct the colour of the rendered face towards the face in the video.
    new_image = colorTransfer(image, new_image, new_mask)

    # Merge the rendered face with the face in the video frame.
    mask3 = new_mask[:, :, np.newaxis]
    swap_image = image * (1 - mask3) + \
                 new_image * alpha * mask3 + \
                 image * (1 - alpha) * mask3

    # Centre of the mask's bounding box, used as the Poisson blending centre.
    mask_u8 = (new_mask * 255).astype(np.uint8)
    r = cv2.boundingRect(mask_u8)
    center = ((r[0] + np.round(r[2] / 2), r[1] + np.round(r[3] / 2)))
    center = tuple(map(int, center))

    output = cv2.seamlessClone(swap_image.astype(np.uint8), image,
                               mask_u8, center, cv2.NORMAL_CLONE)
    out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

    if out is None or out.size == 0:
        return False
    logging.info(f"swap merge face cost time is:  {str(time.time() - start_time)}")

    ret, buf = cv2.imencode(".jpg", out)
    if not ret:
        return False
    out_base64 = base64.b64encode(buf)
    # imageList is shared by every worker thread.
    with lock:
        imageList[int(float(frame_count))] = out_base64

    return True


def faceSwapRun(alpha, new_colors, frame_dict, videoPath, imageList):
    """Serially render the face-swapped version of every frame in a dict.

    Args:
        alpha: Blend weight of the rendered face inside the mask.
        new_colors: Per-vertex colours passed to ``render_colors``.
        frame_dict: Maps frame keys (frame number before the first ``":"``)
            to the text of a dict holding ``"vertices"`` and ``"fps"``. A
            ``"fps"`` key, if present, maps to an object exposing
            ``.get("fps")``.
        videoPath: Path prefix (including the trailing separator) of the
            frame and mask images.
        imageList: List that receives the base64-encoded JPEG of each frame
            at the frame's index; it is modified in place.

    Returns:
        Tuple ``(imageList, fps, w, h)``. ``fps``, ``w`` and ``h`` come from
        the last entry processed and default to 25, 255 and 255.
    """
    fps = 25
    w = 255
    h = 255

    for frame_key, frame_val in frame_dict.items():
        start_time = time.time()
        logging.info(f"frame_key is:  {str(frame_key)}")
        if frame_key == "fps":
            fps = frame_val.get("fps")
            continue

        frame_count = frame_key.split(":")[0]
        # SECURITY: eval() executes whatever text the caller loaded from
        # redis; only acceptable while that store is written by trusted code.
        frame_val = eval(frame_val)
        vertices = frame_val.get("vertices")
        logging.info("vertices")
        fps = int(float(frame_val.get("fps")))

        new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
        image = cv2.imread(videoPath + str(frame_count) + ".jpg")
        # cv2.imread returns None for a missing or unreadable file. Skip the
        # frame instead of crashing in cvtColor.
        if new_mask is None or image is None or image.size == 0:
            continue
        new_mask = cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY)
        new_mask = np.where(new_mask < 1, 0, 1)

        [h, w, _] = image.shape
        # frombuffer replaces the deprecated binary mode of np.fromstring. It
        # returns a read-only view, hence the explicit copy.
        vertices = np.frombuffer(vertices, dtype=np.float32)
        vertices = vertices.astype(np.float32).copy()
        vertices = vertices.reshape((43867, -1))  # (43867, 3)

        new_image = render_colors(vertices, prn.triangles, new_colors, h, w)
        new_image = (255 * new_image).astype(np.uint8)
        # Correct the colour of the rendered face towards the face in the video.
        new_image = colorTransfer(image, new_image, new_mask)
        print(new_image.shape)
        print(image.shape)

        # Merge the rendered face with the face in the video frame.
        mask3 = new_mask[:, :, np.newaxis]
        swap_image = image * (1 - mask3) + \
                     new_image * alpha * mask3 + \
                     image * (1 - alpha) * mask3

        # Centre of the mask's bounding box, used as the Poisson blending centre.
        mask_u8 = (new_mask * 255).astype(np.uint8)
        r = cv2.boundingRect(mask_u8)
        center = ((r[0] + np.round(r[2] / 2), r[1] + np.round(r[3] / 2)))
        center = tuple(map(int, center))

        output = cv2.seamlessClone(swap_image.astype(np.uint8), image,
                                   mask_u8, center, cv2.NORMAL_CLONE)
        out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

        if out is None or out.size == 0:
            continue
        logging.info(f"swap merge face cost time is:  {str(time.time() - start_time)}")

        time1 = time.time()
        ret, buf = cv2.imencode(".jpg", out)
        if not ret:
            continue
        out_base64 = base64.b64encode(buf)
        print(f"encode base64 cost time is:  {str(time.time() - time1)}")

        print("frame_count is : ", frame_count)
        imageList[int(float(frame_count))] = out_base64

    return imageList, fps, w, h


def get_redis(video_key, redisDb, i, alpha, new_colors, videoPath):
    """Load one video segment from redis and swap the face in each frame.

    Args:
        video_key: Redis key of the segment.
        redisDb: Redis client used to read the segment.
        i: Segment number; unused here.
        alpha: Blend weight forwarded to ``swapThread``.
        new_colors: Per-vertex colours forwarded to ``swapThread``.
        videoPath: Frame image path prefix forwarded to ``swapThread``.

    Returns:
        True once every entry of the segment has been handed to
        ``swapThread``.
    """
    global frame_count_all

    logging.info("key is:  " + video_key)
    # SECURITY: eval() executes whatever text is stored under the key. This is
    # only acceptable while redis is written exclusively by the trusted
    # offline job.
    frame_dict = eval(redisDb.get(video_key))
    # Several pool threads run this function at once. "+=" on a shared global
    # is a read-modify-write, so serialise it to avoid lost updates.
    with lock:
        frame_count_all += len(frame_dict)
    for frame_key, frame_val in frame_dict.items():
        swapThread(alpha, new_colors, frame_key, frame_val, videoPath)

    return True


def get_redis1(video_key, redisDb, i, alpha, new_colors, videoPath):
    """Load one video segment from redis into ``frame_dict_list``.

    Unlike ``get_redis`` this only collects the frames and does not swap
    them. ``i``, ``alpha``, ``new_colors`` and ``videoPath`` are accepted so
    both functions share one signature, but they are unused here.

    Args:
        video_key: Redis key of the segment.
        redisDb: Redis client used to read the segment.

    Returns:
        True once the segment has been merged into ``frame_dict_list``.
    """
    global frame_count_all
    global frame_dict_list

    logging.info("key is:  " + video_key)
    # SECURITY: eval() executes whatever text is stored under the key. This is
    # only acceptable while redis is written exclusively by the trusted
    # offline job.
    frame_dict = eval(redisDb.get(video_key))
    # Shared globals are updated from several pool threads; serialise the
    # read-modify-write so no update is lost.
    with lock:
        frame_count_all += len(frame_dict)
        frame_dict_list.update(frame_dict)

    return True


def faceSwap(ref_image, video_id, prn, videoPath):
    """Swap the reference face into every stored frame of a video.

    Extracts the per-vertex colours of the reference face, then submits one
    ``get_redis`` job per existing redis key ``<video_id>-1`` ..
    ``<video_id>-120`` to the thread pool and waits for all of them. The
    rendered frames end up in the module-level ``imageList``.

    Args:
        ref_image: Reference face image as an array.
        video_id: Prefix of the redis keys that hold the video's segments.
        prn: PRNet wrapper used for reconstruction and colour lookup.
        videoPath: Frame image path prefix forwarded to the workers.

    Returns:
        True on success. False when no face is found or any step raises;
        the exception is logged.
    """
    global all_task
    global frame_count_all
    try:
        begin_time = time.time()
        # Blend weight of the rendered face.
        alpha = 0.8
        # Read the reference image and get its colour space.
        ref_image = face_aligner.aligner(ref_image)
        if ref_image is None:
            logging.error("face alignment returned no image")
            return False
        boxes, _ = face_detector(ref_image)
        if boxes is None or len(boxes) < 1:
            logging.error("no face detected in reference image")
            return False
        ref_pos = prn.process(ref_image, image_info=boxes[0])
        logging.info("faceDetector and prn.process  cost time:  " + str(time.time() - begin_time))

        ref_image = ref_image / 255.
        ref_texture = cv2.remap(ref_image, ref_pos[:, :, :2].astype(np.float32), None,
                                interpolation=cv2.INTER_NEAREST,
                                borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
        # Colour of every reconstructed vertex, looked up in the texture.
        new_colors = prn.get_colors_from_texture(ref_texture)
        logging.info("to remap get colors  cost time:  " + str(time.time() - begin_time))

        redis_time = time.time()

        # Start from a clean slate. Frames left over from an earlier request
        # must not leak into this video, and futures from earlier requests
        # must not be waited on (or re-raise their old exceptions) again.
        imageList[:] = [""] * len(imageList)
        frame_count_all = 0
        tasks = []
        all_task = tasks
        for i in range(120):
            video_key = video_id + "-" + str(i + 1)
            if redisDb.exists(video_key):
                tasks.append(executor.submit(get_redis, video_key, redisDb, str(i + 1),
                                             alpha, new_colors, videoPath))

        for future in as_completed(tasks):
            data = future.result()
            logging.info(f"in main: get page {str(data)}s success")

        logging.info("get redis val cost time:  " + str(time.time() - redis_time))
        print("get redis val cost time:  " + str(time.time() - redis_time))

        logging.info("frame_dict_list len is:  " + str(len(frame_dict_list)))
        logging.info("threads swap face cost time:  " + str(time.time() - begin_time))
        return True
    except Exception as ex:
        logging.exception(ex)
        return False


@app.route('/ai/v1/FaceSwap', methods=['POST'])
def faceSwapMethod():
    """HTTP entry point: swap the posted face into the named video.

    The request body is a JSON object (a Python dict literal is also
    accepted) with the keys ``requestId``, ``token``, ``videoName`` and
    ``inputImage`` (base64-encoded image).

    Returns:
        A JSON string. ``code`` is 0 on completion, 3 when the token is
        missing, 4 when a parameter is missing or invalid, 5 when the image
        is too large and 6 on any unexpected exception. On completion
        ``data`` carries ``requestId``, ``faceSwapRes`` and ``timeUsed``.
    """
    import ast  # local import: only needed for the legacy payload fallback

    try:
        start_time = time.time()
        raw_body = str(request.data, encoding="utf-8")
        # SECURITY: this used to be eval(raw_body), which lets any client run
        # arbitrary code in the server process. Parse as data instead: JSON
        # first, then a literal-only parse for clients that still post a
        # Python dict literal.
        try:
            resParm = json.loads(raw_body)
        except ValueError:
            resParm = ast.literal_eval(raw_body)
        if not isinstance(resParm, dict):
            res = {'code': 4, 'msg': 'request body invalid'}
            logging.error("code: 4 msg:  request body invalid")
            return json.dumps(res)

        requestId = resParm.get('requestId')
        # Service authentication.
        token = resParm.get('token')
        if not token:
            res = {'code': 3, 'msg': 'token fail'}
            logging.error("code: 3 msg:  token fail ")
            return json.dumps(res)
        videoId = resParm.get("videoName")
        if not isinstance(videoId, str) or videoId.strip() == '':
            # Previously this branch built a response but never returned it,
            # so the request continued without a video name.
            res = {'code': 4, 'msg': 'videoName is null'}
            logging.error("code: 4 msg:  videoName is null")
            return json.dumps(res)

        # Decode the base64 image.
        modelImg_base64 = resParm.get("inputImage")
        if not modelImg_base64:
            res = {'code': 4, 'msg': ' picture param invalid'}
            logging.error("code: 4  msg:  picture param invalid")
            return json.dumps(res)
        modelImg = base64.b64decode(modelImg_base64)
        modelImg_data = np.frombuffer(modelImg, np.uint8)
        modelImg_data_1 = cv2.imdecode(modelImg_data, cv2.IMREAD_COLOR)
        if modelImg_data_1 is None:
            # cv2.imdecode returns None when the bytes are not a valid image.
            res = {'code': 4, 'msg': ' picture param invalid'}
            logging.error("code: 4  msg:  picture param invalid")
            return json.dumps(res)

        if is_has_glass(modelImg_base64):
            image = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
            edited = editor.removeglasses(image)
            modelImg_data_1 = edited[0]
            # NOTE(review): the channel order returned by removeglasses is not
            # visible here; confirm the conversions below are still correct.
            img = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
            cv2.imwrite("glass_img.jpg", img)

        # Check the picture size.
        if modelImg_data_1.shape[0] > config.size or modelImg_data_1.shape[1] > config.size:
            res = {'code': 5, 'msg': ' picture size invalid'}
            logging.error("code: 5 msg: picture size invalid")
            return json.dumps(res)
        logging.info(f"modelImg_data_1  shape:  {str(modelImg_data_1.shape)}   size:  {str(modelImg_data_1.size)}")

        time_predict = time.time()
        modelImg_data_1 = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
        swapRes = gen_swap_face(modelImg_data_1, videoId, prn)

        logging.info(f"face swap cost Time is: {str(time.time() - time_predict)} ")
        for t in swap_threads:
            t.join()

        timeUsed = time.time() - start_time
        data = {'requestId': requestId, 'faceSwapRes': str(swapRes), 'timeUsed': str(timeUsed)}
        res = {'code': 0, 'msg': 'success', 'data': data}
        logging.info(f"code:0  msg:success  face swap cost Time is: {str(timeUsed)} ")
        return json.dumps(res)
    except Exception as e:
        logging.exception(e)
        res = {'code': 6, 'msg': 'request exception'}
        return json.dumps(res)


def gen_swap_face(modelImg_data_1, videoId, prn):
    """Swap the reference face into a stored video and write it as mp4.

    The output goes to
    ``./img_video/<video>/<md5 of reference image>/<video>-<md5>.mp4``,
    where ``<video>`` is the base name of ``videoId`` up to its first dot.

    Args:
        modelImg_data_1: Reference face image as an array.
        videoId: Video name; also the prefix of the redis keys read by
            ``faceSwap``.
        prn: PRNet wrapper forwarded to ``faceSwap``.

    Returns:
        True when at least one frame was written. False when the swap
        failed, no frame was produced or an exception occurred; the
        exception is logged.
    """
    out = None
    try:
        videoName = os.path.basename(videoId)
        videoName = videoName.split(".")[0]
        refImgMd = hashlib.md5(modelImg_data_1).hexdigest()
        videoPath = './img_video/' + videoName + "/"

        save_res_path = videoPath + refImgMd + "/"
        os.makedirs(save_res_path, exist_ok=True)

        time_predict = time.time()
        swapRes = faceSwap(modelImg_data_1, videoId, prn, videoPath)
        if not swapRes:
            return False

        print(f"face swap Method cost Time is: {str(time.time() - time_predict)} ")

        # fps, w, h, imageList and frame_count_all are module-level values
        # filled in by the workers that faceSwap() ran.
        im_size = (w, h)
        logging.info(f"imageList len is: {str(len(imageList))}")
        if len(imageList) < 1:
            return False

        start_time = time.time()

        # Never index past the buffer, even if more entries were counted.
        for i in range(min(frame_count_all, len(imageList))):
            image = imageList[i]

            if image is None or len(image) < 1:
                continue

            image = base64.b64decode(image)
            img = plt.imread(BytesIO(image), "jpg")
            image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if out is None:
                # Open the writer lazily so that no file is created when
                # there is nothing to write.
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                out = cv2.VideoWriter(save_res_path + videoName + "-" + refImgMd + ".mp4", fourcc, fps, im_size, True)
            out.write(image)
        logging.info("image List to merge face video cost:  " + str(time.time() - start_time))
        if out is None:
            logging.error("no swapped frame available, video not written")
            return False
        return True
    except Exception as x:
        logging.exception(x)
        return False
    finally:
        # Release the writer so the mp4 container is finalised on disk.
        if out is not None:
            out.release()


def save_video_face(videoName, start_frame=350, output_path="test_swap1_200f.mp4"):
    """Copy a video to a new mp4 file, dropping its leading frames.

    Args:
        videoName: Path of the source video.
        start_frame: Number of leading frames to skip.
        output_path: Path of the mp4 file to write. It is only created when
            at least one frame is left after skipping.
    """
    cap = cv2.VideoCapture(videoName)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

        frameId = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frameId >= start_frame:
                if out is None:
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    out = cv2.VideoWriter(output_path, fourcc, fps, im_size, True)
                out.write(frame)
            frameId += 1
    finally:
        # Release both handles so the source is closed and the output
        # container is finalised.
        cap.release()
        if out is not None:
            out.release()


if __name__ == "__main__":
    # Serve the Flask app through meinheld on all interfaces, port 9775.
    logging.info('Starting the server...')
    listen_address = ("0.0.0.0", 9775)
    server.listen(listen_address)
    server.run(app)
    # Alternative using Flask's own server:
    # app.run(host='0.0.0.0', port=18885, threaded=True)

 

你可能感兴趣的:(flask,python)