libtorch aten::Tensor 与 std::vector 互换

在使用 libtorch 的过程中我们可能会遇到需要 libtorch 中的 at::Tensor 类型转化成 std::vector 常规类型存储,或者从 std::vector 生成一个 at::Tensor 供我们使用。

at::Tensor 转 std::vector

这里以 at::Tensor 里面的数据类型都是 float 为例,将 std::vectorT 直接设置成对应的基本数据类型即可:

aten::Tensor ten; // 假设 ten 里面已经有数据了
std::vector<float> v(ten.data_ptr<float>(), ten.data_ptr<float>() + ten.numel());

std::vector 转 aten::Tensor

int64_t 类型的数据:

  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  auto tensor = torch::from_blob(value.data(), {int64_t(value.size())}, opts).clone();

float 类型的数据:

  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  auto tensor = torch::from_blob(value.data(), {int64_t(value.size())}, opts).clone();

其他类型的数据根据 torch::kxxxx 的类型以此类推。还有一个接口直接是 aten::from_blob 也可以使用,可以直接去 pytorch 源码中搜相关函数的使用方法。

注意: 上面在接口最后都加了一个 .clone() ,这个是在实际使用中遇到,创建的 aten::Tensor 数据是我们期望的,但是把该 Tensor 送入其他节点,在最后构建的图上,相关数据就变成了异常值,可能跟底层的实现对数据指针的操作有关,加一个 .clone() 就可以避免遇到值异常的问题。

你可能感兴趣的:(C++,深度学习)