C++代码
#include <torch/torch.h>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
// Compile-time configuration for the MNIST training program.
// All values are constexpr so none of them can be modified at run time.

// Root directory passed to the MNIST dataset constructor.
constexpr const char* kDataRoot = "./data";

// Number of samples per batch handed to the training data loader.
constexpr int64_t kTrainBatchSize = 64;

// Number of samples per batch handed to the test data loader.
constexpr int64_t kTestBatchSize = 1000;

// Number of train + test passes performed by main().
constexpr int64_t kNumberOfEpochs = 10;

// A progress line with the current loss is printed once every this many
// training batches.
constexpr int64_t kLogInterval = 10;
struct Net : torch::nn::Module {
Net()
: conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
fc1(320, 50),
fc2(50, 10) {
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv2_drop", conv2_drop);
register_module("fc1", fc1);
register_module("fc2", fc2);
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
x = torch::relu(
torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
x = x.view({-1, 320});
x = torch::relu(fc1->forward(x));
x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
x = fc2->forward(x);
return torch::log_softmax(x, /*dim=*/1);
}
torch::nn::Conv2d conv1;
torch::nn::Conv2d conv2;
torch::nn::FeatureDropout conv2_drop;
torch::nn::Linear fc1;
torch::nn::Linear fc2;
};
// Runs one training pass over `data_loader`.
//
// For every batch: moves data and targets to `device`, zeroes the optimizer's
// gradients, computes the NLL loss of the model output, back-propagates and
// takes one optimizer step. Every kLogInterval-th batch (starting with the
// first) a progress line is written to stdout.
//
// Parameters:
//   epoch        - epoch number; only used in the progress line.
//   model        - network to train; switched to training mode on entry.
//   device       - device the batch tensors are moved to.
//   data_loader  - iterable yielding batches with `.data` and `.target`.
//   optimizer    - optimizer used for zero_grad() / step().
//   dataset_size - total sample count; only used in the progress line.
//
// Asserts (AT_ASSERT) that the loss is not NaN.
template <typename DataLoader>
void train(
    int32_t epoch,
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t dataset_size) {
  model.train();

  size_t batch_idx = 0;
  // Running total of samples processed so far. Accumulated per batch so the
  // progress count stays correct even when a batch is smaller than the others.
  size_t samples_seen = 0;

  for (auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);

    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);

    // Read the scalar loss once and reuse it for both the NaN check and the
    // progress line instead of extracting it from the tensor twice.
    const float loss_value = loss.template item<float>();
    AT_ASSERT(!std::isnan(loss_value));

    loss.backward();
    optimizer.step();

    samples_seen += static_cast<size_t>(batch.data.size(0));

    if (batch_idx++ % static_cast<size_t>(kLogInterval) == 0) {
      // %zu is the conversion that matches size_t; the previous %ld did not
      // match the unsigned size_t arguments.
      std::printf(
          "\rTrain Epoch: %d [%5zu/%5zu] Loss: %.4f",
          static_cast<int>(epoch),
          samples_seen,
          dataset_size,
          static_cast<double>(loss_value));
      // The line ends without '\n', so flush explicitly to make the
      // carriage-return progress display appear immediately.
      std::fflush(stdout);
    }
  }
}
// Evaluates `model` on every batch of `data_loader` and prints the average
// loss and the accuracy to stdout.
//
// Gradient recording is disabled for the whole function and the model is put
// into evaluation mode. The per-batch NLL loss is summed (Reduction::Sum) and
// divided by `dataset_size` at the end; accuracy is the number of samples
// whose arg-max prediction equals the target, divided by `dataset_size`.
//
// Parameters:
//   model        - network to evaluate; switched to eval mode on entry.
//   device       - device the batch tensors are moved to.
//   data_loader  - iterable yielding batches with `.data` and `.target`.
//   dataset_size - total sample count used as the divisor for both metrics.
template <typename DataLoader>
void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
  torch::NoGradGuard no_grad;
  model.eval();

  double test_loss = 0;
  // 64-bit on purpose: the per-batch count is read as int64_t, so a 32-bit
  // accumulator would narrow it implicitly.
  int64_t correct = 0;

  for (const auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);
    auto output = model.forward(data);

    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     Reduction::Sum)
                     .template item<float>();

    // Predicted class = index of the largest value along dimension 1.
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::printf(
      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
      test_loss,
      static_cast<double>(correct) / dataset_size);
}
// Program entry point: trains Net on MNIST for kNumberOfEpochs epochs and
// evaluates it on the MNIST test split after every epoch.
auto main() -> int {
  // Fixed seed so that runs start from the same random state.
  torch::manual_seed(1);

  // Use CUDA when it is available, otherwise fall back to the CPU.
  torch::DeviceType device_type;
  if (torch::cuda::is_available()) {
    std::cout << "CUDA available! Training on GPU." << std::endl;
    device_type = torch::kCUDA;
  } else {
    std::cout << "Training on CPU." << std::endl;
    device_type = torch::kCPU;
  }
  torch::Device device(device_type);

  Net model;
  model.to(device);

  // Training split: normalize with the constants 0.1307 / 0.3081 (passed as
  // the mean / stddev arguments of Normalize) and stack samples into batches.
  auto train_dataset = torch::data::datasets::MNIST(kDataRoot)
                           .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                           .map(torch::data::transforms::Stack<>());
  // The size is read before the dataset is moved into its loader.
  const size_t train_dataset_size = train_dataset.size().value();
  // NOTE(review): training batches are drawn with SequentialSampler, i.e. in
  // dataset order - confirm that no shuffling was intended here.
  auto train_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(train_dataset), kTrainBatchSize);

  // Test split, with the same transforms as the training split.
  auto test_dataset = torch::data::datasets::MNIST(
                          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
                          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                          .map(torch::data::transforms::Stack<>());
  const size_t test_dataset_size = test_dataset.size().value();
  auto test_loader =
      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

  // The counter has the same type as train()'s `epoch` parameter, which
  // avoids the signed/unsigned comparison and the implicit narrowing that a
  // size_t counter caused.
  for (int32_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
    test(model, device, *test_loader, test_dataset_size);
  }
}
CMakeLists.txt
# Build script for the mnist_train example (links main.cpp against LibTorch).
cmake_minimum_required(VERSION 3.15)
project(mnist_train)

# Request C++14 and fail at configure time if the compiler cannot provide it,
# instead of silently falling back to an older standard.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Location of the unpacked LibTorch distribution. "abs_path" is a placeholder
# and must be replaced with the real absolute path on the build machine.
list(APPEND CMAKE_PREFIX_PATH "abs_path/libtorch-shared-with-deps-1.3.1/libtorch")
find_package(Torch REQUIRED)

# Optionally fetch the MNIST files into <build dir>/data at configure time.
option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/tools/download_mnist.py
      -d ${CMAKE_BINARY_DIR}/data
    RESULT_VARIABLE DOWNLOAD_RESULT
    ERROR_VARIABLE DOWNLOAD_ERROR)
  # Decide on the exit status, not on whether anything was written to stderr:
  # warnings on stderr must not abort the configure step, and a failing script
  # that prints nothing must not pass unnoticed.
  if (NOT DOWNLOAD_RESULT EQUAL 0)
    message(FATAL_ERROR
      "Error downloading MNIST dataset (exit status ${DOWNLOAD_RESULT}): ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(mnist_train main.cpp)
target_compile_features(mnist_train PUBLIC cxx_range_for)
target_link_libraries(mnist_train ${TORCH_LIBRARIES})