caffe-将图片转化为siamese网络需要的数据库格式

  由于最近要用到siamese网络,大概跑了一下caffe自带的siamese网络的例程,发现例程中对于数据格式的转换仅仅局限于mnist数据集,不能直接将其他图片格式的数据集转换为需要的格式,因此,在分析了数据转化的逻辑之后,发现每一张图片的转换过程可以按照下面的步骤执行:
caffe-将图片转化为siamese网络需要的数据库格式_第1张图片


    在参考了caffe自带的convert_imageset.cpp和siamese例程中的数据转换代码之后,我编写的格式转换代码如下。这个代码目前只能处理黑白(灰度)图像,后期版本会增加对彩色图像的支持。

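// This program converts a set of gray images to a leveldb of image pairs for
// a siamese network. Each record is a 2-channel Datum proto buffer holding two
// images; its label is 1 when both images carry the same class label in
// LISTFILE and 0 otherwise.
// Usage:
//   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
//   subfolder1/file1.JPEG 7
//   ....
// --resize_height and --resize_width must be set to the (positive) size that
// every image is scaled to.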

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <iomanip>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "leveldb\db.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

#include "opencv2\opencv.hpp"

using namespace caffe;  // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;

// Command-line flags (gflags).
// NOTE(review): in the code visible in this file only `shuffle`,
// `resize_width` and `resize_height` are read. The remaining flags are
// defined but never referenced here -- confirm whether they are leftovers
// before relying on them.

// Not read by the visible code; ReadImageToMemory always loads images with
// cv::IMREAD_GRAYSCALE regardless of this flag.
DEFINE_bool(gray, false,
    "When this option is on, treat images as grayscale ones");
// When set, main() shuffles the (file name, label) list before pairing.
DEFINE_bool(shuffle, false,
    "Randomly shuffle the order of images and their labels");
// Not read by the visible code; main() always opens a leveldb database, even
// though the default value here is "lmdb".
DEFINE_string(backend, "lmdb",
    "The backend {lmdb, leveldb} for storing the result");
// Target size of every stored image. The default of 0 is not a usable size:
// main() passes these values straight to the resize step.
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
// Not read by the visible code.
DEFINE_bool(check_size, false,
    "When this option is on, check that all the datum have the same size");
// Not read by the visible code.
DEFINE_bool(encoded, false,
    "When this option is on, the encoded image will be save in datum");
// Not read by the visible code.
DEFINE_string(encode_type, "",
    "Optional: What type should we encode the image as ('png','jpg',...).");


static bool ReadImageToMemory(const string& FileName, const int Height,
                              const int Width, char *Pixels)
{
    // read image
    cv::Mat OriginImage = cv::imread(FileName, cv::IMREAD_GRAYSCALE);
    CHECK(OriginImage.data) << "Failed to read the image.\n";


    // resize the image
    cv::Mat ResizeImage;
    cv::resize(OriginImage, ResizeImage, cv::Size(Width, Height));
    CHECK(ResizeImage.rows == Height) << "The heighs of Image is no equal to the input height.\n";
    CHECK(ResizeImage.cols == Width) << "The width of Image is no equal to the input width.\n";
    CHECK(ResizeImage.channels() == 1) << "The channel of Image is no equal to one.\n";

    LOG(INFO) << "The height of image is " << ResizeImage.rows << "\n";
    LOG(INFO) << "The width of image is " << ResizeImage.cols << "\n";
    LOG(INFO) << "The channels of image is " << ResizeImage.channels() << "\n";

    // copy the image data to Pixels
    for (int HeightIndex = 0; HeightIndex < Height; ++HeightIndex)
    {
        const uchar* ptr = ResizeImage.ptr<uchar>(HeightIndex);
        int img_index = 0;
        for (int WidthIndex = 0; WidthIndex < Width; ++WidthIndex)
        {
            for (int ChannelIndex = 0; ChannelIndex < ResizeImage.channels(); ++ChannelIndex)
            {
                int datum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;
                *(Pixels + datum_index) = static_cast<char>(ptr[img_index++]);
            }
        }
    }

    return true;
}


int main(int argc, char** argv)
{
    //::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
    namespace gflags = google;
#endif

    gflags::SetUsageMessage("Convert a set of grey images to the leveldb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n");
    caffe::GlobalInit(&argc, &argv);

    // 输入参数不足时报错
    if (argc < 4)
    {
        gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
        return 1;
    }


    // 读取图像名字和标签
    std::ifstream infile(argv[2]);
    std::vector<std::pair<std::string, int> > lines;
    std::string filename;
    int label;
    while (infile >> filename >> label)
    {
        lines.push_back(std::make_pair(filename, label));
    }

    // 打乱图片顺序
    if (FLAGS_shuffle)
    {
        // randomly shuffle data
        LOG(INFO) << "Shuffling data";
        shuffle(lines.begin(), lines.end());
    }
    LOG(INFO) << "A total of " << lines.size() << " images.";



    // 设置图像的高度和宽度
    int resize_height = std::max<int>(0, FLAGS_resize_height);
    int resize_width = std::max<int>(0, FLAGS_resize_width);


    // 打开数据库
    // Open leveldb
    leveldb::DB* db;
    leveldb::Options options;
    options.create_if_missing = true;
    options.error_if_exists = true;
    leveldb::Status status = leveldb::DB::Open(
        options, argv[3], &db);
    CHECK(status.ok()) << "Failed to open leveldb " << argv[3]
        << ". Is it already existing?";


    // 保存到leveldb
    // Storing to leveldb
    std::string root_folder(argv[1]);
    char* Pixels = new char[2 * resize_height * resize_width];
    const int kMaxKeyLength = 10;
    char key[kMaxKeyLength];
    std::string value;

    caffe::Datum datum;
    datum.set_channels(2);  // one channel for each image in the pair
    datum.set_height(resize_height);
    datum.set_width(resize_width);

    //
    for (int LineIndex = 0; LineIndex < lines.size(); LineIndex++)
    {
        int PairIndex = caffe::caffe_rng_rand() % lines.size();

        char* FirstImagePixel = Pixels;
        ReadImageToMemory(root_folder + lines[LineIndex].first, resize_height, resize_width, FirstImagePixel);

        char *SecondImagePixel = Pixels + resize_width * resize_height;
        ReadImageToMemory(root_folder + lines[PairIndex].first, resize_height, resize_width, SecondImagePixel);

        // set image pair data
        datum.set_data(Pixels, 2 * resize_height * resize_width);

        // set label
        if (lines[LineIndex].second == lines[PairIndex].second)
        {
            datum.set_label(1);
        }
        else
        {
            datum.set_label(0);
        }

        // serialize datum to string
        datum.SerializeToString(&value);
        sprintf_s(key, kMaxKeyLength, "%08d", LineIndex);

        db->Put(leveldb::WriteOptions(), std::string(key), value);
    }

    delete db;
    delete[] Pixels;

    return 0;
}



版权所有,欢迎转载,转载请注明出处,谢谢!




你可能感兴趣的:(caffe-将图片转化为siamese网络需要的数据库格式)