表达式模板是一种C++模板元编程(template metaprogram)技术。典型情况下,表达式模板自身代表一种操作,模板参数代表该操作的操作数。模板表达式可将子表达式的计算推迟,这样 有利于优化(特别是减少临时变量的使用)。表达式模板也可以作为参数传递给一个函数。
例子:我们实现一个用来求表达式 x = 1.2*x + x*y 的模板表达式,其中x、y为数组
//exprarray.h #include <stddef.h> #include <cassert> #include "sarray.h" template<typename T> class A_Scale { public: A_Scale(T const& t):value(t){} T operator[](size_t) const { return value; } size_t size() const { return 0; } private: T const& value; }; template<typename T> class A_Traits { public: typedef T const& exprRef; }; template<typename T> class A_Traits<A_Scale<T> > { public: typedef A_Scale<T> exprRef; }; template<typename T,typename L1,typename R2> class A_Add { private: typename A_Traits<L1>::exprRef op1; typename A_Traits<R2>::exprRef op2; public: A_Add(L1 const& a,R2 const& b):op1(a),op2(b) { } T operator[](size_t indx) const { return op1[indx] + op2[indx]; } size_t size() const { assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size()); return op1.size() != 0 ? op1.size() : op2.size(); } }; template<typename T,typename L1,typename R2> class A_Mul { private: typename A_Traits<L1>::exprRef op1; typename A_Traits<R2>::exprRef op2; public: A_Mul(L1 const& a,R2 const& b):op1(a),op2(b) { } T operator[](size_t indx) const { return op1[indx] * op2[indx]; } size_t size() const { assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size()); return op1.size() != 0 ? op1.size():op2.size(); } }; template<typename T,typename Rep = SArray<T> > class Array { public: explicit Array(size_t N):expr_Rep(N){} Array(Rep const& rep):expr_Rep(rep){} Array& operator=(Array<T> const& orig) { assert(size() == orig.size()); for (size_t indx=0;indx < orig.size();indx++) { expr_Rep[indx] = orig[indx]; } return *this; } template<typename T2,typename Rep2> Array& operator=(Array<T2,Rep2> const& orig) { assert(size() == orig.size()); for (size_t indx=0;indx<orig.size();indx++) { expr_Rep[indx] = orig[indx]; } return *this; } size_t size() const { return expr_Rep.size(); } T operator[](size_t indx) const { assert(indx < size()); return expr_Rep[indx]; } T& operator[](size_t indx) { assert(indx < size()); return expr_Rep[indx]; } Rep const& rep() const { return expr_Rep; } Rep& rep() { return expr_Rep; } private: Rep expr_Rep; }; template<typename T,typename L1,typename R2> Array<T,A_Add<T,L1,R2> > operator+(Array<T,L1> const& a,Array<T,R2> const& b) { return Array<T,A_Add<T,L1,R2> >(A_Add<T,L1,R2>(a.rep(),b.rep())); } template<typename T,typename L1,typename R2> Array<T,A_Mul<T,L1,R2> > operator*(Array<T,L1> const& a,Array<T,R2> const& b) { return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep())); } template<typename T,typename R2> Array<T,A_Mul<T,A_Scale<T>,R2> > operator*(T const& a,Array<T,R2> const& b) { return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep())); }
测试代码(求解表达式1.2*x+x*y):
//test.cpp #include "exprarray.h" #include <iostream> using namespace std; template <typename T> void print (T const& c) { for (int i=0; i<8; ++i) { std::cout << c[i] << ' '; } std::cout << "..." << std::endl; } int main() { Array<double> x(1000), y(1000); for (int i=0; i<1000; ++i) { x[i] = i; y[i] = x[i]+x[i]; } std::cout << "x: "; print(x); std::cout << "y: "; print(y); x = 1.2 * x; std::cout << "x = 1.2 * x: "; print(x); x = 1.2*x + x*y; std::cout << "1.2*x + x*y: "; print(x); x = y; std::cout << "after x = y: "; print(x); return 0; }
下面我们来分析一下模板表达式的解析过程:
我们以表达式 x = 1.2*x + x*y为例
当编译器解析表达式:x = 1.2*x + x*y 的时候,编译器首先会应用最左边的*运算符,它是一个Scale-Array运算符。于是重载解析规则将会选择operator*的Scale-Array形式:
template<typename T,typename R2> Array<T,A_Mul<T,A_Scale<T>,R2> > operator*(T const& a,Array<T,R2> const& b) { return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep())); }
其中操作数的类型是double和Array<double,SArray<double> >,因此实际的结果类型是:
Array<double,A_Mul<double,A_Scale<double>,SArray<double> > >
接下来,编译器会对第二个乘法进行求值:x*y是一个array-array操作,这一次,我们将会选择operator*的Array-Array重载操作:
template<typename T,typename L1,typename R2> Array<T,A_Mul<T,L1,R2> > operator*(Array<T,L1> const& a,Array<T,R2> const& b) { return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep())); }
其中两个操作数类型都是Array<double,SArray<double> >,因此结果类型为:
Array<double,A_Mul<double,SArray<double>,SArray<double> > >
这一次,A_Mul所封装的连个参数对象都引用了一个SArray<double>表示:即一个表示x对象,一个表示y对象。
现在开始对+运算符进行求值。这次还是Array-Array操作,因此调用Array-Array版本的operator+:
template<typename T,typename L1,typename R2> Array<T,A_Add<T,L1,R2> > operator+(SArray<T,L1> const& a,SArray<T,R2> const& b) { return Array<T,L1,R2>(A_Add<T,L1,R2>(a.rep(),b.rep())); }
其中用double来替换T,则R1为:
A_Mul<double,A_Scale<double>,SArray<double> >
R2为:
A_Mul<double,SArray<double>,SArray<double> >
因此赋值表达式 x = 1.2*x + x*y的右边经过编译器解析后的最终类型为:
Array<double, A_Add<double, A_Mul<double,A_Scale<double>,SArray<double> > A_Mul<double,SArray<double>,SArray<double> > > >
这个类型将与Array模板的赋值运算符模板进行匹配:
//针对不同类型数组的赋值运算符 template<typename T2,typename Rep2> Array& operator=(Array<T2,Rep2> const& orig) { assert(size() == orig.size()); for (size_t indx=0;indx<orig.size();indx++) { expr_Rep[indx] = orig[indx]; } return *this; }
此时,赋值运算符将会运用右边Array的下标运算符来计算目标数组的每一个元素,而Array的实际类型为:
Array<double, A_Add<double, A_Mul<double,A_Scale<double>,SArray<double> > A_Mul<double,SArray<double>,SArray<double> > > >
我们记为:ArrayTgt
此时,ArrayTgt[indx]将会匹配模板类A_Add中的重载操作符operator[],即:
T operator[](size_t indx) const { return op1[indx] + op2[indx]; }
匹配之后就变成:
A_Mul<double,A_Scale<double>,SArray<double> >[indx] + A_Mul<double,SArray<double>,SArray<double> >[indx];
而A_Mul[indx]又会匹配模板类A_Mul中的重载操作符operator[],即:
T operator[](size_t indx) const { return op1[indx] * op2[indx]; }
匹配之后就变成:
A_Scale<double>[indx] * SArray<double>[indx] + SArray<double>[indx] * SArray<double>[indx]
而A_Scale[indx]又会匹配模板类A_Scale中的重载操作符operator[],即:
T operator[](size_t) const { return value; }
这样最终的结果就表达式就变成:
value[indx] * SArray<double>[indx] + SArray<double>[indx] * SArray<double>[indx]
至此,整个模板表达式的解析工作已经完成,只需进行计算即可。在整个计算过程中,没有产生任何的中间变量,所以程序的效率得以大幅的提高。
程序注意事项:
1.在上述代码中,如果将模板类Array的代码:
Array& operator=(Array<T2,Rep2> const& orig) { assert(size() == orig.size()); for (size_t indx=0;indx<orig.size();indx++) { expr_Rep[indx] = orig[indx]; } return *this; }
中的参数改为Array<T2,Rep2> & orig,即变成:
Array& operator=(Array<T2,Rep2>& orig) { assert(size() == orig.size()); for (size_t indx=0;indx<orig.size();indx++) { expr_Rep[indx] = orig[indx]; } return *this; }
将会导致编译出错,原因是:
在test.cpp文件中,我们使用了表达式:x = 1.2 * x ,这个表达式的右边将会被编译器解析为如下形式的表达式:
Array<double,A_Mul<double,A_Scale<double>,SArray<doube> > >
这样在进行重载操作符operator[]的匹配时,将会变成如下形式:
SArray[indx] = A_Scale[indx] * SArray[indx]
到了这一步,问题就出现了,因为A_Scale[indx]会匹配模板类Array中的重载操作符operator[],但是我们发现在模板类Array代码中,有两个重载的operator[],即:
T operator[](size_t indx) const { assert(indx < size()); return expr_Rep[indx]; } T& operator[](size_t indx) { assert(indx < size()); return expr_Rep[indx]; }
如果我们没在重载操作符operator=的参数中写入const的话,这里会优先调用无const的operator[]重载函数,但是A_Scale[indx]是个常数,在本例中也就是一个double类型,这样最后在调用operator[]返回的时候就出现了类型不匹配的现象,因为无const的operator[]返回的类型是double&,所以会报错。当然,我们可以将test.cpp程序中的表达式1.2*x去掉,我们会发现,这个时候无const的operator=就会编译通过。
2.在上述代码中,模板类Array的构造函数代码为:
explicit Array(size_t N):expr_Rep(N){}
这表明定义Array必须通过显式转型,不能通过隐式转型。下述代码会导致编译出错:
Array a = 5;
我们只能使用
Array a(5);
进行显式初始化。
下面用一个例子来区别显式转型和隐式转型的细微区别:
X x;
Y y(x); //显式转型
Y y = x;//隐式转型
其中前者通过使用从X到Y类型的显式转型,新建一个类型为Y的对象。后者使用了一个从类型X到Y类型的隐式转型,新建了一个类型Y的对象。
3.在上述代码中,模板类Array的两个重载操作符operator[]代码:
T operator[](size_t indx) const { assert(indx < size()); return expr_Rep[indx]; } T& operator[](size_t indx) { assert(indx < size()); return expr_Rep[indx]; }
注意在一个重载操作符函数后面的const一定不能少,否则会导致编译错误。因为没有const的话,函数
T operator[](size_t indx) { assert(indx < size()); return expr_Rep[indx]; }
和
T& operator[](size_t indx) { assert(indx < size()); return expr_Rep[indx]; }
会被认为是一个函数,因为他们静静是返回类型不同而已。函数
int test(){}
和
int test() const{}
会被编译器理解为两个不同的函数。
最后将SArray的代码附上:
#ifndef SARRAY_H #define SARRAY_H #include <stddef.h> #include <cassert> template<typename T> class SArray { public: explicit SArray(size_t N):ptr(new T[N]),_size(N) { init(); } SArray(SArray<T> const& orig):ptr(new T[orig.size()]),_size(orig.size()) { copy(orig); } ~SArray() { delete[] ptr; } size_t size() const { return _size; } T operator[](size_t indx) const { return ptr[indx]; } T& operator[](size_t indx) { return ptr[indx]; } SArray<T>& operator=(SArray<T> const& orig) { if (&orig != this) { copy(orig); } return *this; } protected: void copy(SArray<T> const& orig) { assert(size() == orig.size()); for (size_t indx=0;indx<orig.size();indx++) { ptr[indx] = orig[indx]; } } void init() { for(size_t i=0;i<size();i++) { ptr[i] = T(); } } private: T* ptr; size_t _size; }; #endif