Source Code Analysis of "Weighted Maximum Mean Discrepancy for Unsupervised Domain Adaptation"

Introduction

This is an analysis of the example code that the authors of "Weighted Maximum Mean Discrepancy for Unsupervised Domain Adaptation" released on GitHub. The code extends the Caffe framework and is forked from mmd-caffe; its core is the definition and implementation of MMDLoss (in fact, WMMDLoss). See: theoretical analysis, source code address.

Dataset

This method belongs to the domain adaptation branch of transfer learning. The authors collect the source and target dataset file names into txt list files: the training and test lists are \data\amazon_to_caltech\traina2c.txt and \data\amazon_to_caltech\testa2c.txt respectively, and both can be generated by the images/code/data_constructor.py script. The example uses amazon -> caltech as the dataset.
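The post does not show the list-file format itself. A standard Caffe ImageData list carries one image path plus an integer label per line; since the modified data layer below reads label_dim integers per image, each line presumably carries at least one extra integer (for example a source/target indicator, assuming label_dim = 2). The paths and numbers below are purely hypothetical:

images/amazon/back_pack/frame_0001.jpg 0 1
images/caltech/backpack/img_0003.jpg 0 0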

Network Architecture

The AlexNet-based network structure is defined in \model\train_val.prototxt and visualized below:
[Figures 1 and 2: visualization of the AlexNet-based network defined in train_val.prototxt]
After the stack of convolutional layers, MMD (in fact the WMMD method; the authors kept the old name, and the backward function in mmd_layer.cu is adjusted to replace the conventional MMD with the weighted MMD specified in the paper) is applied at the second fully connected layer, followed by a dropout, and MMD (WMMD) is applied once more after the classifier layer. Beyond this model, the paper states that WMMD is attached after the last inception module and the fully connected layer for GoogLeNet, and after the last convolutional layer for LeNet.

Usage and Implementation of the WMMD Method

As the structure shows, the key parts are fc7 and fc8_office; their definitions are as follows:

layer {
  name: "fc7"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7"
  param {
    lr_mult: 0.1
    decay_mult: 1.0
  }
  param {
    lr_mult: 0.2
    decay_mult: 0.0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 1.0
    }
  }
}
layer {
  name: "mmd_fc7"
  type: "MMDLoss"
  bottom: "fc7"
  bottom: "label"
  bottom: "weight"
  top: "fc7"               # in-place: the top is the same blob as the first bottom
  mmd_param {
    num_of_kernel: 10      # number of Gaussian kernels in the multi-kernel MMD
    mmd_lambda: 0.3        # weight of the MMD term (presumably scales its gradient)
    iter_of_epoch: 32      # iterations per epoch, used for per-epoch statistics
    method_param {
      top_num: 5           # used by the "max" / "L2" / "max_ratio" methods
      i_lambda: 0.0        # used by the "L2" method
    }
    method: "none"         # kernel-selection strategy: none / max / L2 / max_ratio
    kernel_mul: 2.0        # presumably the multiplicative step between kernel bandwidths
    fix_gamma: false       # whether the kernel bandwidth gamma stays fixed
    mmd_lr: 10.0           # presumably a layer-local scaling of the MMD update
    quad_weight: 0.0
    mmd_lock: 1            # "specify when the weights of each layer need update" (authors' comment)
    num_class: 10          # 10 classes shared by amazon and caltech
    entropy_thresh: 10     # per-sample entropy threshold ("for every sample", authors' comment)
  }
}
layer {
  name: "relu7"
  type: "ReLU"
  bottom: "fc7"
  top: "fc7"
}
layer {
  name: "drop7"
  type: "Dropout"
  bottom: "fc7"
  top: "fc7"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc8_office"
  type: "InnerProduct"
  bottom: "fc7"
  top: "fc8_office"
  param {
    lr_mult: 0.1
    decay_mult: 1.0
  }
  param {
    lr_mult: 0.2
    decay_mult: 0.0
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "mmd_fc8"
  type: "MMDLoss"
  bottom: "fc8_office"
  bottom: "label"
  bottom: "weight"
  top: "fc8_office"
  mmd_param {
    num_of_kernel: 10
    mmd_lambda: 0.1
    iter_of_epoch: 32
    method_param {
      top_num: 5
      i_lambda: 0.0
    }
    method: "none"
    kernel_mul: 2.0
    fix_gamma: false
    mmd_lr: 5.0
    quad_weight: 0.0
    mmd_lock: 1
    num_class: 10
    entropy_thresh: 10
  }
}
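Before reading the implementation, it helps to recall what a multi-kernel MMD estimate looks like. Below is a minimal, self-contained C++ sketch, not the authors' code: the function and variable names are mine, and the bandwidth schedule base_gamma * kernel_mul^g is an assumption modeled on the num_of_kernel, kernel_mul, and fix_gamma parameters above. The uniform kernel weights 1/K mirror the initialization of beta_ in the layer.

// Minimal sketch of a multi-kernel (Gaussian) squared-MMD estimate.
// NOT the authors' implementation; names and bandwidth schedule assumed.
#include <cmath>
#include <vector>

// Squared Euclidean distance between two d-dimensional vectors.
static float sq_dist(const float* a, const float* b, int d) {
  float s = 0.0f;
  for (int i = 0; i < d; ++i) { const float t = a[i] - b[i]; s += t * t; }
  return s;
}

// src: ns*d row-major source features; tgt: nt*d target features.
// K kernels k_g(x, y) = exp(-||x - y||^2 / gamma_g) with
// gamma_g = base_gamma * kernel_mul^g, combined with uniform weights 1/K.
float multi_kernel_mmd2(const std::vector<float>& src, int ns,
                        const std::vector<float>& tgt, int nt,
                        int d, int K, float kernel_mul, float base_gamma) {
  float mmd2 = 0.0f;
  for (int g = 0; g < K; ++g) {
    const float gamma = base_gamma * std::pow(kernel_mul, g);
    float kss = 0.0f, ktt = 0.0f, kst = 0.0f;
    for (int i = 0; i < ns; ++i)        // source-source kernel sum
      for (int j = 0; j < ns; ++j)
        kss += std::exp(-sq_dist(&src[i * d], &src[j * d], d) / gamma);
    for (int i = 0; i < nt; ++i)        // target-target kernel sum
      for (int j = 0; j < nt; ++j)
        ktt += std::exp(-sq_dist(&tgt[i * d], &tgt[j * d], d) / gamma);
    for (int i = 0; i < ns; ++i)        // source-target cross term
      for (int j = 0; j < nt; ++j)
        kst += std::exp(-sq_dist(&src[i * d], &tgt[j * d], d) / gamma);
    mmd2 += (1.0f / K) * (kss / (ns * ns) + ktt / (nt * nt)
                          - 2.0f * kst / (ns * nt));
  }
  return mmd2;
}

Weighted MMD keeps this structure but rescales each source sample's contribution by a class weight, which is what the layer's beta_, sum_of_weig_, and the "weight" bottom below are bookkeeping for.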

MMDLoss is clearly a custom layer; the declaration of MMDLossLayer can be found in neuron_layers.hpp:

template <typename Dtype>
class NeuronLayer : public Layer<Dtype> {
 public:
  explicit NeuronLayer(const LayerParameter& param)
     : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
 
  virtual inline int ExactNumBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }
};

template <typename Dtype>
class MMDLossLayer : public NeuronLayer<Dtype> {
 public:
  explicit MMDLossLayer(const LayerParameter& param)
      : NeuronLayer<Dtype>(param){}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "MMDLoss"; }
  virtual inline int ExactNumBottomBlobs() const { return -1; }
  virtual inline int ExactNumTopBlobs() const { return -1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  Dtype* beta_;
  //Dtype* weig_;
  Dtype* sum_of_weig_;
  Blob<Dtype> mmd_data_;
  Dtype mmd_lambda_;
  int input_num_;
  int data_dim_;
  int size_of_source_;
  int size_of_target_;
  Dtype gamma_;
  int num_of_kernel_;
  int* source_index_;
  int* target_index_;
  int iter_of_epoch_;
  int now_iter_;
  int now_iter_test_;
  bool fix_gamma_;
  Dtype** Q_;
  Dtype* sum_of_epoch_;
  Dtype* variance_;
  Dtype I_lambda_;
  int all_sample_num_;
  int top_k_;
  Dtype* sum_of_pure_mmd_;
  int method_number_;
  Dtype kernel_mul_;
  int class_num;
  float mmd_lr_;
  float quad_weight_;
  int mmd_lock_;
  int num_class_;
  int test_inter_;
  int total_iter_test_;
  int total_target_num;
  // for evaluation
  Dtype *count_soft;
  Dtype *avg_entropy;
  int *count_hard;
  int *count_tmp;
  int *source_num_batch;
  int *target_num_batch;
  int *source_num_resamp;
  int *target_num_resamp;
  float cross_entropy;
  float entropy_stand;
  float entropy_thresh_;
};
 

The CPU implementations of the class's methods live in mmd_layer.cpp, and the CUDA implementations in mmd_layer.cu (analysis of the latter: TODO).
In mmd_layer.cpp, LayerSetUp handles the WMMDLoss layer's parameter parsing and buffer initialization, and dispatches among several kernel-selection methods (max, none, L2, max_ratio).

#include <algorithm>
#include <cmath>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/neuron_layers.hpp"

namespace caffe {

template <typename Dtype>
void MMDLossLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
  input_num_ = bottom[0]->count(0, 1);
  data_dim_ = bottom[0]->count(1);
  num_of_kernel_ = this->layer_param_.mmd_param().num_of_kernel();
  mmd_lambda_ = this->layer_param_.mmd_param().mmd_lambda();
  iter_of_epoch_ = this->layer_param_.mmd_param().iter_of_epoch();
  fix_gamma_ = this->layer_param_.mmd_param().fix_gamma();
  beta_ = new Dtype[num_of_kernel_];
  mmd_lr_ = this->layer_param_.mmd_param().mmd_lr();
  quad_weight_ = this->layer_param_.mmd_param().quad_weight();
  mmd_lock_ = this->layer_param_.mmd_param().mmd_lock();//specify when the weights of each layer need update
  num_class_ = this->layer_param_.mmd_param().num_class();
  test_inter_ = this->layer_param_.mmd_param().test_inter();
  total_iter_test_ = this->layer_param_.mmd_param().total_iter_test();
  //weig_ = new Dtype[10];
  sum_of_weig_ = new Dtype[num_class_];

  caffe_set(num_class_,Dtype(0),sum_of_weig_);
  //caffe_set(10,Dtype(1.0),weig_);
  caffe_set(num_of_kernel_, Dtype(1.0) / num_of_kernel_, beta_);
  now_iter_ = 0;
  now_iter_test_ = 0;

  sum_of_epoch_ = new Dtype[num_of_kernel_];
  caffe_set(num_of_kernel_, Dtype(0), sum_of_epoch_);
  gamma_ = Dtype(-1);
  Q_ = new Dtype* [num_of_kernel_];
  for(int i = 0; i < num_of_kernel_; i++){
      Q_[i] = new Dtype[num_of_kernel_];
      caffe_set(num_of_kernel_, Dtype(0), Q_[i]);
  }
  variance_ = new Dtype[num_of_kernel_];
  caffe_set(num_of_kernel_, Dtype(0), variance_);
  sum_of_pure_mmd_ = new Dtype[num_of_kernel_];
  caffe_set(num_of_kernel_, Dtype(0), sum_of_pure_mmd_);
  all_sample_num_ = 0;
  total_target_num = 0;
  kernel_mul_ = this->layer_param_.mmd_param().kernel_mul();
  if(this->layer_param_.mmd_param().method() == "max"){
        method_number_ = 1;
        top_k_ = this->layer_param_.mmd_param().method_param().top_num();
  }
  else if(this->layer_param_.mmd_param().method() == "none"){
        method_number_ = 0;
  }
  else if(this->layer_param_.mmd_param().method() == "L2"){
        method_number_ = 4;
        top_k_ = this->layer_param_.mmd_param().method_param().top_num();
        I_lambda_ = this->layer_param_.mmd_param().method_param().i_lambda();
  }
  else if(this->layer_param_.mmd_param().method() == "max_ratio"){
        top_k_ = this->layer_param_.mmd_param().method_param().top_num();
        method_number_ = 3;
  }
  LOG(INFO) << this->layer_param_.mmd_param().method() << " num: " << method_number_;
  source_index_ = new int[input_num_];
  target_index_ = new int[input_num_];

  mmd_data_.Reshape(1, 1, 1, data_dim_);

  class_num = bottom[2]->count(0);  // bottom[2] is the "weight" blob: one entry per class

  count_soft = new Dtype[class_num];
  avg_entropy = new Dtype[class_num];
  count_tmp = new int [class_num];
  caffe_set(class_num,Dtype(0.0),count_soft);
  caffe_set(class_num,Dtype(0.0),avg_entropy);
  count_hard = new int[class_num];
  caffe_set(class_num,0,count_hard);
  caffe_set (class_num,0,count_tmp);
  source_num_batch = new int[class_num];
  target_num_batch = new int[class_num];
  source_num_resamp = new int[class_num];
  target_num_resamp= new int[class_num];
  caffe_set(class_num,0,source_num_batch);
  caffe_set(class_num,0,target_num_batch);
  caffe_set(class_num,0,source_num_resamp);
  caffe_set(class_num,0,target_num_resamp);
  cross_entropy = 0.0;
  entropy_stand = 1.0;  // for all classes
  entropy_thresh_ = this->layer_param_.mmd_param().entropy_thresh();  // for every sample

}

template <typename Dtype>
void MMDLossLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  NeuronLayer<Dtype>::Reshape(bottom, top);
}

template <typename Dtype>
void MMDLossLayer<Dtype>::Forward_cpu(
  const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
}

template <typename Dtype>
void MMDLossLayer<Dtype>::Backward_cpu(
  const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
  const vector<Blob<Dtype>*>& bottom) {
}

#ifdef CPU_ONLY
STUB_GPU(MMDLossLayer);
#endif

INSTANTIATE_CLASS(MMDLossLayer);
REGISTER_LAYER_CLASS(MMDLoss);

}  // namespace caffe

Other Modifications by the Authors

  • mmd layer: the backward function in mmd_layer.cu is adjusted to replace the conventional MMD with the weighted MMD specified in the paper (a gradient sketch follows this list)
  • softmax loss layer: instead of ignoring the empirical loss on the target domain, the function is modified so that a logistic loss based on pseudo-labels is added
  • data layer: a parameter is added to the data label so that the number of classes is conveyed to the MMD layer
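Since neither the post nor the excerpts above reproduce mmd_layer.cu, here is a hedged C++ sketch of what the first bullet amounts to for the source-target cross term of one Gaussian kernel k(x, y) = exp(-||x - y||^2 / gamma): the gradient with respect to a source feature is -(2/gamma)(x - y)k(x, y), and weighted MMD scales each source sample's contribution by a class weight w_i. All names are illustrative, not taken from the authors' file.

// Hedged sketch of the weighted cross-term gradient; w_i == 1 recovers
// conventional MMD. Illustrative names only, not from mmd_layer.cu.
#include <cmath>
#include <vector>

void wmmd_cross_term_grad(const float* xi, float wi,        // source sample, its class weight
                          const std::vector<float>& tgt, int nt,
                          int d, float gamma, float scale,   // scale folds in -2/(ns*nt), mmd_lambda
                          float* grad /* d floats, accumulated */) {
  for (int j = 0; j < nt; ++j) {
    const float* yj = &tgt[j * d];
    float sq = 0.0f;
    for (int c = 0; c < d; ++c) { const float t = xi[c] - yj[c]; sq += t * t; }
    const float k = std::exp(-sq / gamma);
    // d k(x, y) / d x = -(2 / gamma) * (x - y) * k(x, y)
    for (int c = 0; c < d; ++c)
      grad[c] += scale * wi * (-2.0f / gamma) * (xi[c] - yj[c]) * k;
  }
}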

Data Layer

The authors modified image_data_layer.cpp to provide a new top output, weight, so the data layer now has three tops; accordingly, in the declaration of ImageDataLayer, the return value of ExactNumTopBlobs is changed from 2 to 3. The main changes are shown below:

template <typename Dtype>
class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
 public:
  explicit ImageDataLayer(const LayerParameter& param)
      : BasePrefetchingDataLayer<Dtype>(param) {}
  virtual ~ImageDataLayer();
  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "ImageData"; }
  virtual inline int ExactNumBottomBlobs() const { return 0; }
  virtual inline int ExactNumTopBlobs() const { return 3; }

 protected:
  shared_ptr<Caffe::RNG> prefetch_rng_;
  virtual void ShuffleImages();
  virtual void InternalThreadEntry();

  vector<std::pair<std::string, int*> > lines_;
  int lines_id_;
};
  // excerpt from ImageDataLayer<Dtype>::DataLayerSetUp: reshaping the three tops
  // image
  const int crop_size = this->layer_param_.transform_param().crop_size();
  const int batch_size = this->layer_param_.image_data_param().batch_size();
  if (crop_size > 0) {
    top[0]->Reshape(batch_size, channels, crop_size, crop_size);
    this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
    this->transformed_data_.Reshape(1, channels, crop_size, crop_size);
  } else {
    top[0]->Reshape(batch_size, channels, height, width);
    this->prefetch_data_.Reshape(batch_size, channels, height, width);
    this->transformed_data_.Reshape(1, channels, height, width);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  vector<int> label_shape;
  label_shape.push_back(batch_size);
  label_shape.push_back(label_dim);
  top[1]->Reshape(label_shape);
  this->prefetch_label_.Reshape(label_shape);
  // weight: the new third top added by the authors
  vector<int> weight_shape;
  weight_shape.push_back(class_num);
  top[2]->Reshape(weight_shape);
  this->prefetch_weight_.Reshape(weight_shape);

  Dtype* top_weight_= top[2]->mutable_cpu_data();
  Dtype* prefetch_weight = this->prefetch_weight_.mutable_cpu_data();
  caffe_set(class_num,Dtype(1.0/class_num),top_weight_);
  caffe_copy(class_num,top[2]->cpu_data(),prefetch_weight);
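Note that the code above only initializes the weight top uniformly to 1.0/class_num. Following the paper's idea of matching class priors across domains, these weights are then re-estimated during training from source label counts and target pseudo-label counts (the source_num_batch / target_num_batch buffers in MMDLossLayer suggest exactly this bookkeeping). A hedged sketch of such a re-estimation; the function, its names, and the normalization are assumptions, not the authors' code:

// Hedged sketch: estimate per-class weights as the ratio of target
// (pseudo-label) class frequency to source class frequency,
// w_c ~ p_t(c) / p_s(c), then normalize. Illustrative names only.
#include <vector>

std::vector<float> estimate_class_weights(const std::vector<int>& src_count,
                                          const std::vector<int>& tgt_count) {
  const int C = static_cast<int>(src_count.size());
  std::vector<float> w(C, 0.0f);
  float sum = 0.0f;
  for (int c = 0; c < C; ++c) {
    // guard against empty source classes in a batch
    const float ps = src_count[c] > 0 ? static_cast<float>(src_count[c]) : 1.0f;
    w[c] = static_cast<float>(tgt_count[c]) / ps;
    sum += w[c];
  }
  for (int c = 0; c < C; ++c)
    w[c] = (sum > 0.0f) ? (w[c] / sum) : (1.0f / C);  // fall back to uniform
  return w;
}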
