libtorch选择显卡运行torchscript

#include 

std::string filename = "centernet.pt"//模型路径
int gpu_id = 1;									//gpu id 0代表第一块可见gpu
cudaSetDevice(gpu_id);					//切换显卡
torch::jit::script::Module module = torch::jit::load(filename,torch::Device(torch::DeviceType::CUDA,gpu_id));//加载模型

libtorch 加载torchscript模型有三个重载函数

TORCH_API script::Module load(
    std::istream& in,
    c10::optional<c10::Device> device = c10::nullopt,
    script::ExtraFilesMap& extra_files = default_extra_files);

TORCH_API script::Module load(
    const std::string& filename,
    c10::optional<c10::Device> device = c10::nullopt,
    script::ExtraFilesMap& extra_files = default_extra_files);

TORCH_API script::Module load(
    std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
    c10::optional<c10::Device> device = c10::nullopt,
    script::ExtraFilesMap& extra_files = default_extra_files);

目前我是从文件加载模型,用第二个函数,选择设备这里主要关注第二个参数

c10::optional<c10::Device> device

这里我们需要构造一个device类传入,我们看Device类定义

Device(DeviceType type, DeviceIndex index = -1)

这里很显然第一个是设备类型,第二个是设备索引
第一个是枚举类:我们选择torch::DeviceType::CUDA  也就是nvidia显卡计算平台
第一个就是显卡id  我们填0代表第一块显卡,1代表第二块显卡

你可能感兴趣的:(libtorch,libtorch,libtorch选择显卡)