caffe源码修改:抽取任意一张图片的特征

转自http://blog.csdn.net/lingerlanlan/article/details/39400375

目前caffe不是很完善,输入的图片数据需要在prototxt指定路径。但是我们往往有这么一个需求:训练后得到一个模型文件,我们想拿这个模型文件来对一张图片抽取特征或者预测分类等。如果非得在prototxt指定路径,就很不方便。因此,这样的工具才是我们需要的:给一个可执行文件通过命令行来传递图片路径,然后caffe读入图片数据,进行一次正向传播。


因此我做了这么一个工具,用来抽取任意一张图片的特征。

这工具的使用方法如下:


extract_one_feature.bin ./model/caffe_reference_imagenet_model ./examples/_temp/imagenet_val.prototxt fc7 ./examples/_temp/features /media/G/imageset/clothing/针织衫/针织衫_426.jpg CPU

参数1:./model/caffe_reference_imagenet_model是训练后的模型文件

参数2:./examples/_temp/imagenet_val.prototxt 网络配置文件

参数3:fc7是blob的名字

参数4:./examples/_temp/features 将该图片的特征保存在该文件

参数5:图片路径

参数6:GPU或者CPU模式


(其实我还想到更好的工具,如果该可执行文件是监听模式的,就是通过一定的方式,给该进程传递 图片路径,进程接到任务就执行。

这样子的话,就不需要每次抽一张图片都要申请内存空间。(*^__^*) 嘻嘻……)


下面给出初步修改方法,大家可以根据自己需求再修改。


extract_one_feature.cpp(该文件参考过源码中extract_features.cpp修改)

[cpp]  view plain copy
  1. #include <stdio.h>  // for snprintf  
  2. #include <string>  
  3. #include <vector>  
  4. #include <iostream>  
  5. #include <fstream>  
  6.   
  7. #include "boost/algorithm/string.hpp"  
  8. #include "google/protobuf/text_format.h"  
  9. #include "leveldb/db.h"  
  10. #include "leveldb/write_batch.h"  
  11.   
  12. #include "caffe/blob.hpp"  
  13. #include "caffe/common.hpp"  
  14. #include "caffe/net.hpp"  
  15. #include "caffe/proto/caffe.pb.h"  
  16. #include "caffe/util/io.hpp"  
  17. #include "caffe/vision_layers.hpp"  
  18.   
  19. using namespace caffe;  // NOLINT(build/namespaces)  
  20.   
  21. template<typename Dtype>  
  22. int feature_extraction_pipeline(int argc, char** argv);  
  23.   
  24. int main(int argc, char** argv) {  
  25.   return feature_extraction_pipeline<float>(argc, argv);  
  26. //  return feature_extraction_pipeline<double>(argc, argv);  
  27. }  
  28.   
  29. template<typename Dtype>  
  30. class writeDb  
  31. {  
  32. public:  
  33.     void open(string dbName)  
  34.     {  
  35.         db.open(dbName.c_str());  
  36.     }  
  37.     void write(const Dtype &data)  
  38.     {  
  39.         db<<data;  
  40.     }  
  41.     void write(const string &str)  
  42.     {  
  43.         db<<str;  
  44.     }  
  45.     virtual ~writeDb()  
  46.     {  
  47.         db.close();  
  48.     }  
  49. private:  
  50.     std::ofstream db;  
  51. };  
  52.   
  53. template<typename Dtype>  
  54. int feature_extraction_pipeline(int argc, char** argv) {  
  55.   ::google::InitGoogleLogging(argv[0]);  
  56.   const int num_required_args = 6;  
  57.   if (argc < num_required_args) {  
  58.     LOG(ERROR)<<  
  59.     "This program takes in a trained network and an input data layer, and then"  
  60.     " extract features of the input data produced by the net.\n"  
  61.     "Usage: extract_features  pretrained_net_param"  
  62.     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"  
  63.     "  save_feature_leveldb_name1[,name2,...]  img_path  [CPU/GPU]"  
  64.     "  [DEVICE_ID=0]\n"  
  65.     "Note: you can extract multiple features in one pass by specifying"  
  66.     " multiple feature blob names and leveldb names seperated by ','."  
  67.     " The names cannot contain white space characters and the number of blobs"  
  68.     " and leveldbs must be equal.";  
  69.     return 1;  
  70.   }  
  71.   int arg_pos = num_required_args;  
  72.   
  73.   arg_pos = num_required_args;  
  74.   if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {  
  75.     LOG(ERROR)<< "Using GPU";  
  76.     uint device_id = 0;  
  77.     if (argc > arg_pos + 1) {  
  78.       device_id = atoi(argv[arg_pos + 1]);  
  79.       CHECK_GE(device_id, 0);  
  80.     }  
  81.     LOG(ERROR) << "Using Device_id=" << device_id;  
  82.     Caffe::SetDevice(device_id);  
  83.     Caffe::set_mode(Caffe::GPU);  
  84.   } else {  
  85.     LOG(ERROR) << "Using CPU";  
  86.     Caffe::set_mode(Caffe::CPU);  
  87.   }  
  88.   Caffe::set_phase(Caffe::TEST);  
  89.   
  90.   arg_pos = 0;  // the name of the executable  
  91.   string pretrained_binary_proto(argv[++arg_pos]);//网络模型参数文件  
  92.   
  93.   string feature_extraction_proto(argv[++arg_pos]);  
  94.   
  95.   shared_ptr<Net<Dtype> > feature_extraction_net(  
  96.       new Net<Dtype>(feature_extraction_proto));  
  97.   
  98.   feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);//将网络参数load进内存  
  99.   
  100.   
  101.   string extract_feature_blob_names(argv[++arg_pos]);  
  102.   vector<string> blob_names;//要抽取特征的layer的名字,可以是多个  
  103.   boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));  
  104.   
  105.   string save_feature_leveldb_names(argv[++arg_pos]);  
  106.   vector<string> leveldb_names;// 这里我改写成一个levedb为一个文件,数据格式不使用真正的levedb,而是自定义  
  107.   boost::split(leveldb_names, save_feature_leveldb_names,  
  108.                boost::is_any_of(","));  
  109.   CHECK_EQ(blob_names.size(), leveldb_names.size()) <<  
  110.       " the number of blob names and leveldb names must be equal";  
  111.   size_t num_features = blob_names.size();  
  112.   
  113.   for (size_t i = 0; i < num_features; i++) {  
  114.     CHECK(feature_extraction_net->has_blob(blob_names[i]))  //检测blob的名字在网络中是否存在  
  115.         << "Unknown feature blob name " << blob_names[i]  
  116.         << " in the network " << feature_extraction_proto;  
  117.   }  
  118.   
  119.   
  120.   vector<shared_ptr<writeDb<Dtype> > > feature_dbs;  
  121.   for (size_t i = 0; i < num_features; ++i) //打开db,准备写入数据  
  122.   {  
  123.     LOG(INFO)<< "Opening db " << leveldb_names[i];  
  124.     writeDb<Dtype>* db = new writeDb<Dtype>();  
  125.     db->open(leveldb_names[i]);  
  126.     feature_dbs.push_back(shared_ptr<writeDb<Dtype> >(db));  
  127.   }  
  128.   
  129.   
  130.   
  131.   LOG(ERROR)<< "Extacting Features";  
  132.   
  133.   const shared_ptr<Layer<Dtype> > layer = feature_extraction_net->layer_by_name("data");//获取第一层  
  134.   MyImageDataLayer<Dtype>* my_layer = (MyImageDataLayer<Dtype>*)layer.get();  
  135.   my_layer->setImgPath(argv[++arg_pos],1);//"/media/G/imageset/clothing/针织衫/针织衫_1.jpg"  
  136.   //设置图片路径  
  137.   
  138.   vector<Blob<float>*> input_vec;  
  139.   vector<int> image_indices(num_features, 0);  
  140.   int num_mini_batches = 1;//atoi(argv[++arg_pos]);//共多少次迭代。  每次迭代的数量在prototxt用batchsize指定  
  141.   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) //共num_mini_batches次迭代  
  142.   {  
  143.     feature_extraction_net->Forward(input_vec);//一次正向传播  
  144.     for (int i = 0; i < num_features; ++i) //多层特征  
  145.     {  
  146.       const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net  
  147.           ->blob_by_name(blob_names[i]);  
  148.       int batch_size = feature_blob->num();  
  149.       int dim_features = feature_blob->count() / batch_size;  
  150.   
  151.       Dtype* feature_blob_data;  
  152.   
  153.       for (int n = 0; n < batch_size; ++n)  
  154.       {  
  155.         feature_blob_data = feature_blob->mutable_cpu_data() +  
  156.             feature_blob->offset(n);  
  157.         feature_dbs[i]->write("3 ");  
  158.         for (int d = 0; d < dim_features; ++d)  
  159.         {  
  160.           feature_dbs[i]->write((Dtype)(d+1));  
  161.           feature_dbs[i]->write(":");  
  162.           feature_dbs[i]->write(feature_blob_data[d]);  
  163.           feature_dbs[i]->write(" ");  
  164.         }  
  165.         feature_dbs[i]->write("\n");  
  166.   
  167.       }  // for (int n = 0; n < batch_size; ++n)  
  168.     }  // for (int i = 0; i < num_features; ++i)  
  169.   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)  
  170.   
  171.   
  172.   LOG(ERROR)<< "Successfully extracted the features!";  
  173.   return 0;  
  174. }  

my_data_layer.cpp(参考image_data_layer修改)

[cpp]  view plain copy
  1. #include <fstream>  // NOLINT(readability/streams)  
  2. #include <iostream>  // NOLINT(readability/streams)  
  3. #include <string>  
  4. #include <utility>  
  5. #include <vector>  
  6.   
  7. #include "caffe/layer.hpp"  
  8. #include "caffe/util/io.hpp"  
  9. #include "caffe/util/math_functions.hpp"  
  10. #include "caffe/util/rng.hpp"  
  11. #include "caffe/vision_layers.hpp"  
  12.   
  13. namespace caffe {  
  14.   
  15.   
  16. template <typename Dtype>  
  17. MyImageDataLayer<Dtype>::~MyImageDataLayer<Dtype>() {  
  18. }  
  19.   
  20.   
  21. template <typename Dtype>  
  22. void MyImageDataLayer<Dtype>::setImgPath(string path,int label)  
  23. {  
  24.     lines_.clear();  
  25.     lines_.push_back(std::make_pair(path, label));  
  26. }  
  27.   
  28.   
  29. template <typename Dtype>  
  30. void MyImageDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,  
  31.       vector<Blob<Dtype>*>* top) {  
  32.   Layer<Dtype>::SetUp(bottom, top);  
  33.   const int new_height  = this->layer_param_.image_data_param().new_height();  
  34.   const int new_width  = this->layer_param_.image_data_param().new_width();  
  35.   CHECK((new_height == 0 && new_width == 0) ||  
  36.       (new_height > 0 && new_width > 0)) << "Current implementation requires "  
  37.       "new_height and new_width to be set at the same time.";  
  38.   
  39.   /* 
  40.    * 因为下面需要随便拿一张图片来初始化blob。 
  41.    * 因此需要硬盘上有一张图片。 
  42.    * 1 从prototxt读取一张图片的路径, 
  43.    * 2 其实也可以在这里将用于初始化的图片路径写死 
  44.   */  
  45.   
  46.   /*1*/  
  47.   /* 
  48.   const string& source = this->layer_param_.image_data_param().source(); 
  49.   LOG(INFO) << "Opening file " << source; 
  50.   std::ifstream infile(source.c_str()); 
  51.   string filename; 
  52.   int label; 
  53.   while (infile >> filename >> label) { 
  54.     lines_.push_back(std::make_pair(filename, label)); 
  55.   } 
  56.   */  
  57.   
  58.   /*2*/  
  59.   lines_.push_back(std::make_pair("/home/linger/init.jpg",1));  
  60.   
  61.   //上面1和2代码可以任意用一段  
  62.   
  63.   lines_id_ = 0;  
  64.   // Read a data point, and use it to initialize the top blob. (随便)读取一张图片,来初始化blob  
  65.   Datum datum;  
  66.   CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,  
  67.                          new_height, new_width, &datum));  
  68.   // image  
  69.   const int crop_size = this->layer_param_.image_data_param().crop_size();  
  70.   const int batch_size = 1;//this->layer_param_.image_data_param().batch_size();  
  71.   const string& mean_file = this->layer_param_.image_data_param().mean_file();  
  72.   if (crop_size > 0) {  
  73.     (*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);  
  74.     prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);  
  75.   } else {  
  76.     (*top)[0]->Reshape(batch_size, datum.channels(), datum.height(),  
  77.                        datum.width());  
  78.     prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(),  
  79.         datum.width());  
  80.   }  
  81.   LOG(INFO) << "output data size: " << (*top)[0]->num() << ","  
  82.       << (*top)[0]->channels() << "," << (*top)[0]->height() << ","  
  83.       << (*top)[0]->width();  
  84.   // label  
  85.   (*top)[1]->Reshape(batch_size, 1, 1, 1);  
  86.   prefetch_label_.Reshape(batch_size, 1, 1, 1);  
  87.   // datum size  
  88.   datum_channels_ = datum.channels();  
  89.   datum_height_ = datum.height();  
  90.   datum_width_ = datum.width();  
  91.   datum_size_ = datum.channels() * datum.height() * datum.width();  
  92.   CHECK_GT(datum_height_, crop_size);  
  93.   CHECK_GT(datum_width_, crop_size);  
  94.   // check if we want to have mean  
  95.   if (this->layer_param_.image_data_param().has_mean_file()) {  
  96.     BlobProto blob_proto;  
  97.     LOG(INFO) << "Loading mean file from" << mean_file;  
  98.     ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);  
  99.     data_mean_.FromProto(blob_proto);  
  100.     CHECK_EQ(data_mean_.num(), 1);  
  101.     CHECK_EQ(data_mean_.channels(), datum_channels_);  
  102.     CHECK_EQ(data_mean_.height(), datum_height_);  
  103.     CHECK_EQ(data_mean_.width(), datum_width_);  
  104.   } else {  
  105.     // Simply initialize an all-empty mean.  
  106.     data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);  
  107.   }  
  108.   // Now, start the prefetch thread. Before calling prefetch, we make two  
  109.   // cpu_data calls so that the prefetch thread does not accidentally make  
  110.   // simultaneous cudaMalloc calls when the main thread is running. In some  
  111.   // GPUs this seems to cause failures if we do not so.  
  112.   prefetch_data_.mutable_cpu_data();  
  113.   prefetch_label_.mutable_cpu_data();  
  114.   data_mean_.cpu_data();  
  115.   
  116.   
  117. }  
  118.   
  119. //--------------------------------下面是读取一张图片数据-----------------------------------------------  
  120. template <typename Dtype>  
  121. void MyImageDataLayer<Dtype>::fetchData() {  
  122.       Datum datum;  
  123.       CHECK(prefetch_data_.count());  
  124.       Dtype* top_data = prefetch_data_.mutable_cpu_data();  
  125.       Dtype* top_label = prefetch_label_.mutable_cpu_data();  
  126.       ImageDataParameter image_data_param = this->layer_param_.image_data_param();  
  127.       const Dtype scale = image_data_param.scale();//image_data_layer相关参数  
  128.       const int batch_size = 1;//image_data_param.batch_size(); 这里我们只需要一张图片  
  129.   
  130.       const int crop_size = image_data_param.crop_size();  
  131.       const bool mirror = image_data_param.mirror();  
  132.       const int new_height = image_data_param.new_height();  
  133.       const int new_width = image_data_param.new_width();  
  134.   
  135.       if (mirror && crop_size == 0) {  
  136.         LOG(FATAL) << "Current implementation requires mirror and crop_size to be "  
  137.             << "set at the same time.";  
  138.       }  
  139.       // datum scales  
  140.       const int channels = datum_channels_;  
  141.       const int height = datum_height_;  
  142.       const int width = datum_width_;  
  143.       const int size = datum_size_;  
  144.       const int lines_size = lines_.size();  
  145.       const Dtype* mean = data_mean_.cpu_data();  
  146.   
  147.       for (int item_id = 0; item_id < batch_size; ++item_id) {//读取一图片  
  148.         // get a blob  
  149.         CHECK_GT(lines_size, lines_id_);  
  150.         if (!ReadImageToDatum(lines_[lines_id_].first,  
  151.               lines_[lines_id_].second,  
  152.               new_height, new_width, &datum)) {  
  153.           continue;  
  154.         }  
  155.         const string& data = datum.data();  
  156.         if (crop_size) {  
  157.           CHECK(data.size()) << "Image cropping only support uint8 data";  
  158.           int h_off, w_off;  
  159.           // We only do random crop when we do training.  
  160.             h_off = (height - crop_size) / 2;  
  161.             w_off = (width - crop_size) / 2;  
  162.   
  163.             // Normal copy 正常读取,把裁剪后的图片数据读给top_data  
  164.             for (int c = 0; c < channels; ++c) {  
  165.               for (int h = 0; h < crop_size; ++h) {  
  166.                 for (int w = 0; w < crop_size; ++w) {  
  167.                   int top_index = ((item_id * channels + c) * crop_size + h)  
  168.                                   * crop_size + w;  
  169.                   int data_index = (c * height + h + h_off) * width + w + w_off;  
  170.                   Dtype datum_element =  
  171.                       static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));  
  172.                   top_data[top_index] = (datum_element - mean[data_index]) * scale;  
  173.                 }  
  174.               }  
  175.             }  
  176.   
  177.         } else {  
  178.           // Just copy the whole data 正常读取,把图片数据读给top_data  
  179.           if (data.size()) {  
  180.             for (int j = 0; j < size; ++j) {  
  181.               Dtype datum_element =  
  182.                   static_cast<Dtype>(static_cast<uint8_t>(data[j]));  
  183.               top_data[item_id * size + j] = (datum_element - mean[j]) * scale;  
  184.             }  
  185.           } else {  
  186.             for (int j = 0; j < size; ++j) {  
  187.               top_data[item_id * size + j] =  
  188.                   (datum.float_data(j) - mean[j]) * scale;  
  189.             }  
  190.           }  
  191.         }  
  192.         top_label[item_id] = datum.label();//读取该图片的标签  
  193.   
  194.       }  
  195. }  
  196.   
  197. template <typename Dtype>  
  198. Dtype MyImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  199.       vector<Blob<Dtype>*>* top) {  
  200.   
  201.   //更新input  
  202.     fetchData();  
  203.   
  204.   // Copy the data  
  205.   caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),  
  206.              (*top)[0]->mutable_cpu_data());  
  207.   caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),  
  208.              (*top)[1]->mutable_cpu_data());  
  209.   
  210.   return Dtype(0.);  
  211. }  
  212.   
  213. #ifdef CPU_ONLY  
  214. STUB_GPU_FORWARD(ImageDataLayer, Forward);  
  215. #endif  
  216.   
  217. INSTANTIATE_CLASS(MyImageDataLayer);  
  218.   
  219. }  // namespace caffe  


在data_layers.hpp添加一下代码,参考ImageDataLayer写的。

[cpp]  view plain copy
  1. template <typename Dtype>  
  2. class MyImageDataLayer : public Layer<Dtype>  {  
  3.  public:  
  4.   explicit MyImageDataLayer(const LayerParameter& param)  
  5.       : Layer<Dtype>(param) {}  
  6.   virtual ~MyImageDataLayer();  
  7.   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,  
  8.       vector<Blob<Dtype>*>* top);  
  9.   
  10.   virtual inline LayerParameter_LayerType type() const {  
  11.     return LayerParameter_LayerType_MY_IMAGE_DATA;  
  12.   }  
  13.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  14.   virtual inline int ExactNumTopBlobs() const { return 2; }  
  15.   void fetchData();  
  16.   void setImgPath(string path,int label);  
  17.  protected:  
  18.   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  19.       vector<Blob<Dtype>*>* top);  
  20.   
  21.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  22.       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}  
  23.   
  24.   
  25.   vector<std::pair<std::string, int> > lines_;  
  26.   int lines_id_;  
  27.   int datum_channels_;  
  28.   int datum_height_;  
  29.   int datum_width_;  
  30.   int datum_size_;  
  31.   Blob<Dtype> prefetch_data_;  
  32.   Blob<Dtype> prefetch_label_;  
  33.   Blob<Dtype> data_mean_;  
  34.   Caffe::Phase phase_;  
  35. };  


修改caffe.proto,在适当的位置添加下面信息,也是参考image_data写的。


MY_IMAGE_DATA = 36;


optional MyImageDataParameter my_image_data_param = 36;


// Message that stores parameters used by MyImageDataLayer
message MyImageDataParameter {
  // Specify the data source.
  optional string source = 1;
  // For data pre-processing, we can do simple scaling and subtracting the
  // data mean, if provided. Note that the mean subtraction is always carried
  // out before scaling.
  optional float scale = 2 [default = 1];
  optional string mean_file = 3;
  // Specify the batch size.
  optional uint32 batch_size = 4;
  // Specify if we would like to randomly crop an image.
  optional uint32 crop_size = 5 [default = 0];
  // Specify if we want to randomly mirror data.
  optional bool mirror = 6 [default = false];
  // The rand_skip variable is for the data layer to skip a few data points
  // to avoid all asynchronous sgd clients to start at the same point. The skip
  // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
  // be larger than the number of keys in the leveldb.
  optional uint32 rand_skip = 7 [default = 0];
  // Whether or not ImageLayer should shuffle the list of files at every epoch.
  optional bool shuffle = 8 [default = false];
  // It will also resize images if new_height or new_width are not zero.
  optional uint32 new_height = 9 [default = 0];
  optional uint32 new_width = 10 [default = 0];
}


以上每行位置不在一起,可以参考读一个image_data对应的位置。


你可能感兴趣的:(caffe源码修改:抽取任意一张图片的特征)