PyTorch C++系列教程1:用 VGG-16 识别 MNIST

PyTorch系列文章目录


文章目录

  • PyTorch系列文章目录
  • 前言
    • 安装
      • CPU 版本:
      • GPU (CUDA 9.0) 版本:
      • GPU (CUDA 10.0) 版本:
  • 一、VGG-16 的网络结构
  • 训练
  • 总结


前言

本文讲解如何用 PyTorch C++ API(libtorch)实现 VGG-16 来识别 MNIST 数据集。

安装

首先下载 libtorch:

CPU 版本:

wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-latest.zip -O libtorch.zip

GPU (CUDA 9.0) 版本:

wget https://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zip -O libtorch.zip

GPU (CUDA 10.0) 版本:

wget https://download.pytorch.org/libtorch/cu100/libtorch-shared-with-deps-latest.zip -O libtorch.zip

然后将下载的压缩包解压缩(例如 `unzip libtorch.zip`)。后面我们将使用解压后的文件夹的绝对路径。

一、VGG-16 的网络结构

VGG-16 由 13 个 3×3 卷积层和 3 个全连接层组成(共 16 个带权重的层)。13 个卷积层分为 5 段,每段之后接一个 2×2 最大池化层。(原文此处的网络结构示意图已丢失。)

首先引入头文件:

#include <cstddef>
#include <iostream>
#include <memory>

#include <torch/torch.h>

然后实现网络定义:

/* Sample code for training a VGG-16-style CNN on the MNIST dataset using the PyTorch C++ API. */
/* The layer layout follows VGG-16 (13 conv layers + 3 fully connected layers), but the channel
   widths are much narrower (10..130 instead of VGG's 64..512) to keep the model small. */

/// VGG-16-style classifier for single-channel images (conv1_1 takes 1 input channel).
///
/// Layout:
///   conv1_1 - conv1_2 - pool1 - conv2_1 - conv2_2 - pool2 - conv3_1 - conv3_2 - conv3_3 - pool3 -
///   conv4_1 - conv4_2 - conv4_3 - pool4 - conv5_1 - conv5_2 - conv5_3 - [pool5] - fc1 - fc2 - fc3
///
/// Every convolution is 3x3 with padding 1, so it preserves the spatial size; only the 2x2
/// max-pool layers shrink it (by half, rounding down).
struct Net : torch::nn::Module {
    Net() {
        // On how to pass strides and padding:
        // https://github.com/pytorch/pytorch/issues/12649#issuecomment-430156160
        // Block 1
        conv1_1 = register_module("conv1_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 10, 3).padding(1)));
        conv1_2 = register_module("conv1_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(10, 20, 3).padding(1)));
        // Block 2
        conv2_1 = register_module("conv2_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(20, 30, 3).padding(1)));
        conv2_2 = register_module("conv2_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(30, 40, 3).padding(1)));
        // Block 3
        conv3_1 = register_module("conv3_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(40, 50, 3).padding(1)));
        conv3_2 = register_module("conv3_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(50, 60, 3).padding(1)));
        conv3_3 = register_module("conv3_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(60, 70, 3).padding(1)));
        // Block 4
        conv4_1 = register_module("conv4_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(70, 80, 3).padding(1)));
        conv4_2 = register_module("conv4_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(80, 90, 3).padding(1)));
        conv4_3 = register_module("conv4_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(90, 100, 3).padding(1)));
        // Block 5
        conv5_1 = register_module("conv5_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 110, 3).padding(1)));
        conv5_2 = register_module("conv5_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(110, 120, 3).padding(1)));
        conv5_3 = register_module("conv5_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(120, 130, 3).padding(1)));
        // Classifier head. fc1 expects exactly 130 features per sample, i.e. the conv stack must
        // end in a 130x1x1 feature map.
        fc1 = register_module("fc1", torch::nn::Linear(130, 50));
        fc2 = register_module("fc2", torch::nn::Linear(50, 20));
        fc3 = register_module("fc3", torch::nn::Linear(20, 10));
    }

    /// Runs the network on a batch.
    /// @param x  input batch; assumed (N, 1, H, W), e.g. H = W = 28 for MNIST.
    /// @return   (N, 10) log-probabilities (log_softmax over dim 1), suitable for torch::nll_loss.
    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(conv1_1->forward(x));
        x = torch::relu(conv1_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv2_1->forward(x));
        x = torch::relu(conv2_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv3_1->forward(x));
        x = torch::relu(conv3_2->forward(x));
        x = torch::relu(conv3_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv4_1->forward(x));
        x = torch::relu(conv4_2->forward(x));
        x = torch::relu(conv4_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv5_1->forward(x));
        x = torch::relu(conv5_2->forward(x));
        x = torch::relu(conv5_3->forward(x));
        // For a 28x28 input the map is already 1x1 here (28 -> 14 -> 7 -> 3 -> 1). A 2x2 max-pool
        // on a 1x1 map has an empty output, and max_pool2d throws ("Output size is too small").
        // Only apply pool 5 when there is still something to pool.
        if (x.size(2) >= 2 && x.size(3) >= 2) {
            x = torch::max_pool2d(x, 2);
        }

        // Flatten per sample. Keeping the batch dimension explicit (rather than view({-1, 130}))
        // means an unexpected feature-map size fails loudly in fc1 instead of silently changing
        // the number of rows (and thus the apparent batch size).
        x = x.view({x.size(0), -1});

        x = torch::relu(fc1->forward(x));
        x = torch::relu(fc2->forward(x));
        x = fc3->forward(x);

        return torch::log_softmax(x, 1);
    }

    // Layers (created and registered in the constructor).
    torch::nn::Conv2d conv1_1{nullptr};
    torch::nn::Conv2d conv1_2{nullptr};
    torch::nn::Conv2d conv2_1{nullptr};
    torch::nn::Conv2d conv2_2{nullptr};
    torch::nn::Conv2d conv3_1{nullptr};
    torch::nn::Conv2d conv3_2{nullptr};
    torch::nn::Conv2d conv3_3{nullptr};
    torch::nn::Conv2d conv4_1{nullptr};
    torch::nn::Conv2d conv4_2{nullptr};
    torch::nn::Conv2d conv4_3{nullptr};
    torch::nn::Conv2d conv5_1{nullptr};
    torch::nn::Conv2d conv5_2{nullptr};
    torch::nn::Conv2d conv5_3{nullptr};

    torch::nn::Linear fc1{nullptr};
    torch::nn::Linear fc2{nullptr};
    torch::nn::Linear fc3{nullptr};
};

训练

接下来训练网络:使用 SGD 优化器,学习率 0.01,损失函数为 nll_loss,共训练 10 个 epoch:

int main() {
 // Create multi-threaded data loader for MNIST data
 auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
 std::move(torch::data::datasets::MNIST("../../data").map(torch::data::transforms::Normalize<>(0.13707, 0.3081)).map(
 torch::data::transforms::Stack<>())), 64);
 
    // Build VGG-16 Network
    auto net = std::make_shared<Net>();
 
    torch::optim::SGD optimizer(net->parameters(), 0.01); // Learning Rate 0.01
 
 // net.train();
 
 for(size_t epoch=1; epoch<=10; ++epoch) {
 size_t batch_index = 0;
 // Iterate data loader to yield batches from the dataset
 for (auto& batch: *data_loader) {
 // Reset gradients
 optimizer.zero_grad();
 // Execute the model
 torch::Tensor prediction = net->forward(batch.data);
 // Compute loss value
 torch::Tensor loss = torch::nll_loss(prediction, batch.target);
 // Compute gradients
 loss.backward();
 // Update the parameters
 optimizer.step();
 
 // Output the loss and checkpoint every 100 batches
 if (++batch_index % 2 == 0) {
 std::cout << "Epoch: " << epoch << " | Batch: " << batch_index 
 << " | Loss: " << loss.item<float>() << std::endl;
 torch::save(net, "net.pt");
 }
 }
 }
}

总结

完整代码请参考:
https://github.com/krshrimali/Digit-Recognition-MNIST-SVHN-PyTorch-CPP

参考资料
https://pytorch.org/cppdocs/
http://yann.lecun.com/exdb/mnist/

你可能感兴趣的:(深度学习,PyTorch系列,ONNX,pytorch,c++,深度学习)