人脸识别0-04:insightFace-模型训练注释详解-史上最全

以下链接是个人关于insightFace所有见解,如有错误欢迎大家指出,我会第一时间纠正,如有兴趣可以加微信:a944284742相互讨论技术。
人脸识别0-00:insightFace目录:https://blog.csdn.net/weixin_43013761/article/details/99646731

版本更替

在作者最初发布的版本中,训练使用的是insightface-master\src下面的代码;本人使用的是目前最新的版本,代码位于insightface-master\recognition目录下。不知道当你看到这篇博客的时候,源码的作者是否又发布了新的版本,不过没关系,上述链接中给出了本人使用的代码。下面我们开始讲解insightface-master\recognition\train.py,该文件是训练的核心代码。通过前面的博客,我们已经把sample_config.py拷贝了一份并命名为config.py,该文件主要为模型训练提供一系列的配置。

config.py

import numpy as np
import os
from easydict import EasyDict as edict

# Base configuration. generate_config() later copies the selected loss, network
# and dataset presets into this object, so any key repeated in a preset wins.
# NOTE(review): bn_mom, workspace and the net_* keys are not read in train.py;
# they are presumably consumed by the symbol modules (fresnet, fmobilefacenet,
# ...). Their meanings below are inferred from the names -- TODO confirm there.
config = edict()

config.bn_mom = 0.9  # presumably BatchNorm moving-average momentum (not backprop momentum)
config.workspace = 256  # presumably the mxnet operator workspace size (MB)
config.emb_size = 128  # embedding dimension; second dim of fc7_weight in train.py
config.ckpt_embedding = True  # on save, keep only the fc1_output sub-graph and drop fc7* params
config.net_se = 0  # presumably toggles squeeze-and-excitation units -- TODO confirm
config.net_act = 'prelu'  # backbone activation type
config.net_unit = 3  # backbone unit variant (the r50v1 preset uses 1) -- TODO confirm
config.net_input = 1  # backbone input-stem variant -- TODO confirm
config.net_blocks = [1,4,6,2]  # presumably per-stage block counts (the y2 preset overrides it)
config.net_output = 'E'  # embedding head type, e.g. 'E', 'FC', 'GDC'; see recognition\symbol\symbol_utils.py
config.net_multiplier = 1.0  # width multiplier (overridden by the mobilenet / mnasnet presets)
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30']  # verification sets, loaded from <dataset_path>/<name>.bin
config.ce_loss = True  # also output a gradient-blocked cross-entropy value and track it with LossValueMetric (monitoring only, NOT focal loss)
config.fc7_lr_mult = 1.0  # learning-rate multiplier of fc7_weight
config.fc7_wd_mult = 1.0  # weight-decay multiplier of fc7_weight
config.fc7_no_bias = False  # plain 'softmax' loss only: build fc7 without a bias term
config.max_steps = 0  # stop the process (sys.exit(0)) once the global batch count exceeds this; 0 = no limit
config.data_rand_mirror = True  # passed to FaceImageIter as rand_mirror (random horizontal flip)
config.data_cutoff = False  # passed to FaceImageIter as cutoff (augmentation, see image_iter)
config.data_color = 0  # passed to the softmax FaceImageIter as color_jittering
config.data_images_filter = 0  # passed to the softmax FaceImageIter as images_filter
config.count_flops = True  # print the FLOPs of the fc1_output sub-graph at start-up
config.memonger = False  # memory optimizer switch; not functional currently



# Backbone presets, selected with --network at training time.
# generate_config() copies every key of the chosen preset into `config`
# (overriding the base values) and also into `default` when `default` already
# defines that key (e.g. per_batch_size, which then becomes the CLI default).
# Presets: r100 r100fc r50 r50v1 d169 d201 y1 y2 m1 m05 mnas mnas05 mnas025
network = edict()

# ResNet family ('fresnet' symbol module).
network.r100 = edict(net_name='fresnet', num_layers=100)
network.r100fc = edict(net_name='fresnet', num_layers=100, net_output='FC')
network.r50 = edict(net_name='fresnet', num_layers=50)
network.r50v1 = edict(net_name='fresnet', num_layers=50, net_unit=1)

# DenseNet family ('fdensenet'); these presets also override the per-device batch size.
network.d169 = edict(net_name='fdensenet', num_layers=169, per_batch_size=64, densenet_dropout=0.0)
network.d201 = edict(net_name='fdensenet', num_layers=201, per_batch_size=64, densenet_dropout=0.0)

# MobileFaceNet ('fmobilefacenet').
network.y1 = edict(net_name='fmobilefacenet', emb_size=128, net_output='GDC')
network.y2 = edict(net_name='fmobilefacenet', emb_size=256, net_output='GDC', net_blocks=[2, 8, 16, 4])

# MobileNet ('fmobilenet'); net_multiplier presumably scales channel width.
network.m1 = edict(net_name='fmobilenet', emb_size=256, net_output='GDC', net_multiplier=1.0)
network.m05 = edict(net_name='fmobilenet', emb_size=256, net_output='GDC', net_multiplier=0.5)

# MnasNet ('fmnasnet') at three width multipliers.
network.mnas = edict(net_name='fmnasnet', emb_size=256, net_output='GDC', net_multiplier=1.0)
network.mnas05 = edict(net_name='fmnasnet', emb_size=256, net_output='GDC', net_multiplier=0.5)
network.mnas025 = edict(net_name='fmnasnet', emb_size=256, net_output='GDC', net_multiplier=0.25)



# Training-set presets, selected with --dataset (exactly one is used per run).
# num_classes is the number of identities and sizes the fc7 classifier;
# image_shape is (rows, cols, channels) and train.py requires rows == cols;
# val_targets names the <name>.bin verification files looked up in dataset_path,
# next to train.rec.
dataset = edict()

dataset.emore = edict(
    dataset='emore',
    dataset_path='../../../2.dataset/1.officialData/1.traindata/faces_glint',
    num_classes=180855,
    image_shape=(112, 112, 3),
    val_targets=['lfw', 'cfp_fp', 'agedb_30'],
)

dataset.retina = edict(
    dataset='retina',
    dataset_path='../datasets/ms1m-retinaface-t1',
    num_classes=93431,
    image_shape=(112, 112, 3),
    val_targets=['lfw', 'cfp_fp', 'agedb_30'],
)


# Loss presets, selected with --loss.
# All 'margin_softmax' presets share one implementation (get_symbol in train.py):
# with normalized weights/embeddings the target logit s*cos(theta) is replaced
# by s*(cos(m1*theta + m2) - m3). Hence:
#   nsoftmax: m1=1, m2=0,   m3=0    (normalized softmax only)
#   arcface:  m2=0.5                (additive angular margin)
#   cosface:  m3=0.35               (additive cosine margin)
#   combined: m2=0.3 and m3=0.2
# The triplet presets also set per_batch_size and lr, which generate_config()
# propagates into `default` (and therefore into the CLI defaults).
loss = edict()

loss.softmax = edict(loss_name='softmax')

loss.nsoftmax = edict(loss_name='margin_softmax', loss_s=64.0, loss_m1=1.0, loss_m2=0.0, loss_m3=0.0)
loss.arcface = edict(loss_name='margin_softmax', loss_s=64.0, loss_m1=1.0, loss_m2=0.5, loss_m3=0.0)
loss.cosface = edict(loss_name='margin_softmax', loss_s=64.0, loss_m1=1.0, loss_m2=0.0, loss_m3=0.35)
loss.combined = edict(loss_name='margin_softmax', loss_s=64.0, loss_m1=1.0, loss_m2=0.3, loss_m3=0.2)

loss.triplet = edict(
    loss_name='triplet',
    images_per_identity=5,
    triplet_alpha=0.3,
    triplet_bag_size=7200,
    triplet_max_ap=0.0,
    per_batch_size=60,
    lr=0.05,
)

loss.atriplet = edict(
    loss_name='atriplet',
    images_per_identity=5,
    triplet_alpha=0.35,
    triplet_bag_size=7200,
    triplet_max_ap=0.0,
    per_batch_size=60,
    lr=0.05,
)

# Fallback values for train.py's command-line options. generate_config()
# overwrites any key here that also appears in the selected loss / network /
# dataset preset before argparse reads these values.
default = edict()

# default network
default.network = 'r100'
# Empty = train from scratch. Pass --pretrained <prefix> --pretrained-epoch <n>
# to warm-start. Keep this empty in a shared config: a machine-local checkpoint
# path here makes every run try to load it, whatever --network was chosen.
default.pretrained = ''
default.pretrained_epoch = 0
# default dataset
default.dataset = 'emore'
default.loss = 'arcface'
default.frequent = 20  # Speedometer logging interval, in batches
default.verbose = 2000  # run verification (and possibly save) every 2000 batches
default.kvstore = 'device'  # mxnet kvstore type passed to Module.fit

default.end_epoch = 10000  # NOTE(review): not read by the train.py shown here (fit uses num_epoch=999999)
default.lr = 0.01  # initial learning rate; smaller per-device batches usually warrant a lower value
default.wd = 0.0005  # SGD weight decay
default.mom = 0.9  # SGD momentum
default.per_batch_size = 48  # samples per device; global batch = 48 * number of devices (e.g. 96 on 2 GPUs)
default.ckpt = 0  # 0: never save; 1: save on a new best verification; 2: save at every verification; 3: like 1 but always overwrite index 1
default.lr_steps = '100000,160000,220000'  # global batch counts at which lr is multiplied by 0.1
default.models_root = './models'  # checkpoints go to <models_root>/<network>-<loss>-<dataset>/model


# Merge the selected presets into the module-level `config` / `default` objects.
def generate_config(_network, _dataset, _loss):
    """Apply the named loss, network and dataset presets.

    Presets are applied in the order loss -> network -> dataset, so for a key
    present in several presets the later one wins. Every key is written into
    ``config``; a key that ``default`` already defines is overwritten there as
    well, which is how presets change the argparse defaults. Finally records
    the chosen preset names and the worker count (``DMLC_NUM_WORKER`` env var,
    falling back to 1).
    """
    # (table, key) pairs: each preset is looked up only when its turn comes.
    for table, name in ((loss, _loss), (network, _network), (dataset, _dataset)):
        for key, value in table[name].items():
            config[key] = value
            if key in default:
                default[key] = value

    config.loss = _loss
    config.network = _network
    config.dataset = _dataset

    config.num_workers = 1
    raw_workers = os.environ.get('DMLC_NUM_WORKER')
    if raw_workers is not None:
        config.num_workers = int(raw_workers)

注释也花了不少心思,如果对你有帮助,希望能点个赞,这是对我最大的鼓励。下面再为大家贴出insightface-master\recognition\train.py的代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import random
import logging
import sklearn
import pickle
import numpy as np
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
from config import config, default, generate_config
from metric import *

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import flops_counter

sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
import verification

sys.path.append(os.path.join(os.path.dirname(__file__), 'symbol'))
import fresnet
import fmobilefacenet
import fmobilenet
import fmnasnet
import fdensenet

print(mx.__file__)  # show which mxnet installation was imported (handy with several installs)

# Calling getLogger() with no name returns the root logger; raise its level to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Parsed command-line options; assigned by main().
args = None


def parse_args():
    """Parse the training command line.

    Two-stage parse. First ``--dataset``, ``--network`` and ``--loss`` are read,
    with unknown options ignored. They are handed to ``generate_config`` so the
    selected presets can update ``default``. Only then are the remaining
    options registered, so their defaults reflect those presets.

    Returns:
        argparse.Namespace holding every option.
    """
    ap = argparse.ArgumentParser(description='Train face network')

    # Stage 1: preset selection (training set, backbone, loss).
    ap.add_argument('--dataset', default=default.dataset, help='dataset config')
    ap.add_argument('--network', default=default.network, help='network config')
    ap.add_argument('--loss', default=default.loss, help='loss config')
    presets, _ = ap.parse_known_args()
    generate_config(presets.network, presets.dataset, presets.loss)

    # Stage 2: every other option. The default.* values must be read here,
    # i.e. after the preset merge above.

    # Output location and optional warm start.
    ap.add_argument('--models-root', default=default.models_root, help='root directory to save model.')
    ap.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load')
    ap.add_argument('--pretrained-epoch', type=int, default=default.pretrained_epoch,
                    help='pretrained epoch to load')

    # Checkpointing and verification cadence.
    ap.add_argument('--ckpt', type=int, default=default.ckpt,
                    help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
    ap.add_argument('--verbose', type=int, default=default.verbose,
                    help='do verification testing and model saving every verbose batches')

    # SGD settings: initial lr, decay steps, weight decay, momentum.
    ap.add_argument('--lr', type=float, default=default.lr, help='start learning rate')
    ap.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing')
    ap.add_argument('--wd', type=float, default=default.wd, help='weight decay')
    ap.add_argument('--mom', type=float, default=default.mom, help='momentum')

    # Logging interval (in batches) for the Speedometer callback.
    ap.add_argument('--frequent', type=int, default=default.frequent, help='')

    # Samples per device; train_net multiplies this by the number of devices.
    ap.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context')

    # mxnet kvstore type used for gradient aggregation.
    ap.add_argument('--kvstore', type=str, default=default.kvstore, help='kvstore setting')
    return ap.parse_args()


def get_symbol(args):
    """Build the training graph: backbone embedding plus a loss head.

    The backbone comes from ``<config.net_name>.get_symbol()``. The head is
    chosen by ``config.loss_name``:

    * ``'softmax'``: fully connected classifier ``fc7`` over the raw embedding.
    * ``'margin_softmax'``: ``fc7`` over the L2-normalized embedding, scaled by
      ``s = config.loss_s``, with L2-normalized weights and no bias. Every logit
      is therefore ``s * cos(theta)``. The target-class logit is then replaced
      by ``s * (cos(m1 * theta + m2) - m3)``, with m1/m2/m3 taken from
      ``config.loss_m1/m2/m3``.
    * any name containing ``'triplet'``: triplet loss on L2-normalized
      embeddings with margin ``config.triplet_alpha``. ``'triplet'`` uses
      squared Euclidean distances. Any other triplet name uses the angles
      ``arccos(a . b)``.

    Args:
        args: parsed options. Only ``args.per_batch_size`` is read, for the
            triplet slicing and the ce_loss normalization.

    Returns:
        An ``mx.symbol.Group`` whose first output is the gradient-blocked
        embedding. For softmax-style losses it is followed by SoftmaxOutput
        and, when ``config.ce_loss`` is set, a gradient-blocked cross-entropy
        value. For triplet losses it is followed by the gradient-blocked label
        and the triplet MakeLoss.

    NOTE(review): any other ``config.loss_name`` leaves ``fc7`` undefined,
    so the SoftmaxOutput line fails with UnboundLocalError.
    """
    # NOTE(review): eval() resolves the backbone module by name; config.net_name
    # must be one of the symbol modules imported at file level (fresnet, ...).
    embedding = eval(config.net_name).get_symbol()

    # Label input placeholder.
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    is_softmax = True

    # Plain softmax classifier.
    if config.loss_name == 'softmax':
        # Classifier weight: one row of size emb_size per identity.
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))

        # Fully connected layer without a bias term.
        if config.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                        name='fc7')
        # With a bias: 2x learning rate and no weight decay on it.
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=config.num_classes,
                                        name='fc7')
    # Margin-based softmax (nsoftmax / arcface / cosface / combined presets).
    elif config.loss_name == 'margin_softmax':
        # Classifier weight placeholder, same shape as in the softmax branch.
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))


        # Scale factor applied to the normalized logits.
        s = config.loss_s

        # L2-normalize weights and embedding (embedding then scaled by s), so each
        # fc7 output is s * cos(angle between embedding and class weight).
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                    name='fc7')

        # m1/m2/m3 unify several losses; with (1, 0, 0) the logits are left as is.
        if config.loss_m1 != 1.0 or config.loss_m2 != 0.0 or config.loss_m3 != 0.0:
            if config.loss_m1 == 1.0 and config.loss_m2 == 0.0:
                # Cosine margin only: subtract s*m3 from the target-class logit.
                s_m = s * config.loss_m3
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                # General case: recover theta of the target class, apply the margins,
                # and add (new_logit - old_logit) to the target column only.
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                # NOTE(review): rounding can push |cos_t| slightly above 1, making
                # arccos NaN; also cos() is not monotonic once m1*theta + m2 > pi.
                t = mx.sym.arccos(cos_t)
                if config.loss_m1 != 1.0:
                    t = t * config.loss_m1
                if config.loss_m2 > 0.0:
                    t = t + config.loss_m2
                body = mx.sym.cos(t)
                if config.loss_m3 > 0.0:
                    body = body - config.loss_m3
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=1.0, off_value=0.0)

                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # Triplet losses ('triplet', 'atriplet', ...).
    elif config.loss_name.find('triplet') >= 0:
        is_softmax = False
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        # Assumes each per-device batch is laid out as [anchors | positives | negatives],
        # one third each -- TODO confirm against triplet_image_iter.
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        if config.loss_name == 'triplet':
            # Squared Euclidean distances: mean(relu(|a-p|^2 - |a-n|^2 + alpha)).
            ap = anchor - positive
            an = anchor - negative
            ap = ap * ap
            an = an * an
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        else:
            # Angular distances: mean(relu(arccos(a.p) - arccos(a.n) + alpha)).
            ap = anchor * positive
            an = anchor * negative
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            ap = mx.sym.arccos(ap)
            an = mx.sym.arccos(an)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        triplet_loss = mx.symbol.MakeLoss(triplet_loss)
    # First output: the embedding, exposed for inspection but excluded from backprop.
    out_list = [mx.symbol.BlockGrad(embedding)]

    # Softmax-style losses: SoftmaxOutput supplies the training gradient.
    if is_softmax:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
        if config.ce_loss:
            # Monitoring-only cross-entropy: -sum(log softmax[label]) / per_batch_size,
            # wrapped in BlockGrad so it does not affect training.
            # ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
            body = mx.symbol.SoftmaxActivation(data=fc7)
            body = mx.symbol.log(body)
            _label = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=-1.0, off_value=0.0)
            body = body * _label
            ce_loss = mx.symbol.sum(body) / args.per_batch_size
            out_list.append(mx.symbol.BlockGrad(ce_loss))
    # Triplet losses: expose the label (no gradient) and the MakeLoss output.
    else:
        out_list.append(mx.sym.BlockGrad(gt_label))
        out_list.append(triplet_loss)

    # Bundle all outputs into one symbol.
    out = mx.symbol.Group(out_list)
    return out


def _get_context():
    """Return the list of mxnet contexts to train on.

    Returns one ``mx.gpu(i)`` per comma-separated entry of CUDA_VISIBLE_DEVICES.
    CUDA renumbers the visible devices from 0, so the indices are 0..n-1.
    Returns ``[mx.cpu()]`` when the variable is unset or empty.
    """
    # .get(): an unset CUDA_VISIBLE_DEVICES must fall back to the CPU, not raise KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))] if len(cvd) > 0 else []
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    return ctx


def _load_verification_sets(data_dir, image_size):
    """Load the verification sets named in ``config.val_targets``.

    Each name is looked up as ``<data_dir>/<name>.bin``; names whose file does
    not exist are skipped.

    Returns:
        (ver_list, ver_name_list): the loaded sets and their names, in order.
    """
    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            ver_list.append(verification.load_bin(path, image_size))
            ver_name_list.append(name)
            print('ver', name)
    return ver_list, ver_name_list


def train_net(args):
    """Train the configured face-recognition model with ``Module.fit``.

    Reads the module-level ``config``, filled by generate_config, and ``args``
    from parse_args. Mutates ``args`` by adding ctx_num, batch_size,
    rescale_threshold and image_channel, and mutates ``config`` by setting
    batch_size and per_batch_size.

    Every ``args.verbose`` batches the model is evaluated on the verification
    sets and, depending on ``args.ckpt``, checkpointed to
    ``<models_root>/<network>-<loss>-<dataset>/model``. Training runs until
    interrupted, or until more than ``config.max_steps`` batches when that is
    > 0, which ends with ``sys.exit(0)``.
    """
    ctx = _get_context()

    # Checkpoint prefix: <models_root>/<network>-<loss>-<dataset>/model
    prefix = os.path.join(args.models_root, '%s-%s-%s' % (args.network, args.loss, args.dataset), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)

    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    # per_batch_size is per device; the global batch covers every context.
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0

    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    # Training data: <dataset_path>/train.rec
    data_dir = config.dataset_path

    # image_shape is (rows, cols, channels); only square inputs are supported.
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)

    # Number of identities (classifier width).
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0

    if len(args.pretrained) == 0:
        # Train from scratch: every weight comes from `initializer` below.
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            # NOTE(review): broken branch -- `spherenet` is not imported in this file
            # and args has no num_layers; no network preset in config.py uses it.
            data_shape_dict = {'data': (args.per_batch_size,) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        # Warm start. Parameters absent from the checkpoint (e.g. fc7 when only
        # the embedding was saved) are initialized, since fit() gets allow_missing=True.
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    # Report the FLOPs of the embedding sub-graph for a single image.
    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym, data=(1, args.image_channel, image_size[0], image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        print('Network FLOPs: %s' % _str)

    # Bind to every context in `ctx` (all visible GPUs, or the CPU fallback).
    # batch_size above and rescale_grad below are both computed for len(ctx)
    # devices, so binding to a single hard-coded GPU would be inconsistent.
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    # Triplet losses use a dedicated iterator (images_per_identity, triplet
    # params, and the model itself -- presumably for online triplet selection);
    # softmax-style losses use the plain record iterator.
    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))

    # Weight initialization for parameters not loaded from a checkpoint.
    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    # Presumably averages the per-device gradients that the kvstore sums.
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list, ver_name_list = _load_verification_sets(data_dir, image_size)

    def ver_test(nbatch):
        """Evaluate on every loaded verification set; return the flip accuracies."""
        results = []
        for data_set, set_name in zip(ver_list, ver_name_list):
            _, _, acc2, std2, xnorm, _ = verification.test(data_set, model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (set_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (set_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    # highest_acc[-1]: best accuracy so far on the LAST verification target;
    # highest_acc[0]: sum of all target accuracies for that best model (tie-breaker).
    highest_acc = [0.0, 0.0]
    # One-element lists so the nested callback can update them without `nonlocal`.
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: lr schedule, logging, periodic verification and checkpointing."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Step decay: multiply lr by 0.1 when the global batch count hits a step.
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        # verbose <= 0 disables periodic verification (verbose == 0 would otherwise
        # raise ZeroDivisionError on the modulo).
        if args.verbose > 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            is_highest = False
            if len(acc_list) > 0:
                score = sum(acc_list)
                # New best on the last target, or a tie there with a total at least as
                # good. The total is recorded in both cases so that later ties are
                # compared against the current best model rather than an older one.
                if acc_list[-1] > highest_acc[-1] or (
                        acc_list[-1] == highest_acc[-1] and score >= highest_acc[0]):
                    is_highest = True
                    highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]

            # ckpt: 0 never, 1 on new best, 2 always, 3 on new best into slot 1.
            do_save = is_highest
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    # Save only the embedding sub-graph and drop the classifier (fc7*) params.
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {k: v for k, v in arg.items() if not k.startswith('fc7')}
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    # Prefetch batches in the background.
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=999999,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore=args.kvstore,
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)


def main():
    """Script entry point: parse the CLI, publish it as the global ``args``, then train."""
    global args
    parsed = parse_args()
    args = parsed
    train_net(parsed)


if __name__ == '__main__':
    main()

以上除了def get_symbol(args)函数没有详细注释外,其他部分基本都已注释完成。该函数涉及损失函数,比较复杂,下一小节再为大家详细讲解。记得关注点赞哦。

你可能感兴趣的:(人脸技术)