1. tensor的创建
tensor的创建有三种常用的形式,如下所示
ones创建一个指定维度,数据全为1的tensor. 例子中的维度是2维,5行3列。
torch::Tensor t = torch::ones({5,3});
zeros创建一个指定维度,数据全为0的tensor,例子中的维度是三维,分别是4,5,3
torch::Tensor m = torch::zeros({4,5,3});
使用tensor函数创建使用指定数据初始化的tensor
torch::Tensor n = torch::tensor({1,2,3}); 创建一维数据是1,2,3的tensor
torch::Tensor n = torch::tensor({{1,2,3},{4,5,6}); 创建二维维的tensor
但在使用时通常不会使用这种方式创建二维的tensor,而是创建一维的tensor,再通过维度调整函数调整维度,在数据多的时候比较实用。
除以上三种常用的初始化方式以外,官网还给出了如下几种创建的方式
创建tensor时,有时需要指定数据类型,可如下操作(其中kFloat指定了浮点数据)
torch::Tensor a = torch::zeros({1,3,4,4},torch::kFloat);
torch::Tensor b = torch::ones({1,3,4,4},torch::kFloat);
2. tensor的维度信息的获取
获取维度使用Tensor::sizes函数。代码如下:
torch::Tensor t = torch::ones({5,3});
torch::IntArrayRef s = t.sizes(); //获取维度,这里时s是一个vector数据,是5和3
Std::vector
3. tensor的维度调整
tensor的维度调整使用reshape函数.
torch::Tensor t = torch::zeros({75});
torch::Tensor b = t.reshape({3,5,5});
通过reshape函数将1x75的数据变成3x5x5的数据。这里需要申明的一点,这里只是数据维度的调整,并没有调整数据存储空间,tensor数据的存储是线性存储的。
与维度相关的另一个函数是permute,它也调整数据空间的维度如
torch::Tensor c = b. permute(2,1,0)
它将b的第3个围堵作为c的第一个维度,b的第二个维度作为c的第二个维度,b的第一个维度作为C的第三个维度,所以c的维度是5x5x3。注意这里只是维度的调整,数据的顺序一点都没变。拿二位数据举个例子
1,2, 3, 4 1,2,3,4,5,6
5,6, 7, 8 7,8,9,10,11,12
9,10,11,12
维度发生了变换,数据顺序依旧没有变换
3. tensor元素访问
tensor元素的访问torch中提供的是select函数和index_select函数(我对这几个函数的理解还不够好,使用老不合自己预期),但我更喜欢下标方式的访问。
torch::Tensor m = torch::zeros({4,5,3});
torch::Tensor k = m[2]; //k的shape是5x3
下标方式可以用来访问元素获取值,也可以设置元素的值,代码如下
torch::Tensor t = torch::ones({5,3});
for(int i=0; i<5; i++)
{
t[i] = torch::tensor({1,2,3}); //通过下标设置值
}
torch::Tensor m = torch::zeros({4,5,3});
m[0] = t; //通过下标设置值
m[1] = t;
m[2] = t;
m[3] = t;
torch::Tensor k = m[2];
通过下标的访问,返回的是一个tensor对象,即使是最后一维元素也是,无法得到像float,int这样的数据类型,需要使用的tensor的item函数获取
torch::Tensor pms = torch::ones({3,400.300});
float v = pms[0][5][100].item
但需要注意,调用item需要确保它是一个具体的值,不会是一个一维或二位的tensor.
4. tensor的算数操作
tensor的算数操作,比较多,但概念和使用上都比较容易,和矩阵的概念一直。(后继在补上这部分内容)
5. tensor的拼接
tensor的拼接有两种基本形式的拼接,一种是stack式的扩展维度,另一种是cat形式的在某个维度上的连接
Stack形式的拼接
torch::Tensor target = torch::zeros({10},torch::kFloat);
target[5] = 1;
std::vector
for(int i=0;i<5; i++)
{
tags.push_back(target);
}
torch::Tensor t = torch::stack(tags);
Cat形式的拼接
std::vector
for(int v : m_rpnCoreSize)
{
std::tuple
torch::Tensor tt = std::get<0>(tv);
ls.push_back(torch::flatten(tt));
}
return torch::cat(ls,0);