在上一篇文章中,我们讨论了如何使用 PyTorch C++ API 实现 VGG-16 来识别 MNIST 数据集。这篇文章我们讨论一下如何用 C++ API 使用自定义数据集。
我们先来看一下上一篇教程中我们是怎么读取数据的:
// 0.1307 / 0.3081 are the MNIST pixel mean / standard deviation used for
// Normalize. make_data_loader takes the dataset by value, so the temporary
// pipeline is passed directly (wrapping it in std::move adds nothing).
auto data_loader = torch::data::make_data_loader(
    torch::data::datasets::MNIST("../../data")
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>()),
    /*batch_size=*/64);
我们来细细讲解。
首先,我们将数据集读入 tensor:
// Load the MNIST training split from ../data (Mode::kTrain is also the
// constructor's default mode).
auto data_set = torch::data::datasets::MNIST(
    "../data", torch::data::datasets::MNIST::Mode::kTrain);
接下来,我们应用一些 transforms :
// map() wraps a dataset in a new (MapDataset) type, so the transformed
// dataset cannot be initialized from the very variable being declared;
// build the whole pipeline in one expression instead.
// Normalize uses the MNIST mean / std (0.1307 / 0.3081); Stack then merges
// individual examples into batch tensors.
auto data_set = torch::data::datasets::MNIST("../data")
                    .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                    .map(torch::data::transforms::Stack<>());
最后,我们把数据集交给 data loader,并将 batch_size 设为 64:
std::move(data_set, 64);
之后,data loader 会按批次取出数据,再交给网络进行训练。
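我们需要了解一下这背后到底是怎么工作的。因此,我们来看一下读取 MNIST 的 torch::data::datasets::MNIST 类的源码(声明位于 PyTorch 仓库的 torch/csrc/api/include/torch/data/datasets/mnist.h,实现位于 torch/csrc/api/src/data/datasets/mnist.cpp):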
namespace torch {
namespace data {
namespace datasets {
/// The MNIST dataset.
///
/// Derives from the CRTP base `Dataset<MNIST>`, which requires `get()` and
/// `size()` to be implemented.
class TORCH_API MNIST : public Dataset<MNIST> {
 public:
  /// The mode in which the dataset is loaded.
  enum class Mode { kTrain, kTest };

  /// Loads the MNIST dataset from the `root` path.
  ///
  /// The supplied root path should contain the content of the unzipped
  /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
  explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);

  /// Returns the Example (image, target) at the given index.
  Example<> get(size_t index) override;

  /// Returns the size of the dataset.
  optional<size_t> size() const override;

  /// Returns true if this is the training subset of MNIST.
  bool is_train() const noexcept;

  /// Returns all images stacked into a single tensor.
  const Tensor& images() const;

  /// Returns all targets stacked into a single tensor.
  const Tensor& targets() const;

 private:
  Tensor images_, targets_;
};
} // namespace datasets
} // namespace data
} // namespace torch
MNIST 类的构造函数如下:
/// Loads the split selected by `mode` (training or test) from `root`:
/// images into `images_`, labels into `targets_`.
MNIST::MNIST(const std::string& root, Mode mode)
    : images_(read_images(root, /*train=*/Mode::kTrain == mode)),
      targets_(read_targets(root, /*train=*/Mode::kTrain == mode)) {}
我们可以看到这里调用了两个函数:
read_images(root, mode == Mode::kTrain):读取图像(第二个参数是表示是否为训练集的 bool)
read_targets(root, mode == Mode::kTrain):读取图像标签
我们来看看这两个函数具体怎么工作。
read_images(root, mode)
/// Reads the MNIST image file for one split and returns every image as a
/// single float tensor of shape {count, 1, kImageRows, kImageColumns}, with
/// pixel values scaled from [0, 255] to [0, 1].
///
/// @param root  Directory containing the unzipped MNIST files.
/// @param train true for the training split, false for the test split.
Tensor read_images(const std::string& root, bool train) {
  // kTrainImagesFilename / kTestImagesFilename are the MNIST image file names
  // inside `root`.
  const auto path =
      join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);

  std::ifstream images(path, std::ios::binary);
  TORCH_CHECK(images, "Error opening images file at ", path);

  // kTrainSize / kTestSize: number of images in each split.
  const auto count = train ? kTrainSize : kTestSize;

  // Header fields of the MNIST image file format
  // (http://yann.lecun.com/exdb/mnist/): magic number, image count, rows,
  // columns. expect_int32 presumably rejects a file whose header does not
  // match these values -- confirm against its definition.
  expect_int32(images, kImageMagicNumber);
  expect_int32(images, count);
  expect_int32(images, kImageRows);
  expect_int32(images, kImageColumns);

  // One byte per pixel: read the raw pixel data straight into a uint8 tensor
  // laid out as (count, channels = 1, height, width).
  auto tensor =
      torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
  images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
  // A truncated file would otherwise leave part of the uninitialized tensor
  // masquerading as pixel data.
  TORCH_CHECK(
      images.gcount() == tensor.numel(),
      "Truncated images file at ", path, ": expected ", tensor.numel(),
      " bytes of pixel data, read ", images.gcount());

  // Normalize the image from 0..255 to 0..1.
  return tensor.to(torch::kFloat32).div_(255);
}
read_targets(root, mode)
/// Reads the MNIST label file for one split and returns every label as a 1-D
/// int64 tensor of length `count`.
///
/// @param root  Directory containing the unzipped MNIST files.
/// @param train true for the training split, false for the test split.
Tensor read_targets(const std::string& root, bool train) {
  // kTrainTargetsFilename / kTestTargetsFilename are the MNIST label file
  // names inside `root`.
  const auto path =
      join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);

  std::ifstream targets(path, std::ios::binary);
  TORCH_CHECK(targets, "Error opening targets file at ", path);

  // kTrainSize / kTestSize: number of labels in each split.
  const auto count = train ? kTrainSize : kTestSize;

  // Header fields of the MNIST label file format: magic number, label count.
  expect_int32(targets, kTargetMagicNumber);
  expect_int32(targets, count);

  // One byte per label: read them into a uint8 tensor of length `count`.
  auto tensor = torch::empty(count, torch::kByte);
  targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
  // Guard against a truncated file leaving uninitialized "labels" behind.
  TORCH_CHECK(
      targets.gcount() == count,
      "Truncated targets file at ", path, ": expected ", count,
      " labels, read ", targets.gcount());

  // int64 is the dtype expected for class-index targets.
  return tensor.to(torch::kInt64);
}
此外,MNIST 还重写了 Dataset 的两个成员函数:
/// Returns the (image, label) pair stored at position `index`.
Example<> MNIST::get(size_t index) {
  Tensor image = images_[index];
  Tensor label = targets_[index];
  return {image, label};
}
/// Returns the number of samples: the length of dimension 0 of `images_`.
optional<size_t> MNIST::size() const {
  return images_.size(0);
}
上面两个函数分别用于获取一个图像及其标签,和返回数据集大小。
自定义数据集的流程
通过上面的源代码查看,我们知道了自定义数据集的大概流程:
读取数据和标签
转换成 tensor
重写 get() 和 size() 两个函数
创建数据集类的实例
将类实例传给 data loader
自定义数据集示例
接下来我们看一个具体示例。下面是整个代码的大体框架:
#include
#include
#include
#include
#include
#include
#include
/* Convert and load the image at `location` into a tensor.
 * Skeleton only: returns an undefined (default-constructed) tensor until the
 * reading logic is filled in. */
torch::Tensor read_data(std::string location) {
  // Read Data here
  // Return tensor form of the image
  return torch::Tensor();
}
/* Converts the integer `label` into a tensor.
 * Skeleton only: returns an undefined (default-constructed) tensor until the
 * conversion is filled in. */
torch::Tensor read_label(int label) {
  // Read label here
  // Convert to tensor and return
  return torch::Tensor();
}
/* Loads images to tensor type in the string argument */
vector<torch::Tensor> process_images(vector<string> list_images) {
cout << "Reading Images..." << endl;
// Return vector of Tensor form of all the images
return vector<torch::Tensor>;
}
/* Loads labels to tensor type in the string argument */
vector<torch::Tensor> process_labels(vector<string> list_labels) {
cout << "Reading Labels..." << endl;
// Return vector of Tensor form of all the labels
return vector<torch::Tensor>;
}
class CustomDataset : public torch::data::dataset<CustomDataset> {
private:
// Declare 2 vectors of tensors for images and labels
vector<torch::Tensor> images, labels;
public:
// Constructor
CustomDataset(vector<string> list_images, vector<string> list_labels) {
images = process_images(list_images);
labels = process_labels(list_labels);
};
// Override get() function to return tensor at location index
torch::data::Example<> get(size_t index) override {
torch::Tensor sample_img = images.at(index);
torch::Tensor sample_label = labels.at(index);
return {sample_img.clone(), sample_label.clone()};
};
// Return the length of data
torch::optional<size_t> size() const override {
return labels.size();
};
};
这里我们使用 OpenCV 来读取图像数据,读取的方法相对比较简单:
cv::imread(std::string location, int)
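注意要把 OpenCV 图像的 height, width, channels 维度顺序转换成 PyTorch 使用的 channels, height, width。batch_size 维度会在之后组 batch 时(例如通过 Stack 变换)加在最前面,最终得到 batch_size, channels, height, width: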
/* Loads the image at `loc` and returns it as a uint8 tensor of shape
 * {channels, height, width}. Channels are in OpenCV's BGR order. */
torch::Tensor read_data(std::string loc) {
  // Flag 1 (cv::IMREAD_COLOR) always decodes to a 3-channel image, which the
  // {rows, cols, 3} view below relies on.
  cv::Mat img = cv::imread(loc, 1);
  // imread reports failure by returning an empty Mat, not by throwing; fail
  // here with a clear message instead of producing an empty tensor.
  TORCH_CHECK(!img.empty(), "Could not read image at ", loc);
  // Convert image to tensor (a view over img's pixel buffer)
  torch::Tensor img_tensor =
      torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
  img_tensor = img_tensor.permute({2, 0, 1});  // Channels x Height x Width
  // from_blob does not own img's buffer, which is released when img goes out
  // of scope, so return an owning copy.
  return img_tensor.clone();
}
读取标签:
// Read Label (int) and convert to a 1-element int64 torch::Tensor.
// The dtype is set explicitly: classification losses (nll_loss /
// cross_entropy) expect int64 class indices, and older libtorch releases
// produce a float tensor when torch::full infers the dtype from an integer.
torch::Tensor read_label(int label) {
  return torch::full({1}, label, torch::kLong);
}
最终代码:
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/torch.h>
/* Loads the image at `loc`, resizes it to 1920x1080 with bicubic
 * interpolation, and returns it as a uint8 tensor of shape
 * {3, 1080, 1920} (channels in OpenCV's BGR order). */
torch::Tensor read_data(std::string loc) {
  // Flag 1 (cv::IMREAD_COLOR) always decodes to a 3-channel image.
  cv::Mat img = cv::imread(loc, 1);
  // imread reports failure with an empty Mat; resize would then fail with a
  // much less helpful error.
  TORCH_CHECK(!img.empty(), "Could not read image at ", loc);
  // cv::resize(src, dst, dsize, fx, fy, interpolation): the interpolation
  // flag is the 6th argument. Passing cv::INTER_CUBIC 4th made it the fx
  // scale factor (ignored when dsize is set), so INTER_LINEAR was used.
  cv::resize(img, img, cv::Size(1920, 1080), 0, 0, cv::INTER_CUBIC);
  std::cout << "Sizes: " << img.size() << std::endl;
  torch::Tensor img_tensor =
      torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
  img_tensor = img_tensor.permute({2, 0, 1});  // Channels x Height x Width
  // from_blob does not own img's buffer, so return an owning copy.
  return img_tensor.clone();
}
/* Converts the integer `label` into a 1-element int64 tensor.
 * The dtype is explicit because classification losses (nll_loss /
 * cross_entropy) expect int64 class indices, and older libtorch releases
 * return a float tensor when torch::full infers the dtype from an integer. */
torch::Tensor read_label(int label) {
  return torch::full({1}, label, torch::kLong);
}
/* Loads images to tensor type in the string argument */
vector<torch::Tensor> process_images(vector<string> list_images) {
cout << "Reading Images..." << endl;
// Return vector of Tensor form of all the images
vector<torch::Tensor> states;
for (std::vector<string>::iterator it = list_images.begin(); it != list_images.end(); ++it) {
torch::Tensor img = read_data(*it);
states.push_back(img);
}
return states;
}
/* Loads labels to tensor type in the string argument */
vector<torch::Tensor> process_labels(vector<string> list_labels) {
cout << "Reading Labels..." << endl;
// Return vector of Tensor form of all the labels
vector<torch::Tensor> labels;
for (std::vector<int>::iterator it = list_labels.begin(); it != list_labels.end(); ++it) {
torch::Tensor label = read_label(*it);
labels.push_back(label);
}
return labels;
}
class CustomDataset : public torch::data::dataset<CustomDataset> {
private:
// Declare 2 vectors of tensors for images and labels
vector<torch::Tensor> images, labels;
public:
// Constructor
CustomDataset(vector<string> list_images, vector<string> list_labels) {
images = process_images(list_images);
labels = process_labels(list_labels);
};
// Override get() function to return tensor at location index
torch::data::Example<> get(size_t index) override {
torch::Tensor sample_img = images.at(index);
torch::Tensor sample_label = labels.at(index);
return {sample_img.clone(), sample_label.clone()};
};
// Return the length of data
torch::optional<size_t> size() const override {
return labels.size();
};
};
int main(int argc, char** argv) {
  // Paths of the images to load.
  std::vector<std::string> list_images;
  // Integer class labels written as text (e.g. "0", "1"), one per image:
  // CustomDataset's constructor takes its labels as std::vector<std::string>.
  std::vector<std::string> list_labels;
  // Dataset init; the only transform is Stack, which collates samples into
  // batch tensors.
  auto custom_dataset = CustomDataset(list_images, list_labels)
                            .map(torch::data::transforms::Stack<>());
}
在下一篇教程中,我们将介绍如何在 CNN 中使用自定义的 data loader。