argparse模块的作用是用于解析命令行参数,基本使用:
import argparse
#创建解析器对象,description:描述程序
parser = argparse.ArgumentParser(description='3DDFA inference pipeline')
#添加参数,type:把从命令行输入的结果转成设置的类型
parser.add_argument('-f', '--files', nargs='+',
help='image files paths fed into network, single or multiple images')
parser.add_argument('-m', '--mode', default='gpu', type=str, help='gpu or cpu mode')
args = parser.parse_args()
#torch.load(model_path)返回的是一个 OrderedDict
checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict']
'''
pytorch允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上
torch.load(checkpoint_fp) #CPU->CPU,GPU->GPU
torch.load(checkpoint_fp, map_location=lambda storage, loc: storage) #GPU->CPU
torch.load(checkpoint_fp, map_location=lambda storage, loc: storage.cuda(1)) #CPU->GPU1
'''
将模型的全部参数保存到model_dict
model_dict = model.state_dict()
加载训练好的模型
model.load_state_dict(model_dict)
加载dlib模块进行人脸检测和裁剪
if args.dlib_landmark:
dlib_landmark_model = 'models/shape_predictor_68_face_landmarks.dat'
face_regressor = dlib.shape_predictor(dlib_landmark_model)
if args.dlib_bbox:
face_detector = dlib.get_frontal_face_detector()
torchvision.transforms是pytorch中的图像预处理包
用transforms.Compose()将多个步骤融合到一起
transform = transforms.Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)])
'''
Normalize:Normalized an tensor image with mean and standard deviation
ToTensor:convert a PIL image to tensor (H*W*C) in range [0,255] to a torch.Tensor(C*H*W) in the range [0.0,1.0]
'''
对args.files中的每一张图像循环进行如下操作:
#默认使用dlib landmark和bbox
#将检测到的所有人脸位置保存到rects中
rects = face_detector(img_ori, 1)
#然后将每一个rect的左、上、右、下四个位置保存到roi_bbox,利用crop_img()裁剪图片
roi_box = parse_roi_box_from_landmark(pts)
img = crop_img(img_ori, roi_box)
#将图片resize到网络要求的大小,插值方式为最近邻插值
img = cv2.resize(img, dsize=(STD_SIZE, STD_SIZE), interpolation=cv2.INTER_LINEAR)
input = transform(img).unsqueeze(0) #在图片第零维增加一个维度
#传入模型
param = model(input)
#降维然后flatten至一维
param = param.squeeze().cpu().numpy().flatten().astype(np.float32)
# 预测68个特征点或者dense特征点
pts68 = predict_68pts(param, roi_box)
未完…