【MXNet Gluon】自己动手实现fit函数,实现断点保存

承接图像分类、检测、分割、生成相关项目,私信。

【MXNet Gluon】自己动手实现断点保存

…用过caffe的炼丹师应该都知道,在用caffe训练模型时,可以通过命令行窗口提前终止训练过程,caffe会自动保存当前状态的参数,以供继续训练。
…但是,对于MXNet,无论你使用何种接口,都不存在这种机制。
在这里,提供一个实现方式。
实现该过程,分两个步骤:
一是写自己的训练过程函数
二是监控模型训练过程中来自命令行的按键相应

实现如下:

Gluon接口

# -*- coding:utf-8 -*-
'''
	相关头文件
'''
import signal
def signal_handler(signal,frame):
    print('You pressed Ctrl+C !')
    global BreakFlag
    BreakFlag=True
def train(net, train_list, val_iter, start_iter, epochs, lr, ctx):
    ctx = [ctx]
    loss1 = gluon.loss.L1Loss()
    loss2 = gluon.loss.L2Loss()
    bg_loss=BGLoss()
    fg_loss=FGLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
    print('start trainning...')
    for epoch in range(epochs):
        metric.reset()
        for i in range(len(train_list)/batch_size):
            batch_data,batch_label=getDataBatch(train_list,i,batchsize=batch_size,image_shape=256)
            data = gluon.utils.split_and_load(mx.nd.array(batch_data), ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(mx.nd.array(batch_label), ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            smoothing_constant=0.1
            with autograd.record():
                for x, y in zip(data, label):
                    output = net(x)
                    l1=loss2(output, y)
                    l2=bg_loss(output,y)
                    l3=fg_loss(output,y)
                    Ls.append(l1+l2+l3)
                    outputs.append(output)
                for L in Ls:
                    L.backward()
            trainer.step(batch_data.shape[0])
            metric.update(label, outputs)
			#响应函数
            if BreakFlag:
                print('Early Stop ! Saving model ...')
                net.save_params('params/net-%d-batch=%d.params' % (start_iter + epoch + 1,i))
                print('Saved model to params/net-%d-batch=%d.params' % (start_iter + epoch + 1,i))
                sys.exit(0)
        names, value = metric.get()
        metric.reset()
        logging.info('[Epoch %d] train-%s-loss : %f' % (start_iter + epoch, names[0],value[0]))
        net.save_params('params/net-%d.params' % (start_iter + epoch+1))
if __name__ == '__main__':
    '''
	    功能代码
    '''
    #启动监控键盘按键响应
    BreakFlag = False
    signal.signal(signal.SIGINT,signal_handler)
    train(net,train_data,valid_data,start_iter,num_epoch,lr,mx.gpu(0))

你可能感兴趣的:(MXNet从上手到入门)