Multi-task learning: NDDR-CNN, a layerwise feature fusion network

In this paper, we propose a novel Convolutional Neural Network (CNN) structure for general-purpose multi-task learning (MTL), which enables automatic feature fusing at every layer from different tasks. This is in contrast with the most widely used MTL CNN structures which empirically or heuristically share features on some specific layers (e.g., share all the features except the last convolutional layer). The proposed layerwise feature fusing scheme is formulated by combining existing CNN components in a novel way, with clear mathematical interpretability as discriminative dimensionality reduction, which is referred to as Neural Discriminative Dimensionality Reduction (NDDR). Specifically, we first concatenate features with the same spatial resolution from different tasks according to their channel dimension. Then, we show that the discriminative dimensionality reduction can be fulfilled by 1 × 1 Convolution, Batch Normalization, and Weight Decay in one CNN. The use of existing CNN components ensures the end-to-end training and the extensibility of the proposed NDDR layer to various state-of-the-art CNN architectures in a “plug-and-play” manner. The detailed ablation analysis shows that the proposed NDDR layer is easy to train and also robust to different hyperparameters. Experiments on different task sets with various base network architectures demonstrate the promising performance and desirable generalizability of our proposed method. The code of our paper is available at https://github.com/ethanygao/NDDR-CNN.
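The abstract's claim that a 1 × 1 convolution performs dimensionality reduction can be made concrete at a single pixel. The following is a sketch in our own notation for one NDDR layer (the symbols f^k, z, W^k, λ and the BN parameters γ, β, μ, σ are introduced here for illustration and are not taken from the paper):

```latex
\begin{aligned}
\mathbf{z} &= \big[\mathbf{f}^{1};\ \mathbf{f}^{2};\ \dots;\ \mathbf{f}^{K}\big] \in \mathbb{R}^{KC}
  && \text{(channel concatenation at one pixel)}\\
\mathbf{y}^{k} &= (\mathbf{W}^{k})^{\top}\mathbf{z} \in \mathbb{R}^{C},
  \quad \mathbf{W}^{k} \in \mathbb{R}^{KC \times C}
  && \text{(projection } KC \to C \text{ for task } k)\\
\hat{y}^{k}_{c} &= \gamma^{k}_{c}\,\frac{y^{k}_{c} - \mu^{k}_{c}}{\sqrt{(\sigma^{k}_{c})^{2} + \epsilon}} + \beta^{k}_{c}
  && \text{(batch normalization of channel } c)\\
\mathcal{L} &= \sum_{k=1}^{K} \mathcal{L}_{k} + \lambda \sum_{k=1}^{K} \lVert \mathbf{W}^{k} \rVert_{F}^{2}
  && \text{(weight decay on the projections)}
\end{aligned}
```

Because the same W^k is applied at every spatial position (μ and σ being batch statistics over N, H, W), the projection in the second line is exactly a 1 × 1 convolution with C output channels.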

  1. NDDR Layer
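    Here K is the number of tasks. At layer l, the feature map of each task has shape (N, H, W, C); concatenating the K feature maps along the channel dimension gives a feature of shape (N, H, W, KC). This concatenated feature is then convolved **separately for each task** with C kernels of shape (1, 1, KC). Using C kernels makes the output shape (N, H, W, C), so it can be fed directly into the subsequent layers of the network. After the convolution, each of the K resulting (N, H, W, C) features is fed back into its own task's original network for the next convolution. This is repeated layer by layer until the loss of each task is computed, as shown in the figure below: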
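The steps above can be sketched in NumPy as a single forward pass. This is a minimal illustration under stated assumptions, not the authors' implementation (their code is in the repository linked in the abstract): features are channels-last (N, H, W, C) arrays as in the description, the 1 × 1 convolution is written as a per-pixel matrix multiply with `np.einsum`, batch normalization uses training-mode batch statistics with its learnable scale and shift fixed at 1 and 0 for brevity, and weight decay is shown as a separate loss penalty because it does not act in the forward pass. The function names and the α/β block-identity initialization in `diagonal_init` are hypothetical.

```python
import numpy as np


def diagonal_init(K, C, k, alpha=0.9, beta=0.1):
    """Hypothetical init for task k's kernels: alpha * I on its own
    channel block, beta * I on every other task's block. Shape (K*C, C)."""
    w = np.zeros((K * C, C))
    for j in range(K):
        scale = alpha if j == k else beta
        w[j * C:(j + 1) * C, :] = scale * np.eye(C)
    return w


def nddr_layer(features, weights, eps=1e-5, use_bn=True):
    """Forward pass of one NDDR layer (training-mode BN, scale=1, shift=0).

    features: list of K arrays, each (N, H, W, C), one per task at layer l.
    weights:  list of K arrays, each (K*C, C); column o of weights[k] is
              the o-th (1, 1, K*C) kernel for task k.
    Returns a list of K arrays, each (N, H, W, C).
    """
    # Concatenate all tasks along the channel axis: (N, H, W, K*C)
    z = np.concatenate(features, axis=-1)
    outputs = []
    for w in weights:
        # A 1x1 convolution is a matrix multiply over channels at every pixel
        y = np.einsum("nhwi,io->nhwo", z, w)
        if use_bn:
            # Normalize each output channel with statistics over (N, H, W)
            mean = y.mean(axis=(0, 1, 2), keepdims=True)
            var = y.var(axis=(0, 1, 2), keepdims=True)
            y = (y - mean) / np.sqrt(var + eps)
        outputs.append(y)
    return outputs


def weight_decay_penalty(weights, lam):
    """L2 penalty lam * sum_k ||W_k||^2, added to the summed task losses."""
    return lam * sum(float(np.sum(w ** 2)) for w in weights)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K, N, H, W, C = 2, 4, 8, 8, 16
    feats = [rng.standard_normal((N, H, W, C)) for _ in range(K)]
    ws = [diagonal_init(K, C, k) for k in range(K)]
    outs = nddr_layer(feats, ws)
    print([o.shape for o in outs])  # each (4, 8, 8, 16)
```

With `alpha=1.0`, `beta=0.0` and `use_bn=False`, each task gets exactly its own features back, so the layer starts as an identity; a nonzero `beta` lets it begin mixing the other tasks' features. Stacking, as described above, amounts to calling `nddr_layer` after each per-task convolution stage.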
  2. Shortcuts
  3. Results comparison
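    The paper uses VGG-16 and ResNet-101 as base networks.
    For comparison, it trains a separate network for each task (single-task baseline) and a heuristically shared multi-task network (multi-task baseline), as well as two closely related methods: the cross-stitch network and the sluice network.
    The comparisons are reported on two tasks: semantic segmentation and surface normal prediction.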
