Problem definition should take the following aspects into account:
Deep learning algorithm engineers are often called "parameter-tuning engineers," because a model involves a very large number of parameters, with dozens of hyperparameters alone. Most hyperparameters, however, have little effect on the results as long as they are set within a reasonable range, so the bulk of one's effort should go into data analysis, problem definition, and algorithm design instead.
import mxnet as mx
# Xavier initialization drawing from a Gaussian distribution
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
mod = mx.mod.Module(symbol=symbol)
mod.fit(..., initializer=initializer, ...)
If no initializer is passed to fit(), the default is mx.init.Uniform(0.01), which draws every parameter uniformly from [-0.01, 0.01]. This default can make training behave abnormally, so it is not recommended.
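A quick way to confirm this behavior is to initialize a tiny network by hand and inspect the resulting weights. The snippet below is a minimal sketch; the two-layer symbol is made up purely for illustration:
import mxnet as mx
# A hypothetical two-layer network used only to inspect initialized weights
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data=data, num_hidden=2, name='fc')
out = mx.sym.SoftmaxOutput(data=fc, name='softmax')
mod = mx.mod.Module(symbol=out, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 4))], label_shapes=[('softmax_label', (1,))])
mod.init_params(initializer=mx.init.Uniform(0.01))  # the default used by fit()
arg_params, _ = mod.get_params()
print(arg_params['fc_weight'].asnumpy())  # every value lies in [-0.01, 0.01]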
The optimizer and learning rate can be passed directly to fit(), mainly through the optimizer_params (a dict) and optimizer (a string) arguments.
import mxnet as mx
optimizer_params = {'learning_rate': 0.001}
mod = mx.mod.Module(symbol=symbol)
mod.fit(...,
        optimizer_params=optimizer_params,
        optimizer='sgd',
        ...)
# factor=0.1: each time the update count passes a value in step, the learning
# rate is multiplied by 0.1. Note that step counts weight updates (batches),
# not epochs. For example, after 3200 updates the scheduler has passed the
# 1000 and 3000 boundaries, so the learning rate is 0.01x the initial value.
import mxnet as mx
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[1000, 3000, 4000],
                                                    factor=0.1)
optimizer_params = {'learning_rate': 0.001,
                    'momentum': 0.9,
                    'wd': 0.0001,
                    'lr_scheduler': lr_scheduler}
mod = mx.mod.Module(symbol=symbol)
mod.fit(...,
        optimizer_params=optimizer_params,
        optimizer='sgd',
        ...)
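A scheduler can also be called directly with an update count to check the learning rate it would return, which is a handy sanity check before training. A minimal sketch; base_lr is set by hand here only because the optimizer normally supplies it, and the scheduler is stateful, so calls should use non-decreasing update counts:
import mxnet as mx
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[1000, 3000, 4000],
                                                    factor=0.1)
lr_scheduler.base_lr = 0.001  # normally set by the optimizer from learning_rate
print(lr_scheduler(500))   # 0.001: no step boundary passed yet
print(lr_scheduler(3200))  # about 1e-05: passed 1000 and 3000, so 0.001 * 0.1 * 0.1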
import mxnet as mx
# prefix: e.g. with prefix='resnet-50', saved models are named
# resnet-50-0001.params, resnet-50-0002.params, ...
# period: e.g. with period=10, a model is saved every 10 epochs:
# resnet-50-0010.params, resnet-50-0020.params, ...
checkpoint = mx.callback.do_checkpoint(prefix=prefix, period=1)
mod = mx.mod.Module(symbol=symbol)
mod.fit(...,
        epoch_end_callback=checkpoint,
        ...)
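The counterpart call reads a saved checkpoint back in; the epoch number selects which ".params" file to load:
# A sketch: load the epoch-10 checkpoint back
# (reads resnet-50-symbol.json and resnet-50-0010.params)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix='resnet-50',
                                                       epoch=10)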
- logging_test.py
import logging
logger = logging.getLogger()  # get a logger object
logger.setLevel(logging.INFO)  # log at INFO level, i.e. normal runtime messages
stream_handler = logging.StreamHandler()  # handler that writes records to the console
logger.addHandler(stream_handler)  # attach the console handler
file_handler = logging.FileHandler('train.log')  # handler that writes records to a file
#file_handler = logging.FileHandler('train.log', mode='w')  # truncate the existing log first, then write
logger.addHandler(file_handler)  # attach the file handler
logger.info("Hello logging")  # the message is printed and saved to the specified file
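Running logging_test.py prints "Hello logging" to the console through stream_handler and writes the same line to train.log through file_handler; without mode='w', repeated runs keep appending to the file.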
Evaluation metrics can be taken from MXNet's built-in metrics or defined by hand.
import mxnet as mx
mod = mx.mod.Module(symbol=symbol)
mod.fit(...,
eval_metric='acc',
...)
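Built-in metrics can also be instantiated explicitly through the metric factory, which is equivalent to passing the string name:
import mxnet as mx
metric = mx.metric.create('acc')  # the same metric as eval_metric='acc' in fit()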
In MXNet, the classes related to evaluation metrics live in the mxnet.metric module, and the most basic one is mxnet.metric.EvalMetric.
MXNet's built-in evaluation metrics are themselves implemented by subclassing mxnet.metric.EvalMetric. Below is a custom metric that computes the recall of class 0.
import mxnet as mx

class Recall(mx.metric.EvalMetric):
    # __init__ does two things: it stores the metric name on the object
    # (name is the label under which results are reported) and resets the counters
    def __init__(self, name):
        super(Recall, self).__init__('Recall')
        self.name = name
        self.reset()

    # reset clears the variables this metric accumulates
    def reset(self):
        self.num_inst = 0
        self.sum_metric = 0.0

    # update is the core of the metric: recall = TP / (TP + FN)
    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        for pred, label in zip(preds, labels):
            pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32')
            label = label.asnumpy().astype('int32')
            true_positives = 0
            false_negatives = 0
            for index in range(len(pred.flat)):
                if pred[index] == 0 and label[index] == 0:
                    true_positives += 1
                if pred[index] != 0 and label[index] == 0:
                    false_negatives += 1
            self.sum_metric += true_positives
            self.num_inst += (true_positives + false_negatives)

    # get returns the metric name and its current value
    def get(self):
        if self.num_inst == 0:
            return (self.name, float('nan'))
        else:
            return (self.name, self.sum_metric / self.num_inst)
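As a quick check, the metric can be fed a batch of toy predictions and labels by hand; the numbers below are made up for illustration:
import mxnet as mx
metric = Recall(name='class0_recall')
preds = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])]  # argmax per row: 0, 1, 0
labels = [mx.nd.array([0, 0, 1])]
metric.update(labels, preds)
print(metric.get())  # ('class0_recall', 0.5): one true positive, one false negative

Training on multiple GPUs only requires passing a list of contexts when constructing the Module: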
import mxnet as mx
context = [mx.gpu(0), mx.gpu(1)]
mod = mx.mod.Module(symbol=symbol, context=context)
mod.fit(...)
Transfer learning means taking the feature-extraction ability of a model trained on one task or dataset and reusing it on another task or dataset.
import mxnet as mx
# mx.model.load_checkpoint reads the ".params" and ".json" files together and
# returns three outputs
symbol, arg_params, aux_params = mx.model.load_checkpoint(
    prefix='./models/resnet-18', epoch=0)
all_layers = symbol.get_internals()  # information on every layer of the network
# Take the output of the layer just before the one to replace.
# Assuming the layer before the fully connected layer is named 'flatten',
# the layer name needs the '_output' suffix.
new_symbol = all_layers['flatten' + '_output']
new_symbol = mx.sym.FullyConnected(data=new_symbol, num_hidden=10, name='new_fc')
# Truncating the network also drops the loss layer, so add it back
new_symbol = mx.sym.SoftmaxOutput(data=new_symbol, name='softmax')
# Pass the pretrained model's parameters to fit(); during initialization fit()
# uses the entries in arg_params and aux_params whose names match layers of
# new_symbol to initialize the network.
# allow_missing defaults to False. Because the pretrained fully connected layer
# was replaced, arg_params and aux_params contain no parameters for the new
# layer; to have it initialized by the configured initializer instead, set
# allow_missing=True.
mod = mx.mod.Module(symbol=new_symbol)
mod.fit(...,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        ...)
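When the name of the truncation point is not known in advance, the internal output names can be listed first. A short sketch; the names shown in the comment are examples:
# List the last few internal output names to locate the truncation point,
# e.g. [..., 'flatten_output', 'fc1_output', 'softmax_output']
print(all_layers.list_outputs()[-10:])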
Training can also resume from a saved model instead of starting from scratch.
import mxnet as mx
symbol, arg_params, aux_params = mx.model.load_checkpoint(
    prefix='./models/resnet-18', epoch=0)
mod = mx.mod.Module(symbol=symbol)
# here allow_missing keeps its default value of False
mod.fit(...,
        arg_params=arg_params,
        aux_params=aux_params,
        ...)
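To make the resumed run continue the epoch numbering (and anything tied to it, such as checkpoint names), fit() also accepts a begin_epoch argument. A minimal sketch, assuming a checkpoint saved at epoch 3 had been loaded instead:
mod.fit(...,
        arg_params=arg_params,
        aux_params=aux_params,
        begin_epoch=3,  # continue counting from the loaded checkpoint's epoch
        ...)
The complete training example below ties these pieces together, training LeNet on MNIST.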
import mxnet as mx
import argparse
import numpy as np
import gzip
import struct
import logging
from custom_metric import *
def get_network(num_classes):
    """
    LeNet
    """
    data = mx.sym.Variable("data")
    conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=6,
                               name="conv1")
    relu1 = mx.sym.Activation(data=conv1, act_type="relu", name="relu1")
    pool1 = mx.sym.Pooling(data=relu1, kernel=(2,2), stride=(2,2),
                           pool_type="max", name="pool1")
    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=16,
                               name="conv2")
    relu2 = mx.sym.Activation(data=conv2, act_type="relu", name="relu2")
    pool2 = mx.sym.Pooling(data=relu2, kernel=(2, 2), stride=(2, 2),
                           pool_type="max", name="pool2")
    fc1 = mx.sym.FullyConnected(data=pool2, num_hidden=120, name="fc1")
    relu3 = mx.sym.Activation(data=fc1, act_type="relu", name="relu3")
    fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=84, name="fc2")
    relu4 = mx.sym.Activation(data=fc2, act_type="relu", name="relu4")
    fc3 = mx.sym.FullyConnected(data=relu4, num_hidden=num_classes, name="fc3")
    sym = mx.sym.SoftmaxOutput(data=fc3, name="softmax")
    return sym
def get_args():
    parser = argparse.ArgumentParser(description='score a model on a dataset')
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--gpus', type=str, default='0')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-epoch', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.1, help="learning rate")
    parser.add_argument('--save-result', type=str, default='output/')
    parser.add_argument('--save-name', type=str, default='LeNet')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = get_args()
    if args.gpus:
        context = [mx.gpu(int(index)) for index in
                   args.gpus.strip().split(",")]
    else:
        context = mx.cpu()
    # get data
    train_data = mx.io.MNISTIter(
        image='train-images.idx3-ubyte',
        label='train-labels.idx1-ubyte',
        batch_size=args.batch_size,
        shuffle=1)
    val_data = mx.io.MNISTIter(
        image='t10k-images.idx3-ubyte',
        label='t10k-labels.idx1-ubyte',
        batch_size=args.batch_size,
        shuffle=0)
    # get network (symbol)
    sym = get_network(num_classes=args.num_classes)
    optimizer_params = {'learning_rate': args.lr}
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                                 magnitude=2)
    mod = mx.mod.Module(symbol=sym, context=context)
    # log to both the console and output/train.log
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler('output/train.log')
    logger.addHandler(file_handler)
    logger.info(args)
    # note: do_checkpoint only saves when the epoch number is a multiple of
    # period, so with period=20 and num_epoch=10 no checkpoint is ever
    # written; use a period <= num_epoch to actually save models
    checkpoint = mx.callback.do_checkpoint(prefix=args.save_result +
                                           args.save_name, period=20)
    batch_callback = mx.callback.Speedometer(args.batch_size, 200)
    # metric: combine the custom class-0 recall with built-in accuracy
    # and cross-entropy
    eval_metric = mx.metric.CompositeEvalMetric()
    eval_metric.add(Recall(name="class0_recall"))
    eval_metric.add(['acc', 'ce'])
    mod.fit(train_data=train_data,
            eval_data=val_data,
            eval_metric=eval_metric,
            optimizer_params=optimizer_params,
            optimizer='sgd',
            batch_end_callback=batch_callback,
            initializer=initializer,
            num_epoch=args.num_epoch,
            epoch_end_callback=checkpoint)
Output:
/home/yuyang/anaconda3/envs/mxnet/bin/python3.5 /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo7/train_mnist.py
[16:31:38] src/io/iter_mnist.cc:113: MNISTIter: load 60000 images, shuffle=1, shape=[64,1,28,28]
[16:31:38] src/io/iter_mnist.cc:113: MNISTIter: load 10000 images, shuffle=0, shape=[64,1,28,28]
Namespace(batch_size=64, gpus='0', lr=0.1, num_classes=10, num_epoch=10, save_name='LeNet', save_result='output/')
[16:31:42] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Epoch[0] Batch [0-200] Speed: 31317.88 samples/sec class0_recall=0.912637 accuracy=0.825249 cross-entropy=0.556456
Epoch[0] Batch [200-400] Speed: 33466.00 samples/sec class0_recall=0.969072 accuracy=0.949766 cross-entropy=0.162854
Epoch[0] Batch [400-600] Speed: 36086.80 samples/sec class0_recall=0.981206 accuracy=0.959844 cross-entropy=0.129993
Epoch[0] Batch [600-800] Speed: 29235.05 samples/sec class0_recall=0.984603 accuracy=0.970469 cross-entropy=0.094682
Epoch[0] Train-class0_recall=0.964352
Epoch[0] Train-accuracy=0.933081
Epoch[0] Train-cross-entropy=0.214967
Epoch[0] Time cost=1.899
Epoch[0] Validation-class0_recall=0.988764
Epoch[0] Validation-accuracy=0.980769
Epoch[0] Validation-cross-entropy=0.063637
Epoch[1] Batch [0-200] Speed: 32101.59 samples/sec class0_recall=0.985179 accuracy=0.974736 cross-entropy=0.079682
Epoch[1] Batch [200-400] Speed: 32740.68 samples/sec class0_recall=0.991277 accuracy=0.976719 cross-entropy=0.076194
Epoch[1] Batch [400-600] Speed: 37568.90 samples/sec class0_recall=0.989820 accuracy=0.980078 cross-entropy=0.070138
Epoch[1] Batch [600-800] Speed: 37054.64 samples/sec class0_recall=0.991086 accuracy=0.982891 cross-entropy=0.055028
Epoch[1] Train-class0_recall=0.989356
Epoch[1] Train-accuracy=0.979189
Epoch[1] Train-cross-entropy=0.068472
Epoch[1] Time cost=1.697
Epoch[1] Validation-class0_recall=0.993871
Epoch[1] Validation-accuracy=0.985076
Epoch[1] Validation-cross-entropy=0.047333
Epoch[2] Batch [0-200] Speed: 33036.56 samples/sec class0_recall=0.989860 accuracy=0.983753 cross-entropy=0.052290
Epoch[2] Batch [200-400] Speed: 37436.53 samples/sec class0_recall=0.992863 accuracy=0.983203 cross-entropy=0.055996
Epoch[2] Batch [400-600] Speed: 35763.55 samples/sec class0_recall=0.992169 accuracy=0.984844 cross-entropy=0.049648
Epoch[2] Batch [600-800] Speed: 34058.20 samples/sec class0_recall=0.995948 accuracy=0.988203 cross-entropy=0.039633
Epoch[2] Train-class0_recall=0.992397
Epoch[2] Train-accuracy=0.985242
Epoch[2] Train-cross-entropy=0.048564
Epoch[2] Time cost=1.750
Epoch[2] Validation-class0_recall=0.995914
Epoch[2] Validation-accuracy=0.986879
Epoch[2] Validation-cross-entropy=0.042824
Epoch[3] Batch [0-200] Speed: 33766.61 samples/sec class0_recall=0.992200 accuracy=0.987562 cross-entropy=0.039315
Epoch[3] Batch [200-400] Speed: 34258.88 samples/sec class0_recall=0.992863 accuracy=0.985859 cross-entropy=0.044844
Epoch[3] Batch [400-600] Speed: 35800.58 samples/sec class0_recall=0.996085 accuracy=0.988672 cross-entropy=0.037022
Epoch[3] Batch [600-800] Speed: 33737.29 samples/sec class0_recall=0.995948 accuracy=0.991094 cross-entropy=0.029925
Epoch[3] Train-class0_recall=0.993918
Epoch[3] Train-accuracy=0.988444
Epoch[3] Train-cross-entropy=0.037272
Epoch[3] Time cost=1.723
Epoch[3] Validation-class0_recall=0.996936
Epoch[3] Validation-accuracy=0.987280
Epoch[3] Validation-cross-entropy=0.039437
Epoch[4] Batch [0-200] Speed: 35346.01 samples/sec class0_recall=0.994540 accuracy=0.989661 cross-entropy=0.031420
Epoch[4] Batch [200-400] Speed: 35871.27 samples/sec class0_recall=0.994449 accuracy=0.989062 cross-entropy=0.035467
Epoch[4] Batch [400-600] Speed: 38914.38 samples/sec class0_recall=0.995301 accuracy=0.990625 cross-entropy=0.029810
Epoch[4] Batch [600-800] Speed: 32392.26 samples/sec class0_recall=0.997569 accuracy=0.994141 cross-entropy=0.022864
Epoch[4] Train-class0_recall=0.994932
Epoch[4] Train-accuracy=0.990878
Epoch[4] Train-cross-entropy=0.029375
Epoch[4] Time cost=1.682
Epoch[4] Validation-class0_recall=0.996936
Epoch[4] Validation-accuracy=0.988081
Epoch[4] Validation-cross-entropy=0.038098
Epoch[5] Batch [0-200] Speed: 35009.40 samples/sec class0_recall=0.995320 accuracy=0.991371 cross-entropy=0.024879
Epoch[5] Batch [200-400] Speed: 37891.87 samples/sec class0_recall=0.994449 accuracy=0.991406 cross-entropy=0.029174
Epoch[5] Batch [400-600] Speed: 37632.95 samples/sec class0_recall=0.998434 accuracy=0.993359 cross-entropy=0.023785
Epoch[5] Batch [600-800] Speed: 38003.66 samples/sec class0_recall=0.997569 accuracy=0.993984 cross-entropy=0.018242
Epoch[5] Train-class0_recall=0.995945
Epoch[5] Train-accuracy=0.992596
Epoch[5] Train-cross-entropy=0.023498
Epoch[5] Time cost=1.634
Epoch[5] Validation-class0_recall=0.996936
Epoch[5] Validation-accuracy=0.987280
Epoch[5] Validation-cross-entropy=0.040373
Epoch[6] Batch [0-200] Speed: 34940.66 samples/sec class0_recall=0.996880 accuracy=0.993470 cross-entropy=0.019952
Epoch[6] Batch [200-400] Speed: 37034.16 samples/sec class0_recall=0.994449 accuracy=0.993047 cross-entropy=0.023962
Epoch[6] Batch [400-600] Speed: 37740.89 samples/sec class0_recall=0.997651 accuracy=0.994375 cross-entropy=0.019386
Epoch[6] Batch [600-800] Speed: 40234.91 samples/sec class0_recall=0.995948 accuracy=0.996016 cross-entropy=0.014943
Epoch[6] Train-class0_recall=0.996283
Epoch[6] Train-accuracy=0.994380
Epoch[6] Train-cross-entropy=0.018967
Epoch[6] Time cost=1.598
Epoch[6] Validation-class0_recall=0.992850
Epoch[6] Validation-accuracy=0.986879
Epoch[6] Validation-cross-entropy=0.042056
Epoch[7] Batch [0-200] Speed: 37206.74 samples/sec class0_recall=0.997660 accuracy=0.996424 cross-entropy=0.014272
Epoch[7] Batch [200-400] Speed: 39221.24 samples/sec class0_recall=0.995242 accuracy=0.994062 cross-entropy=0.019821
Epoch[7] Batch [400-600] Speed: 36042.13 samples/sec class0_recall=0.999217 accuracy=0.995703 cross-entropy=0.015387
Epoch[7] Batch [600-800] Speed: 32745.00 samples/sec class0_recall=0.998379 accuracy=0.997422 cross-entropy=0.010844
Epoch[7] Train-class0_recall=0.997973
Epoch[7] Train-accuracy=0.996098
Epoch[7] Train-cross-entropy=0.014557
Epoch[7] Time cost=1.664
Epoch[7] Validation-class0_recall=0.992850
Epoch[7] Validation-accuracy=0.986779
Epoch[7] Validation-cross-entropy=0.040577
Epoch[8] Batch [0-200] Speed: 34629.35 samples/sec class0_recall=0.997660 accuracy=0.996502 cross-entropy=0.011716
Epoch[8] Batch [200-400] Speed: 35748.35 samples/sec class0_recall=0.997621 accuracy=0.994687 cross-entropy=0.016355
Epoch[8] Batch [400-600] Speed: 33267.15 samples/sec class0_recall=0.999217 accuracy=0.995703 cross-entropy=0.014401
Epoch[8] Batch [600-800] Speed: 36100.29 samples/sec class0_recall=1.000000 accuracy=0.997031 cross-entropy=0.010151
Epoch[8] Train-class0_recall=0.998648
Epoch[8] Train-accuracy=0.996148
Epoch[8] Train-cross-entropy=0.012642
Epoch[8] Time cost=1.724
Epoch[8] Validation-class0_recall=0.995914
Epoch[8] Validation-accuracy=0.988582
Epoch[8] Validation-cross-entropy=0.040548
Epoch[9] Batch [0-200] Speed: 35794.78 samples/sec class0_recall=0.996880 accuracy=0.997201 cross-entropy=0.009314
Epoch[9] Batch [200-400] Speed: 32352.01 samples/sec class0_recall=0.996828 accuracy=0.996094 cross-entropy=0.012941
Epoch[9] Batch [400-600] Speed: 32864.85 samples/sec class0_recall=0.999217 accuracy=0.997266 cross-entropy=0.011650
Epoch[9] Batch [600-800] Speed: 35118.66 samples/sec class0_recall=0.999190 accuracy=0.997812 cross-entropy=0.007090
Epoch[9] Train-class0_recall=0.998142
Epoch[9] Train-accuracy=0.997148
Epoch[9] Train-cross-entropy=0.010119
Epoch[9] Time cost=1.757
Epoch[9] Validation-class0_recall=0.995914
Epoch[9] Validation-accuracy=0.987480
Epoch[9] Validation-cross-entropy=0.043713
Process finished with exit code 0