libtorch之tensor的使用

1. tensor的创建

    tensor的创建有三种常用的形式,如下所示

    ones创建一个指定维度,数据全为1的tensor. 例子中的维度是2维,5行3列。

               torch::Tensor t = torch::ones({5,3});

    zeros创建一个指定维度,数据全为0的tensor,例子中的维度是三维,分别是4,5,3

               torch::Tensor m = torch::zeros({4,5,3});

   使用tensor函数创建使用指定数据初始化的tensor

               torch::Tensor n = torch::tensor({1,2,3});      创建一维数据是1,2,3的tensor

              torch::Tensor n = torch::tensor({{1,2,3},{4,5,6});      创建二维维的tensor

    但在使用时通常不会使用这种方式创建二维的tensor,而是创建一维的tensor,再通过维度调整函数调整维度,在数据多的时候比较实用。

    除以上三种常用的初始化方式以外,官网还给出了如下几种创建的方式

libtorch之tensor的使用_第1张图片

     创建tensor时,有时需要指定数据类型,可如下操作(其中kFloat指定了浮点数据)

             torch::Tensor a = torch::zeros({1,3,4,4},torch::kFloat);

             torch::Tensor b = torch::ones({1,3,4,4},torch::kFloat);

2.  tensor的维度信息的获取

      获取维度使用Tensor::sizes函数。代码如下:

                torch::Tensor t = torch::ones({5,3});

                torch::IntArrayRef s = t.sizes();    //获取维度,这里时s是一个vector数据,是5和3

               Std::vector s = t.sizes();     //这段代码和上面的代码是等效的,torch::InfArrayRef可以和std::vector互转。更可以说torch::ArrarRef可以和std::vector互转。

3.  tensor的维度调整

     tensor的维度调整使用reshape函数.

                torch::Tensor t = torch::zeros({75});

                torch::Tensor b = t.reshape({3,5,5});

     通过reshape函数将1x75的数据变成3x5x5的数据。这里需要申明的一点,这里只是数据维度的调整,并没有调整数据存储空间,tensor数据的存储是线性存储的。

      与维度相关的另一个函数是permute,它也调整数据空间的维度如

                       torch::Tensor c = b. permute(2,1,0)

       它将b的第3个围堵作为c的第一个维度,b的第二个维度作为c的第二个维度,b的第一个维度作为C的第三个维度,所以c的维度是5x5x3。注意这里只是维度的调整,数据的顺序一点都没变。拿二位数据举个例子

                1,2, 3,  4                        1,2,3,4,5,6

                 5,6, 7,  8                        7,8,9,10,11,12

                 9,10,11,12

              维度发生了变换,数据顺序依旧没有变换

3. tensor元素访问

       tensor元素的访问torch中提供的是select函数和index_select函数(我对这几个函数的理解还不够好,使用老不合自己预期),但我更喜欢下标方式的访问。

               torch::Tensor m = torch::zeros({4,5,3});

              torch::Tensor k = m[2];     //k的shape是5x3

      下标方式可以用来访问元素获取值,也可以设置元素的值,代码如下

              torch::Tensor t = torch::ones({5,3});

              for(int i=0; i<5; i++)

              {

                          t[i] = torch::tensor({1,2,3});      //通过下标设置值

              }

             torch::Tensor m = torch::zeros({4,5,3});

            m[0] = t;                         //通过下标设置值

             m[1] = t;

            m[2] = t;

            m[3] = t;

             torch::Tensor k = m[2];

        通过下标的访问,返回的是一个tensor对象,即使是最后一维元素也是,无法得到像float,int这样的数据类型,需要使用的tensor的item函数获取

         torch::Tensor pms = torch::ones({3,400.300});

         float v = pms[0][5][100].item();

      但需要注意,调用item需要确保它是一个具体的值,不会是一个一维或二位的tensor.

4.  tensor的算数操作

     tensor的算数操作,比较多,但概念和使用上都比较容易,和矩阵的概念一直。(后继在补上这部分内容)

5. tensor的拼接

       tensor的拼接有两种基本形式的拼接,一种是stack式的扩展维度,另一种是cat形式的在某个维度上的连接

        Stack形式的拼接

                  torch::Tensor target = torch::zeros({10},torch::kFloat);

                  target[5] = 1;

                  std::vector tags;

                  for(int i=0;i<5; i++)

                 {

                        tags.push_back(target);

                  }

                 torch::Tensor t = torch::stack(tags);

   Cat形式的拼接

              std::vector ls;

             for(int v : m_rpnCoreSize)

             {

                     std::tuple tv = torch::adaptive_max_pool2d(x,{v,v});

                    torch::Tensor tt = std::get<0>(tv);

                    ls.push_back(torch::flatten(tt));

               }

             return torch::cat(ls,0);

你可能感兴趣的:(深度学习,c++)