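由于最近要用到Siamese网络，我大概跑了一下Caffe自带的Siamese网络例程，发现例程中的数据格式转换仅局限于MNIST数据集，不能直接将其他图片格式的数据集转换为需要的格式。因此，在分析了数据转换的逻辑之后，发现每一张图片的转换过程可以按照下面的步骤执行：
1. 从列表文件中读取图片路径及其类别标签；
2. 对列表中的每一张图片，再从列表中随机选取一张图片与之配对；
3. 将两张图片以灰度方式读入，并缩放到统一的高度和宽度；
4. 把两张图片的像素分别作为两个通道写入同一个Datum；
5. 若两张图片的类别标签相同，则将Datum的标签设为1，否则设为0；
6. 将Datum序列化后写入leveldb。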
在参考了Caffe自带的convert_imageset.cpp和examples/siamese/convert_mnist_siamese_data.cpp之后，我编写的格式转换代码如下。这个代码目前只能处理灰度图像，后续版本会增加对彩色图像的支持。
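// This program converts a list of grayscale images into a leveldb of image
// pairs for training a Siamese network. Each record is a Datum proto buffer
// with two channels (a listed image and a randomly chosen partner image, both
// resized to --resize_height x --resize_width) and a label that is 1 when the
// two images share a class label and 0 otherwise.
// Usage:
// convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
// subfolder1/file1.JPEG 7
// ....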
#include <algorithm>
#include <fstream> // NOLINT(readability/streams)
#include <iomanip>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "leveldb\db.h"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"
#include "opencv2\opencv.hpp"
// Bring caffe names into scope. Unqualified names used later in this file,
// such as string and shuffle, presumably resolve through it (confirm against
// caffe/common.hpp and caffe/util/rng.hpp).
using namespace caffe; // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;
// Command-line flags. This flag set appears to be copied from caffe's
// tools/convert_imageset; in this file only --shuffle, --resize_height and
// --resize_width are actually read. The others are accepted but have no
// effect here: images are always loaded as grayscale and the output is always
// a leveldb, whatever --backend says.
DEFINE_bool(gray, false,
"When this option is on, treat images as grayscale ones");
// Shuffles the list of (image, label) entries before the pairs are built.
DEFINE_bool(shuffle, false,
"Randomly shuffle the order of images and their labels");
// NOTE(review): ignored; main() always opens a leveldb despite the "lmdb"
// default.
DEFINE_string(backend, "lmdb",
"The backend {lmdb, leveldb} for storing the result");
// Target size for every image. Both must be set to positive values: each
// image is resized to exactly this size and the pixel buffer is sized from
// them, so the default of 0 is not usable by this tool.
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
DEFINE_bool(check_size, false,
"When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false,
"When this option is on, the encoded image will be save in datum");
DEFINE_string(encode_type, "",
"Optional: What type should we encode the image as ('png','jpg',...).");
static bool ReadImageToMemory(const string& FileName, const int Height,
const int Width, char *Pixels)
{
// read image
cv::Mat OriginImage = cv::imread(FileName, cv::IMREAD_GRAYSCALE);
CHECK(OriginImage.data) << "Failed to read the image.\n";
// resize the image
cv::Mat ResizeImage;
cv::resize(OriginImage, ResizeImage, cv::Size(Width, Height));
CHECK(ResizeImage.rows == Height) << "The heighs of Image is no equal to the input height.\n";
CHECK(ResizeImage.cols == Width) << "The width of Image is no equal to the input width.\n";
CHECK(ResizeImage.channels() == 1) << "The channel of Image is no equal to one.\n";
LOG(INFO) << "The height of image is " << ResizeImage.rows << "\n";
LOG(INFO) << "The width of image is " << ResizeImage.cols << "\n";
LOG(INFO) << "The channels of image is " << ResizeImage.channels() << "\n";
// copy the image data to Pixels
for (int HeightIndex = 0; HeightIndex < Height; ++HeightIndex)
{
const uchar* ptr = ResizeImage.ptr<uchar>(HeightIndex);
int img_index = 0;
for (int WidthIndex = 0; WidthIndex < Width; ++WidthIndex)
{
for (int ChannelIndex = 0; ChannelIndex < ResizeImage.channels(); ++ChannelIndex)
{
int datum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;
*(Pixels + datum_index) = static_cast<char>(ptr[img_index++]);
}
}
}
return true;
}
int main(int argc, char** argv)
{
//::google::InitGoogleLogging(argv[0]);
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Convert a set of grey images to the leveldb\n"
"format used as input for Caffe.\n"
"Usage:\n"
" convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n");
caffe::GlobalInit(&argc, &argv);
// 输入参数不足时报错
if (argc < 4)
{
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
return 1;
}
// 读取图像名字和标签
std::ifstream infile(argv[2]);
std::vector<std::pair<std::string, int> > lines;
std::string filename;
int label;
while (infile >> filename >> label)
{
lines.push_back(std::make_pair(filename, label));
}
// 打乱图片顺序
if (FLAGS_shuffle)
{
// randomly shuffle data
LOG(INFO) << "Shuffling data";
shuffle(lines.begin(), lines.end());
}
LOG(INFO) << "A total of " << lines.size() << " images.";
// 设置图像的高度和宽度
int resize_height = std::max<int>(0, FLAGS_resize_height);
int resize_width = std::max<int>(0, FLAGS_resize_width);
// 打开数据库
// Open leveldb
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status = leveldb::DB::Open(
options, argv[3], &db);
CHECK(status.ok()) << "Failed to open leveldb " << argv[3]
<< ". Is it already existing?";
// 保存到leveldb
// Storing to leveldb
std::string root_folder(argv[1]);
char* Pixels = new char[2 * resize_height * resize_width];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;
caffe::Datum datum;
datum.set_channels(2); // one channel for each image in the pair
datum.set_height(resize_height);
datum.set_width(resize_width);
//
for (int LineIndex = 0; LineIndex < lines.size(); LineIndex++)
{
int PairIndex = caffe::caffe_rng_rand() % lines.size();
char* FirstImagePixel = Pixels;
ReadImageToMemory(root_folder + lines[LineIndex].first, resize_height, resize_width, FirstImagePixel);
char *SecondImagePixel = Pixels + resize_width * resize_height;
ReadImageToMemory(root_folder + lines[PairIndex].first, resize_height, resize_width, SecondImagePixel);
// set image pair data
datum.set_data(Pixels, 2 * resize_height * resize_width);
// set label
if (lines[LineIndex].second == lines[PairIndex].second)
{
datum.set_label(1);
}
else
{
datum.set_label(0);
}
// serialize datum to string
datum.SerializeToString(&value);
sprintf_s(key, kMaxKeyLength, "%08d", LineIndex);
db->Put(leveldb::WriteOptions(), std::string(key), value);
}
delete db;
delete[] Pixels;
return 0;
}
版权所有,欢迎转载,转载请注明出处,谢谢