PyTorch C++系列教程2:使用自定义数据集



在上一篇文章中,我们讨论了如何使用 PyTorch C++ API 实现 VGG-16 来识别 MNIST 数据集。这篇文章我们讨论一下如何用 C++ API 使用自定义数据集。



// One-statement version of the pipeline that the following steps break down:
// construct the MNIST dataset, map a Normalize transform and then a Stack
// transform over it, and hand the result to make_data_loader together with
// the batch size (64).
auto data_loader = torch::data::make_data_loader(
			std::move(torch::data::datasets::MNIST("../../data").map(torch::data::transforms::Normalize<>(0.13707, 0.3081))).map(
				torch::data::transforms::Stack<>()), 64);


首先,我们将数据集读入 tensor:

// Construct the MNIST dataset object from the files under "../data".
auto data_set = torch::data::datasets::MNIST("../data");

接下来,我们应用一些 transforms :

auto data_set =<>(0.13707, 0.3081)).map(torch::data::transforms::Stack<>())

我们设置 batch_size 为 64:

std::move(data_set, 64);

然后我们就可以将数据传给 data loader,再由 data loader 传给网络。


我们需要了解一下这背后到底是怎么工作的,因此,我们看一下 MNIST 读取的源码文件中的 torch::data::datasets::MNIST 类(位于 PyTorch 源码的 torch/csrc/api/include/torch/data/datasets/mnist.h):

namespace torch {
namespace data {
namespace datasets {
/// The MNIST dataset.
class TORCH_API MNIST : public Dataset<MNIST> {
 public:
  /// The mode in which the dataset is loaded.
  enum class Mode { kTrain, kTest };

  /// Loads the MNIST dataset from the root path.
  /// The supplied root path should contain the content of the unzipped
  /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
  explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);

  /// Returns the Example at the given index.
  Example<> get(size_t index) override;

  /// Returns the size of the dataset.
  optional<size_t> size() const override;

  /// Returns true if this is the training subset of MNIST.
  bool is_train() const noexcept;

  /// Returns all images stacked into a single tensor.
  const Tensor& images() const;

  /// Returns all targets stacked into a single tensor.
  const Tensor& targets() const;

 private:
  /// Image and target tensors filled in by the constructor.
  Tensor images_, targets_;
};
} // namespace datasets
} // namespace data
} // namespace torch

对于 MNIST 类的构造器:

// Constructor: fills images_ and targets_ by calling read_images() and
// read_targets() on `root`. The boolean passed to both helpers is true when
// `mode` is Mode::kTrain, i.e. when the training files should be read.
MNIST::MNIST(const std::string& root, Mode mode)
    : images_(read_images(root, mode == Mode::kTrain)),
      targets_(read_targets(root, mode == Mode::kTrain)) {}


read_images(root, mode) 读取图像
read_targets(root, mode) 读取图像标签


read_images(root, mode)

// Reads the MNIST image file found under `root` (training file when `train`
// is true, test file otherwise) and returns the images as a float tensor of
// shape (count, 1, kImageRows, kImageColumns) scaled to [0, 1].
Tensor read_images(const std::string& root, bool train) {
  // kTrainImagesFilename and kTestImagesFilename are specific to the MNIST
  // dataset. A custom dataset that already has full file paths does not need
  // join_paths.
  const auto path =
      join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);

  // Load images
  std::ifstream images(path, std::ios::binary);
  TORCH_CHECK(images, "Error opening images file at ", path);
  // kTrainSize = len(training data)
  // kTestSize = len(testing_data)
  const auto count = train ? kTrainSize : kTestSize;

  // Specific to MNIST data: validate the header fields of the file.
  // From http://yann.lecun.com/exdb/mnist/
  expect_int32(images, kImageMagicNumber);
  expect_int32(images, count);
  expect_int32(images, kImageRows);
  expect_int32(images, kImageColumns);

  // This converts images to tensors
  // Allocate an empty tensor of size of image (count, channels, height, width)
  auto tensor =
      torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
  // Read the raw image bytes straight into the tensor's storage
  images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
  // Normalize the image from 0 to 255 to 0 to 1
  return tensor.to(torch::kFloat32).div_(255);
}

read_targets(root, mode)

// Reads the MNIST label file found under `root` (training file when `train`
// is true, test file otherwise) and returns the labels as a 1-D int64 tensor
// with `count` elements.
Tensor read_targets(const std::string& root, bool train) {
  // Specific to MNIST dataset (kTrainTargetsFilename and kTestTargetsFilename)
  const auto path =
      join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
  // Read the labels
  std::ifstream targets(path, std::ios::binary);
  TORCH_CHECK(targets, "Error opening targets file at ", path);
  // kTrainSize = len(training_labels)
  // kTestSize = len(testing_labels)
  const auto count = train ? kTrainSize : kTestSize;
  // Validate the header fields of the file
  expect_int32(targets, kTargetMagicNumber);
  expect_int32(targets, count);
  // Allocate an empty tensor of size of number of labels
  auto tensor = torch::empty(count, torch::kByte);
  // Read the raw label bytes straight into the tensor's storage
  targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
  // Convert to an int64 tensor
  return tensor.to(torch::kInt64);
}


// Returns the (image, target) pair stored at position `index`.
Example<> MNIST::get(size_t index) {
  return {images_[index], targets_[index]};
}

// Returns the number of examples, i.e. the extent of the first dimension of
// the image tensor.
optional<size_t> MNIST::size() const {
  return images_.size(0);
}



转换成 tensor
定义 get() 和 size() 两个函数
将类实例传给 data loader

/* Convert and Load image to tensor from location argument */
torch::Tensor read_data(std::string location) {
  // Read Data here
  // Return tensor form of the image
  // (skeleton only: an empty tensor stands in for the real result)
  return torch::Tensor();
}

/* Converts label to tensor type in the integer argument */
torch::Tensor read_label(int label) {
  // Read label here
  // Convert to tensor and return
  // (skeleton only: an empty tensor stands in for the real result)
  return torch::Tensor();
}

/* Loads images to tensor type in the string argument */
vector<torch::Tensor> process_images(vector<string> list_images) {
  cout << "Reading Images..." << endl;
  // Return vector of Tensor form of all the images
  return vector<torch::Tensor>;
/* Loads labels to tensor type in the string argument */
vector<torch::Tensor> process_labels(vector<string> list_labels) {
  cout << "Reading Labels..." << endl;
  // Return vector of Tensor form of all the labels
  return vector<torch::Tensor>;
class CustomDataset : public torch::data::dataset<CustomDataset> {
  // Declare 2 vectors of tensors for images and labels
  vector<torch::Tensor> images, labels;
  // Constructor
  CustomDataset(vector<string> list_images, vector<string> list_labels) {
    images = process_images(list_images);
    labels = process_labels(list_labels);
  // Override get() function to return tensor at location index
  torch::data::Example<> get(size_t index) override {
    torch::Tensor sample_img =;
    torch::Tensor sample_label =;
    return {sample_img.clone(), sample_label.clone()};
  // Return the length of data
  torch::optional<size_t> size() const override {
    return labels.size();

这里我们使用 OpenCV 来读取图像数据,读取的方法相对比较简单:

cv::imread(std::string location, int)

注意要转换成 PyTorch 使用的 tensor 顺序,即 batch_size, channels, height, width:

// Reads the image stored at `loc` with OpenCV and returns it as a byte tensor
// laid out as Channels x Height x Width.
torch::Tensor read_data(std::string loc) {
  // Read Image from the location of image (flag 1: load as a 3-channel image)
  cv::Mat img = cv::imread(loc, 1);
  // imread signals failure by returning an empty matrix rather than throwing
  TORCH_CHECK(!img.empty(), "Error opening image file at ", loc);

  // Convert image to tensor: wrap the pixel buffer (Height x Width x 3)
  // without copying it
  torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
  img_tensor = img_tensor.permute({2, 0, 1}); // Channels x Height x Width
  // from_blob does not own the pixels, so copy them before `img` is destroyed
  return img_tensor.clone();
}


// Read Label (int) and convert to torch::Tensor type:
// returns a one-element tensor filled with `label`.
torch::Tensor read_label(int label) {
  torch::Tensor label_tensor = torch::full({1}, label);
  return label_tensor.clone();
}


/* Convert and Load image to tensor from location argument.
 * The image is resized to 1920x1080 and returned as a byte tensor laid out
 * as Channels x Height x Width. */
torch::Tensor read_data(std::string loc) {
  // Read the image (flag 1: load as a 3-channel image)
  cv::Mat img = cv::imread(loc, 1);
  // imread signals failure by returning an empty matrix rather than throwing
  TORCH_CHECK(!img.empty(), "Error opening image file at ", loc);
  // The interpolation flag is the 6th parameter of cv::resize; the 4th and
  // 5th are the scale factors fx / fy, which are unused when a size is given
  cv::resize(img, img, cv::Size(1920, 1080), 0, 0, cv::INTER_CUBIC);
  std::cout << "Sizes: " << img.size() << std::endl;
  // Wrap the pixel buffer (Height x Width x 3) without copying it
  torch::Tensor img_tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
  img_tensor = img_tensor.permute({2, 0, 1}); // Channels x Height x Width
  // from_blob does not own the pixels, so copy them before `img` is destroyed
  return img_tensor.clone();
}

/* Converts label to tensor type in the integer argument:
 * returns a one-element tensor filled with `label`. */
torch::Tensor read_label(int label) {
  torch::Tensor label_tensor = torch::full({1}, label);
  return label_tensor.clone();
}

/* Loads images to tensor type in the string argument */
vector<torch::Tensor> process_images(vector<string> list_images) {
  cout << "Reading Images..." << endl;
  // Return vector of Tensor form of all the images
  vector<torch::Tensor> states;
 for (std::vector<string>::iterator it = list_images.begin(); it != list_images.end(); ++it) {
 torch::Tensor img = read_data(*it);
 return states;
/* Loads labels to tensor type in the string argument */
vector<torch::Tensor> process_labels(vector<string> list_labels) {
  cout << "Reading Labels..." << endl;
  // Return vector of Tensor form of all the labels
  vector<torch::Tensor> labels;
 for (std::vector<int>::iterator it = list_labels.begin(); it != list_labels.end(); ++it) {
 torch::Tensor label = read_label(*it);
 return labels;
class CustomDataset : public torch::data::dataset<CustomDataset> {
  // Declare 2 vectors of tensors for images and labels
  vector<torch::Tensor> images, labels;
  // Constructor
  CustomDataset(vector<string> list_images, vector<string> list_labels) {
    images = process_images(list_images);
    labels = process_labels(list_labels);
  // Override get() function to return tensor at location index
  torch::data::Example<> get(size_t index) override {
    torch::Tensor sample_img =;
    torch::Tensor sample_label =;
    return {sample_img.clone(), sample_label.clone()};
  // Return the length of data
  torch::optional<size_t> size() const override {
    return labels.size();
int main(int argc, char** argv) {
  vector<string> list_images; // list of path of images
  vector<int> list_labels; // list of integer labels
  // Dataset init and apply transforms - None!
  auto custom_dataset = CustomDataset(list_images, list_labels).map(torch::data::transforms::Stack<>());

在下一篇教程中,我们将介绍如何在 CNN 中使用自定义的 data loader。


