Eigen::Tensor使用,定义高维矩阵

在实际项目中,需要存储大于等于三维的矩阵,而平常中我们使用Eigen::MatrixXd二维数据,这里我们使用Eigen::Tensor来定义

1.Using the Tensor module
#include 
2.定义矩阵
2.一般矩阵

官方文档

  // 定义一个2x3x4大小的矩阵
  Eigen::Tensor<float, 3> a(2, 3, 4);
  // 初始化为0
  a.setZero();
  // 访问元素
  a(0, 1, 0) = 12.0f;
  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 3; j++) {
      for (int k = 0; k < 4; k++) {
        std::cout << a(i, j, k) << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl << std::endl;
  }
  // 输出维度
  std::cout<<a.dimension(0)<<" "<<a.dimension(1)<<" "<<a.dimension(2)<<std::endl;

上面输出结果

0 0 0 0 
12 0 0 0 
0 0 0 0 

0 0 0 0 
0 0 0 0 
0 0 0 0

2 3 4
2.固定大小矩阵TensorFixedSize

参考官方解释

The fixed sized equivalent of Eigen::Tensor t(3, 5, 7); is Eigen::TensorFixedSize> t;

这里我们定义

  // 固定 大小的Size 2x3x4
  Eigen::TensorFixedSize<float, Eigen::Sizes<2, 3, 4>> b;
  // 每个元素都设置固定值
  b.setConstant(3.f);
  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 3; j++) {
      for (int k = 0; k < 4; k++) {
        std::cout << b(i, j, k) << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl << std::endl;
  }

结果如下

3 3 3 3 
3 3 3 3 
3 3 3 3 

3 3 3 3 
3 3 3 3 
3 3 3 3
3.常用函数API

参考从零开始编写深度学习库(四)Eigen::Tensor学习使用及代码重构
1.维度

  Eigen::Tensor<float, 2> a(3, 4);
  std::cout << "Dims " << a.NumDimensions;
  //=> Dims 2
  Eigen::Tensor<float, 2> a(3, 4);
  int dim1 = a.dimension(1);
  std::cout << "Dim 1: " << dim1;
  //=> Dim 1: 4

2.形状

  Eigen::Tensor<float, 2> a(3, 4);
  const Eigen::Tensor<float, 2>::Dimensions& d = a.dimensions();
  std::cout << "Dim size: " << d.size << ", dim 0: " << d[0]
	     << ", dim 1: " << d[1];
  //=> Dim size: 2, dim 0: 3, dim 1: 4

3.矩阵元素个数

  Eigen::Tensor<float, 2> a(3, 4);
  std::cout << "Size: " << a.size();
  //=> Size: 12

4.初始化

  /// 1.
  // setConstant(const Scalar& val),用于把一个矩阵的所有元素设置成一个指定的常数。
  Eigen::Tensor<string, 2> a(2, 3);
  a.setConstant("yolo");
  std::cout << "String tensor: " << endl << a << endl << endl;
  //=>
  // String tensor:
  // yolo yolo yolo
  // yolo yolo yolo
 
  /// 2.
  // setZero() 全部置零
  a.setZero();

  /// 3.
  // setRandom() 随机初始化
  a.setRandom();
  std::cout << "Random: " << endl << a << endl << endl;
  //=>
  //Random:
  //  0.680375    0.59688  -0.329554    0.10794
  // -0.211234   0.823295   0.536459 -0.0452059
  // 0.566198  -0.604897  -0.444451   0.257742
   
  /// 4.
  // setValues({..initializer_list}) 从列表、数据初始化
  Eigen::Tensor<float, 2> a(2, 3);
  a.setValues({{0.0f, 1.0f, 2.0f}, {3.0f, 4.0f, 5.0f}});
  std::cout << "a" << endl << a << endl << endl;
  //=>
  // a
  // 0 1 2
  // 3 4 5
  
  //如果给定的数组数据,少于矩阵元素的个数,那么后面不足的元素其值不变:
  Eigen::Tensor<int, 2> a(2, 3);
  a.setConstant(1000);
  a.setValues({{10, 20, 30}});
  std::cout << "a" << endl << a << endl << endl;
  //=>
  // a
  // 10   20   30
  // 1000 1000 1000

4.运算
参考Eigen Tensor详解【二】
4.1 一元运算

 operator-() 求相反数
 sqrt() 平方根
 rsqrt() 逆平方根
 square() 平方
 inverse()求逆
 exp()指数
 log() log运算
 abs() 绝对值
 pow(Scalar exponent)
 operator * (Scalar scale) 乘以某个值
 
void testUnary()
{
	Eigen::Tensor<int, 2> a(2, 3);
	a.setValues({ {0, 1, 8}, {27, 64, 125} });
	Eigen::Tensor<double, 2> b = a.cast<double>().pow(1.0 / 3.0);
	Eigen::Tensor<double, 2> sqrt = a.cast<double>().sqrt();
	Eigen::Tensor<double, 2> rsqrt = a.cast<double>().rsqrt();
	Eigen::Tensor<double, 2> square = a.cast<double>().square();
	Eigen::Tensor<double, 2> inverse = a.cast<double>().inverse();
	Eigen::Tensor<double, 2> exp = a.cast<double>().exp();
	Eigen::Tensor<double, 2> log = a.cast<double>().log();
	Eigen::Tensor<double, 2> abs = a.cast<double>().abs();
	Eigen::Tensor<int, 2> multiply = a * 2;
	std::cout << "a" << std::endl << a << std::endl <<std:: endl;
}

4.2 二元运算

 operator+(const OtherDerived& other)
 operator-(const OtherDerived& other)
 operator*(const OtherDerived& other)
 operator/(const OtherDerived& other)
 cwiseMax(const OtherDerived& other) //返回与原tensor同类型,同尺寸的tensor,且以两个原tensor的最大值填充
 cwiseMin(const OtherDerived& other)
//返回与原tensor同类型,同尺寸的tensor,且以两个原tensor的最小值填充
operator&&(const OtherDerived& other)
operator||(const OtherDerived& other)
operator<(const OtherDerived& other)
operator<=(const OtherDerived& other)
operator>(const OtherDerived& other)
operator>=(const OtherDerived& other)
operator==(const OtherDerived& other)
operator!=(const OtherDerived& other)
 
void testBinary()
{
	Eigen::Tensor<int, 2> a(2, 3);
	a.setValues({ {0, 1, 8}, {27, 64, 125} });
 
	Eigen::Tensor<int, 2> b = a * 3;
 
	std::cout << "a" << std::endl << a << std::endl << std::endl;
	std::cout << "b" << std::endl << b << std::endl << std::endl;
	std::cout << "a+b" << std::endl << a + b << std::endl << std::endl;
	std::cout << "a-b" << std::endl << a - b << std::endl << std::endl;
	std::cout << "a*b" << std::endl << a * b << std::endl << std::endl;
	std::cout << "a.cwiseMax(b)" << std::endl <<a.cwiseMax(b) << std::endl << std::endl;
	std::cout << "b.cwiseMax(a)" << std::endl << b.cwiseMax(a) << std::endl << std::endl;
	std::cout << "a.cwiseMin(b)" << std::endl << a.cwiseMin(b) << std::endl << std::endl;
	std::cout << "b.cwiseMin(a)" << std::endl << b.cwiseMin(a) << std::endl << std::endl;
}

4.3 三元运算和降维运算
看参考链接Eigen Tensor详解【二】

4.其他方式

参考Eigen构造使用三维矩阵
如果定义多维数据也可以使用Matrix模板来自定义,

Matrix<typename Scalar, int RowsAtCompileTime, int ColsAtCompileTime>
Eigen::Matrix<Eigen::MatrixXd,1,1> a;
Eigen::Matrix<Eigen::Matrix<double,1,5>,1,1> a;

Eigen::Matrix<Eigen::MatrixXd, 1, 1> a;//声明a,一个1*1矩阵
Eigen::MatrixXd b;      //声明b
b.setZero(1, 5); //对b初始化
b << 1, 2, 3, 4, 5;//对b赋值
a(0, 0) = b;//对a(0,0)赋值
std::cout << "a(0,0):  " << a(0, 0) << std::endl;//输出a(0,0)
std::cout << "b:  " << b << std::endl;//输出b
int row = a(0, 0).rows();//row为a(0,0)处矩阵的行维数
int col = a(0, 0).cols();//col为a(0,0)处矩阵的列维数
std::cout << "row:  " << row << "  col:  " << col << std::endl;//输出row和col值
参考

https://blog.csdn.net/hjimce/article/details/71710893
https://blog.csdn.net/fengshengwei3/article/details/103591178
http://eigen.tuxfamily.org/index.php?title=Tensor_support#Using_the_Tensor_module
https://eigen.tuxfamily.org/dox/unsupported/classEigen_1_1TensorFixedSize.html
https://zhuanlan.zhihu.com/p/148019818

你可能感兴趣的:(c++)