mxnet fine-tune

When fine-tuning a network there are two common strategies: fine-tune the whole network, or fine-tune only the last few layers. Taking MobileFaceNet as the example, the goal here is to load the pretrained model, freeze its parameters, and train only a newly added fully connected layer.

 

Some searching shows that fixed_param_names can be used to freeze parameters. The code looks like this:

# load the pretrained MobileFaceNet checkpoint (epoch 0)
symbol, arg_params, aux_params = mx.model.load_checkpoint('model-y1-arcface', 0000)
all_layers = symbol.get_internals()
# cut the graph at the 128-d embedding output
net = all_layers['fc1_output']
# every argument of the pretrained sub-graph gets frozen
fixed_names = net.list_arguments()

model = mx.module.Module(symbol=symbol, context=devs, fixed_param_names=fixed_names)

When actually running it, however, the parameters still changed. After a lot of searching and analysis it turned out that the BatchNorm layers in mobilefacenet are defined with momentum=0.9: fixed_param_names only freezes the trainable arguments, while BN's moving mean and variance are auxiliary states that keep being updated by a moving average, so the overall parameter set still drifts after training. The momentum in mobilefacenet therefore has to be changed to 1. But how do you change it?
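
To see why momentum=1 freezes the statistics, here is the moving-average rule that BatchNorm applies to its auxiliary states during training. This is a minimal sketch with made-up numbers, not MXNet's actual implementation:

def update_moving_stat(moving, batch, momentum):
    # the rule BatchNorm uses to update moving_mean / moving_var at each training step
    return moving * momentum + batch * (1.0 - momentum)

print(update_moving_stat(0.5, 2.0, 0.9))   # 0.65 -> the statistic drifts toward the batch value
print(update_moving_stat(0.5, 2.0, 1.0))   # 0.5  -> with momentum=1 it never moves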

 

An MXNet model is saved as two files, a .params file and a .json file. load_checkpoint returns symbol, arg_params and aux_params: arg_params holds the trainable parameters, aux_params holds the fixed (auxiliary) parameters such as the BN means and variances. Opening the .json file shows entries like mean and var, so here is a bold guess: the json stores the network structure plus the untrained parameters (not sure if all of them). A Ctrl+F then turns up momentum in there too?! And the 0.9 that follows it looks an awful lot like the 0.9 in mobilefacenet's BN layer definition. Emmmmmm, an evil idea is born: brute-force set momentum to 1 in the json and retrain.
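
If you would rather not hand-edit the file, the same brute-force patch can be scripted. This is only a sketch: the input and output paths are hypothetical, and depending on the MXNet version the per-node attribute dictionary is stored under "attrs" (newer) or "param" (older):

import json

json_path = 'model-y1-arcface-symbol.json'   # hypothetical path to the symbol file

with open(json_path) as f:
    graph = json.load(f)

# force momentum to 1 on every BatchNorm node so the moving stats stop updating
for node in graph['nodes']:
    if node.get('op') == 'BatchNorm':
        key = 'attrs' if 'attrs' in node else 'param'
        node.setdefault(key, {})['momentum'] = '1'

with open('model-y1-arcface-frozenbn-symbol.json', 'w') as f:
    json.dump(graph, f, indent=2)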

 

Then something interesting happens: the parameters stop changing!!! Unbelievable!!!
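
A quick way to double-check is to compare the BN statistics in aux_params before and after fine-tuning. A minimal sketch, assuming the fine-tuned checkpoint was saved with the hypothetical prefix 'finetuned-model' at epoch 107 (match this to whatever save_checkpoint call you actually use):

import mxnet as mx
import numpy as np

# original pretrained checkpoint vs. the fine-tuned one
_, _, aux_before = mx.model.load_checkpoint('model-y1-arcface', 0)
_, _, aux_after  = mx.model.load_checkpoint('finetuned-model', 107)   # hypothetical prefix

# BN moving means / variances live in aux_params; with momentum=1 they should be identical
for name in sorted(aux_before):
    same = np.allclose(aux_before[name].asnumpy(), aux_after[name].asnumpy())
    print(name, 'unchanged' if same else 'CHANGED')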

 

Well, that's it. If this method works for you, remember to leave a like; if it doesn't, all I can say is:

                                      none of my business

Hahaha. Finally, the key code is attached below:

import logging
import mxnet as mx
import numpy as np
import os.path, time, sys


# data iterators: generate data iterator from .rec file
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 112, 112)):
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        # shuffle=False,
        rand_crop=True,
        mirror=True,
        rand_mirror=True,
        max_rotate_angle=0)
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train,val
 
# load and tune model
def get_fine_tune_model(model_name):
    # load model
    symbol, arg_params, aux_params = mx.model.load_checkpoint('/home/xxx/anaconda2/envs/mobilefacenet/insightface_root/insightface/models/MobileFaceNet/model-y1-arcface', 0000)
    # model tuning
    all_layers = symbol.get_internals()
    net = all_layers['fc1_output']

    fixed_names=net.list_arguments()

    _weight_newfc1 = mx.symbol.Variable("newfc1_weight", shape=(10, 128), lr_mult=1.0, wd_mult=5)
    # attach the explicitly created weight so its lr_mult/wd_mult actually take effect
    net = mx.symbol.FullyConnected(data=net, weight=_weight_newfc1, num_hidden=10, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # eliminate weights of new layer
  
    #finetune_lr = dict({k: 0 for k in arg_params})
    #print(finetune_lr)

    new_args = dict({k:arg_params[k] for k in arg_params if 'newfc1' not in k})

    return (net, new_args,aux_params,fixed_names)
 
#model training
def fit(symbol, arg_params, aux_params, iter_train, iter_val, num_epoch, batch_size, gpu_avaliable,fixed_names):
    devs = [mx.gpu(i) for i in gpu_avaliable]
    model = mx.module.Module(symbol=symbol, context=devs,fixed_param_names=fixed_names)
    # metric
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    
    # optimizer (layers are actually frozen via fixed_param_names above; the lr_mult-based
    # alternative below is left commented out)
    sgd = mx.optimizer.Optimizer.create_optimizer('sgd', learning_rate=0.01, momentum=0, wd=0.01)
    finetune_lr = dict({k: 0 for k in arg_params})
    #print(finetune_lr)
    #sgd.set_lr_mult(finetune_lr)
    # training
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),
        #epoch_end_callback  = mx.callback.do_checkpoint('/home/xxxx/anaconda2/envs/mobilefacenet/insightface_root/insightface/newmodels/chkmodel2', 0),
        kvstore='device',
        optimizer=sgd,  # learning_rate=0.01 is already set on the Optimizer instance
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    
    arg, aux = model.get_params()
    mx.model.save_checkpoint(prefix, 107, model.symbol, arg, aux)    # (prefix, epoch, symbol, arg_params, aux_params)
    return model.score(iter_val, com_metric)
 
#=======================================================================================================================
# set logger, print message on screen and file
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s',filename='acc_record.log',filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
logging.getLogger('').addHandler(console)
 
# data and pre-train model
prefix='/home/****/anaconda2/envs/mobilefacenet/insightface_root/insightface/models'

rec_train='/home/****/anaconda2/envs/jiemian/dataset_rec/train_train.rec'

model_name='ahah'

rec_val='/home/****/anaconda2/envs/jiemian/dataset_rec/train_val.rec'
lst_train=rec_train[:-3]+'lst'
 
# parameter
num_classes = 10

batch_per_gpu = 3
num_epoch =30
gpu_avaliable=[0]
num_gpus = len(gpu_avaliable)
batch_size = batch_per_gpu * num_gpus
print(batch_size)
#-----------------------------------------------------------------------------------------------------------------------
 
(new_sym,new_args,aux_params,fixnames)=get_fine_tune_model(model_name)

#mx.viz.plot_network(new_sym).view()    # visualize the model architecture

print('========================= 1 =============================')
(iter_train, iter_val) = get_iterators(batch_size,rec_train,rec_val,lst_train)

print('========================= 2 =============================')
mod_score = fit(new_sym, new_args, aux_params, iter_train, iter_val, num_epoch, batch_size, gpu_avaliable,fixnames)
print(mod_score)

Gotta run, gotta run.
