The link below collects all of my notes on InsightFace. If you find any mistakes, please point them out and I will correct them as soon as possible. If you are interested, feel free to add me on WeChat (a944284742) to discuss the technical details.
Face recognition 0-00: InsightFace table of contents: https://blog.csdn.net/weixin_43013761/article/details/99646731
In the author's initial release, training was done with the code under insightface-master\src. I am using what is currently the latest version, which lives under insightface-master\recognition; by the time you read this post the author may well have published a newer version again, but that does not matter, since the link above also points to the code I used. We now start with insightface-master\recognition\train.py, the core training script. As described in the previous post, sample_config.py was copied to config.py; this file provides the full set of configuration options used for training.
import numpy as np
import os
from easydict import EasyDict as edict
# config holds the base settings; any key that reappears later (in a network/dataset/loss block) overrides the value here
config = edict()
config.bn_mom = 0.9 # momentum of the batch normalization layers
config.workspace = 256 # workspace size (in MB) allowed for MXNet convolution operators
config.emb_size = 128 # dimension of the output embedding (feature vector)
config.ckpt_embedding = True # when saving checkpoints, keep only the embedding network (up to fc1) and drop the fc7 classification layer
config.net_se = 0 # whether to add SE (squeeze-and-excitation) modules to the backbone (0 = off)
config.net_act = 'prelu' # activation function
config.net_unit = 3 # residual unit version used by the fresnet backbone (see recognition\symbol)
config.net_input = 1 # input block version used by the fresnet backbone
config.net_blocks = [1,4,6,2] # number of blocks per stage (used by fmobilefacenet)
config.net_output = 'E' # type of the output/embedding layer, e.g. 'E', 'FC' or 'GDC'; see recognition\symbol\symbol_utils.py
config.net_multiplier = 1.0 # width multiplier for the mobilenet/mnasnet backbones
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30'] # verification sets (.bin files) evaluated during training
config.ce_loss = True # additionally output the softmax cross-entropy value as a monitoring metric
config.fc7_lr_mult = 1.0 # learning-rate multiplier for the fc7 weight
config.fc7_wd_mult = 1.0 # weight-decay multiplier for the fc7 weight
config.fc7_no_bias = False # if True, the fc7 layer has no bias term
config.max_steps = 0 # maximum number of training steps; 0 means no limit (training exits once it is exceeded)
config.data_rand_mirror = True # randomly mirror (horizontally flip) training images
config.data_cutoff = False # random cut-off (cutout-style) augmentation
config.data_color = 0 # color jittering level for data augmentation (0 = off)
config.data_images_filter = 0 # drop identities with fewer than this many images (0 = keep everything)
config.count_flops = True # whether to count and print the network's FLOPs at startup
config.memonger = False # not working at the moment
# A number of network architectures are defined below; I will not annotate each one individually,
# since I have not studied every backbone in depth. Any of them can be selected at training time.
# network settings:
# r100 r100fc r50 r50v1 d169 d201 y1 y2 m1 m05 mnas mnas05 mnas025
network = edict()
network.r100 = edict()
network.r100.net_name = 'fresnet'
network.r100.num_layers = 100
network.r100fc = edict()
network.r100fc.net_name = 'fresnet'
network.r100fc.num_layers = 100
network.r100fc.net_output = 'FC'
network.r50 = edict()
network.r50.net_name = 'fresnet'
network.r50.num_layers = 50
network.r50v1 = edict()
network.r50v1.net_name = 'fresnet'
network.r50v1.num_layers = 50
network.r50v1.net_unit = 1
network.d169 = edict()
network.d169.net_name = 'fdensenet'
network.d169.num_layers = 169
network.d169.per_batch_size = 64
network.d169.densenet_dropout = 0.0
network.d201 = edict()
network.d201.net_name = 'fdensenet'
network.d201.num_layers = 201
network.d201.per_batch_size = 64
network.d201.densenet_dropout = 0.0
network.y1 = edict()
network.y1.net_name = 'fmobilefacenet'
network.y1.emb_size = 128
network.y1.net_output = 'GDC'
network.y2 = edict()
network.y2.net_name = 'fmobilefacenet'
network.y2.emb_size = 256
network.y2.net_output = 'GDC'
network.y2.net_blocks = [2,8,16,4]
network.m1 = edict()
network.m1.net_name = 'fmobilenet'
network.m1.emb_size = 256
network.m1.net_output = 'GDC'
network.m1.net_multiplier = 1.0
network.m05 = edict()
network.m05.net_name = 'fmobilenet'
network.m05.emb_size = 256
network.m05.net_output = 'GDC'
network.m05.net_multiplier = 0.5
network.mnas = edict()
network.mnas.net_name = 'fmnasnet'
network.mnas.emb_size = 256
network.mnas.net_output = 'GDC'
network.mnas.net_multiplier = 1.0
network.mnas05 = edict()
network.mnas05.net_name = 'fmnasnet'
network.mnas05.emb_size = 256
network.mnas05.net_output = 'GDC'
network.mnas05.net_multiplier = 0.5
network.mnas025 = edict()
network.mnas025.net_name = 'fmnasnet'
network.mnas025.emb_size = 256
network.mnas025.net_output = 'GDC'
network.mnas025.net_multiplier = 0.25
# Two datasets are defined, emore and retina; only one of them can be chosen for a training run.
# num_classes comes from the dataset's property file and is the number of face identities (classes);
# it has to match the data, otherwise the fc7 classifier will have the wrong shape.
# dataset settings
dataset = edict()
dataset.emore = edict()
dataset.emore.dataset = 'emore'
dataset.emore.dataset_path = '../../../2.dataset/1.officialData/1.traindata/faces_glint'
dataset.emore.num_classes = 180855
dataset.emore.image_shape = (112,112,3)
dataset.emore.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
dataset.retina = edict()
dataset.retina.dataset = 'retina'
dataset.retina.dataset_path = '../datasets/ms1m-retinaface-t1'
dataset.retina.num_classes = 93431
dataset.retina.image_shape = (112,112,3)
dataset.retina.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
# The loss functions are the key part; do not be put off by how many settings there are.
# The three parameters loss_m1, loss_m2, loss_m3 let the author fold several losses into a single
# 'margin_softmax' implementation: the target logit becomes s * (cos(m1*theta + m2) - m3),
# which covers nsoftmax, arcface, cosface and the combined margin.
loss = edict()
loss.softmax = edict()
loss.softmax.loss_name = 'softmax'
loss.nsoftmax = edict()
loss.nsoftmax.loss_name = 'margin_softmax'
loss.nsoftmax.loss_s = 64.0
loss.nsoftmax.loss_m1 = 1.0
loss.nsoftmax.loss_m2 = 0.0
loss.nsoftmax.loss_m3 = 0.0
loss.arcface = edict()
loss.arcface.loss_name = 'margin_softmax'
loss.arcface.loss_s = 64.0
loss.arcface.loss_m1 = 1.0
loss.arcface.loss_m2 = 0.5
loss.arcface.loss_m3 = 0.0
loss.cosface = edict()
loss.cosface.loss_name = 'margin_softmax'
loss.cosface.loss_s = 64.0
loss.cosface.loss_m1 = 1.0
loss.cosface.loss_m2 = 0.0
loss.cosface.loss_m3 = 0.35
loss.combined = edict()
loss.combined.loss_name = 'margin_softmax'
loss.combined.loss_s = 64.0
loss.combined.loss_m1 = 1.0
loss.combined.loss_m2 = 0.3
loss.combined.loss_m3 = 0.2
loss.triplet = edict()
loss.triplet.loss_name = 'triplet'
loss.triplet.images_per_identity = 5
loss.triplet.triplet_alpha = 0.3
loss.triplet.triplet_bag_size = 7200
loss.triplet.triplet_max_ap = 0.0
loss.triplet.per_batch_size = 60
loss.triplet.lr = 0.05
loss.atriplet = edict()
loss.atriplet.loss_name = 'atriplet'
loss.atriplet.images_per_identity = 5
loss.atriplet.triplet_alpha = 0.35
loss.atriplet.triplet_bag_size = 7200
loss.atriplet.triplet_max_ap = 0.0
loss.atriplet.per_batch_size = 60
loss.atriplet.lr = 0.05
# default settings
default = edict()
# default network
default.network = 'r100'
#default.pretrained = ''
default.pretrained = '../models/model-y1-test2/model'
default.pretrained_epoch = 0
# default dataset
default.dataset = 'emore'
default.loss = 'arcface'
default.frequent = 20 # log speed and training metrics every 20 batches
default.verbose = 2000 # run verification (and possibly save a checkpoint) every 2000 batches
default.kvstore = 'device' # MXNet kvstore type used for gradient aggregation
default.end_epoch = 10000 # epoch at which training ends
default.lr = 0.01 # initial learning rate; with a smaller total batch size it should be lowered accordingly
default.wd = 0.0005 # weight decay (L2 regularization coefficient)
default.mom = 0.9 # SGD momentum
default.per_batch_size = 48 # samples per GPU per batch; with two GPUs the effective batch size is 96
default.ckpt = 0 # checkpoint option: 0 discard, 1 save only on a new best accuracy, 2 always save, 3 always overwrite checkpoint 1
default.lr_steps = '100000,160000,220000' # at each of these step counts the learning rate is multiplied by 0.1
default.models_root = './models' # directory where models are saved
# merge the chosen loss/network/dataset settings into config (and default)
def generate_config(_network, _dataset, _loss):
    for k, v in loss[_loss].items():
        config[k] = v
        if k in default:
            default[k] = v
    for k, v in network[_network].items():
        config[k] = v
        if k in default:
            default[k] = v
    for k, v in dataset[_dataset].items():
        config[k] = v
        if k in default:
            default[k] = v
    config.loss = _loss
    config.network = _network
    config.dataset = _dataset
    config.num_workers = 1
    if 'DMLC_NUM_WORKER' in os.environ:
        config.num_workers = int(os.environ['DMLC_NUM_WORKER'])
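To see how these pieces fit together, here is a minimal illustrative sketch (not part of config.py) of what generate_config does when train.py later calls it: the chosen loss, network and dataset blocks are merged into config, and any keys that also exist in default are updated there as well.

from config import config, default, generate_config

generate_config('r100', 'emore', 'arcface')
print(config.loss_name, config.loss_m1, config.loss_m2, config.loss_m3)  # margin_softmax 1.0 0.5 0.0
print(config.net_name, config.num_layers)                                # fresnet 100
print(config.num_classes)                                                # 180855, taken from dataset.emore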
Writing these comments took a fair bit of effort; if they help you, a like would be greatly appreciated and is the best encouragement for me. Next, here is the code of insightface-master\recognition\train.py.
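Before reading through it, note that training is typically launched from the recognition directory with something like CUDA_VISIBLE_DEVICES=0,1 python train.py --network r100 --loss arcface --dataset emore (an illustrative example, not the author's exact command); the three flags map directly to parse_args below, and CUDA_VISIBLE_DEVICES determines which GPUs train_net picks up.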
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import random
import logging
import sklearn
import pickle
import numpy as np
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer
from config import config, default, generate_config
from metric import *
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import flops_counter
sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
import verification
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbol'))
import fresnet
import fmobilefacenet
import fmobilenet
import fmnasnet
import fdensenet
print(mx.__file__)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
args = None
def parse_args():
    parser = argparse.ArgumentParser(description='Train face network')
    # general
    # dataset configuration to train on
    parser.add_argument('--dataset', default=default.dataset, help='dataset config')
    # network architecture to use
    parser.add_argument('--network', default=default.network, help='network config')
    # loss function to use
    parser.add_argument('--loss', default=default.loss, help='loss config')
    # parse these three options first, then merge the chosen settings into config/default
    args, rest = parser.parse_known_args()
    generate_config(args.network, args.dataset, args.loss)
    # directory where models are saved
    parser.add_argument('--models-root', default=default.models_root, help='root directory to save model.')
    # pretrained model to load
    parser.add_argument('--pretrained', default=default.pretrained, help='pretrained model to load')
    # epoch of the pretrained model to load
    parser.add_argument('--pretrained-epoch', type=int, default=default.pretrained_epoch,
                        help='pretrained epoch to load')
    # whether/how to save checkpoints
    parser.add_argument('--ckpt', type=int, default=default.ckpt,
                        help='checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save')
    # run verification (and possibly save) every `verbose` batches
    parser.add_argument('--verbose', type=int, default=default.verbose,
                        help='do verification testing and model saving every verbose batches')
    # learning rate
    parser.add_argument('--lr', type=float, default=default.lr, help='start learning rate')
    parser.add_argument('--lr-steps', type=str, default=default.lr_steps, help='steps of lr changing')
    # weight decay
    parser.add_argument('--wd', type=float, default=default.wd, help='weight decay')
    # SGD momentum
    parser.add_argument('--mom', type=float, default=default.mom, help='momentum')
    parser.add_argument('--frequent', type=int, default=default.frequent, help='')
    # number of samples per GPU in each batch
    parser.add_argument('--per-batch-size', type=int, default=default.per_batch_size, help='batch size in each context')
    # kvstore setting
    parser.add_argument('--kvstore', type=str, default=default.kvstore, help='kvstore setting')
    args = parser.parse_args()
    return args
def get_symbol(args):
    # build the backbone and obtain the embedding (feature vector)
    embedding = eval(config.net_name).get_symbol()
    # placeholder variable holding the ground-truth labels
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    is_softmax = True
    # plain softmax loss
    if config.loss_name == 'softmax':
        # weight of the fc7 fully connected (classification) layer
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
        # without a bias term, apply the fully connected layer directly
        if config.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                        name='fc7')
        # otherwise create a bias variable and include it in the fully connected layer
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=config.num_classes,
                                        name='fc7')
    # margin-based softmax (nsoftmax / arcface / cosface / combined)
    elif config.loss_name == 'margin_softmax':
        # weight of the fc7 fully connected (classification) layer
        _weight = mx.symbol.Variable("fc7_weight", shape=(config.num_classes, config.emb_size),
                                     lr_mult=config.fc7_lr_mult, wd_mult=config.fc7_wd_mult, init=mx.init.Normal(0.01))
        # scale factor s applied to the normalized features/logits
        s = config.loss_s
        # L2-normalize both the weights and the embeddings, then apply the fully connected layer
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=config.num_classes,
                                    name='fc7')
        # m1, m2, m3 unify the different margin-based losses (see the note after this function)
        if config.loss_m1 != 1.0 or config.loss_m2 != 0.0 or config.loss_m3 != 0.0:
            if config.loss_m1 == 1.0 and config.loss_m2 == 0.0:
                s_m = s * config.loss_m3
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if config.loss_m1 != 1.0:
                    t = t * config.loss_m1
                if config.loss_m2 > 0.0:
                    t = t + config.loss_m2
                body = mx.sym.cos(t)
                if config.loss_m3 > 0.0:
                    body = body - config.loss_m3
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # triplet-style losses
    elif config.loss_name.find('triplet') >= 0:
        is_softmax = False
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        if config.loss_name == 'triplet':
            # Euclidean triplet loss
            ap = anchor - positive
            an = anchor - negative
            ap = ap * ap
            an = an * an
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        else:
            # angular triplet loss: compare angles instead of squared distances
            ap = anchor * positive
            an = anchor * negative
            ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
            an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
            ap = mx.sym.arccos(ap)
            an = mx.sym.arccos(an)
            triplet_loss = mx.symbol.Activation(data=(ap - an + config.triplet_alpha), act_type='relu')
            triplet_loss = mx.symbol.mean(triplet_loss)
        triplet_loss = mx.symbol.MakeLoss(triplet_loss)
    out_list = [mx.symbol.BlockGrad(embedding)]
    # softmax-style losses: attach the softmax output (and optionally the monitored cross-entropy)
    if is_softmax:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
        if config.ce_loss:
            # ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
            body = mx.symbol.SoftmaxActivation(data=fc7)
            body = mx.symbol.log(body)
            _label = mx.sym.one_hot(gt_label, depth=config.num_classes, on_value=-1.0, off_value=0.0)
            body = body * _label
            ce_loss = mx.symbol.sum(body) / args.per_batch_size
            out_list.append(mx.symbol.BlockGrad(ce_loss))
    # triplet losses: block gradients on the label and append the triplet loss
    else:
        out_list.append(mx.sym.BlockGrad(gt_label))
        out_list.append(triplet_loss)
    # group all output symbols together
    out = mx.symbol.Group(out_list)
    return out
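
# Annotation (added for this post, not part of the original train.py): the margin_softmax branch
# above implements the unified margin  target_logit = s * (cos(m1 * theta + m2) - m3), where theta
# is the angle between the L2-normalized embedding and the class weight. With the values from
# config.py this reduces to:
#   nsoftmax: (m1, m2, m3) = (1.0, 0.0, 0.0)  -> s * cos(theta)            (no margin, only scaling)
#   arcface : (m1, m2, m3) = (1.0, 0.5, 0.0)  -> s * cos(theta + 0.5)      (additive angular margin)
#   cosface : (m1, m2, m3) = (1.0, 0.0, 0.35) -> s * (cos(theta) - 0.35)   (additive cosine margin)
# A quick numpy sanity check of the two margins, assuming cos(theta) = 0.6 and s = 64.0:
#   theta = np.arccos(0.6)
#   64.0 * np.cos(theta + 0.5)      # arcface target logit
#   64.0 * (np.cos(theta) - 0.35)   # cosface target logit
# Note that the cosface case never enters the arccos path: since m1 == 1.0 and m2 == 0.0, the code
# takes the shortcut fc7 = fc7 - one_hot(label) * (s * m3), which is equivalent.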
def train_net(args):
    # decide whether to run on GPU(s) or on the CPU
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    # prefix used for saved model files
    prefix = os.path.join(args.models_root, '%s-%s-%s' % (args.network, args.loss, args.dataset), 'model')
    # directory the models are saved in
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    # number of GPUs (contexts)
    args.ctx_num = len(ctx)
    # total batch size across all contexts
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    # batch size per GPU
    config.per_batch_size = args.per_batch_size
    # directory of the training data
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    # image size and sanity checks
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    # number of identities (classes) in the dataset
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")
    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None
    begin_epoch = 0
    # if no pretrained model is given, the weights will be initialized from scratch
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)  # build the network symbol
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size,) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  # otherwise load the pretrained checkpoint
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)
    # count the FLOPs of the embedding network
    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym, data=(1, 3, image_size[0], image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        print('Network FLOPs: %s' % _str)
    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=mx.gpu(),
        symbol=sym,
    )
    val_dataiter = None
    # build the training data iterator; the triplet losses and the softmax-style losses use
    # different iterators (the differences are covered in a later post)
    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))
    # weight initialization
    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)
    # load all verification sets
    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    # evaluate on the verification sets
    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10,
                                                                               None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    # highest accuracy seen so far
    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in range(len(ver_list)):
    #     highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        # at each configured step, multiply the learning rate by 0.1
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break
        _cb(param)
        # print progress every 1000 batches
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        # every `verbose` batches: run verification and decide whether to save a checkpoint
        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                # lfw_score = acc_list[0]
                # if lfw_score>highest_acc[0]:
                #     highest_acc[0] = lfw_score
                #     if lfw_score>=0.998:
                #         do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                    highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    # if lfw_score>=0.99:
                    #     do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1
            # save the checkpoint
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)
    epoch_cb = None
    # wrap train_dataiter in an mx.io.PrefetchingIter so batches are prefetched in the background
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=999999,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore=args.kvstore,
              optimizer=opt,
              # optimizer_params = optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
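
# Annotation (added for this post, not part of the original train.py): because config.ckpt_embedding
# is True by default, the checkpoints written by _batch_callback contain only the embedding network
# (everything up to fc1_output); the fc7 classification weights are dropped. With the default
# settings the prefix is './models/r100-arcface-emore/model', so a saved checkpoint could later be
# loaded for feature extraction with, for example:
#   sym, arg_params, aux_params = mx.model.load_checkpoint('./models/r100-arcface-emore/model', 1)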
def main():
    global args
    args = parse_args()
    train_net(args)


if __name__ == '__main__':
    main()
Everything above is now annotated, except that the def get_symbol(args) function has not been covered in detail; it deals with the loss functions and is fairly complex, so the next section will walk through it carefully. Remember to follow and leave a like.