This note covers https://github.com/deepinsight/insightface, specifically train_softmax.py. The get_symbol function, from line 129 to line 272, defines the model structure. Its overall skeleton is:
def get_symbol(args, arg_params, aux_params):
    embedding = fresnet.get_symbol(...)  # ResNet backbone as an example
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult)
    fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')  # plain softmax as an example
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
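Here, out groups two outputs: mx.symbol.BlockGrad(embedding) and softmax.
The first is the feature vector embedding, wrapped in BlockGrad so that this extra output sends no gradient back.
The second is the softmax output, through which the training gradients are computed.
The embedding is used by verification.test. The softmax output is used for training and is also what the evaluation metric sees. During model.fit, the module's update_metric() (not update(), which applies the optimizer step) ends up calling update() of the class AccMetric. The details follow.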
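The NumPy sketch below makes the two grouped outputs concrete by mirroring only the forward computation of the head. The sizes are illustrative assumptions, not InsightFace defaults, and the backbone output is replaced by random numbers. It shows that fc7 with no_bias=True is a matrix product with a weight of shape (num_classes, emb_size), and that the group keeps the embedding at index 0 and the softmax probabilities at index 1.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, emb_size, num_classes = 4, 8, 10   # illustrative sizes only

embedding = rng.normal(size=(batch, emb_size))          # stands in for the backbone output
fc7_weight = rng.normal(size=(num_classes, emb_size))   # same layout as "fc7_weight" above

fc7 = embedding @ fc7_weight.T            # FullyConnected with no_bias=True
z = fc7 - fc7.max(axis=1, keepdims=True)  # subtract row max for numerical stability
softmax = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # what SoftmaxOutput returns in forward

out = [embedding, softmax]                # same order as mx.symbol.Group(out_list)
print([o.shape for o in out])             # [(4, 8), (4, 10)]
```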
Obtaining sym and the model parameters arg_params and aux_params:
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
Here, sym is the out returned above, containing mx.symbol.BlockGrad(embedding) and softmax.
model = mx.mod.Module(
    context = ctx,
    symbol = sym,
)
Here, ctx is the list of devices, e.g. ctx = [mx.gpu(0)], and sym is the sym obtained above.
Using the embedding: verification.test, called from ver_test(nbatch). Lines 235, 236 and 249 of test use the mx.symbol.BlockGrad(embedding) element of out:
model.forward(db, is_train=False)
net_out = model.get_outputs()
...
_embeddings = net_out[0].asnumpy()
Here, model.get_outputs() returns the list net_out. net_out[0].asnumpy() takes the embedding, which is the first output of the group, and converts it to a NumPy array.
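Once the embeddings are NumPy arrays, a verification benchmark compares them pair by pair. The sketch below is not the repository's verification code. It assumes a hypothetical layout in which rows 0, 2, 4, ... hold the first image of each pair and rows 1, 3, 5, ... the second. It also assumes a simple rule: L2-normalize each embedding, take the squared Euclidean distance, and call a pair "same identity" when the distance is below a threshold. The function names and the threshold are illustrative.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    # Scale each row (one embedding per row) to unit L2 norm.
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)

def pair_accuracy(embeddings, issame, threshold):
    # Hypothetical layout: even rows are the first image of a pair, odd rows the second.
    emb = l2_normalize(embeddings)
    diff = emb[0::2] - emb[1::2]
    dist = np.sum(diff ** 2, axis=1)      # squared Euclidean distance per pair
    predict_same = dist < threshold
    return float(np.mean(predict_same == np.asarray(issame)))

rng = np.random.default_rng(0)
a = rng.normal(size=(1, 128))
b = rng.normal(size=(1, 128))
# Pair 1: an embedding and a scaled copy (same after normalization).
# Pair 2: two unrelated random vectors.
embeddings = np.vstack([a, 3.0 * a, a, b])
print(pair_accuracy(embeddings, [True, False], threshold=1.0))
```

For unit vectors the squared distance equals 2 - 2cos(theta), so a threshold on it is equivalent to a threshold on cosine similarity.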
Defining the loss: taking softmax as an example, line 182 defines the fc7 layer and line 269 defines the loss function:
fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
...
softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
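In the forward pass, SoftmaxOutput returns class probabilities rather than a loss value. In the backward pass, as we understand the operator's documentation, it injects the cross-entropy gradient with respect to fc7: (p - onehot(label)). Under normalization='valid', this is divided by the number of non-ignored labels, which is the batch size here because use_ignore is not set. The NumPy sketch below computes that gradient and checks one entry against a finite difference of the mean cross-entropy. It reimplements the math and does not call MXNet.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_ce(logits, labels):
    # Mean cross-entropy over the batch.
    p = softmax(logits)
    return -np.mean(np.log(p[np.arange(len(labels)), labels]))

def softmax_output_grad(logits, labels):
    # (p - onehot) / N, with N = batch size when no label is ignored.
    n = logits.shape[0]
    p = softmax(logits)
    onehot = np.zeros_like(p)
    onehot[np.arange(n), labels] = 1.0
    return (p - onehot) / n

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))   # fc7 output: batch of 4, 5 classes
labels = np.array([0, 3, 1, 4])
g = softmax_output_grad(logits, labels)

# Central finite difference on a single logit.
eps = 1e-6
plus, minus = logits.copy(), logits.copy()
plus[1, 2] += eps
minus[1, 2] -= eps
numeric = (mean_ce(plus, labels) - mean_ce(minus, labels)) / (2 * eps)
print(g[1, 2], numeric)
```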
To define a custom loss function instead, you can use mxnet.symbol.MakeLoss. An example follows: linear regression, adapted from https://www.zhihu.com/question/51043310/answer/148268459.
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
x = mx.sym.Variable('data')
y = mx.sym.FullyConnected(data=x, num_hidden=1)
label = mx.sym.Variable('label')
loss = mx.sym.MakeLoss(mx.sym.square(y - label))
pred_loss = mx.sym.Group([mx.sym.BlockGrad(y), loss])
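# bind with explicit input shapes; label=(32, 1) is given to avoid relying on shape inference
ex = pred_loss.simple_bind(mx.cpu(), data=(32, 2), label=(32, 1))
# test
test_data = mx.nd.array(np.random.random(size=(32, 2)))
test_label = mx.nd.array(np.random.random(size=(32, 1)))
ex.forward(is_train=True, data=test_data, label=test_label)
ex.backward()  # MakeLoss supplies the head gradient, so no out_grads are passed
print(ex.arg_dict)
# arguments are ordered data, fullyconnected0_weight, fullyconnected0_bias, label,
# so grad_arrays[1] and grad_arrays[2] hold the weight and bias gradients
fc_w = ex.arg_dict['fullyconnected0_weight'].asnumpy()
fc_w_grad = ex.grad_arrays[1].asnumpy()
fc_bias = ex.arg_dict['fullyconnected0_bias'].asnumpy()
fc_bias_grad = ex.grad_arrays[2].asnumpy()
logging.info('fc_weight:{}, fc_weights_grad:{}'.format(fc_w, fc_w_grad))
logging.info('fc_bias:{}, fc_bias_grad:{}'.format(fc_bias, fc_bias_grad))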
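To see what ex.backward() should produce in the example above, the NumPy sketch below computes the same gradients by hand. It assumes MakeLoss's defaults as we read its documentation: grad_scale=1.0 and normalization='null'. Under those defaults, every element of square(y - label) receives a head gradient of 1, so the weight gradient is that of the summed squared error, not averaged over the batch. The random values are illustrative and will not match the MXNet run.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((32, 2))        # data, shape (32, 2) as in the example
label = rng.random((32, 1))    # label, shape (32, 1)
W = rng.normal(size=(1, 2))    # FullyConnected weight: (num_hidden, input_dim)
b = np.zeros(1)                # FullyConnected bias: (num_hidden,)

def summed_loss(W, b):
    # Sum of square(y - label) over the batch (head gradient 1 per element).
    y = x @ W.T + b
    return float(((y - label) ** 2).sum())

# Analytic gradients via the chain rule.
r = x @ W.T + b - label        # residual, shape (32, 1)
grad_W = 2 * r.T @ x           # shape (1, 2)
grad_b = 2 * r.sum(axis=0)     # shape (1,)

# Central finite difference on W[0, 0] as a sanity check.
eps = 1e-6
Wp, Wm = W.copy(), W.copy()
Wp[0, 0] += eps
Wm[0, 0] -= eps
numeric = (summed_loss(Wp, b) - summed_loss(Wm, b)) / (2 * eps)
print(grad_W[0, 0], numeric)
```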
Metrics: lines 45 to 83 define AccMetric and LossValueMetric respectively. If loss_type < 10, the former is used; otherwise the latter. Taking AccMetric as an example:
class AccMetric(mx.metric.EvalMetric):
    def __init__(self):
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []
        self.count = 0

    def update(self, labels, preds):
        self.count += 1
        preds = [preds[1]]  # use softmax output
        for label, pred_label in zip(labels, preds):
            # class probabilities -> predicted class ids
            if pred_label.shape != label.shape:
                pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
            pred_label = pred_label.asnumpy().astype('int32').flatten()
            label = label.asnumpy()
            if label.ndim == 2:
                label = label[:, 0]
            label = label.astype('int32').flatten()
            assert label.shape == pred_label.shape
            self.sum_metric += (pred_label.flat == label.flat).sum()
            self.num_inst += len(pred_label.flat)
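The accuracy logic in AccMetric.update can be tried without MXNet. The sketch below is a hypothetical NumPy-only mirror of that logic; the class name NumpyAccMetric and its get() method are illustrative, not part of InsightFace or MXNet. It keeps only preds[1], converts probabilities to class ids with argmax along axis 1, and accumulates sum_metric and num_inst.

```python
import numpy as np

class NumpyAccMetric:
    """Hypothetical NumPy-only mirror of the AccMetric.update logic."""

    def __init__(self, axis=1):
        self.axis = axis
        self.sum_metric = 0
        self.num_inst = 0

    def update(self, labels, preds):
        preds = [preds[1]]  # keep only the softmax output, as AccMetric does
        for label, pred in zip(labels, preds):
            if pred.shape != label.shape:
                pred = np.argmax(pred, axis=self.axis)  # probabilities -> class ids
            pred = pred.astype('int32').flatten()
            label = np.asarray(label)
            if label.ndim == 2:
                label = label[:, 0]
            label = label.astype('int32').flatten()
            assert label.shape == pred.shape
            self.sum_metric += int((pred == label).sum())
            self.num_inst += len(pred)

    def get(self):
        return self.sum_metric / self.num_inst

metric = NumpyAccMetric()
embedding = np.zeros((3, 4))              # preds[0]: ignored by the metric
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8],
                  [0.2, 0.3, 0.5]])       # preds[1]: softmax output
labels = [np.array([0, 2, 1])]
metric.update(labels, [embedding, probs])
print(metric.get())                       # 2 of 3 predictions correct
```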
Training starts at line 464:
model.fit(train_dataiter,
    begin_epoch = begin_epoch,
    num_epoch = end_epoch,
    eval_data = val_dataiter,
    eval_metric = eval_metrics,
    kvstore = 'device',
    optimizer = opt,
    #optimizer_params = optimizer_params,
    initializer = initializer,
    arg_params = arg_params,
    aux_params = aux_params,
    allow_missing = True,
    batch_end_callback = _batch_callback,
    epoch_end_callback = epoch_cb )
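Here, the metric passed as eval_metric is AccMetric. model.fit is defined in site-packages/mxnet/module/base_module.py, at lines 395 to 562; these line numbers depend on the installed MXNet version. To update the metric, line 521 calls:
self.update_metric(eval_metric, data_batch.label)
The update_metric method, through the executor group, in turn calls update of AccMetric: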
def update(self, labels, preds):
    self.count += 1
    preds = [preds[1]]  # use softmax output
    ...
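Here, preds is the list of output arrays that sym produces in the forward pass, in the order given to mx.symbol.Group. preds[1] is therefore the output of sym[1], namely softmax (the class probabilities), while preds[0] is the embedding.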
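The whole chain (a grouped symbol producing [embedding, softmax], and fit handing that list to the metric) can be sketched in plain Python. The sketch is a simplified, hypothetical stand-in, not MXNet code: forward, CountingAcc and fit_one_epoch are illustrative names. The backbone is replaced by the identity, and the backward pass and optimizer update that the real fit performs are omitted. It only shows why index 1 of preds is the softmax output.

```python
import numpy as np

def forward(batch_x, W):
    """Hypothetical stand-in for the grouped symbol: returns outputs in Group order."""
    embedding = batch_x                              # pretend the backbone is the identity
    logits = embedding @ W.T                         # fc7
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    softmax = e / e.sum(axis=1, keepdims=True)
    return [embedding, softmax]                      # index 0: embedding, index 1: softmax

class CountingAcc:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, preds):
        pred_ids = np.argmax(preds[1], axis=1)       # preds[1] is the softmax output
        self.correct += int((pred_ids == labels[0]).sum())
        self.total += len(pred_ids)

def fit_one_epoch(batches, W, metric):
    # Simplified loop: forward, then hand labels and the list of outputs to the metric.
    # The real BaseModule.fit also runs backward and the optimizer update here.
    for batch_x, batch_y in batches:
        outputs = forward(batch_x, W)
        metric.update([batch_y], outputs)
    return metric.correct / metric.total

W = np.eye(3)  # 3 classes, 3-dim embeddings: predicted class = largest coordinate
batches = [(np.array([[5., 0., 0.], [0., 5., 0.]]), np.array([0, 1])),
           (np.array([[0., 0., 5.], [5., 0., 0.]]), np.array([2, 2]))]
print(fit_one_epoch(batches, W, CountingAcc()))  # 3 of 4 correct
```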