mxnet 中模型的输入,保存,以及在当前文件系统下建立目录来保存生成的模型。(笔记)

1.mxnet中创建文件夹,并输入输入和模型。

#检查是否有这个文件目录,没有则创建一个
if not os.path.exists(model_directory):
        os.mkdir(model_directory)
#检查并创建一个日志文件,并创建下级目录??后面是flag分别是可读可写,创建并打开一个新文件,在文件后可添加信息
log_file=os.open(model_directory+"//_train_.csv",os.O_RDWR|os.O_CREAT|os.O_APPEND)
#判断是否有模型,解决用那种方法来训练数据

后面的是在创建的文件夹中加一个可读可写的日志文件,我们可以把训练的模型输出情况记录下来。

2.在训练中我们可以选择两种方式来构建模型,一是自己搭建网络,二是从网上下载已经提前训练好的模型。具体的模型导入方式如下。

if n_epoch_load==1:
    module=mx.mod.Module(symbol=net,context=mx.gpu(0))
    arg_params=None
    aux_params=None
else:
        sym,arg_params,aux_params=mx.model.save_checkpoint(model_prefix,n_epoch_load)
        module=mx.mod.Module(symbol=sym,context=mx.gpu(0))

3.在训练过程中,如果我们想用日志文件来记录每个epoch的记录,我们可以如下:

def epoch_callback(epoch,symbol,arg_params,aux_params):
	#if epoch % save_period==0:
		#module.save_checkpoint(model_prefix,epoch,save_optimizer_states=True)
	os.write(log_file,str(logging.getLogger().setLevel(logging.DEBUG))+"\n")
	os.fsync(log_file)

module。fit(
	epoch_end_callback = epoch_callback,

)

在其中,我们可以用判断语句来判断模型是否训练的比较好,这是我们可以保存模型数据,具体检验情况还得看自己的判断。

 

你可能感兴趣的:(PYTHON)