实现Lenet-5模型,实现模型的训练与验证,并编写两种识别方式:
1. 直接使用数据集中原始数据测试;
2. 使用图片测试;
Lenet-5 模型
- 把这个模型代码放入下面两个程序中编译。
#include
#include
// BatchNorm
// Dropout
class Lenet5 : public torch::nn::Module{
private:
// 卷积特征运算
torch::nn::Conv2d conv1;
torch::nn::Conv2d conv2;
torch::nn::Conv2d conv3;
torch::nn::Linear fc1;
torch::nn::Linear fc2;
public:
Lenet5():
conv1(torch::nn::Conv2dOptions(1, 6, 5).stride(1).padding(2)), // 1 * 28 * 28 -> 6 * 28 * 28 -> 6 * 14 * 14
conv2(torch::nn::Conv2dOptions(6, 16, 5).stride(1).padding(0)), // 6 * 14 * 14 -> 16 * 10 * 10 -> 16 * 5 * 5
conv3(torch::nn::Conv2dOptions(16, 120, 5).stride(1).padding(0)), // 16 * 5 * 5 -> 120 * 1 * 1 (不需要池化)
fc1(120, 84), // 120 -> 84
fc2(84, 10){ // 84 -> 10 (分量最大的小标就是识别的数字)
// 注册需要学习的矩阵(Kernel Matrix)
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv3", conv3);
register_module("fc1", fc1);
register_module("fc2", fc2);
}
// override
torch::Tensor forward(torch::Tensor x){ // {n * 1 * 28 * 28}
// 1. conv
x = conv1->forward(x); // {n * 6 * 28 * 28}
x = torch::max_pool2d(x, 2); // {n * 6 * 14 * 14}
x = torch::relu(x); // 激活函数 // {n * 6 * 14 * 14}
// 2. conv
x = conv2->forward(x); // {n * 16 * 10 * 10}
x = torch::max_pool2d(x, 2); // {n * 16 * 5 * 5}
x = torch::relu(x); // 激活函数 // {n * 16 * 5 * 5}
// 3. conv
x = conv3->forward(x); // {n * 120 * 1 * 1}
x = torch::relu(x); // 激活函数 // {n * 120 * 1 * 1}
// 做数据格式转换
x = x.view({-1, 120}); // {n * 120}
// 4. fc
x = fc1->forward(x);
x = torch::relu(x);
// 5. fc
x = fc2->forward(x);
return torch::log_softmax(x, 1); // CrossEntryLoss = log_softmax + nll
}
};
训练与验证main.cpp
template
void train(std::shared_ptr &model, DataLoader &loader, torch::optim::Adam &optimizer){
model->train();
// 迭代数据
int n = 0;
for(torch::data::Example &batch: loader){
torch::Tensor data = batch.data;
auto target = batch.target;
optimizer.zero_grad(); // 清空上一次的梯度
// 计算预测值
torch::Tensor y = model->forward(data);
// 计算误差
torch::Tensor loss = torch::nll_loss(y, target);
// 计算梯度: 前馈求导
loss.backward();
// 根据梯度更新参数矩阵
optimizer.step();
// 为了观察效果,输出损失
// std::cout << "\t|--批次:" << std::setw(2) << std::setfill(' ')<< ++n
// << ",\t损失值:" << std::setw(8) << std::setprecision(4) << loss.item() << std::endl;
}
// 输出误差
}
template
void valid(std::shared_ptr &model, DataLoader &loader) {
model->eval();
// 禁止求导的图跟踪
torch::NoGradGuard no_grad;
// 循环测试集
double sum_loss = 0.0;
int32_t num_correct = 0;
int32_t num_samples = 0;
for(const torch::data::Example<> &batch: loader){
// 每个批次预测值
auto data = batch.data;
auto target = batch.target;
num_samples += data.sizes()[0];
auto y = model->forward(data);
// 计算纯预测的结果
auto pred = y.argmax(1);
// 计算损失值
sum_loss += torch::nll_loss(y, target, {}, at::Reduction::Sum).item();
// 比较预测结果与真实的标签值
num_correct += pred.eq(target).sum().item();
}
// 输出正确值
std::cout << std::setw(8) << std::setprecision(4)
<< "平均损失值:" << sum_loss / num_samples
<< ",\t准确率:" << 100.0 * num_correct / num_samples << " %" << std::endl;
}
int main(int argc, const char** argv){
// 数据集
auto ds_train = torch::data::datasets::MNIST(".\\data", torch::data::datasets::MNIST::Mode::kTrain);
auto ds_valid = torch::data::datasets::MNIST(".\\data", torch::data::datasets::MNIST::Mode::kTest);
// torch::data::transforms::Normalize<> norm(0.1307, 0.3081);
torch::data::transforms::Stack<> stack;
// 数据批次加载器
// auto n_train = ds_train.map(norm);
auto s_train = ds_train.map(stack);
auto train_loader = torch::data::make_data_loader(std::move(s_train), 1000);
// auto n_valid = ds_valid.map(norm);
auto s_valid = ds_valid.map(stack);
auto valid_loader = torch::data::make_data_loader(std::move(s_valid), 1000);
// 1. 创建模型对象
std::shared_ptr model = std::make_shared();
// for(auto &batch: *train_loader){
// auto data = batch.data;
// auto target = batch.target;
// data = data.view({-1, 1, 28, 28});
// auto pred = model->forward(data);
// // pred <-> target 存在误差,计算误差,计算调整5 * 5 核矩阵的依据,调整的方向是 loss(pred - target) -> 0
// }
// 优化器(管理模型中可训练矩阵)
torch::optim::Adam optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(0.001)); // 根据经验一般设置为10e-4
std::cout<< "开始训练" << std::endl;
int epoch = 20;
int interval = 1; // 从测试间隔
for(int e = 0; e < epoch; e++){
std::printf("第%02d论训练\n", e+1);
train(model, *train_loader, optimizer);
if (e % interval == 0){
valid(model, *valid_loader);
}
}
std:: cout << "训练结束" << std::endl;
torch::save(model, "lenet5.pt");
return 0;
}
识别实现main.cpp
int main(){
const char * data_filename = ".\\data";
// 加载模型
std::shared_ptr model = std::make_shared();
torch::load(model, "lenet5.pt");
// 一. 使用测试集中数据识别
auto imgs = torch::data::datasets::MNIST(data_filename, torch::data::datasets::MNIST::Mode::kTest);
// 取一张图像
for(int i = 0; i < 10; i++){
torch::data::Example<> example = imgs.get(i);
// std::cout << "识别的数字是:" << example.target.item() << std::endl;
// 获取图像
torch::Tensor a_img = example.data;
// 预测
a_img = a_img.view({-1, 1, 28, 28}); // 我们的模型只接受4为的固定的数据格式(N * C * H * W)(NCHW格式)
torch::Tensor y = model->forward(a_img);
int32_t result = y.argmax(1).item();
std::cout << "识别的结果是:" << result << "->" << example.target.item() << std::endl;
}
std::cout << "----------------------------------" << std::endl;
// 二. 使用图像文件来识别
// 读取图像
cv::Mat im = cv::imread("img_9_9.png"); // 换图像,测试是否准确
cv::cvtColor(im, im, cv::COLOR_BGR2GRAY); // 注意:png图是3-4通道,需要转换为1通道灰度图。
// 转换为Tensor,处理成0-1之间的数字
im.convertTo(im, CV_32FC1, 1.0f / 255.0f);
torch::Tensor t_img = torch::from_blob(im.data, {1, 28, 28});
t_img = t_img.view({-1, 1, 28, 28});
// 识别
torch::Tensor y_ = model->forward(t_img);
int32_t pred = y_.argmax(1).item();
std::cout << "识别的结果是:" << pred << std::endl;
return 0;
}
编译脚本CMakeLists.txt
cmake_minimum_required(VERSION 3.16)
project(main)
set(CMAKE_PREFIX_PATH "C:/libtorch")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(Torch REQUIRED)
# opencv的配置
include_directories("C:/opencv_new/install/include")
link_directories("C:/opencv_new/install/x64/vc16/lib")
add_executable(main main.cpp)
target_link_libraries(main "${TORCH_LIBRARIES}" "opencv_core420d.lib" "opencv_imgcodecs420d.lib" "opencv_imgproc420d.lib" )
set_property(TARGET main PROPERTY CXX_STANDARD 11)