MXNet fine-tuning example (fine-tuning only a few layers)

1. Search for the MXNet model zoo and download the corresponding pre-trained model:

http://mxnet.incubator.apache.org/model_zoo/index.html
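
If you prefer to script the download, a minimal sketch (assuming the vgg16 files are still hosted at http://data.mxnet.io; the URLs may change):

import mxnet as mx

base = 'http://data.mxnet.io/models/imagenet/vgg/'
for fname in ['vgg16-symbol.json', 'vgg16-0000.params']:
    # downloads into data/pre_train/, matching the path used by the code below
    mx.test_utils.download(base + fname, dirname='data/pre_train')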


2. Convert the data to .rec format; the first part of the official fine-tuning tutorial walks through this:

http://mxnet.incubator.apache.org/how_to/finetune.html
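
For reference, a typical im2rec invocation looks like this (a sketch; the image folder layout is an assumption, and im2rec.py lives under tools/ in the MXNet source tree):

# build a .lst file from a folder with one subfolder per class
python ~/mxnet/tools/im2rec.py --list --recursive data/rec/hico_train_full data/images/train
# pack the images into a .rec file, resizing the short edge to 256
python ~/mxnet/tools/im2rec.py --resize 256 --quality 95 --num-thread 8 data/rec/hico_train_full data/images/train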


3. Define the function that builds the data iterators:

def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    # training iterator: shuffled, with random-crop and random-mirror augmentation
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,
        rand_mirror=True,
        max_rotate_angle=0)
    # validation iterator: no shuffling, no augmentation
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train, val
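
To sanity-check the iterators, pull one batch and inspect the shapes (paths assume the .rec files produced in step 2):

train, val = get_iterators(32, './data/rec/hico_train_full.rec',
                           './data/rec/hico_val.rec', './data/rec/hico_train_full.lst')
batch = train.next()
print(batch.data[0].shape)   # (32, 3, 224, 224)
print(batch.label[0].shape)  # (32,)
train.reset()  # rewind before handing the iterator to fit()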

4. Define the function that loads the pre-trained model and swaps in a new output layer:

def get_fine_tune_model(model_name):
    # load the pre-trained checkpoint (epoch 0)
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
    # cut the network just before its original classifier
    all_layers = symbol.get_internals()
    if model_name=="vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    # attach a new output layer sized for our task (num_classes is a module-level global)
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the weights of the original output layer so the new one is freshly initialized
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args, aux_params)
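
For networks other than vgg16 and resnet, the right layer name to cut at can be found by listing the symbol's internal outputs, for example:

symbol, arg_params, aux_params = mx.model.load_checkpoint('data/pre_train/resnet-152', 0)
# print the last few internal outputs; pick the one just before the old classifier
print(symbol.get_internals().list_outputs()[-10:])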

5. Train the model. During training, set the learning-rate multiplier of every layer that should not be adjusted to 0, so that only the last few layers are actually fine-tuned:

def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_avaliable):
    devs = [mx.gpu(i) for i in gpu_avaliable]
    model = mx.mod.Module(symbol=symbol, context=devs)
    # metrics
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str)) # remove if unnecessary
    # optimizer: freeze every pre-trained layer by setting its learning-rate
    # multiplier to 0; the new layer ('newfc1') is absent from arg_params, so it
    # keeps the default multiplier of 1 and is the only layer that gets updated
    sgd = mx.optimizer.Optimizer.create_optimizer('sgd', learning_rate=0.01)
    finetune_lr = dict({k: 0 for k in arg_params})
    sgd.set_lr_mult(finetune_lr)
    # training; allow_missing=True lets the initializer fill in the weights of
    # the new layer, which are missing from arg_params
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer=sgd, # optimizer_params would be ignored when an Optimizer instance is passed
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    return model.score(iter_val, com_metric)
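
The lr_mult trick also generalizes to fine-tuning the last few pre-trained layers instead of just the new one: set the multiplier to 0 only for the parameters you want frozen. A sketch, assuming the vgg16 parameter names (fc6_*, fc7_*):

finetune_lr = {}
for k in arg_params:
    if k.startswith('fc6') or k.startswith('fc7'):
        finetune_lr[k] = 0.1   # fine-tune the top fully connected layers at a reduced rate
    else:
        finetune_lr[k] = 0.0   # freeze all convolutional layers
sgd.set_lr_mult(finetune_lr)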



Complete code:

import logging
import mxnet as mx
import numpy as np
import os.path, time,sys
from mAP_metric import mAP

print ("\n******File updated %ds ago%s******" % (time.time()-os.path.getmtime(sys.argv[0])))# file updatation check

# data iterators: build training and validation iterators from the .rec files
def get_iterators(batch_size, rec_train, rec_val, lst_train, data_shape=(3, 224, 224)):
    # training iterator: shuffled, with random-crop and random-mirror augmentation
    train = mx.io.ImageRecordIter(
        path_imgrec=rec_train,
        path_imglist=lst_train,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape,
        shuffle=True,
        rand_crop=True,
        rand_mirror=True,
        max_rotate_angle=0)
    # validation iterator: no shuffling, no augmentation
    val = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        data_name='data',
        label_name='softmax_label',
        batch_size=batch_size,
        data_shape=data_shape)
    return train, val

# load the pre-trained model and swap in a new output layer
def get_fine_tune_model(model_name):
    # load the pre-trained checkpoint (epoch 0)
    symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
    # cut the network just before its original classifier
    all_layers = symbol.get_internals()
    if model_name=="vgg16":
        net = all_layers['drop7_output']
    else:
        net = all_layers['flatten0_output']
    # attach a new output layer sized for our task (num_classes is a module-level global)
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the weights of the original output layer so the new one is freshly initialized
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args, aux_params)

# model training
def fit(symbol, arg_params, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_avaliable):
    devs = [mx.gpu(i) for i in gpu_avaliable]
    model = mx.mod.Module(symbol=symbol, context=devs)
    # metrics
    com_metric = mx.metric.CompositeEvalMetric()
    com_metric.add(mx.metric.Accuracy())
    com_metric.add(mAP(class_str)) # remove if unnecessary
    # optimizer: freeze every pre-trained layer by setting its learning-rate
    # multiplier to 0; the new layer ('newfc1') is absent from arg_params, so it
    # keeps the default multiplier of 1 and is the only layer that gets updated
    sgd = mx.optimizer.Optimizer.create_optimizer('sgd', learning_rate=0.01)
    finetune_lr = dict({k: 0 for k in arg_params})
    sgd.set_lr_mult(finetune_lr)
    # training; allow_missing=True lets the initializer fill in the weights of
    # the new layer, which are missing from arg_params
    model.fit(iter_train, iter_val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer=sgd, # optimizer_params would be ignored when an Optimizer instance is passed
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    return model.score(iter_val, com_metric)

#=======================================================================================================================
# set up logging: write messages both to the screen and to acc_record.log
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s',filename='acc_record.log',filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(asctime)-15s %(message)s'))
logging.getLogger('').addHandler(console)

# data and pre-trained model
rec_train='./data/rec/hico_train_full.rec'
# rec_train='./data/rec/hico_train_200500.rec'
model_name='vgg16'
# model_name='resnet-152'
rec_val='./data/rec/hico_val.rec'
lst_train=rec_train[:-3]+'lst'

# parameters
num_classes = 600
class_str = ["c"+str(i) for i in range(num_classes)]
batch_per_gpu = 40
num_epoch = 10
gpu_avaliable = [0, 1, 2, 3]
num_gpus = len(gpu_avaliable)
batch_size = batch_per_gpu * num_gpus
if rec_train=='./data/rec/hico_train_full.rec':
    print ('-----------Batches per epoch: %d-----------' % (7000.0/batch_size))
if rec_train=='./data/rec/hico_train_200500.rec':
    print ('-----------Batches per epoch: %d-----------' % (137120.0/batch_size))
#-----------------------------------------------------------------------------------------------------------------------

(new_sym,new_args,aux_params)=get_fine_tune_model(model_name)
(iter_train, iter_val) = get_iterators(batch_size,rec_train,rec_val,lst_train)
mod_score = fit(new_sym, new_args, aux_params, iter_train, iter_val, class_str, num_epoch, batch_size, gpu_avaliable)
print(mod_score)
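
To keep the fine-tuned weights, add a checkpoint callback to the model.fit call inside fit(); a sketch (the 'vgg16-hico' prefix is arbitrary):

checkpoint = mx.callback.do_checkpoint('vgg16-hico')
# pass `epoch_end_callback=checkpoint` to model.fit(...) above; this writes
# vgg16-hico-symbol.json plus vgg16-hico-00XX.params after every epoch, and a
# saved epoch can later be reloaded with:
sym, arg_params, aux_params = mx.model.load_checkpoint('vgg16-hico', num_epoch)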


 
 
