caffe-将图片转化为siamese网络需要的数据库格式

本文转自:http://blog.csdn.net/sheng_ai/article/details/48174729
在此十分感谢博主分享!
由于最近要用到siamese网络,大概跑了一下caffe自带的siamese网络的例程,发现例程中对于数据格式的转换仅仅局限于mnist数据集,不能直接将其他图片格式的数据集转换为需要的格式,因此,在分析了数据转化的逻辑之后,发现每一张图片的转换过程可以按照下面的步骤执行:
caffe-将图片转化为siamese网络需要的数据库格式_第1张图片


    在参考了caffe自带的convert_imageset.cpp和之后,我编写的格式转换代码如下,这个代码目前只能处理黑白的图像,后期版本会增加对于彩色图像的支持

[cpp]  view plain copy print ?
  1. // This program converts a set of gray images to a leveldb by storing them  
  2. // as Datum proto buffers.  
  3. // Usage:  
  4. //   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME  
  5. //  
  6. // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE  
  7. // should be a list of files as well as their labels, in the format as  
  8. //   subfolder1/file1.JPEG 7  
  9. //   ....  
  10.   
  11. #include <algorithm>  
  12. #include <fstream>  // NOLINT(readability/streams)  
  13. #include <string>  
  14. #include <utility>  
  15. #include <vector>  
  16.   
  17. #include "boost/scoped_ptr.hpp"  
  18. #include "gflags/gflags.h"  
  19. #include "glog/logging.h"  
  20. #include "leveldb\db.h"  
  21.   
  22. #include "caffe/proto/caffe.pb.h"  
  23. #include "caffe/util/io.hpp"  
  24. #include "caffe/util/rng.hpp"  
  25.   
  26. #include "opencv2\opencv.hpp"  
  27.   
  28. using namespace caffe;  // NOLINT(build/namespaces)  
  29. using std::pair;  
  30. using boost::scoped_ptr;  
  31.   
  32. DEFINE_bool(gray, false,  
  33.     "When this option is on, treat images as grayscale ones");  
  34. DEFINE_bool(shuffle, false,  
  35.     "Randomly shuffle the order of images and their labels");  
  36. DEFINE_string(backend, "lmdb",  
  37.     "The backend {lmdb, leveldb} for storing the result");  
  38. DEFINE_int32(resize_width, 0, "Width images are resized to");  
  39. DEFINE_int32(resize_height, 0, "Height images are resized to");  
  40. DEFINE_bool(check_size, false,  
  41.     "When this option is on, check that all the datum have the same size");  
  42. DEFINE_bool(encoded, false,  
  43.     "When this option is on, the encoded image will be save in datum");  
  44. DEFINE_string(encode_type, "",  
  45.     "Optional: What type should we encode the image as ('png','jpg',...).");  
  46.   
  47.   
  48. static bool ReadImageToMemory(const string& FileName, const int Height,  
  49.                               const int Width, char *Pixels)  
  50. {  
  51.     // read image  
  52.     cv::Mat OriginImage = cv::imread(FileName, cv::IMREAD_GRAYSCALE);  
  53.     CHECK(OriginImage.data) << "Failed to read the image.\n";  
  54.   
  55.   
  56.     // resize the image  
  57.     cv::Mat ResizeImage;  
  58.     cv::resize(OriginImage, ResizeImage, cv::Size(Width, Height));  
  59.     CHECK(ResizeImage.rows == Height) << "The heighs of Image is no equal to the input height.\n";  
  60.     CHECK(ResizeImage.cols == Width) << "The width of Image is no equal to the input width.\n";  
  61.     CHECK(ResizeImage.channels() == 1) << "The channel of Image is no equal to one.\n";  
  62.   
  63.     LOG(INFO) << "The height of image is " << ResizeImage.rows << "\n";  
  64.     LOG(INFO) << "The width of image is " << ResizeImage.cols << "\n";  
  65.     LOG(INFO) << "The channels of image is " << ResizeImage.channels() << "\n";  
  66.   
  67.     // copy the image data to Pixels  
  68.     for (int HeightIndex = 0; HeightIndex < Height; ++HeightIndex)  
  69.     {  
  70.         const uchar* ptr = ResizeImage.ptr<uchar>(HeightIndex);  
  71.         int img_index = 0;  
  72.         for (int WidthIndex = 0; WidthIndex < Width; ++WidthIndex)  
  73.         {  
  74.             for (int ChannelIndex = 0; ChannelIndex < ResizeImage.channels(); ++ChannelIndex)  
  75.             {  
  76.                 int datum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;  
  77.                 *(Pixels + datum_index) = static_cast<char>(ptr[img_index++]);  
  78.             }  
  79.         }  
  80.     }  
  81.   
  82.     return true;  
  83. }  
  84.   
  85.   
  86. int main(int argc, char** argv)  
  87. {  
  88.     //::google::InitGoogleLogging(argv[0]);  
  89.   
  90. #ifndef GFLAGS_GFLAGS_H_  
  91.     namespace gflags = google;  
  92. #endif  
  93.   
  94.     gflags::SetUsageMessage("Convert a set of grey images to the leveldb\n"  
  95.         "format used as input for Caffe.\n"  
  96.         "Usage:\n"  
  97.         "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n");  
  98.     caffe::GlobalInit(&argc, &argv);  
  99.   
  100.     // 输入参数不足时报错  
  101.     if (argc < 4)  
  102.     {  
  103.         gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");  
  104.         return 1;  
  105.     }  
  106.   
  107.   
  108.     // 读取图像名字和标签  
  109.     std::ifstream infile(argv[2]);  
  110.     std::vector<std::pair<std::string, int> > lines;  
  111.     std::string filename;  
  112.     int label;  
  113.     while (infile >> filename >> label)  
  114.     {  
  115.         lines.push_back(std::make_pair(filename, label));  
  116.     }  
  117.   
  118.     // 打乱图片顺序  
  119.     if (FLAGS_shuffle)  
  120.     {  
  121.         // randomly shuffle data  
  122.         LOG(INFO) << "Shuffling data";  
  123.         shuffle(lines.begin(), lines.end());  
  124.     }  
  125.     LOG(INFO) << "A total of " << lines.size() << " images.";  
  126.   
  127.   
  128.   
  129.     // 设置图像的高度和宽度  
  130.     int resize_height = std::max<int>(0, FLAGS_resize_height);  
  131.     int resize_width = std::max<int>(0, FLAGS_resize_width);  
  132.   
  133.   
  134.     // 打开数据库  
  135.     // Open leveldb  
  136.     leveldb::DB* db;  
  137.     leveldb::Options options;  
  138.     options.create_if_missing = true;  
  139.     options.error_if_exists = true;  
  140.     leveldb::Status status = leveldb::DB::Open(  
  141.         options, argv[3], &db);  
  142.     CHECK(status.ok()) << "Failed to open leveldb " << argv[3]  
  143.         << ". Is it already existing?";  
  144.   
  145.   
  146.     // 保存到leveldb  
  147.     // Storing to leveldb  
  148.     std::string root_folder(argv[1]);  
  149.     char* Pixels = new char[2 * resize_height * resize_width];  
  150.     const int kMaxKeyLength = 10;  
  151.     char key[kMaxKeyLength];  
  152.     std::string value;  
  153.   
  154.     caffe::Datum datum;  
  155.     datum.set_channels(2);  // one channel for each image in the pair  
  156.     datum.set_height(resize_height);  
  157.     datum.set_width(resize_width);  
  158.   
  159.     //  
  160.     for (int LineIndex = 0; LineIndex < lines.size(); LineIndex++)  
  161.     {  
  162.         int PairIndex = caffe::caffe_rng_rand() % lines.size();  
  163.   
  164.         char* FirstImagePixel = Pixels;  
  165.         ReadImageToMemory(root_folder + lines[LineIndex].first, resize_height, resize_width, FirstImagePixel);  
  166.   
  167.         char *SecondImagePixel = Pixels + resize_width * resize_height;  
  168.         ReadImageToMemory(root_folder + lines[PairIndex].first, resize_height, resize_width, SecondImagePixel);  
  169.   
  170.         // set image pair data  
  171.         datum.set_data(Pixels, 2 * resize_height * resize_width);  
  172.   
  173.         // set label  
  174.         if (lines[LineIndex].second == lines[PairIndex].second)  
  175.         {  
  176.             datum.set_label(1);  
  177.         }  
  178.         else  
  179.         {  
  180.             datum.set_label(0);  
  181.         }  
  182.   
  183.         // serialize datum to string  
  184.         datum.SerializeToString(&value);  
  185.         sprintf_s(key, kMaxKeyLength, "%08d", LineIndex);  
  186.   
  187.         db->Put(leveldb::WriteOptions(), std::string(key), value);  
  188.     }  
  189.   
  190.     delete db;  
  191.     delete[] Pixels;  
  192.   
  193.     return 0;  
  194. }  



版权所有,欢迎转载,转载请注明出处,谢谢
原文地址:http://blog.csdn.net/sheng_ai/article/details/48174729

你可能感兴趣的:(caffe-将图片转化为siamese网络需要的数据库格式)