github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。
1、数据制作
需要自己生成*.lst文件,里面内容如下:
index task1标签 task2标签 task3标签 图片路径(这行是说明,不需要写入,每一列用\t隔开)
2476 0.000000 0.000000 1.000000 photo_02_8159/00022552.jpg
7623 3.000000 2.000000 2.000000 photo_03_7397/00029434.jpg
14149 0.000000 0.000000 1.000000 photo_05_15560/00060839.jpg
6874 3.000000 1.000000 2.000000 photo_03_7397/00028414.jpg
6048 0.000000 0.000000 1.000000 photo_02_8159/00027259.jpg
14479 3.000000 3.000000 2.000000 photo_05_15560/00065068.jpg
10429 2.000000 0.000000 1.000000 photo_04_15224/00040186.jpg
6949 3.000000 0.000000 1.000000 photo_03_7397/00028521.jpg
81 3.000000 3.000000 2.000000 photo_01_19992/00002536.jpg
11725 2.000000 0.000000 1.000000 photo_05_15560/00051778.jpg
1517 2.000000 3.000000 2.000000 photo_02_8159/00021245.jpg
具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):
fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别
fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别
fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别
#分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到
s1 = mx.symbol.SoftmaxOutput(data=fc1, name='softmax1')
s2 = mx.symbol.SoftmaxOutput(data=fc2, name='softmax2')
s3 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax3')
return mx.symbol.Group([s1,s2,s3])
3、编写ImageRecordIter选项
train = mx.io.ImageRecordIter(
path_imgrec='/path/to/train.rec',
label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],#label名称,于softmax名称一样,后面要加入_label
label_width=3, #重要,需要设置label宽度为3,因为有3个任务
data_shape=[3,224,224],
batch_size=64
)
val = mx.io.ImageRecordIter(
path_imgrec='/path/to/val.rec',
label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],
label_width=3,
batch_size=64,
data_shape=[3,224,224],
)
4、定义多任务训练迭代器
class MultiTask_iter(mx.io.DataIter):
def __init__(self, data_iter):
super(MultiTask_iter,self).__init__('multitask_iter')
self.data_iter = data_iter
self.batch_size = self.data_iter.batch_size
@property
def provide_data(self):
return self.data_iter.provide_data
@property
def provide_label(self):
provide_label = self.data_iter.provide_label[0]
# the name of the label if corresponding to the model you define in get_fine_tune_model() function
return [('softmax1_label', [provide_label[1][0]]),#需要注意的地方
('softmax2_label', [provide_label[1][0]]),
('softmax3_label', [provide_label[1][0]])]
def hard_reset(self):
self.data_iter.hard_reset()
def reset(self):
self.data_iter.reset()
def next(self):
batch = self.data_iter.next()
#需要注意的地方
label = batch.label[0]
ll = label.asnumpy()
label1 = mx.nd.array(ll[:,0]).astype('float32')
label2 = mx.nd.array(ll[:,1]).astype('float32')
label3 = mx.nd.array(ll[:,2]).astype('float32')
# we set task 2 as: if label>0 or not
return mx.io.DataBatch(data=batch.data, label=[label1,label2,label3], \
pad=batch.pad, index=batch.index)
5、定义正确率计算方法
class Multi_Accuracy(mx.metric.EvalMetric):
"""Calculate accuracies of multi label"""
def __init__(self, num=None):
super(Multi_Accuracy, self).__init__('multi-accuracy', num)
def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds)
if self.num is not None:
assert len(labels) == self.num
for i in range(len(labels)):
pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
label = labels[i].asnumpy().astype('int32')
mx.metric.check_label_shapes(label, pred_label)
if i is None:
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
else:
self.sum_metric[i] += (pred_label.flat == label.flat).sum()
self.num_inst[i] += len(pred_label.flat)
6、训练
train = MultiTask_iter(train)#调用多任务迭代器,其中train参数就是第3步的东西
val = MultiTask_iter(val)
new_sym = get_symbol(10,50,image_shape)
optimizer_params = {
'learning_rate': 0.001,
'momentum' : args.mom,
'wd' : args.wd,
}
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
model = mx.mod.Module(
context = devs,
symbol = new_sym,
data_names=['data'],
label_names=['softmax1_label','softmax2_label','softmax3_label']
)
saveroot = args.save_result+'/' + args.save_name
checkpoint = mx.callback.do_checkpoint(saveroot)
model.fit(train,
begin_epoch=0,
num_epoch=100000,
eval_data=val,
eval_metric=Multi_Accuracy(num=3),#需要注意的地方
optimizer='sgd',
optimizer_params=optimizer_params,
initializer=initializer,
allow_missing=True,
batch_end_callback=mx.callback.Speedometer(64, 50),
epoch_end_callback=checkpoint
)