从理论上讲,只要允许使用栈,所有的递归程序都可以转化成迭代。
但是并非所有递归都必须用栈,不用堆栈也可以转化成迭代的,大致有两类
- 尾递归:可以通过简单的变换,让递归作为最后一条语句,并且仅此一个递归调用。
// recursive int fac1(int n) { if (n <= 0) return 1; return n * fac1(n-1); } // iterative int fac2(int n) { int i = 1, y = 1; for (; i <= n; ++i) y *= i; return y; }
- 自顶向下->自底向上:对程序的结构有深刻理解后,自底向上计算,比如 fibnacci 数列的递归->迭代转化。
// recursive, top-down int fib1(int n) { if (n <= 1) return 1; return fib1(n-1) + fib1(n-2); } // iterative, down-top int fib2(int n) { int f0 = 1, f1 = 1, i; for (i = 2; i <= n; ++i) { int f2 = f1 + f0; f0 = f1; f1 = f2; } return f1; }
对于非尾递归,就必须使用堆栈。可以简单生硬地使用堆栈进行转化:把函数调用和返回的地方翻译成汇编代码,然后把对硬件 stack 的 push, pop 操作转化成对私有 stack 的 push, pop ,这其中需要特别注意的是对返回地址的 push/pop,对应的硬件指令一般是 call/ret。使用私有 stack 有两个好处:
- 可以省去公用局部变量,也就是在任何一次递归调用中都完全相同的函数参数,再加上从这些参数计算出来的局部变量。
- 如果需要得到当前的递归深度,可以从私有 stack 直接拿到,而用递归一般需要一个单独的 depth 变量,然后每次递归调用加 1。
我们把私有 stack 元素称为 Frame,那么 Frame 中必须包含以下信息:
- 返回地址(对应于每个递归调用的下一条语句的地址)
- 对每次递归调用都不同的参数
通过实际操作,我发现,有一类递归的 Frame 可以省去返回地址!所以,这里又分为两种情况:
- Frame 中可以省去返回地址的递归:仅有两个递归调用,并且其中有一个是尾递归。
// here used a function 'partition', but don't implement it tempalte<class RandIter> void QuickSort1(RandIter beg, RandIter end) { if (end - beg <= 1) return; RandIter pos = partition(beg, end); QuickSort1(beg, pos); QuickSort1(pos + 1, end); } tempalte<class RandIter> void QuickSort2(RandIter beg, RandIter end) { std::stack<std::pair<RandIter> > stk; stk.push({beg, end}); while (!stk.empty()) { std::pair<RandIter, RandIter> ii = stk.top(); stk.pop(); if (ii.second - ii.first) > 1) { RandIter pos = partition(beg, end); stk.push({ii.first, pos}); stk.push({pos + 1, ii.second}); } } }
- Frame 中必须包含返回地址的递归,这个比较复杂,所以我写了个完整的示例:
- 以MergeSort为例,因为 MergeSort 是个后序过程,两个递归调用中没有任何一个是尾递归
- MergeSort3 使用了 GCC 的 Label As Value 特性,只能在 GCC 兼容的编译器中使用
- 单纯对于这个实例来说,返回地址其实只有两种,返回地址为 0 的情况可以通过判断私有栈(varname=stk)是否为空,stk为空时等效于 retaddr == 0。如果要精益求精,一般情况下指针的最低位总是0,可以把这个标志保存在指针的最低位,当然,如此的话就无法对 sizeof(T)==1 的对象如 char 进行排序了。
- #include <stdio.h> #include <string.h> # if 1 #include <stack> #include <vector> template<class T> class MyStack : public std::stack<T, std::vector<T> > { }; #else template<class T> class MyStack { union { char* a; T* p; }; int n, t; public: explicit MyStack(int n=128) { this->n = n; this->t = 0; a = new char[n*sizeof(T)]; } ~MyStack() { while (t > 0) pop(); delete[] a; } void swap(MyStack<T>& y) { char* q = y.a; y.a = a; a = q; int z; z = y.n; y.n = n; n = z; z = y.t; y.t = t; t = z; } T& top() const { return p[t-1]; } void pop() { --t; p[t].~T(); } void push(const T& x) { x.print(); // debug p[t] = x; ++t; } int size() const { return t; } bool empty() const { return 0 == t; } bool full() const { return n == t; } }; #endif template<class T> struct Frame { static T* base; T *beg, *tmp; int len; int retaddr; Frame(T* beg, T* tmp, int len, int retaddr) : beg(beg), tmp(tmp), len(len), retaddr(retaddr) {} void print() const { // for debug printf("%4d %4d %d/n", int(beg-base), len, retaddr); } }; template<class T> T* Frame<T>::base; #define TOP(field) stk.top().field template<class T> bool issorted(const T* a, int n) { for (int i = 1; i < n; ++i) { if (a[i-1] > a[i]) return false; } return true; } template<class T> void mymerge(const T* a, int la, const T* b, int lb, T* c) { int i = 0, j = 0, k = 0; for (; i < la && j < lb; ++k) { if (b[j] < a[i]) c[k] = b[j], ++j; else c[k] = a[i], ++i; } for (; i < la; ++i, ++k) c[k] = a[i]; for (; j < lb; ++j, ++k) c[k] = b[j]; } template<class T> void MergeSort1(T* beg, T* tmp, int len) { if (len > 1) { int mid = len / 2; MergeSort1(beg , tmp , mid); MergeSort1(beg+mid, tmp+mid, len-mid); mymerge(tmp, mid, tmp+mid, len-mid, beg); memcpy(tmp, beg, sizeof(T)*len); } else *tmp = *beg; } template<class T> void MergeSort2(T* beg0, T* tmp0, int len0) { int mid; int cnt = 0; Frame<T>::base = beg0; MyStack<Frame<T> > stk; stk.push(Frame<T>(beg0, tmp0, len0, 0)); while (true) { ++cnt; if (TOP(len) > 1) { mid = TOP(len) / 2; stk.push(Frame<T>(TOP(beg), TOP(tmp), mid, 1)); continue; L1: mid = TOP(len) / 2; stk.push(Frame<T>(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid, 2)); continue; L2: mid = TOP(len) / 2; mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg)); memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len)); } else *TOP(tmp) = *TOP(beg); int retaddr0 = TOP(retaddr); stk.pop(); switch (retaddr0) { case 0: return; case 1: goto L1; case 2: goto L2; } } } // This Implementation Use GCC's goto saved label value // Very similiar with recursive version template<class T> void MergeSort3(T* beg0, T* tmp0, int len0) { MyEntry: int mid; int retaddr; Frame<T>::base = beg0; MyStack<Frame<T> > stk; stk.push(Frame<T>(beg0, tmp0, len0, 0)); #define Cat1(a,b) a##b #define Cat(a,b) Cat1(a,b) #define HereLabel() Cat(HereLable_, __LINE__) #define RecursiveCall(beg, tmp, len) / stk.push(Frame<T>(beg, tmp, len, (char*)&&HereLabel() - (char*)&&MyEntry)); / continue; / HereLabel():; //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // retaddr == 0 是最外层的递归调用, // 只要到达这一层时 retaddr 才为 0, // 此时就可以返回了 #define MyReturn / retaddr = TOP(retaddr); / stk.pop(); / if (0 == retaddr) { / return; / } / goto *((char*)&&MyEntry + retaddr); //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ while (true) { if (TOP(len) > 1) { mid = TOP(len) / 2; RecursiveCall(TOP(beg), TOP(tmp), mid); mid = TOP(len) / 2; RecursiveCall(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid); mid = TOP(len) / 2; mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg)); memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len)); } else *TOP(tmp) = *TOP(beg); MyReturn; } } template<class T> void MergeSortDriver(T* beg, int len, void (*mf)(T* beg_, T* tmp_, int len_)) { T* tmp = new T[len]; (*mf)(beg, tmp, len); delete[] tmp; } #define test(a,n,mf) / memcpy(a, b, sizeof(a[0])*n); / MergeSortDriver(a, n, &mf); / printf("sort by %s:", #mf); / for (i = 0; i < n; ++i) printf("% ld", a[i]); / printf("/n"); int main(int argc, char* argv[]) { int n = argc - 1; int i; long* a = new long[n]; long* b = new long[n]; for (i = 0; i < n; ++i) b[i] = strtol(argv[i+1], NULL, 10); test(a, n, MergeSort1); test(a, n, MergeSort2); test(a, n, MergeSort3); printf("All Successed/n"); delete[] a; delete[] b; return 0; }