强大的C# Expression在一个函数求导问题中的简单运用

号称面试的题目总是非常有趣的,这里是又一个例子:

【原题出处

http://topic.csdn.net/u/20110928/15/B00A34FE-8544-42E2-A771-3C4A888DB85A.html

【问题梗概】

求一个函数的一阶导数。

【代码方案】

namespace Derivative
{
    class Program
    {
        // 求一个节点表达的算式的导函数  
        static Expression GetDerivative(Expression node)
        {
            if (node.NodeType == ExpressionType.Add
                || node.NodeType == ExpressionType.Subtract)
            {   // 该节点在做加减法,套用加减法导数公式  
                BinaryExpression binexp = (BinaryExpression)node;
                Expression dleft = GetDerivative(binexp.Left);
                Expression dright = GetDerivative(binexp.Right);
                BinaryExpression resbinexp;

                if (node.NodeType == ExpressionType.Add)
                    resbinexp = Expression.Add(dleft, dright);
                else
                    resbinexp = Expression.Subtract(dleft, dright);
                return resbinexp;
            }
            else if (node.NodeType == ExpressionType.Multiply)
            {   // 该节点在做乘法,套用乘法导数公式  
                BinaryExpression binexp = (BinaryExpression)node;
                Expression left = binexp.Left;
                Expression right = binexp.Right;

                Expression dleft = GetDerivative(left);
                Expression dright = GetDerivative(right);

                return Expression.Add(Expression.Multiply(dleft, right),
                    Expression.Multiply(left, dright));
            }
            else if (node.NodeType == ExpressionType.Parameter)
            {   // 该节点是x本身(叶子节点),故而其导数即常数1  
                return Expression.Constant(1.0);
            }
            else if (node.NodeType == ExpressionType.Constant)
            {   // 该节点是一个常数(叶子节点),故其导数为零  
                return Expression.Constant(0.0);
            }
            else if (node.NodeType == ExpressionType.Call)
            {
                MethodCallExpression callexp = (MethodCallExpression)node;
                Expression arg0 = callexp.Arguments[0];
                // 一下一元函数求导后均需要乘上自变量的导数
                Expression darg0 = GetDerivative(arg0);
                if (callexp.Method.Name == "Exp")
                {
                    // 指数函数的导数还是其本身
                    return Expression.Multiply(
                           Expression.Call(null, callexp.Method, arg0), darg0);
                }
                else if (callexp.Method.Name == "Sin")
                {
                    // 正弦函数的倒数是余弦函数
                    MethodInfo miCos = typeof(Math).GetMethod("Cos", 
                                       BindingFlags.Public | BindingFlags.Static);
                    return Expression.Multiply(
                           Expression.Call(null, miCos, arg0), darg0);
                }
                else if (callexp.Method.Name == "Cos")
                {
                    // 余弦函数的导数是正弦函数的相反数
                    MethodInfo miSin = typeof(Math).GetMethod("Sin", 
                                       BindingFlags.Public | BindingFlags.Static);
                    return Expression.Multiply(
                           Expression.Negate(Expression.Call(null, miSin, arg0)), darg0);
                }
            }

            throw new NotImplementedException();    // 其余的尚未实现          
        }

        static Func<double, double> GetDerivative(Expression<Func<double, double>> func)
        {
            // 从Lambda表达式中获得函数体  
            Expression resBody = GetDerivative(func.Body);

            // 需要续用Lambda表达式的自变量  
            ParameterExpression parX = func.Parameters[0];

            Expression<Func<double, double>> resFunc
                = (Expression<Func<double, double>>)Expression.Lambda(resBody, parX);

            Console.WriteLine("diff function = {0}", resFunc);

            // 编译成CLR的IL表达的函数  
            return resFunc.Compile();
        }

        static double GetDerivative(Expression<Func<double, double>> func, double x)
        {
            Func<double, double> diff = GetDerivative(func);
            return diff(x);
        }

        static void Main(string[] args)
        {
            // 举例:求出函数f(x) = cos(x*x)+sin(3*x)+exp(2*x)在x=2.0处的导数  
            double y = GetDerivative(x => Math.Cos(x*x) + Math.Sin(3*x) + Math.Exp(2*x), 2.0);
            Console.WriteLine("f'(x) = {0}", y);
        }
    }
}  


【实现大意】

用表达式分解并递归求导(过程是相当容易的,比想象的还容易)。目前只是实现了一个最简单的模型。

【优势】

给出的是解析解,在求导运算方面没有任何数值解的误差,输出运算也是瞬时的,时间复杂度仅和表达式复杂度相关。

【限制】

1. 函数只能以Lambda表达式输入,只能是能求出解析解的表达式

2. 目前只实现了加减法和乘法

【后续扩展】

1. 实现其他运算符(没有太大难度,只是比较繁琐而已)

2. 表达式树优化(也不太难的,根据情况定),最基本的可以从常数乘法开始……

3. 条件运算符的处理(这个会变得极难极复杂,但一定程度上实现分段函数求导),其他特殊情况(对求导还可以,如果考虑求不定积分问题可能会有很多特殊情况和hardcode)

4. 输入端向字符串解析过渡;复杂运算符->逐渐向自定义的数据结构过渡?……

...

【更新】

20110611: 添加三角和指数函数支持,优化仍未进行。

你可能感兴趣的:(强大的C# Expression在一个函数求导问题中的简单运用)