Caffe 源码的修改(用于车辆的定位)

Caffe 源码的修改(用于车辆的定位)

主页 分类 标签 链接 关于

修改后的caffe源码 https://gitcafe.com/lxiongh/Caffe_for_Multi-label

Caffe 的源码仅适用于模式分类问题,其标签是一维的。为了能够将 Caffe 代码应用于车辆的检测问题上,首先要解决的问题就是将 Caffe 源码中与标签相关的代码由一维改为多维。本项目由于只需预测车辆的上下两角的坐标(x1,y1, x2,y2),即将标签修改为4维。

这次能够很顺利的修改源码,很大一原因是实验室的师弟太给力了,他们对 Caffe 源码的研究开始地很早,并在以 Caffe 作为其深度学习的工具,并且在也发表了一些很不错的文章。Thanks to TianShui, LiangLi, JinZhu. 『车辆定位项目链接 http://lxiongh.github.io/2015/02/04/car_localization/』

首先,利用 Understanding 软件,可以方便的查看到 caffe 源码的目录结构,如下图所示。Caffe 源码的修改(用于车辆的定位)_第1张图片

可以注意到,在 Caffe 源码里有一个『Tools』的目录,里面有一些相当有用的工具,如『compute_image_mean.cpp』、『convert_imageset.cpp』等,其中『convert_imageset.cpp』直接操作到了文本文件,如下列代码所示。那么函数『ReadImageToDatum』将就是突破口。

1 // convert_imageset.cpp, line 129-
2 for (int line_id = 0; line_id < lines.size(); ++line_id) {
3     if (!ReadImageToDatum(root_folder + lines[line_id].first,
4         lines[line_id].second, resize_height, resize_width, is_color, &datum)) {
5       continue;
6     }
7     // ...
8 }

Caffe 源码的修改(用于车辆的定位)_第2张图片

定义在io.cpp里的函数ReadImageToDatum完成了将图片数据转换成caffe能够处理的Datum类型,主要修改的文件大都集中在数据层,将其单标签改成多标签支持。针对车辆检测,对caffe所作的修改有如下的部分:

  • caffe.proto

caffe.proto

optional改为repeated,使得标签变量label为数组,即支持多标签。否则无此属性Datum.label_size()

  • data_layer.hpp

data_layer_hpp

lines_protected修改为public,使得后续能够利用指针直接访问lines_数据,其中保存了图片名及其对应的标签信息。详见test_det_net.cpp

  • data_layer.cpp

Caffe 源码的修改(用于车辆的定位)_第3张图片Caffe 源码的修改(用于车辆的定位)_第4张图片

修改top_label,使得其保存图片的多标签信息。

  • image_data_layer.cpp

Caffe 源码的修改(用于车辆的定位)_第5张图片image_data_layer_2.cppimage_data_layer_3.cpp

从文本文件里读取图片的路径及标签信息,将原来int label修改成std::<vector> vec_label。同时需要特别注意的就是不要忘记申请相应的存储空间(*top)[1]->Reshape(this->...),否则在初始化网络时就会出现错误。

  • memory_data_layer.cpp

memory_data_layer.cpp

虽然这个在实际应用中没有用到,但因其涉及到最底层的数据层,所以也修改了。

  • convert_imageset.cppCaffe 源码的修改(用于车辆的定位)_第6张图片

    这个程序是将图片数据打包与数据库的形式,默认为leveldb

  • io.hpp

io_hpp

  • io.cpp

io_1.cppio_1.cpp

这个文件涉及到最底层的数据读写工作。

  • test_det_net.cpp,由extract_feature.cpp修改而来的对输入的图片进行车辆预测,并画框输出
 1 #include <stdio.h> // for snprintf
 2 #include <string>
 3 #include <vector>
 4 
 5 #include "boost/algorithm/string.hpp"
 6 #include "google/protobuf/text_format.h"
 7 #include "leveldb/db.h"
 8 #include "leveldb/write_batch.h"
 9 
 10 #include "caffe/blob.hpp"
 11 #include "caffe/common.hpp"
 12 #include "caffe/net.hpp"
 13 #include "caffe/proto/caffe.pb.h"
 14 #include "caffe/util/io.hpp"
 15 #include "caffe/vision_layers.hpp"
 16 
 17 // liu 
 18 #include <opencv2/core/core.hpp>
 19 #include <opencv2/highgui/highgui.hpp>
 20 #include <opencv2/highgui/highgui_c.h>
 21 #include <opencv2/imgproc/imgproc.hpp>
 22 #include <opencv2/opencv.hpp>
 23 #include <sys/stat.h> // for mkdir
 24 
 25 using namespace caffe;  // NOLINT(build/namespaces)
 26 
 27 template<typename Dtype>
 28 int feature_extraction_pipeline(int argc, char** argv);
 29 
 30 int main(int argc, char** argv) {
 31   return feature_extraction_pipeline<float>(argc, argv);
 32 // return feature_extraction_pipeline<double>(argc, argv);
 33 }
 34 
 35 template<typename Dtype>
 36 int feature_extraction_pipeline(int argc, char** argv) {
 37   ::google::InitGoogleLogging(argv[0]);
 38   const int num_required_args = 5;
 39   if (argc < num_required_args) {
 40     LOG(ERROR)<<
 41     "This program takes in a trained network and an input data layer, and then"
 42     " extract features of the input data produced by the net.\n"
 43     "Usage: test_det_net pretrained_net_param"
 44     " feature_extraction_proto_file num_mini_batches"
 45         " output_dir"
 46     " [CPU/GPU] [DEVICE_ID=0]\n"
 47     "Note: the feature blob names is fixed as 'fc_8_det' in code\n";
 48     return 1;
 49   }
 50   int arg_pos = num_required_args;
 51 
 52   arg_pos = num_required_args;
 53   if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
 54     LOG(ERROR)<< "Using GPU";
 55     uint device_id = 0;
 56     if (argc > arg_pos + 1) {
 57       device_id = atoi(argv[arg_pos + 1]);
 58       CHECK_GE(device_id, 0);
 59     }
 60     LOG(ERROR) << "Using Device_id=" << device_id;
 61     Caffe::SetDevice(device_id);
 62     Caffe::set_mode(Caffe::GPU);
 63   } else {
 64     LOG(ERROR) << "Using CPU";
 65     Caffe::set_mode(Caffe::CPU);
 66   }
 67   Caffe::set_phase(Caffe::TEST);
 68 
 69   arg_pos = 0;  // the name of the executable
 70   string pretrained_binary_proto(argv[++arg_pos]);
 71 
 72   string feature_extraction_proto(argv[++arg_pos]);
 73   shared_ptr<Net<Dtype> > feature_extraction_net(
 74       new Net<Dtype>(feature_extraction_proto));
 75   feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
 76     // to get image_paths
 77     const vector<shared_ptr<Layer<float> > > layers = feature_extraction_net->layers();
 78     const caffe::ImageDataLayer<float> *image_layer = dynamic_cast<caffe::ImageDataLayer<float>* >(layers[0].get());
 79     CHECK(image_layer);
 80             
 81   const string blob_name = "fc_8_det";
 82   
 83   CHECK(feature_extraction_net->has_blob(blob_name))   \
 84         << "Unknown feature blob name " << blob_name      \
 85         << " in the network " << feature_extraction_proto;
 86 
 87 
 88   int num_mini_batches = atoi(argv[++arg_pos]);
 89     string output_dir = argv[++arg_pos];
 90     CHECK_EQ(mkdir(output_dir.c_str(),0744), 0) << "mkdir " << output_dir << " failed";
 91 
 92   LOG(ERROR)<< "Extracting Features";
 93 
 94   vector<Blob<float>*> input_vec;
 95   int image_index=0;
 96   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
 97     feature_extraction_net->Forward(input_vec);
 98         
 99         const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net->blob_by_name(blob_name);
100     
101         int batch_size = feature_blob->num();
102         
103         int dim_features = feature_blob->count() / batch_size;
104         CHECK_EQ(dim_features, 4) << "the dim of feature is not equal to 4";
105         
106         Dtype* feature_blob_data;
107         int x1, y1, x2, y2;
108         for (int n = 0; n < batch_size; ++n) {
109             feature_blob_data = feature_blob->mutable_cpu_data() + feature_blob->offset(n);
110 
111             x1 = feature_blob_data[0];
112             y1 = feature_blob_data[1];
113             x2 = feature_blob_data[2];
114             y2 = feature_blob_data[3];
115             
116             string image_path = image_layer->lines_[image_index].first;
117             //LOG(ERROR) << "image_index " << image_index << " " << image_path \
118  << " x1 " << feature_blob_data[0] << " y1 " << feature_blob_data[1] \
119  << " x2 " << feature_blob_data[2] << " y2 " << feature_blob_data[3];
120             
121             cv::Mat img_origin = cv::imread(image_path);
122             
123             std::vector<string> part_names;
124             boost::split(part_names, image_path, boost::is_any_of("/"));
125             string subname = part_names[part_names.size()-1];             // the last element is the image name.
126             string out_path(output_dir + "/" + subname);
127             
128             //LOG(ERROR) << subname;
129             line(img_origin, cv::Point(x1, y1), cv::Point(x2, y1), cv::Scalar(0, 0, 255), 3);
130             line(img_origin, cv::Point(x2, y1), cv::Point(x2, y2), cv::Scalar(0, 0, 255), 3);
131             line(img_origin, cv::Point(x2, y2), cv::Point(x1, y2), cv::Scalar(0, 0, 255), 3);
132             line(img_origin, cv::Point(x1, y2), cv::Point(x1, y1), cv::Scalar(0, 0, 255), 3);
133             CHECK_EQ(imwrite(output_dir + "/" + subname, img_origin), true) << "write image " + out_path + " failed";
134             
135             image_index ++ ;
136             if (image_index>=image_layer->lines_.size()){
137                 LOG(ERROR) << "Restarting data prefetching from start.";
138                 image_index = 0;
139             }
140             // (image_index>image_layer->lines_.size()-1)?(image_index=0):(image_index++);
141         }
142         
143   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
144   // write the last batch
145   
146   LOG(ERROR)<< "Successfully extracted the features!";
147   return 0;  
148 }

你可能感兴趣的:(Caffe 源码的修改(用于车辆的定位))