参考博客:https://blog.csdn.net/tlzhatao/article/details/86555269
方法一:通过 tracing(追踪,torch.jit.trace)导出 TorchScript 模型
import torch
import torchvision
# An instance of your model (randomly initialised here; load pretrained
# weights through torchvision if you need meaningful predictions).
model = torchvision.models.resnet18()
# Switch to eval mode *before* tracing: a traced module always behaves as in the
# mode it was traced in, so tracing in training mode would freeze BatchNorm
# (and any Dropout) in training behaviour inside the exported model.
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
# no_grad: exporting for inference, so no autograd history is needed.
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)
    # Sanity check: the traced ScriptModule can be called like the original model.
    output = traced_script_module(torch.ones(1, 3, 224, 224))
traced_script_module.save("resnet18_model.pt")
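方法二:通过 annotation / scripting(脚本注解,torch.jit.script)导出 TorchScript 模型;当模型里有依赖输入数据的控制流(如 if)时应使用这种方式,因为 tracing 只会记录实际执行到的那一个分支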
import torch
import torchvision
class MyModule(torch.nn.Module):
    """Toy module with data-dependent control flow, exported via scripting.

    ``torch.jit.script`` compiles both branches of the ``if`` below, which
    tracing could not capture.

    Args:
        N: number of rows of the weight matrix (length of the ``mv`` result).
        M: number of columns of the weight matrix; ``forward`` expects a 1-D
            input of length ``M``.
    """

    def __init__(self, N=10, M=20):
        super(MyModule, self).__init__()
        # Tensor.mv is a matrix-vector product, so the weight must be 2-D;
        # a 4-D weight would script fine but raise as soon as mv runs.
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            # (N, M) @ (M,) -> (N,)
            output = self.weight.mv(input)
        else:
            # Broadcasting (N, M) + (M,) -> (N, M)
            output = self.weight + input
        return output


# Compile the module (both branches) to TorchScript and serialize it.
my_script_module = torch.jit.script(MyModule())
my_script_module.save("model.pt")
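Linux:从官网下载与本机 CUDA 版本对应的 libtorch(下面是当时 CUDA 10.0 的 nightly 版本链接,旧链接可能已失效,可到 https://pytorch.org/get-started/locally/ 选择 LibTorch 获取当前下载地址):https://download.pytorch.org/libtorch/nightly/cu100/libtorch-shared-with-deps-latest.zip 。浏览器下载太慢时可以用 wget。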
准备example-app.cpp文件
#include "torch/script.h"
#include "torch/torch.h"

#include <chrono>
#include <iostream>
#include <vector>

/// Loads the TorchScript module named on the command line, moves it to the
/// GPU and prints the wall-clock time of 100 forward passes on a constant
/// (1, 3, 224, 224) tensor of ones.
///
/// Usage: example-app <path-to-exported-script-module>
/// Returns -1 on a usage error or when the module cannot be loaded.
int main(int argc, const char* argv[])
{
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Load the serialized module (graph + weights); torch::jit::load throws
  // c10::Error for a missing or incompatible file.
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model: " << e.what() << '\n';
    return -1;
  }
  module.to(at::kCUDA);
  // Eval mode for scripted modules (traced ones keep the mode they were traced in).
  module.eval();
  std::cout << "ok\n";

  // Inference only: do not record autograd history during forward().
  at::NoGradGuard no_grad;

  // Build the (1, 3, 224, 224) input on the GPU once; it is identical for every
  // iteration, so its creation stays out of the loop and out of the timing.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));

  for (int i = 0; i < 100; ++i) {
    // steady_clock measures elapsed wall time; clock() reports processor time
    // used by the process, which is not the latency of the GPU work.
    const auto start = std::chrono::steady_clock::now();
    at::Tensor output = module.forward(inputs).toTensor();
    // CUDA kernels are launched asynchronously: wait for them before stopping
    // the clock, otherwise mostly launch overhead is measured. (On old libtorch
    // without torch::cuda::synchronize, output.cpu() also forces a sync.)
    torch::cuda::synchronize();
    const std::chrono::duration<double> elapsed =
        std::chrono::steady_clock::now() - start;
    // Expect the first iterations to be slower (warm-up).
    std::cout << "The run time is: " << elapsed.count() << "s" << '\n';
  }
}
准备 CMakeLists.txt 文件。
注意最后一行的 C++ 标准至少要设为 14,否则会编译报错(libtorch 2.1 及以上版本要求 C++17,此时需改为 17)。
# The official LibTorch example uses 3.18; minimums below 3.5 (such as the old
# 3.0) are rejected outright by CMake 4.x.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(example-app)

# find_package(Torch) searches <prefix>/share/cmake/Torch, so the libtorch
# *root* belongs in CMAKE_PREFIX_PATH (Torch_DIR would have to point at the
# share/cmake/Torch directory itself). Appending keeps a -DCMAKE_PREFIX_PATH
# given on the command line first in the search order.
list(APPEND CMAKE_PREFIX_PATH /home/wxy/code/libtorch)
find_package(Torch REQUIRED)
# Compile with the flags Torch was built with (e.g. its _GLIBCXX_USE_CXX11_ABI setting).
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
# libtorch needs at least C++14; libtorch 2.1 and newer require 17.
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
新建 build 文件夹并进入(mkdir build && cd build),然后配置、编译并运行:
cmake -DCMAKE_PREFIX_PATH=/home/wxy/code/libtorch ..
make
# example-app feeds a (1, 3, 224, 224) image tensor, so run it on the traced
# ResNet-18; model.pt (MyModule) cannot take that input.
./example-app ../resnet18_model.pt
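Windows:从官网下载对应的 libtorch(下面是当时 CUDA 9.0 的 Windows nightly 版本链接,旧链接可能已失效,可到 https://pytorch.org/get-started/locally/ 选择 LibTorch 获取当前下载地址):https://download.pytorch.org/libtorch/nightly/cu90/libtorch-win-shared-with-deps-latest.zip 。浏览器下载太慢时可以用 wget。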
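准备 example-app.cpp 文件。
注意:libtorch 1.1 及以下版本中 torch::jit::load 返回的是指针(std::shared_ptr<Module>),1.2 及以上版本直接返回 Module 对象,用到 module 的地方也要相应修改(指针用 module->,对象用 module.)。旧版本的写法是: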
// libtorch <= 1.1 only: torch::jit::load returns a std::shared_ptr, so members are accessed with module->...
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
#include // One-stop header.
#include
#include
#include
using namespace std;
int main(int argc, const char* argv[]) {
/*if (argc != 2) {
std::cerr << "usage: example-app \n";
return -1;
}*/
// Deserialize the ScriptModule from a file using torch::jit::load().
torch::jit::script::Module module = torch::jit::load("F:/python_project/pytorch/inference_c++/demo/resnet18_model.pt");
module.to(at::kCUDA);
std::cout << "ok\n";
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::rand({1, 3, 224, 224}).to(at::kCUDA));
// Execute the model and turn its output into a tensor.
auto output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
}
准备 CMakeLists.txt 文件。
注意最后一行的 C++ 标准至少要设为 14,否则会编译报错(libtorch 2.1 及以上版本要求 C++17,此时需改为 17)。
# The official LibTorch example uses 3.18; minimums below 3.5 (such as the old
# 3.0) are rejected outright by CMake 4.x.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(example-app)

# find_package(Torch) searches <prefix>/share/cmake/Torch, so the libtorch
# *root* belongs in CMAKE_PREFIX_PATH (Torch_DIR would have to point at the
# share/cmake/Torch directory itself). Appending keeps a -DCMAKE_PREFIX_PATH
# given on the command line first in the search order.
list(APPEND CMAKE_PREFIX_PATH F:/python_project/pytorch/inference_c++/libtorch-win-shared-with-deps-latest/libtorch)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
# Plain ASCII quotes: typographic quotes are not CMake quoting syntax and
# would end up inside the library names.
target_link_libraries(example-app "${TORCH_LIBRARIES}")
# libtorch needs at least C++14; libtorch 2.1 and newer require 17.
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

# Windows looks for DLLs next to the .exe (or on PATH), so copy libtorch's
# DLLs into the output directory after each build.
if(MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET example-app
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:example-app>)
endif()
新建 build 文件夹并进入,然后执行下面的命令生成并编译 Visual Studio 工程:
cmake -DCMAKE_PREFIX_PATH=F:/python_project/pytorch/inference_c++/libtorch-win-shared-with-deps-latest/libtorch -DCMAKE_GENERATOR_PLATFORM=x64 -DCMAKE_BUILD_TYPE=Release ..
cmake --build . --config Release
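最后用 VS 打开生成的工程。因为我的 CUDA 没有装在默认路径,所以部分库目录可能需要手动修改。生成的 .exe 运行时可能提示缺少 nvToolsExt64_1.dll:该 DLL 属于 NVIDIA Tools Extension(NVTX,通常随 CUDA Toolkit 一起安装),一般位于 C:\Program Files\NVIDIA Corporation\NvToolsExt\bin\x64,建议从那里复制到 .exe 所在目录或把该目录加入 PATH。网上也能下载到(例如 https://www.dll4free.com/nvtoolsext64_1.dll.html),但第三方网站提供的 DLL 来源不可信,不建议使用。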