PaddlePaddle "Get Started with Deep Learning Quickly Using the High-Level API", 7-Day Deep Learning Camp, Lesson 3 Homework: Facial Keypoint Detection

7-Day Deep Learning Camp: Facial Keypoint Detection

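1. Problem Definition

Facial keypoint detection takes a face image as input and returns the coordinates of a set of facial keypoints, which locate the key features of the face.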

# Import the environment
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import cv2
import paddle

paddle.set_device('gpu') # run on the GPU

import warnings
warnings.filterwarnings('ignore') # suppress warnings

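2. Data Preparation

2.1 Download the Dataset

The dataset used in this experiment comes from an open-source project on GitHub.

The dataset has been uploaded to AI Studio as "人脸关键点识别" (Facial Keypoint Recognition). Once it is mounted, it can be unzipped with the command below.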

# !unzip data/data69065/data.zip

After unzipping, the dataset has the following structure:

data/
|—— test
|   |—— Abdel_Aziz_Al-Hakim_00.jpg
    ... ...
|—— test_frames_keypoints.csv
|—— training
|   |—— Abdullah_Gul_10.jpg
    ... ...
|—— training_frames_keypoints.csv

The training and test folders hold the training and test images, and training_frames_keypoints.csv and test_frames_keypoints.csv hold the corresponding labels. Let's first look at training_frames_keypoints.csv to see how the training labels are defined.

key_pts_frame = pd.read_csv('data/training_frames_keypoints.csv') # load the label file
print('Number of images: ', key_pts_frame.shape[0]) # number of labelled images
key_pts_frame.head(5) # show the first five rows
Number of images:  3462
Unnamed: 0 0 1 2 3 4 5 6 7 8 ... 126 127 128 129 130 131 132 133 134 135
0 Luis_Fonsi_21.jpg 45.0 98.0 47.0 106.0 49.0 110.0 53.0 119.0 56.0 ... 83.0 119.0 90.0 117.0 83.0 119.0 81.0 122.0 77.0 122.0
1 Lincoln_Chafee_52.jpg 41.0 83.0 43.0 91.0 45.0 100.0 47.0 108.0 51.0 ... 85.0 122.0 94.0 120.0 85.0 122.0 83.0 122.0 79.0 122.0
2 Valerie_Harper_30.jpg 56.0 69.0 56.0 77.0 56.0 86.0 56.0 94.0 58.0 ... 79.0 105.0 86.0 108.0 77.0 105.0 75.0 105.0 73.0 105.0
3 Angelo_Reyes_22.jpg 61.0 80.0 58.0 95.0 58.0 108.0 58.0 120.0 58.0 ... 98.0 136.0 107.0 139.0 95.0 139.0 91.0 139.0 85.0 136.0
4 Kristen_Breitweiser_11.jpg 58.0 94.0 58.0 104.0 60.0 113.0 62.0 121.0 67.0 ... 92.0 117.0 103.0 118.0 92.0 120.0 88.0 122.0 84.0 122.0

5 rows × 137 columns

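Each row of the table above is one sample. The first column is the image file name, and the columns named 0 through 135 hold that image's keypoint data. Each keypoint is described by two coordinates (x, y), so the 136 values correspond to 136 / 2 = 68 keypoints: this is a 68-point facial keypoint dataset.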
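To make the 136 / 2 = 68 arithmetic concrete, here is a minimal sketch of how a flat label row maps to per-point coordinates. The array `flat` is made-up data, and the interleaved x0, y0, x1, y1, ... layout is the one that show_keypoints relies on later in this post.

```python
import numpy as np

# Hypothetical flat label row: 136 values stored as x0, y0, x1, y1, ...
flat = np.arange(136, dtype="float32")

# One (x, y) pair per keypoint.
points = flat.reshape(-1, 2)

xs = flat[0::2]  # even positions hold the x coordinates
ys = flat[1::2]  # odd positions hold the y coordinates

print(points.shape)  # (68, 2)
```

Both views describe the same data: `points[:, 0]` equals `xs` and `points[:, 1]` equals `ys`.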

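Tip 1: Commonly used facial keypoint annotation schemes include:

  • 5 points
  • 21 points
  • 68 points
  • 98 points

Tip 2: This exercise uses the 68-point annotation, with the points ordered as follows: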

# Compute the mean and standard deviation of the labels, used to normalize them
key_pts_values = key_pts_frame.values[:,1:] # label columns only
data_mean = key_pts_values.mean() # mean
data_std = key_pts_values.std()   # standard deviation
print('Label mean:', data_mean)
print('Label standard deviation:', data_std)
Label mean: 104.4724870017331
Label standard deviation: 43.17302271754281
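These two statistics are what the label normalization later in this post uses. As a rough sketch (the helper names `normalize_keypoints` and `denormalize_keypoints` are illustrative and not part of the course code), standardizing labels and mapping predictions back to pixel coordinates could look like this:

```python
import numpy as np

DATA_MEAN = 104.4724870017331  # mean printed above
DATA_STD = 43.17302271754281   # standard deviation printed above

def normalize_keypoints(key_pts, mean=DATA_MEAN, std=DATA_STD):
    """Map pixel coordinates to standardized values."""
    return (np.asarray(key_pts, dtype="float32") - mean) / std

def denormalize_keypoints(key_pts, mean=DATA_MEAN, std=DATA_STD):
    """Map standardized values (e.g. model outputs) back to pixels."""
    return np.asarray(key_pts, dtype="float32") * std + mean

# Made-up coordinates, used only to show the round trip.
pixels = np.array([45.0, 98.0, 47.0, 106.0], dtype="float32")
restored = denormalize_keypoints(normalize_keypoints(pixels))
```

The inverse mapping is needed whenever predicted keypoints are to be drawn on top of the original image.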

2.2 Viewing the Images

def show_keypoints(image, key_pts):
    """
    Show an image with its keypoints drawn on top.
    Args:
        image: image array
        key_pts: flat keypoint array laid out as x0, y0, x1, y1, ...
    """
    plt.imshow(image.astype('uint8'))  # draw the image
    for i in range(len(key_pts)//2):
        plt.scatter(key_pts[i*2], key_pts[i*2+1], s=20, marker='.', c='b') # draw one keypoint

# Show a single sample
n = 14 # row index of the sample in the table
image_name = key_pts_frame.iloc[n, 0] # image file name
key_pts = key_pts_frame.iloc[n, 1:].to_numpy() # labels as a numpy array (as_matrix() was removed in pandas 1.0)
key_pts = key_pts.astype('float').reshape(-1) # flat float array of keypoint coordinates
print(key_pts.shape)
plt.figure(figsize=(5, 5)) # figure size
show_keypoints(mpimg.imread(os.path.join('data/training/', image_name)), key_pts) # draw image and keypoints
plt.show() # display the figure
(136,)


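2.3 Dataset Definition

Use paddle.io.Dataset from the Paddle high-level API to define a custom dataset class; see the "自定义数据集" (custom dataset) page of the official documentation for details.

Homework 1: Define a custom Dataset for the facial keypoint data

Following the definition in __init__, implement __getitem__ and __len__.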

# Build the facial keypoint dataset following the Dataset conventions

from paddle.io import Dataset

class FacialKeypointsDataset(Dataset):
    # Facial keypoint dataset
    """
    Step 1: inherit from paddle.io.Dataset
    """
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Step 2: implement the constructor and load the label table that defines the dataset
        Args:
            csv_file (string): path to the csv file with the annotations
            root_dir (string): path to the folder containing the images
            transform (callable, optional): preprocessing applied to each sample
        """
        self.key_pts_frame = pd.read_csv(csv_file) # read the csv file
        self.root_dir = root_dir # image folder
        self.transform = transform # transform callable

    def __getitem__(self, idx):
        """
        Step 3: implement __getitem__, which returns one sample (image, label) for a given index
        """

        # Implement __getitem__
        image_name = os.path.join(self.root_dir, self.key_pts_frame.iloc[idx, 0])
        image = mpimg.imread(image_name)

        # Drop the alpha channel of RGBA images (the ndim check keeps 2-D grayscale files from raising)
        if image.ndim == 3 and image.shape[2] == 4:
            image = image[:, :, 0:3]

        # Read the keypoints (to_numpy replaces as_matrix, which was removed in pandas 1.0)
        key_pts = self.key_pts_frame.iloc[idx, 1:].to_numpy()
        key_pts = key_pts.astype('float').reshape(-1) # shape (136,)

        # Apply the transform if one was given
        if self.transform:
            image, key_pts = self.transform([image, key_pts])

        # Convert to numpy arrays
        image = np.array(image, dtype='float32')
        key_pts = np.array(key_pts, dtype='float32')

        return image, key_pts

    def __len__(self):
        """
        Step 4: implement __len__, which returns the number of samples in the dataset
        """

        # Implement __len__
        return len(self.key_pts_frame)
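The two methods implemented above are all a loader needs in order to walk through a dataset. The following is a framework-free sketch of that idea. `ToyKeypointDataset` and `iterate_batches` are hypothetical stand-ins written for illustration; they are not paddle.io.Dataset or paddle.io.DataLoader, and the data is random.

```python
import numpy as np

class ToyKeypointDataset:
    """Hypothetical in-memory dataset exposing __getitem__ and __len__."""

    def __init__(self, num_samples=5, num_keypoints=68):
        rng = np.random.default_rng(0)
        self.images = [rng.random((32, 32, 3), dtype=np.float32)
                       for _ in range(num_samples)]
        self.labels = [rng.random(num_keypoints * 2, dtype=np.float32)
                       for _ in range(num_samples)]

    def __getitem__(self, idx):
        # One sample: (image, flat keypoint vector)
        return self.images[idx], self.labels[idx]

    def __len__(self):
        return len(self.images)

def iterate_batches(dataset, batch_size):
    """Yield stacked (images, labels) batches using only len() and indexing."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        images = np.stack([s[0] for s in samples])
        labels = np.stack([s[1] for s in samples])
        yield images, labels

toy = ToyKeypointDataset()
batch_shapes = [(x.shape, y.shape) for x, y in iterate_batches(toy, batch_size=2)]
```

With five samples and a batch size of two, the last batch holds a single sample. Stacking works here only because every toy image has the same shape.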

2.4 Visualizing the Training Set

Instantiate the dataset and display a few images.

# Build the dataset
face_dataset = FacialKeypointsDataset(csv_file='data/training_frames_keypoints.csv',
                                      root_dir='data/training/')

# Print the dataset size
print('Dataset size: ', len(face_dataset))
# Visualize samples from face_dataset
num_to_display = 3

for i in range(num_to_display):

    # Set the figure size
    fig = plt.figure(figsize=(20,10))

    # Pick a random sample
    rand_i = np.random.randint(0, len(face_dataset))
    sample = face_dataset[rand_i]

    # Print the image shape and the label shape
    print(i, sample[0].shape, sample[1].shape)

    # Set up the subplot and its title
    ax = plt.subplot(1, num_to_display, i + 1)
    ax.set_title('Sample #{}'.format(i))

    # Draw the image with its keypoints
    show_keypoints(sample[0], sample[1])
Dataset size:  3462
0 (131, 115, 3) (136,)
1 (293, 229, 3) (136,)
2 (148, 121, 3) (136,)


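The code above defines the dataset, but a few problems remain:

  • The images have different sizes and must be brought to one size to match the network input
  • The image layout must be converted to the format the model expects
  • The dataset is fairly small and no data augmentation is applied

All of these affect the final model performance, so the data needs to be preprocessed.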
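The first problem is easy to reproduce without any framework. In the sketch below, the two arrays are blank placeholders that only borrow shapes from the sample output above, and the plain slicing is a crude stand-in for a real resize and crop step.

```python
import numpy as np

# Placeholder images with two of the shapes printed above.
img_a = np.zeros((131, 115, 3), dtype="float32")
img_b = np.zeros((293, 229, 3), dtype="float32")

try:
    np.stack([img_a, img_b])
    stack_failed = False
except ValueError:
    # Arrays of different shapes cannot form one batch array.
    stack_failed = True

# Crude stand-in for resize + crop: cut both to a shared 100x100 window.
same_a = img_a[:100, :100]
same_b = img_b[:100, :100]
batch = np.stack([same_a, same_b])  # shape (2, 100, 100, 3)
```

Once every sample has the same shape, the samples stack into a single batch array, which is what the network input expects.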

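2.5 Transforms

The images are preprocessed to meet the data requirements: grayscale conversion, normalization, resizing, random cropping, changing the channel layout, and so on. Each step serves the following purpose:

  • Grayscale: discards color and keeps edge information. Keypoint detection depends little on color, and color can reduce robustness. A grayscale image also needs only one channel instead of three, which speeds up computation while preserving gradients. (The GrayNormalize class below passes num_output_channels=3, so it keeps three channels.)
  • Normalization: speeds up convergence
  • Resize: brings images to a common scale and, together with random cropping, acts as data augmentation
  • Random crop: data augmentation
  • Channel layout change: converts the image to the layout the model requires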
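As an illustration of the grayscale step, here is a small numpy sketch. It uses the widely used ITU-R BT.601 luma weights; whether paddle's Grayscale transform uses exactly these weights internally is an assumption, and `to_gray` is a hypothetical helper, not a paddle function.

```python
import numpy as np

def to_gray(image_rgb, num_output_channels=1):
    """Convert an HWC RGB image to grayscale with BT.601 luma weights."""
    weights = np.array([0.299, 0.587, 0.114], dtype="float32")
    gray = image_rgb[..., :3].astype("float32") @ weights  # shape (H, W)
    # Repeat the single channel when a 3-channel layout is wanted.
    return np.repeat(gray[..., None], num_output_channels, axis=-1)

rgb = np.full((4, 5, 3), 255, dtype="uint8")  # made-up all-white image
gray1 = to_gray(rgb)                          # shape (4, 5, 1)
gray3 = to_gray(rgb, num_output_channels=3)   # shape (4, 5, 3)
```

Keeping three identical channels, as in the second call, leaves the input shape compatible with backbones that expect RGB input.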

Homework 2: Implement a custom ToCHW

Implement the preprocessing method ToCHW.
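For orientation, the layout change itself is a single axis permutation. The sketch below uses plain numpy on a made-up array; it only illustrates what HWC to CHW means and is not the homework solution.

```python
import numpy as np

hwc = np.zeros((224, 224, 3), dtype="float32")  # height, width, channels
hwc[10, 20, 1] = 7.0                            # mark one pixel in channel 1

chw = np.transpose(hwc, (2, 0, 1))              # channels, height, width

print(chw.shape)  # (3, 224, 224)
```

The marked pixel moves from index `[10, 20, 1]` to index `[1, 10, 20]`: the values are unchanged, only the axis order differs.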

# Template for a custom transform

# class TransformAPI(object):
#     """
#     Step 1: inherit from object
#     """
#     def __call__(self, data):

#         """
#         Step 2: define the processing in __call__
#         """

#         processed_data = data
#         return  processed_data
import paddle.vision.transforms.functional as F

class GrayNormalize(object):
    # Convert the image to grayscale and scale its values to [0, 1]
    # Standardize the labels with the dataset mean and std
    # (values end up centered near 0, not strictly within [-1, 1])

    def __call__(self, data):
        image = data[0]   # image
        key_pts = data[1] # labels

        image_copy = np.copy(image)
        key_pts_copy = np.copy(key_pts)

        # Convert to grayscale, keeping 3 output channels
        gray_scale = paddle.vision.transforms.Grayscale(num_output_channels=3)
        image_copy = gray_scale(image_copy)

        # Scale pixel values to [0, 1]
        image_copy = image_copy / 255.0

        # Standardize the keypoint coordinates
        mean = data_mean # label mean
        std = data_std   # label standard deviation
        key_pts_copy = (key_pts_copy - mean)/std

        return image_copy, key_pts_copy

class Resize(object):
    # Resize the input image to the given size
    # int: the shorter side becomes output_size and the aspect ratio is kept
    # tuple: interpreted as (height, width)

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, data):

        image = data[0]    # image
        key_pts = data[1]  # labels

        image_copy = np.copy(image)
        key_pts_copy = np.copy(key_pts)

        h, w = image_copy.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = F.resize(image_copy, (new_h, new_w))

        # Scale the keypoints by the same factors (x with width, y with height)
        key_pts_copy[::2] = key_pts_copy[::2] * new_w / w
        key_pts_copy[1::2] = key_pts_copy[1::2] * new_h / h

        return img, key_pts_copy


class RandomCrop(object):
    # Crop the input image at a random position

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, data):
        image = data[0]
        key_pts = data[1]

        image_copy = np.copy(image)
        key_pts_copy = np.copy(key_pts)

        h, w = image_copy.shape[:2]
        new_h, new_w = self.output_size

        # randint excludes the upper bound; the +1 allows the last valid offset
        # and avoids an error when the image is exactly the crop size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image_copy = image_copy[top: top + new_h,
                      left: left + new_w]

        # Shift the keypoints into the coordinate frame of the crop
        key_pts_copy[::2] = key_pts_copy[::2] - left
        key_pts_copy[1::2] = key_pts_copy[1::2] - top

        return image_copy, key_pts_copy

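class ToCHW(object):
    # Convert the image layout from HWC to CHW
    def __call__(self, data):

        # Implement ToCHW; paddle.vision.transforms.Transpose can be used
        image = data[0]
        key_pts = data[1]

        # The full module path avoids relying on the alias T, which is imported only after this class
        transpose = paddle.vision.transforms.Transpose() # default order (2, 0, 1): HWC -> CHW
        image = transpose(image)
        return image, key_pts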
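The part of these transforms that is easiest to get wrong is the keypoint arithmetic: labels must follow the image through every geometric change. Below is a minimal numpy sketch of that bookkeeping. The helpers `resize_keypoints`, `crop_keypoints`, and `inside` are hypothetical names introduced for illustration, and the coordinates are made up.

```python
import numpy as np

def resize_keypoints(key_pts, old_hw, new_hw):
    """Scale x by the width ratio and y by the height ratio."""
    out = np.array(key_pts, dtype="float32")
    out[0::2] *= new_hw[1] / old_hw[1]
    out[1::2] *= new_hw[0] / old_hw[0]
    return out

def crop_keypoints(key_pts, top, left):
    """Shift keypoints into the coordinate frame of a crop."""
    out = np.array(key_pts, dtype="float32")
    out[0::2] -= left
    out[1::2] -= top
    return out

def inside(key_pts, hw):
    """True if every keypoint lies inside an image of size (h, w)."""
    xs, ys = key_pts[0::2], key_pts[1::2]
    return bool(np.all((xs >= 0) & (xs < hw[1]) & (ys >= 0) & (ys < hw[0])))

pts = [50.0, 100.0, 60.0, 120.0]  # two made-up points as x0, y0, x1, y1
resized = resize_keypoints(pts, old_hw=(200, 100), new_hw=(400, 200))
cropped = crop_keypoints(resized, top=16, left=8)

print(inside(resized, (400, 200)))  # True
print(inside(cropped, (224, 224)))  # False: the second point has y == 224
```

The last line shows a side effect worth keeping in mind: a random crop can push keypoints outside the cropped image, and the labels then point at pixels that are no longer there.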

Let's look at the effect of each preprocessing method.

import paddle.vision.transforms as T

# Test Resize
resize = Resize(256)

# Test RandomCrop
random_crop = RandomCrop(128)

# Test GrayNormalize
norm = GrayNormalize()

# Test Resize + RandomCrop: scale the shorter side to 250, then crop a 224*224 patch
composed = paddle.vision.transforms.Compose([Resize(250), RandomCrop(224)])

test_num = 800 # index of the sample to test
data = face_dataset[test_num]

transforms = {'None': None,
              'norm': norm,
              'random_crop': random_crop,
              'resize': resize,
              'composed': composed}
for i, func_name in enumerate(['None', 'norm', 'random_crop', 'resize', 'composed']):

    # Set the figure size
    fig = plt.figure(figsize=(20,10))

    # Apply the transform
    if transforms[func_name] is not None:
        transformed_sample = transforms[func_name](data)
    else:
        transformed_sample = data

    # Set up the subplot and its title
    ax = plt.subplot(1, 5, i + 1)
    ax.set_title(' Transform is #{}'.format(func_name))

    # Draw the image with its keypoints
    show_keypoints(transformed_sample[0], transformed_sample[1])


2.6 Defining the Datasets with Preprocessing

Now apply Resize, RandomCrop, GrayNormalize, and ToCHW to new datasets.

from paddle.vision.transforms import Compose

data_transform = Compose([Resize(256), RandomCrop(224), GrayNormalize(), ToCHW()])

# create the transformed dataset
train_dataset = FacialKeypointsDataset(csv_file='data/training_frames_keypoints.csv',
                                       root_dir='data/training/',
                                       transform=data_transform)
print('Number of train dataset images: ', len(train_dataset))

for i in range(4):
    sample = train_dataset[i]
    print(i, sample[0].shape, sample[1].shape)

test_dataset = FacialKeypointsDataset(csv_file='data/test_frames_keypoints.csv',
                                      root_dir='data/test/',
                                      transform=data_transform)

print('Number of test dataset images: ', len(test_dataset))
Number of train dataset images:  3462
0 (3, 224, 224) (136,)
1 (3, 224, 224) (136,)
2 (3, 224, 224) (136,)
3 (3, 224, 224) (136,)
Number of test dataset images:  770
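The pipeline above works because each transform takes an (image, key_pts) pair and returns a pair of the same kind, so the transforms can be chained. The sketch below shows that chaining idea in plain Python. `SimpleCompose` is a hypothetical stand-in, not paddle's Compose, and the two toy transforms operate on a number instead of an image.

```python
class SimpleCompose:
    """Hypothetical stand-in: apply callables from left to right."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data

def add_one(data):
    image, key_pts = data
    return image + 1, key_pts

def double(data):
    image, key_pts = data
    return image * 2, key_pts

pipeline = SimpleCompose([add_one, double])
result = pipeline((3, "labels"))                              # (3 + 1) * 2
reversed_result = SimpleCompose([double, add_one])((3, "labels"))  # 3 * 2 + 1
```

Order matters: `result` is `(8, 'labels')` while `reversed_result` is `(7, 'labels')`. For the same reason, Resize has to come before RandomCrop in the pipeline above so that the 224 crop fits inside the resized image.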

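3. Model Construction

3.1 Building the Network Can Be Simple

As discussed earlier, facial keypoint detection and classification can share the same network structure for feature extraction, such as LeNet or ResNet50. Only the last part of the model needs to change: the output size is set to the number of facial keypoints * 2, that is, the x and y coordinates of every keypoint. See the code below, or the official example "人脸关键点检测" (facial keypoint detection).

The network structure is as follows:

Homework 3: Implement the network structure according to the figure above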

import paddle.nn as nn
from paddle.vision.models import resnet101
class SimpleNet(nn.Layer):

    def __init__(self, key_pts):

        super(SimpleNet, self).__init__()

        # Implement __init__
        # Pretrained ResNet101 backbone; its output has 1000 features
        self.backbone = resnet101(pretrained=True)
        # First linear layer
        self.linear1 = nn.Linear(in_features=1000, out_features=512)

        # ReLU activation
        self.act1 = nn.ReLU()

        # Second linear layer is the output; it has key_pts*2 elements, one (x, y) pair per keypoint
        self.linear2 = nn.Linear(in_features=512, out_features=key_pts*2)

    def forward(self, x):

        # Implement forward
        x = self.backbone(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        return x
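To see what the regression head does to tensor shapes, here is a framework-free sketch using numpy. The random `features` array stands in for the 1000-dimensional backbone output and the weights are random placeholders; only the shapes are meaningful.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, feat, hidden, key_pts = 4, 1000, 512, 68

# Stand-in for the backbone output: one 1000-d vector per image.
features = rng.standard_normal((batch, feat)).astype("float32")

# Placeholder weights for the two linear layers.
w1 = rng.standard_normal((feat, hidden)).astype("float32") * 0.01
b1 = np.zeros(hidden, dtype="float32")
w2 = rng.standard_normal((hidden, key_pts * 2)).astype("float32") * 0.01
b2 = np.zeros(key_pts * 2, dtype="float32")

h = np.maximum(features @ w1 + b1, 0.0)  # linear1 followed by ReLU
out = h @ w2 + b2                        # linear2: (batch, 136)
coords = out.reshape(batch, key_pts, 2)  # one (x, y) pair per keypoint
```

The final reshape is optional: the flat 136-vector matches the label layout used for training, while the (68, 2) view is handy for plotting.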

3.2 Visualizing the Network Structure

Use model.summary to visualize the network structure.
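The Param # column of such a summary can be checked by hand. The sketch below reproduces the first two rows of the summary in this section (9,408 and 256) and computes the sizes of the two linear layers of the head. The 7x7 kernel without bias for the first convolution, and the count of four entries per batch-norm channel (weight, bias, running mean, running variance), are assumptions that happen to match the printed figures; the helper names are illustrative.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=False):
    """Weights of a square-kernel 2-D convolution, plus optional bias."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)

def batchnorm_entries(channels):
    """Weight, bias, running mean and running variance per channel."""
    return 4 * channels

def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features

first_conv = conv2d_params(3, 64, 7)   # 9408, as in the Conv2D-1 row
first_bn = batchnorm_entries(64)       # 256, as in the BatchNorm2D-1 row
head1 = linear_params(1000, 512)       # 512512
head2 = linear_params(512, 68 * 2)     # 69768
```

Compared with the backbone, the head is small: the two linear layers together add 582,280 parameters.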

model = paddle.Model(SimpleNet(key_pts=68))
model.summary((-1, 3, 224, 224))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-1        [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
   BatchNorm2D-1    [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
      ReLU-1        [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
    MaxPool2D-1     [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
     Conv2D-3        [[1, 64, 56, 56]]     [1, 64, 56, 56]         4,096     
   BatchNorm2D-3     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-2         [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-4        [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
   BatchNorm2D-4     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
     Conv2D-5        [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
   BatchNorm2D-5     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
     Conv2D-2        [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
   BatchNorm2D-2     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
 BottleneckBlock-1   [[1, 64, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-6        [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
   BatchNorm2D-6     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-3         [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-7        [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
   BatchNorm2D-7     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
     Conv2D-8        [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
   BatchNorm2D-8     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
 BottleneckBlock-2   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-9        [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
   BatchNorm2D-9     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-4         [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-10       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-10     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
     Conv2D-11       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-11     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
 BottleneckBlock-3   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-13       [[1, 256, 56, 56]]    [1, 128, 56, 56]       32,768     
  BatchNorm2D-13     [[1, 128, 56, 56]]    [1, 128, 56, 56]         512      
      ReLU-5         [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-14       [[1, 128, 56, 56]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-14     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
     Conv2D-15       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-15     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
     Conv2D-12       [[1, 256, 56, 56]]    [1, 512, 28, 28]       131,072    
  BatchNorm2D-12     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
 BottleneckBlock-4   [[1, 256, 56, 56]]    [1, 512, 28, 28]          0       
     Conv2D-16       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-16     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-6         [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-17       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-17     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
     Conv2D-18       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-18     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
 BottleneckBlock-5   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-19       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-19     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-7         [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-20       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-20     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
     Conv2D-21       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-21     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
 BottleneckBlock-6   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-22       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-22     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-8         [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-23       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-23     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
     Conv2D-24       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-24     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
 BottleneckBlock-7   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-26       [[1, 512, 28, 28]]    [1, 256, 28, 28]       131,072    
  BatchNorm2D-26     [[1, 256, 28, 28]]    [1, 256, 28, 28]        1,024     
      ReLU-9        [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-27       [[1, 256, 28, 28]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-27     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-28       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-28    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
     Conv2D-25       [[1, 512, 28, 28]]   [1, 1024, 14, 14]       524,288    
  BatchNorm2D-25    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
 BottleneckBlock-8   [[1, 512, 28, 28]]   [1, 1024, 14, 14]          0       
     Conv2D-29      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-29     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-10       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-30       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-30     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-31       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-31    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
 BottleneckBlock-9  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-32      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-32     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-11       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-33       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-33     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-34       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-34    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-10  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-35      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-35     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-12       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-36       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-36     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-37       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-37    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-11  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-38      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-38     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-13       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-39       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-39     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-40       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-40    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-12  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-41      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-41     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-14       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-42       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-42     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-43       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-43    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-13  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-44      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-44     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-15       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-45       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-45     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-46       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-46    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-14  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-47      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-47     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-16       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-48       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-48     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-49       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-49    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-15  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-50      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-50     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-17       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-51       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-51     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-52       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-52    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-16  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-53      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-53     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-18       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-54       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-54     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-55       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-55    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-17  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-56      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-56     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-19       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-57       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-57     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-58       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-58    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-18  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-59      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-59     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-20       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-60       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-60     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-61       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-61    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-19  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-62      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-62     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-21       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-63       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-63     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-64       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-64    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-20  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-65      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-65     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-22       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-66       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-66     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-67       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-67    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-21  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-68      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-68     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-23       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-69       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-69     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-70       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-70    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-22  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-71      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-71     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-24       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-72       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-72     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-73       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-73    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-23  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-74      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-74     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-25       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-75       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-75     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-76       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-76    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-24  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-77      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-77     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-26       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-78       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-78     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-79       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-79    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-25  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-80      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-80     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-27       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-81       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-81     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-82       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-82    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-26  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-83      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-83     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-28       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-84       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-84     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-85       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-85    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-27  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-86      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-86     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-29       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-87       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-87     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-88       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-88    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-28  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-89      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-89     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-30       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-90       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-90     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-91       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-91    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-29  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-92      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-92     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-31       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-93       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-93     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
     Conv2D-94       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-94    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-30  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-96      [[1, 1024, 14, 14]]    [1, 512, 14, 14]       524,288    
  BatchNorm2D-96     [[1, 512, 14, 14]]    [1, 512, 14, 14]        2,048     
      ReLU-32        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
     Conv2D-97       [[1, 512, 14, 14]]     [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-97      [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
     Conv2D-98        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-98     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
     Conv2D-95      [[1, 1024, 14, 14]]    [1, 2048, 7, 7]       2,097,152   
  BatchNorm2D-95     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
BottleneckBlock-31  [[1, 1024, 14, 14]]    [1, 2048, 7, 7]           0       
     Conv2D-99       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-99      [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-33        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-100        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-100     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-101        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-101    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
BottleneckBlock-32   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-102       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-102     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-34        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-103        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-103     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-104        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-104    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
BottleneckBlock-33   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
AdaptiveAvgPool2D-1  [[1, 2048, 7, 7]]     [1, 2048, 1, 1]           0       
     Linear-1           [[1, 2048]]           [1, 1000]          2,049,000   
     ResNet-1        [[1, 3, 224, 224]]       [1, 1000]              0       
     Linear-2           [[1, 1000]]            [1, 512]           512,512    
      ReLU-35            [[1, 512]]            [1, 512]              0       
     Linear-3            [[1, 512]]            [1, 136]           69,768     
===============================================================================
Total params: 45,236,784
Trainable params: 45,026,096
Non-trainable params: 210,688
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 391.66
Params size (MB): 172.56
Estimated Total Size (MB): 564.80
-------------------------------------------------------------------------------






{'total_params': 45236784, 'trainable_params': 45026096}
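The head layers and the reported parameter size in the summary above can be checked by hand. The sketch below is only a sanity check: the helper `linear_params` is a name introduced here for illustration, and the size estimate assumes each parameter is stored as a 4-byte float32.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameter count of a fully connected layer: weights plus one bias per output."""
    return in_features * out_features + out_features

# Linear-2 and Linear-3 as listed in the summary table
linear_2 = linear_params(1000, 512)
linear_3 = linear_params(512, 136)   # 136 = 68 keypoints * 2 coordinates

# "Total params" copied from the summary; 4 bytes per parameter is an assumption (float32)
total_params = 45_236_784
params_mb = total_params * 4 / 1024 ** 2

print(linear_2, linear_3, round(params_mb, 2))
```

The three printed values should match the `Linear-2`, `Linear-3`, and `Params size (MB)` entries of the summary.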

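4. Model Training

4.1 Model Configuration

Before training, we need to choose the optimizer, the loss function, and the evaluation metric.

  • Optimizer: Adam, which usually converges quickly.
  • Loss function: SmoothL1Loss
  • Evaluation metric: NME (normalized mean error)

4.2 Custom Evaluation Metric

When the metric a task needs is not available among the framework's built-in Metric classes, or the built-in algorithm does not fit your requirements, you have to define the metric yourself. This section shows how to write a custom Metric; for more details, see the "Custom Metric" page in the official documentation. Let's start with the code below.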


from paddle.metric import Metric

class NME(Metric):
    """
    1. Inherit from paddle.metric.Metric.
    """
    def __init__(self, name='nme', *args, **kwargs):
        """
        2. Implement the constructor; add any custom arguments here.
        """
        super(NME, self).__init__(*args, **kwargs)
        self._name = name
        self.rmse = 0
        self.sample_num = 0
    
    def name(self):
        """
        3. Implement name(), which returns the name of the metric.
        """
        return self._name
    
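    def update(self, preds, labels):
        """
        4. Implement update(), which computes the metric for a single batch.
        - When `compute` is not implemented, the model outputs and the labels
          are flattened and passed to `update` as its arguments.
        """
        N = preds.shape[0]

        preds = preds.reshape((N, -1, 2))
        labels = labels.reshape((N, -1, 2))

        # Error summed over this batch only; kept separate from the running
        # total so that accumulate() covers every batch since the last reset
        batch_nme = 0.0

        for i in range(N):
            pts_pred, pts_gt = preds[i], labels[i]
            # Normalize by the distance between keypoints 36 and 45 (inter-ocular distance)
            interocular = np.linalg.norm(pts_gt[36] - pts_gt[45])

            batch_nme += np.sum(np.linalg.norm(pts_pred - pts_gt, axis=1)) / (interocular * preds.shape[1])

        self.rmse += batch_nme
        self.sample_num += N

        return batch_nme / N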
    
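    def accumulate(self):
        """
        5. Implement accumulate(), which returns the metric computed over all batches seen so far.
        Every call to `update` adds to the stored data, and `accumulate` computes the result from all of it.
        The result is shown in the training log of the `fit` API.
        """
        return self.rmse / self.sample_num
    
    def reset(self):
        """
        6. Implement reset(), which clears the metric at the end of each epoch so the next epoch starts from scratch.
        """
        self.rmse = 0
        self.sample_num = 0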
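To make the metric concrete, here is a small NumPy-only sketch of the per-sample quantity that `update` sums up: the mean point-to-point distance divided by the distance between keypoints 36 and 45. The function name `nme_per_sample` and the toy data are assumptions made for illustration; they are not part of the tutorial code.

```python
import numpy as np

def nme_per_sample(pred, gt, eye_a=36, eye_b=45):
    """Mean keypoint error of one sample, normalized by the eye_a-to-eye_b distance."""
    pred = np.asarray(pred, dtype=float).reshape(-1, 2)
    gt = np.asarray(gt, dtype=float).reshape(-1, 2)
    interocular = np.linalg.norm(gt[eye_a] - gt[eye_b])
    per_point_error = np.linalg.norm(pred - gt, axis=1)
    return per_point_error.mean() / interocular

# Toy ground truth: all 68 points at the origin except keypoint 45,
# so the normalizing distance is exactly 10
gt = np.zeros((68, 2))
gt[45] = [10.0, 0.0]

# Shift every predicted point by (3, 4), i.e. an error of 5 per point
pred = gt + np.array([3.0, 4.0])

value = nme_per_sample(pred.ravel(), gt.ravel())
print(value)  # 5 / 10 = 0.5
```

Because the error is divided by a distance measured on the same face, the value does not depend on how large the face appears in the image.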

Assignment 4: Configure and train the model

# Wrap the network with paddle.Model
model = paddle.Model(SimpleNet(key_pts=68))

# Define the Adam optimizer
optimizer = paddle.optimizer.Adam(learning_rate=0.001,
                                weight_decay=5e-4,
                                parameters=model.parameters())

# Define the SmoothL1Loss loss function
loss = paddle.nn.SmoothL1Loss()

# Use the custom metric
metric = NME()

# Configure the model
model.prepare(optimizer=optimizer, loss=loss, metrics=metric)

# Train the model
# model.fit(train_dataset, epochs=100, batch_size=128, verbose=1)

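Choosing a loss function: L1Loss vs. L2Loss vs. SmoothL1Loss

  • L1Loss: late in training, when the prediction is already close to the ground truth, the absolute value of the gradient of the loss with respect to the prediction is still 1. With a fixed learning rate, the loss then oscillates around a stable value and struggles to converge to higher accuracy.
  • L2Loss: early in training, when the prediction is far from the ground truth, the gradient of the loss with respect to the prediction is very large, which makes training unstable.
  • SmoothL1Loss: when the error x is small, the gradient with respect to x shrinks as well; when x is large, the absolute value of the gradient is capped at 1, so it never becomes large enough to wreck the network parameters.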
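The comparison above can be checked numerically. The sketch below evaluates the gradient of each loss with respect to the error x in plain NumPy. It assumes L2 is defined as x², and that the smooth L1 loss switches from quadratic to linear at |x| = 1; the function names are illustrative and are not Paddle APIs.

```python
import numpy as np

def l1_grad(x):
    """Gradient of |x|: always +1 or -1, however small the error is."""
    return np.sign(x)

def l2_grad(x):
    """Gradient of x**2: grows linearly with the error."""
    return 2.0 * x

def smooth_l1_grad(x):
    """Gradient of 0.5*x**2 for |x| < 1 and |x| - 0.5 otherwise."""
    return np.where(np.abs(x) < 1.0, x, np.sign(x))

errors = np.array([0.01, 0.5, 10.0])
g_l1 = l1_grad(errors)
g_l2 = l2_grad(errors)
g_smooth = smooth_l1_grad(errors)

for e, a, b, c in zip(errors, g_l1, g_l2, g_smooth):
    print(f"x={e:>5}: L1={a:.2f}  L2={b:.2f}  SmoothL1={c:.2f}")
```

For the small error the smooth L1 gradient follows the error itself, while for the large error it stays at 1 instead of growing like the L2 gradient.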

4.3 Model Training

model.fit(train_dataset, epochs=100, batch_size=64, verbose=1)

4.4 Model Saving

checkpoints_path = './checkpoints/models'
model.save(checkpoints_path)

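5. Model Prediction

# Helper functions

def show_all_keypoints(image, predicted_key_pts):
    """
    Show the image together with the predicted keypoints.
    Args:
        image: cropped image [224, 224, 3]
        predicted_key_pts: coordinates of the predicted keypoints
    """
    # Show the image
    plt.imshow(image.astype('uint8'))

    # Draw the keypoints; the vector is laid out as [x0, y0, x1, y1, ...]
    for i in range(0, len(predicted_key_pts), 2):
        plt.scatter(predicted_key_pts[i], predicted_key_pts[i+1], s=20, marker='.', c='m')

def visualize_output(test_images, test_outputs, batch_size=1, h=20, w=10):
    """
    Show each image together with its predicted keypoints.
    Args:
        test_images: cropped image(s) [224, 224, 3]
        test_outputs: model output
        batch_size: batch size
        h: height of the displayed figure
        w: width of the displayed figure
    """

    if len(test_images.shape) == 3:
        test_images = np.array([test_images])

    for i in range(batch_size):

        # Note: matplotlib's figsize is (width, height), so h is used as the width here
        plt.figure(figsize=(h, w))
        ax = plt.subplot(1, batch_size, i+1)

        # The randomly cropped image
        image = test_images[i]

        # Model output: predicted keypoint coordinates, still normalized
        predicted_key_pts = test_outputs[i]

        # Undo the normalization to get pixel coordinates
        predicted_key_pts = predicted_key_pts * data_std + data_mean
        
        # Show the image and the keypoints
        show_all_keypoints(np.squeeze(image), predicted_key_pts)
            
        plt.axis('off')

    plt.show()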
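The step `predicted_key_pts * data_std + data_mean` reverses the keypoint normalization. A minimal round-trip sketch follows, assuming the forward step is `(pts - data_mean) / data_std`; the values of `data_mean` and `data_std` below are placeholders chosen for illustration, not the ones used in this tutorial.

```python
import numpy as np

# Placeholder statistics (assumptions); use the values defined with the dataset transforms
data_mean, data_std = 100.0, 50.0

# Two keypoints in pixel coordinates, as (x, y) rows
pixel_pts = np.array([[45.0, 98.0], [47.0, 106.0]])

# Forward step assumed for training, and the inverse step used when plotting
normalized = (pixel_pts - data_mean) / data_std
restored = normalized * data_std + data_mean

print(normalized)
print(restored)
```

The restored array equals the original pixel coordinates, which is why the model output has to be rescaled before it is drawn on the image.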

# Load the image
img = mpimg.imread('1.jpg')

# Placeholder keypoints (the transforms take an [image, keypoints] pair)
kpt = np.ones((136, 1))

transform = Compose([Resize(256), RandomCrop(224)])

# Resize the image first, then crop it to 224x224
rgb_img, kpt = transform([img, kpt])

norm = GrayNormalize()
to_chw = ToCHW()

# Normalize the image and convert it to CHW layout
img, kpt = norm([rgb_img, kpt])
img, kpt = to_chw([img, kpt])

img = np.array([img], dtype='float32')

# Load the saved model for prediction
model = paddle.Model(SimpleNet(key_pts=68))
model.load(checkpoints_path)
model.prepare()

# Run the prediction
out = model.predict_batch([img])
out = out[0].reshape((out[0].shape[0], 136, -1))

# Visualize
visualize_output(rgb_img, out, batch_size=1)

[Figure: output_46_0.png, predicted keypoints drawn on 1.jpg]

# Load the image
img = mpimg.imread('2.jpg')

# Placeholder keypoints (the transforms take an [image, keypoints] pair)
kpt = np.ones((136, 1))

transform = Compose([Resize(256), RandomCrop(224)])

# Resize the image first, then crop it to 224x224
rgb_img, kpt = transform([img, kpt])

norm = GrayNormalize()
to_chw = ToCHW()

# Normalize the image and convert it to CHW layout
img, kpt = norm([rgb_img, kpt])
img, kpt = to_chw([img, kpt])

img = np.array([img], dtype='float32')

# Load the saved model for prediction
model = paddle.Model(SimpleNet(key_pts=68))
model.load(checkpoints_path)
model.prepare()

# Run the prediction
out = model.predict_batch([img])
out = out[0].reshape((out[0].shape[0], 136, -1))

# Visualize
visualize_output(rgb_img, out, batch_size=1)

[Figure: output_47_0.png, predicted keypoints drawn on 2.jpg]


6. Fun Applications

Once we have the keypoint coordinates, we can build some fun applications on top of them.

# Helper functions

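def show_fu(image, predicted_key_pts):
    """
    Show the image with a sticker pasted on it.
    Args:
        image: cropped image [224, 224, 3]
        predicted_key_pts: coordinates of the predicted keypoints
    """
    # Sticker position: midpoint between the 15th and 34th keypoints
    # (flat indices 28/29 and 66/67 in the [x0, y0, x1, y1, ...] vector)
    x = (int(predicted_key_pts[28]) + int(predicted_key_pts[66]))//2
    y = (int(predicted_key_pts[29]) + int(predicted_key_pts[67]))//2

    # Load the Spring Festival sticker
    star_image = mpimg.imread('light.jpg')

    # If the sticker has an alpha channel (RGBA), keep only the RGB channels
    if star_image.shape[2] == 4:
        star_image = star_image[:, :, :3]
    
    # Paste the sticker onto the image: rows span its height, columns its width.
    # The sticker must fit inside the image at (x, y).
    sticker_h, sticker_w = star_image.shape[:2]
    image[y:y+sticker_h, x:x+sticker_w, :] = star_image
    
    # Show the resulting image
    plt.imshow(image.astype('uint8'))

    # Draw the keypoints
    for i in range(len(predicted_key_pts) // 2):
        plt.scatter(predicted_key_pts[i*2], predicted_key_pts[i*2+1], s=20, marker='.', c='m')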


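def custom_output(test_images, test_outputs, batch_size=1, h=20, w=10):
    """
    Show each image with the sticker and its predicted keypoints.
    Args:
        test_images: cropped image(s) [224, 224, 3]
        test_outputs: model output
        batch_size: batch size
        h: height of the displayed figure
        w: width of the displayed figure
    """

    if len(test_images.shape) == 3:
        test_images = np.array([test_images])

    for i in range(batch_size):

        # Note: matplotlib's figsize is (width, height), so h is used as the width here
        plt.figure(figsize=(h, w))
        ax = plt.subplot(1, batch_size, i+1)

        # The randomly cropped image
        image = test_images[i]

        # Model output: predicted keypoint coordinates, still normalized
        predicted_key_pts = test_outputs[i]

        # Undo the normalization to get pixel coordinates
        predicted_key_pts = predicted_key_pts * data_std + data_mean
        
        # Show the image with the sticker and the keypoints
        show_fu(np.squeeze(image), predicted_key_pts)
            
        plt.axis('off')

    plt.show()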
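`show_fu` places the sticker at the midpoint of two keypoints and writes it straight into the image array, which only works when the sticker fits inside the image at that position. The sketch below shows one way to read a keypoint out of the flat `[x0, y0, x1, y1, ...]` vector and to paste a sticker while cropping whatever falls outside the image. The helpers `keypoint_xy` and `paste_clipped` and the synthetic arrays are assumptions made for illustration, not part of the tutorial code.

```python
import numpy as np

def keypoint_xy(flat_pts, k):
    """Return (x, y) of the k-th keypoint (0-based) from a flat [x0, y0, x1, y1, ...] vector."""
    flat = np.asarray(flat_pts).reshape(-1)
    return float(flat[2 * k]), float(flat[2 * k + 1])

def paste_clipped(image, sticker, x, y):
    """Paste `sticker` with its top-left corner at (x, y), cropping the part outside `image`."""
    out = image.copy()
    img_h, img_w = out.shape[:2]
    st_h, st_w = sticker.shape[:2]
    # Intersection of the sticker rectangle with the image rectangle
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + st_w, img_w), min(y + st_h, img_h)
    if x0 < x1 and y0 < y1:
        out[y0:y1, x0:x1] = sticker[y0 - y:y1 - y, x0 - x:x1 - x]
    return out

# Synthetic keypoint vector whose values equal their flat indices
flat = np.arange(136, dtype=float)
xa, ya = keypoint_xy(flat, 14)   # 15th keypoint: flat indices 28 and 29
xb, yb = keypoint_xy(flat, 33)   # 34th keypoint: flat indices 66 and 67
mid_x, mid_y = (int(xa) + int(xb)) // 2, (int(ya) + int(yb)) // 2

# A 40x30 white sticker pasted near the bottom-right corner of a black 224x224 image
image = np.zeros((224, 224, 3), dtype=np.uint8)
sticker = np.full((40, 30, 3), 255, dtype=np.uint8)
result = paste_clipped(image, sticker, x=210, y=200)

print(mid_x, mid_y, int((result[..., 0] == 255).sum()))
```

Working on a copy keeps the original image untouched, so the same cropped image can be reused for several stickers.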

# Load the image
img = mpimg.imread('xiaojiejie.jpg')

# Placeholder keypoints (the transforms take an [image, keypoints] pair)
kpt = np.ones((136, 1))

transform = Compose([Resize(256), RandomCrop(224)])

# Resize the image first, then crop it to 224x224
rgb_img, kpt = transform([img, kpt])

norm = GrayNormalize()
to_chw = ToCHW()

# Normalize the image and convert it to CHW layout
img, kpt = norm([rgb_img, kpt])
img, kpt = to_chw([img, kpt])

img = np.array([img], dtype='float32')

# The model loaded above is reused; uncomment to reload it from the checkpoint
# model = paddle.Model(SimpleNet(key_pts=68))
# model.load(checkpoints_path)
# model.prepare()

# Run the prediction
out = model.predict_batch([img])
out = out[0].reshape((out[0].shape[0], 136, -1))

# Visualize
custom_output(rgb_img, out, batch_size=1)
