Once a trained model has been deployed with libtorch, you cannot measure its inference time the same way you would time an ordinary CPU program: two issues must be handled first, GPU warm-up and CUDA/CPU synchronization. For the code and the reasoning behind measuring model inference time in PyTorch, see "The Correct Way to Measure Inference Time of Deep Neural Networks". The code in this post is adapted from that article, but runs on libtorch.
Here is the code for measuring model inference time on libtorch:
#include <torch/script.h>
#include <torch/cuda.h> // torch::cuda::is_available / torch::cuda::synchronize
#include <chrono>
#include <iostream>
#include <vector>
torch::Device device_type = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
torch::jit::script::Module m = torch::jit::load("traced_net1.pt");
m.to(device_type);
m.eval();
torch::NoGradGuard no_grad;
torch::Tensor input_t = torch::ones({ 10, 3, 100 }).to(device_type);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input_t);
// GPU warm-up: the first forward passes are slow (CUDA context creation, kernel loading)
for (int i = 0; i < 30; i++) {
    auto dd = m.forward(inputs);
}
std::chrono::steady_clock::time_point start, stop;
double processingTime;
if (torch::cuda::is_available())
    torch::cuda::synchronize(); // drain the warm-up kernels before starting the clock
start = std::chrono::steady_clock::now();
for (int i = 0; i < 300; i++)
{
    auto b = m.forward(inputs);
}
if (torch::cuda::is_available())
    torch::cuda::synchronize(); // CUDA kernels run asynchronously; wait for them to finish
stop = std::chrono::steady_clock::now();
processingTime = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count() / 1000.0;
std::cout << "Time during: " << processingTime / 300 << std::endl; // average per inference, in milliseconds (ms)
The main ideas are: warm up the GPU first, synchronize the CUDA stream with the CPU before reading the clock, and average the inference time over many runs.
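The warm-up / synchronize / average pattern can be factored into a small reusable helper. Below is a minimal sketch in plain C++ (the name `benchmark_ms` and the callback-based design are my own, not from the article); with libtorch on GPU you would pass something like `[]{ torch::cuda::synchronize(); }` as the `sync` argument:

```cpp
#include <chrono>
#include <functional>

// Times `fn`, returning the average per-call time in milliseconds:
// runs `warmup` untimed calls first, then averages over `iters` timed
// calls. `sync` is invoked before each clock read so that any
// asynchronously queued work (e.g. CUDA kernels) is drained; for
// CPU-only code, pass an empty lambda.
double benchmark_ms(const std::function<void()>& fn,
                    const std::function<void()>& sync,
                    int warmup, int iters) {
    for (int i = 0; i < warmup; ++i) fn();
    sync(); // drain pending work before starting the clock
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i) fn();
    sync(); // make sure all timed work has actually finished
    auto stop = std::chrono::steady_clock::now();
    double total_us = static_cast<double>(
        std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count());
    return total_us / 1000.0 / iters;
}
```

Using `std::chrono::steady_clock` throughout avoids mixing clock types, and keeping the synchronization in one place makes it harder to forget before the second clock read.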
The code above follows the code posted in the answers to https://github.com/pytorch/pytorch/issues/19106. I attach that code below for reference:
#include <torch/script.h> // One-stop header.
#include <torch/cuda.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <memory>
int main() {
std::string weightPath = "E:/toolsku/ReportLibtorchBug/resNet50.pt";
torch::jit::script::Module model;
torch::Device targetDevice = torch::kCPU;
try {
model = torch::jit::load(weightPath);// Deserialize the ScriptModule from a file using torch::jit::load().
if (torch::cuda::is_available()) {
std::cout << "GPU is available -> Switch to GPU mode" << std::endl;
targetDevice = torch::kCUDA;//to GPU
}
model.eval();
model.to(targetDevice);
}
catch (const c10::Error& e) {
std::cerr << "Error in loading the model!\n";
return -1;
}
torch::NoGradGuard no_grad;
std::cout << "Success in loading the model!\n";
std::vector<torch::Tensor> batch_data;// using a tensor list
int netHeight = 224, netWidth = 224;
cv::Size inpDimension(netWidth, netHeight);
torch::TensorOptions options(torch::kFloat32);
torch::Tensor means = torch::tensor({ 0.485, 0.456, 0.406 }, options).view({ 1, 3, 1, 1 }).to(targetDevice);
torch::Tensor stds = torch::tensor({ 0.229, 0.224, 0.225 }, options).view({ 1, 3, 1, 1 }).to(targetDevice);
std::string imgPath = "E:/toolsku/ReportLibtorchBug/puppy.jpg";
cv::Mat img = cv::imread(imgPath, cv::IMREAD_COLOR);
cv::resize(img, img, inpDimension, 0, 0, cv::INTER_CUBIC);// interpolation is the 6th parameter; passing it 4th (as fx) silently falls back to INTER_LINEAR
img.convertTo(img, CV_32FC3, 1.0f);
torch::Tensor input1 = torch::from_blob(img.data, { 1,netHeight, netWidth, 3 }, options).clone().toType(torch::kFloat32);
batch_data.push_back(input1);
torch::Tensor input_tensor = torch::cat(batch_data, 0);
input_tensor = input_tensor.to(targetDevice);
input_tensor = input_tensor.permute({ 0, 3, 1, 2 });
input_tensor = (input_tensor.div_(255.0) - means) / stds;
std::chrono::steady_clock::time_point start, stop;
double processingTime;
torch::Tensor out_tensor;
for (int i = 0; i < 500; ++i) {//warmup
out_tensor = model.forward({ input_tensor }).toTensor();
}
start = std::chrono::steady_clock::now();// use steady_clock consistently; high_resolution_clock may be a different type
for (int i = 0; i < 1000; ++i) {
out_tensor = model.forward({ input_tensor }).toTensor();
}
if (targetDevice.is_cuda())
torch::cuda::synchronize();// without this, only kernel-launch time is measured on GPU
stop = std::chrono::steady_clock::now();
processingTime = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count() / 1000.0;
std::cout << "Avg. processing time: " << (processingTime / 1000) << " ms\n";
out_tensor = torch::argmax(out_tensor, 1);
out_tensor = out_tensor.to(torch::kCPU);// move the result back to the CPU
std::cout << out_tensor << std::endl;
return 0;
}
In addition, there is an article on Zhihu comparing libtorch and PyTorch performance, "PyTorch vs LibTorch:网络推理速度谁更快?", which also provides code for computing inference time:
#include <torch/script.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/Exceptions.h> // AT_CUDA_CHECK
...
start = std::chrono::system_clock::now();
output = civilnet->forward(inputs).toTensor();
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
forward_duration = std::chrono::system_clock::now() - start;
msg = gemfield_org::format(" time: %f", forward_duration.count() );
std::cout<<"civilnet->forward(inputs).toTensor() "<<msg<<std::endl;
This code has a problem on my machine: the line
msg = gemfield_org::format(" time: %f", forward_duration.count() );
fails to compile, because I cannot tell which namespace or class gemfield_org is supposed to come from. If anyone has gotten this code to run, please share how.
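For what it's worth, `gemfield_org::format` looks like a printf-style string formatter from that author's own utility library rather than a libtorch API. A minimal stand-in (my own sketch; the name `format` and its signature are assumptions, not the original helper) using `std::snprintf` would be:

```cpp
#include <cstdio>
#include <string>

// Printf-style formatting into a std::string, as a stand-in for the
// article's gemfield_org::format helper: first measure the required
// length with a null buffer, then format into the string's storage.
template <typename... Args>
std::string format(const char* fmt, Args... args) {
    int n = std::snprintf(nullptr, 0, fmt, args...); // length excluding '\0'
    std::string s(static_cast<size_t>(n), '\0');
    std::snprintf(&s[0], static_cast<size_t>(n) + 1, fmt, args...);
    return s;
}
```

With this in place the reporting line becomes `msg = format(" time: %f", forward_duration.count());`; alternatively, you can drop the helper entirely and stream the value directly: `std::cout << " time: " << forward_duration.count() << std::endl;`.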