深度学习【22】Mxnet多任务(multi-task)训练

github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。

1、数据制作
需要自己生成*.lst文件,里面内容如下:


index   task1标签   task2标签    task3标签    图片路径(这行是说明,不需要写入,每一列用\t隔开)
2476    0.000000    0.000000    1.000000    photo_02_8159/00022552.jpg 
7623    3.000000    2.000000    2.000000    photo_03_7397/00029434.jpg
14149   0.000000    0.000000    1.000000    photo_05_15560/00060839.jpg
6874    3.000000    1.000000    2.000000    photo_03_7397/00028414.jpg
6048    0.000000    0.000000    1.000000    photo_02_8159/00027259.jpg
14479   3.000000    3.000000    2.000000    photo_05_15560/00065068.jpg
10429   2.000000    0.000000    1.000000    photo_04_15224/00040186.jpg
6949    3.000000    0.000000    1.000000    photo_03_7397/00028521.jpg
81      3.000000    3.000000    2.000000    photo_01_19992/00002536.jpg
11725   2.000000    0.000000    1.000000    photo_05_15560/00051778.jpg
1517    2.000000    3.000000    2.000000    photo_02_8159/00021245.jpg

具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):

    fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别
    fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别
    fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别
    #分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到
    s1 = mx.symbol.SoftmaxOutput(data=fc1, name='softmax1') 
    s2 = mx.symbol.SoftmaxOutput(data=fc2, name='softmax2')
    s3 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax3')
    return  mx.symbol.Group([s1,s2,s3])

3、编写ImageRecordIter选项

    train = mx.io.ImageRecordIter(
        path_imgrec='/path/to/train.rec',
        label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],#label名称,于softmax名称一样,后面要加入_label
        label_width=3, #重要,需要设置label宽度为3,因为有3个任务
        data_shape=[3,224,224],
        batch_size=64
    )

    val = mx.io.ImageRecordIter(
        path_imgrec='/path/to/val.rec',
        label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],
        label_width=3,
        batch_size=64,
        data_shape=[3,224,224],
    )

4、定义多任务训练迭代器

class MultiTask_iter(mx.io.DataIter):
    def __init__(self, data_iter):
        super(MultiTask_iter,self).__init__('multitask_iter')
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # the name of the label if corresponding to the model you define in get_fine_tune_model() function
        return [('softmax1_label', [provide_label[1][0]]),#需要注意的地方
        ('softmax2_label', [provide_label[1][0]]),
        ('softmax3_label', [provide_label[1][0]])]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        #需要注意的地方
        label = batch.label[0]
        ll = label.asnumpy()
        label1 = mx.nd.array(ll[:,0]).astype('float32')
        label2 = mx.nd.array(ll[:,1]).astype('float32')
        label3 = mx.nd.array(ll[:,2]).astype('float32')
        # we set task 2 as: if label>0 or not

        return mx.io.DataBatch(data=batch.data, label=[label1,label2,label3], \
                pad=batch.pad, index=batch.index)

5、定义正确率计算方法

class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        super(Multi_Accuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
            assert len(labels) == self.num

        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            if i is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)

6、训练

    train = MultiTask_iter(train)#调用多任务迭代器,其中train参数就是第3步的东西
    val = MultiTask_iter(val)

    new_sym = get_symbol(10,50,image_shape)

    optimizer_params = {
            'learning_rate': 0.001,
            'momentum' : args.mom,
            'wd' : args.wd,
           }
    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
    model = mx.mod.Module(
        context       = devs,
        symbol        = new_sym,
        data_names=['data'],
        label_names=['softmax1_label','softmax2_label','softmax3_label']
    )
    saveroot = args.save_result+'/' + args.save_name
    checkpoint = mx.callback.do_checkpoint(saveroot)


    model.fit(train,
              begin_epoch=0,
              num_epoch=100000,
              eval_data=val,
              eval_metric=Multi_Accuracy(num=3),#需要注意的地方
              optimizer='sgd',

              optimizer_params=optimizer_params,

              initializer=initializer,
              allow_missing=True,
              batch_end_callback=mx.callback.Speedometer(64, 50),

              epoch_end_callback=checkpoint
              )

你可能感兴趣的:(深度学习)