玩转 .NET Expression

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Fast.Framework.Models;


namespace Fast.Framework.Extensions
{

    /// <summary>
    /// Extension methods that turn an expression tree into SQL information.
    /// </summary>
    public static class ExpressionExtensions
    {
        /// <summary>
        /// Walks <paramref name="expression"/> with a freshly created
        /// <see cref="SqlExpressionVisitor"/> and returns what the visitor collected.
        /// </summary>
        /// <param name="expression">Expression tree to translate.</param>
        /// <param name="options">Options passed to the visitor's constructor.</param>
        /// <returns>The visitor's <c>sqlInfo</c> instance after the walk has finished.</returns>
        public static SqlInfo ResolveSql(this Expression expression, SqlExpressionOptions options)
        {
            var visitor = new SqlExpressionVisitor(options);
            _ = visitor.Visit(expression);
            return visitor.sqlInfo;
        }
    }

    /// <summary>
    /// Mode in which <see cref="SqlExpressionVisitor"/> treats the constants it encounters.
    /// </summary>
    public enum ResolveType
    {
        /// <summary>
        /// Emit SQL text. This is the default mode of the visitor's constructor; constants are
        /// pushed onto the SQL stack either inline or as named parameters.
        /// </summary>
        SqlString = 0,

        /// <summary>
        /// Evaluate only. In this mode the visitor's constant handling emits no SQL and returns
        /// the raw value; nested visitors use it to compute arguments and objects.
        /// </summary>
        Call = 1,

        /// <summary>
        /// LIKE pattern with the wildcard on the left: the value is rendered as <c>%value</c>.
        /// </summary>
        LeftLike = 2,

        /// <summary>
        /// LIKE pattern with the wildcard on the right: the value is rendered as <c>value%</c>.
        /// </summary>
        RightLike = 3,

        /// <summary>
        /// LIKE pattern with wildcards on both sides: the value is rendered as <c>%value%</c>.
        /// </summary>
        Like = 4
    }

    /// <summary>
    /// Lookup tables that map expression-tree constructs to SQL fragments.
    /// </summary>
    /// <remarks>
    /// Every handler pushes its fragments onto <c>SqlInfo.SqlStack</c> last-to-first (closing
    /// parenthesis first, function name last). The visitor later joins the stack contents to
    /// obtain the text, so the reverse push order presumably relies on LIFO enumeration of the stack.
    /// </remarks>
    public static class ExpressionMapper
    {
        /// <summary>
        /// Binary operator node type mapped to its SQL operator text.
        /// </summary>
        public static readonly Dictionary<ExpressionType, string> expressionTypes;

        /// <summary>
        /// Method name mapped to the handler that emits SQL for a call to that method.
        /// Handler arguments: the SQL accumulator, the call node, a callback that visits a
        /// sub-expression, and a callback that switches the visitor's <see cref="ResolveType"/>.
        /// </summary>
        public static readonly Dictionary<string, Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>>> methodCall;

        /// <summary>
        /// Conversion handler. It is never assigned in this file.
        /// NOTE(review): the generic arguments were lost in the published listing and are
        /// reconstructed here by analogy with <see cref="methodCall"/> - confirm against upstream.
        /// </summary>
        public static readonly Action<SqlInfo, UnaryExpression, Action<Expression>, Action<ResolveType>> convert;

        /// <summary>
        /// Fills the lookup tables.
        /// </summary>
        static ExpressionMapper()
        {
            expressionTypes = new Dictionary<ExpressionType, string>()
            {
                { ExpressionType.Add, "+" },
                { ExpressionType.Subtract, "-" },
                { ExpressionType.Multiply, "*" },
                { ExpressionType.Divide, "/" },
                { ExpressionType.AndAlso, "AND" },
                { ExpressionType.OrElse, "OR" },
                { ExpressionType.Equal, "=" },
                { ExpressionType.NotEqual, "<>" },
                { ExpressionType.GreaterThan, ">" },
                { ExpressionType.LessThan, "<" },
                { ExpressionType.GreaterThanOrEqual, ">=" },
                { ExpressionType.LessThanOrEqual, "<=" }
            };

            methodCall = new Dictionary<string, Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>>>()
            {
                #region Aggregate functions
                { "Sum", ArgumentFunction("SUM") },
                { "Max", ArgumentFunction("MAX") },
                { "Min", ArgumentFunction("MIN") },
                { "Count", ArgumentFunction("COUNT") },
                { "Avg", ArgumentFunction("AVG") },
                #endregion

                #region Date functions
                { "GetDate", NiladicFunction("GETDATE") },
                { "Year", ArgumentFunction("YEAR") },
                { "Month", ArgumentFunction("MONTH") },
                { "Day", ArgumentFunction("DAY") },
                { "Current_Timestamp", NiladicFunction("CURRENT_TIMESTAMP") },
                { "Current_Date", NiladicFunction("CURRENT_DATE") },
                { "Current_Time", NiladicFunction("CURRENT_TIME") },
                #endregion

                #region Math functions
                { "Abs", ArgumentFunction("ABS") },
                { "Round", TwoArgumentFunction("ROUND") },
                #endregion

                #region String functions
                { "Contains", LikePattern(ResolveType.Like) },
                // StartsWith("abc") must match 'abc%', i.e. the wildcard goes on the right.
                { "StartsWith", LikePattern(ResolveType.RightLike) },
                // EndsWith("abc") must match '%abc', i.e. the wildcard goes on the left.
                { "EndsWith", LikePattern(ResolveType.LeftLike) },
                { "Substring", (sql, node, visit, call) =>
                {
                    sql.SqlStack.Push(")");
                    // The length argument is optional on the .NET side.
                    if (node.Arguments.Count > 1)
                    {
                        visit(node.Arguments[1]);
                        sql.SqlStack.Push(",");
                    }
                    visit(node.Arguments[0]);
                    sql.SqlStack.Push(",");
                    visit(node.Object);
                    sql.SqlStack.Push("(");
                    sql.SqlStack.Push("SUBSTRING");
                } },
                { "Replace", (sql, node, visit, call) =>
                {
                    sql.SqlStack.Push(")");
                    visit(node.Arguments[1]);
                    sql.SqlStack.Push(",");
                    visit(node.Arguments[0]);
                    sql.SqlStack.Push(",");
                    visit(node.Object);
                    sql.SqlStack.Push("(");
                    sql.SqlStack.Push("REPLACE");
                } },
                { "ToUpper", InstanceFunction("UPPER") },
                { "ToLower", InstanceFunction("LOWER") },
                { "Trim", InstanceFunction("TRIM") },
                #endregion

                #region Other functions
                { "Equals", (sql, node, visit, call) =>
                {
                    visit(node.Arguments[0]);
                    sql.SqlStack.Push(" = ");
                    visit(node.Object);
                } },
                { "IsNull", (sql, node, visit, call) =>
                {
                    sql.SqlStack.Push(")");
                    visit(node.Arguments[1]);
                    sql.SqlStack.Push(",");
                    visit(node.Arguments[0]);
                    sql.SqlStack.Push("(");
                    sql.SqlStack.Push("ISNULL");
                    // NOTE(review): kept from the original; for a static IsNull the call target
                    // is null, so this visit presumably emits nothing - confirm intent.
                    visit(node.Object);
                } }
                #endregion
            };
        }

        /// <summary>Builds a handler that emits <c>NAME(arg0)</c>.</summary>
        private static Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>> ArgumentFunction(string sqlName)
        {
            return (sql, node, visit, call) =>
            {
                sql.SqlStack.Push(")");
                visit(node.Arguments[0]);
                sql.SqlStack.Push("(");
                sql.SqlStack.Push(sqlName);
            };
        }

        /// <summary>Builds a handler that emits <c>NAME(arg0,arg1)</c>.</summary>
        private static Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>> TwoArgumentFunction(string sqlName)
        {
            return (sql, node, visit, call) =>
            {
                sql.SqlStack.Push(")");
                visit(node.Arguments[1]);
                sql.SqlStack.Push(",");
                visit(node.Arguments[0]);
                sql.SqlStack.Push("(");
                sql.SqlStack.Push(sqlName);
            };
        }

        /// <summary>Builds a handler that emits <c>NAME()</c> and ignores the call's arguments.</summary>
        private static Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>> NiladicFunction(string sqlName)
        {
            return (sql, node, visit, call) =>
            {
                sql.SqlStack.Push(")");
                sql.SqlStack.Push("(");
                sql.SqlStack.Push(sqlName);
            };
        }

        /// <summary>Builds a handler that emits <c>NAME(target)</c> for an instance method call.</summary>
        private static Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>> InstanceFunction(string sqlName)
        {
            return (sql, node, visit, call) =>
            {
                sql.SqlStack.Push(")");
                visit(node.Object);
                sql.SqlStack.Push("(");
                sql.SqlStack.Push(sqlName);
            };
        }

        /// <summary>
        /// Builds a handler that emits <c>target LIKE pattern</c>. The visitor is switched to
        /// <paramref name="wildcardMode"/> only while the pattern argument is visited and is
        /// switched back to <see cref="ResolveType.SqlString"/> before the target is visited.
        /// </summary>
        private static Action<SqlInfo, MethodCallExpression, Action<Expression>, Action<ResolveType>> LikePattern(ResolveType wildcardMode)
        {
            return (sql, node, visit, call) =>
            {
                call(wildcardMode);
                visit(node.Arguments[0]);
                call(ResolveType.SqlString);
                sql.SqlStack.Push(" LIKE ");
                visit(node.Object);
            };
        }
    }

    /// <summary>
    /// Expression visitor that translates an expression tree into SQL (Sql表达式访问).
    /// </summary>
    public class SqlExpressionVisitor : ExpressionVisitor
    {

        /// <summary>
        /// Result object filled while the tree is visited (SQL stack, parameters, projections).
        /// </summary>
        public readonly SqlInfo sqlInfo;

        /// <summary>
        /// Lambda parameter name mapped to its position; written when a lambda is visited and
        /// read when a parameter is rendered as a <c>p{index}.</c> prefix.
        /// </summary>
        private readonly Dictionary<string, int> parameterIndex;

        /// <summary>
        /// Members of the member-access chain currently being visited; consumed and cleared
        /// when the chain ends in a constant.
        /// </summary>
        private readonly Stack<MemberInfo> memberInfos;

        /// <summary>
        /// Translation options supplied by the caller.
        /// </summary>
        private readonly SqlExpressionOptions options;

        /// <summary>
        /// Current resolve mode; method-call handlers switch it temporarily via a callback.
        /// </summary>
        private ResolveType resolveType;

        /// <summary>
        /// Creates a visitor with an empty <see cref="sqlInfo"/>.
        /// </summary>
        /// <param name="options">Translation options.</param>
        /// <param name="resolveType">
        /// Initial resolve mode. Nested visitors are created with <see cref="ResolveType.Call"/>
        /// to evaluate sub-expressions to values.
        /// </param>
        public SqlExpressionVisitor(SqlExpressionOptions options, ResolveType resolveType = ResolveType.SqlString)
        {
            sqlInfo = new SqlInfo();
            parameterIndex = new Dictionary<string, int>();
            memberInfos = new Stack<MemberInfo>();
            this.options = options;
            this.resolveType = resolveType;
        }

        /// <summary>
        /// Entry point of the walk; forwards unchanged to <see cref="ExpressionVisitor.Visit(Expression)"/>.
        /// </summary>
        /// <param name="node">Node to visit; may be null.</param>
        /// <returns>Whatever the base dispatcher returns for <paramref name="node"/>.</returns>
        [return: NotNullIfNotNull("node")]
        public override Expression Visit(Expression node) => base.Visit(node);

        /// <summary>
        /// Registers the lambda's parameters by position (unless parameter output is disabled
        /// in the options) and then visits only the lambda body.
        /// </summary>
        /// <typeparam name="T">Delegate type of the lambda.</typeparam>
        /// <param name="node">Lambda to visit.</param>
        /// <returns>The result of visiting the lambda body.</returns>
        /// <exception cref="ArgumentException">
        /// A parameter name is already registered, e.g. when a second lambda reuses a name.
        /// </exception>
        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            if (!options.IgnoreParameterExpression)
            {
                for (var position = 0; position < node.Parameters.Count; position++)
                {
                    parameterIndex.Add(node.Parameters[position].Name, position);
                }
            }
            return Visit(node.Body);
        }

        /// <summary>
        /// Handles unary nodes. <c>ArrayLength</c> is evaluated to a number and emitted as a
        /// constant; <c>Convert</c>/<c>ConvertChecked</c> are transparent and the operand is
        /// visited directly (no value conversion is applied).
        /// </summary>
        /// <param name="node">Unary node.</param>
        /// <returns>The result of visiting the replacement constant or the operand.</returns>
        /// <exception cref="NotSupportedException">
        /// The node type is not handled, or the array-length operand did not evaluate to an array.
        /// </exception>
        protected override Expression VisitUnary(UnaryExpression node)
        {
            switch (node.NodeType)
            {
                case ExpressionType.ArrayLength:
                {
                    // Evaluate the operand in Call mode so no SQL is emitted for it here.
                    var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);
                    var evaluated = (evaluator.Visit(node.Operand) as ConstantExpression)?.Value;
                    // Array (not object[]) so that value-type arrays such as int[] work too.
                    if (!(evaluated is Array array))
                    {
                        throw new NotSupportedException($"暂不支持{node.NodeType}类型解析");
                    }
                    return Visit(Expression.Constant(array.Length));
                }
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                    return Visit(node.Operand);
                default:
                    throw new NotSupportedException($"暂不支持{node.NodeType}类型解析");
            }
        }

        /// <summary>
        /// Translates a binary node. Array indexing is evaluated to a constant; every other
        /// node is emitted as <c>left op right</c>, where comparing against a null constant
        /// with <c>=</c> or <c>&lt;&gt;</c> becomes <c>IS NULL</c> / <c>IS NOT NULL</c>.
        /// Operands that are themselves mapped binary operations are wrapped in parentheses so
        /// the grouping of the original expression survives SQL operator precedence.
        /// </summary>
        /// <param name="node">Binary node.</param>
        /// <returns>The node itself, or the constant produced for an array index.</returns>
        /// <exception cref="NotSupportedException">
        /// The operator has no SQL mapping, or an array index could not be evaluated.
        /// </exception>
        protected override Expression VisitBinary(BinaryExpression node)
        {
            if (node.NodeType == ExpressionType.ArrayIndex)
            {
                var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);
                // Array (not object[]) so that value-type arrays such as int[] work too.
                var array = (evaluator.Visit(node.Left) as ConstantExpression)?.Value as Array;
                var index = (evaluator.Visit(node.Right) as ConstantExpression)?.Value;
                if (array == null || index == null)
                {
                    throw new NotSupportedException($"暂不支持{node.NodeType}类型解析");
                }
                return VisitConstant(Expression.Constant(array.GetValue(Convert.ToInt32(index))));
            }

            if (!ExpressionMapper.expressionTypes.TryGetValue(node.NodeType, out var op))
            {
                throw new NotSupportedException($"暂不支持{node.NodeType}类型解析");
            }

            // Fragments are pushed right-to-left: right operand, operator, left operand.
            if (node.Right is ConstantExpression { Value: null })
            {
                if (op == "=")
                {
                    sqlInfo.SqlStack.Push(" IS NULL");
                }
                else if (op == "<>")
                {
                    sqlInfo.SqlStack.Push(" IS NOT NULL");
                }
                // Any other operator against a null constant emits only the left operand,
                // exactly as before.
            }
            else
            {
                VisitBinaryOperand(node.Right);
                sqlInfo.SqlStack.Push($" {op} ");
            }

            VisitBinaryOperand(node.Left);
            return node;
        }

        /// <summary>
        /// Visits one operand of a binary node, surrounding it with parentheses when it is
        /// (after ignoring conversions) another mapped binary operation.
        /// </summary>
        private void VisitBinaryOperand(Expression operand)
        {
            var core = operand;
            while (core is UnaryExpression unary
                && (core.NodeType == ExpressionType.Convert || core.NodeType == ExpressionType.ConvertChecked))
            {
                core = unary.Operand;
            }

            var needsParentheses = core is BinaryExpression
                && ExpressionMapper.expressionTypes.ContainsKey(core.NodeType);

            if (needsParentheses)
            {
                sqlInfo.SqlStack.Push(")");
                Visit(operand);
                sqlInfo.SqlStack.Push("(");
            }
            else
            {
                Visit(operand);
            }
        }

        /// <summary>
        /// Index expressions get no special treatment; the base implementation is used.
        /// </summary>
        /// <param name="node">Index node.</param>
        /// <returns>The result of the base visitor.</returns>
        protected override Expression VisitIndex(IndexExpression node) => base.VisitIndex(node);

        /// <summary>
        /// Emits the table alias prefix <c>p{index}.</c> for a lambda parameter, where the
        /// index is the position registered when the lambda was visited. Nothing is emitted
        /// when parameter output is disabled in the options.
        /// </summary>
        /// <param name="node">Parameter node.</param>
        /// <returns>The node itself.</returns>
        protected override Expression VisitParameter(ParameterExpression node)
        {
            if (options.IgnoreParameterExpression)
            {
                return node;
            }

            var position = parameterIndex[node.Name];
            sqlInfo.SqlStack.Push($"p{position}.");
            return node;
        }

        /// <summary>
        /// Handles <c>new T(...)</c>. In SQL mode each constructor argument is translated and
        /// recorded as a name/value projection under the matching member's column name. In
        /// Call mode the arguments are evaluated and the object is actually constructed.
        /// </summary>
        /// <param name="node">New-object node.</param>
        /// <returns>
        /// The node itself in SQL mode (and any other non-Call mode); in Call mode the result
        /// of visiting a constant that holds the constructed object.
        /// </returns>
        /// <exception cref="NotSupportedException">
        /// SQL mode was asked to project constructor arguments that have no associated members.
        /// </exception>
        protected override Expression VisitNew(NewExpression node)
        {
            if (resolveType == ResolveType.SqlString)
            {
                if (node.Arguments.Count > 0 && node.Members == null)
                {
                    // Only member-mapped constructors (e.g. anonymous types) can be projected.
                    throw new NotSupportedException($"暂不支持{node.NodeType}类型解析");
                }

                for (var i = 0; i < node.Arguments.Count; i++)
                {
                    var member = node.Members[i];
                    // Falls back to the member name when [Column] is absent or carries no name.
                    var name = AddIdentifier(member.GetCustomAttribute<ColumnAttribute>()?.Name ?? member.Name);

                    Visit(node.Arguments[i]);
                    var value = string.Join("", sqlInfo.SqlStack.ToList());

                    sqlInfo.NewKeyValues.Add(name, value);
                    sqlInfo.NewNames.Add(name);
                    sqlInfo.NewValues.Add(value);
                    sqlInfo.NewAssignMapper.Add($"{name} = {value}");
                    sqlInfo.NewAsMapper.Add($"{value} AS {name}");

                    // The stack only held this argument's fragments; start clean for the next one.
                    sqlInfo.SqlStack.Clear();
                }
                return node;
            }

            if (resolveType == ResolveType.Call)
            {
                var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);
                var args = new object[node.Arguments.Count];
                for (var i = 0; i < args.Length; i++)
                {
                    args[i] = (evaluator.Visit(node.Arguments[i]) as ConstantExpression).Value;
                }

                // Invoke the constructor the expression was bound to instead of letting
                // Activator guess an overload from runtime argument types (which fails for
                // null arguments). Constructor is null for parameterless value-type creation.
                var instance = node.Constructor != null
                    ? node.Constructor.Invoke(args)
                    : Activator.CreateInstance(node.Type);

                return VisitConstant(Expression.Constant(instance));
            }

            return node;
        }

        /// <summary>
        /// Evaluates a collection initializer (<c>new List&lt;T&gt; { a, b }</c>): the collection
        /// is constructed, every initializer's Add method is invoked with evaluated arguments,
        /// and the result is passed on as a constant. In SQL mode the elements are joined with
        /// commas into a single string first.
        /// </summary>
        /// <param name="node">List-init node.</param>
        /// <returns>The result of visiting the replacement constant.</returns>
        protected override Expression VisitListInit(ListInitExpression node)
        {
            var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);

            // The evaluated new-expression already is the fresh collection instance; it is
            // used directly so constructor arguments such as comparers are preserved.
            var collection = (evaluator.Visit(node.NewExpression) as ConstantExpression).Value;

            foreach (var initializer in node.Initializers)
            {
                var addArguments = initializer.Arguments
                    .Select(argument => (evaluator.Visit(argument) as ConstantExpression).Value)
                    .ToArray();
                initializer.AddMethod.Invoke(collection, addArguments);
            }

            if (resolveType == ResolveType.SqlString)
            {
                // Enumerate explicitly: passing the collection as a plain object would bind to
                // string.Join(string, params object[]) and yield the collection's type name.
                var elements = ((System.Collections.IEnumerable)collection).Cast<object>();
                return VisitConstant(Expression.Constant(string.Join(",", elements)));
            }

            return VisitConstant(Expression.Constant(collection));
        }

        /// <summary>
        /// Evaluates an array creation expression to a real array and passes it on as a
        /// constant. In SQL mode the elements are joined with commas into a single string first.
        /// </summary>
        /// <param name="node">
        /// Either an element list (<c>new[] { a, b }</c>) or a bounds form (<c>new T[n]</c>).
        /// </param>
        /// <returns>The result of visiting the replacement constant.</returns>
        protected override Expression VisitNewArray(NewArrayExpression node)
        {
            var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);
            var elementType = node.Type.GetElementType();
            Array array;

            if (node.NodeType == ExpressionType.NewArrayBounds)
            {
                // For the bounds form the expressions are dimension lengths, not elements.
                var lengths = node.Expressions
                    .Select(bound => Convert.ToInt32((evaluator.Visit(bound) as ConstantExpression).Value))
                    .ToArray();
                array = Array.CreateInstance(elementType, lengths);
            }
            else
            {
                array = Array.CreateInstance(elementType, node.Expressions.Count);
                for (var i = 0; i < array.Length; i++)
                {
                    array.SetValue((evaluator.Visit(node.Expressions[i]) as ConstantExpression).Value, i);
                }
            }

            if (resolveType == ResolveType.SqlString)
            {
                // Cast<object>() also covers value-type arrays, which cannot be cast to object[].
                return VisitConstant(Expression.Constant(string.Join(",", array.Cast<object>())));
            }

            return VisitConstant(Expression.Constant(array));
        }

        /// <summary>
        /// Handles a method call. In SQL mode a method whose name has an entry in
        /// <c>ExpressionMapper.methodCall</c> is translated by that handler. Every other call
        /// is executed via reflection with evaluated arguments, and its return value is visited
        /// as a constant in the visitor's current mode.
        /// </summary>
        /// <param name="node">Method-call node.</param>
        /// <returns>
        /// The node itself when a handler translated it; otherwise the result of visiting the
        /// constant that holds the method's return value.
        /// </returns>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (resolveType == ResolveType.SqlString
                && ExpressionMapper.methodCall.TryGetValue(node.Method.Name, out var handler))
            {
                // The last callback lets the handler switch the resolve mode, e.g. to a LIKE
                // mode while the pattern argument is visited.
                handler(sqlInfo, node, e => Visit(e), mode => resolveType = mode);
                return node;
            }

            var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);

            var arguments = new List<object>();
            foreach (var argument in node.Arguments)
            {
                arguments.Add((evaluator.Visit(argument) as ConstantExpression).Value);
            }

            // Static methods have no target object.
            var target = node.Object == null
                ? null
                : (evaluator.Visit(node.Object) as ConstantExpression).Value;

            var result = node.Method.Invoke(target, arguments.ToArray());

            // The resolve mode is deliberately left untouched: a Call-mode visitor must keep
            // returning raw values, and a LIKE mode must still add its wildcards to the result.
            return Visit(Expression.Constant(result));
        }

        /// <summary>
        /// Handles an object initializer (<c>new T { A = x, B = y }</c>). In SQL mode each
        /// assigned value is translated and recorded as a name/value projection under the
        /// member's column name. In Call mode the object is constructed and the members are
        /// set to the evaluated values.
        /// </summary>
        /// <param name="node">Member-init node.</param>
        /// <returns>
        /// The node itself in SQL mode (and any other non-Call mode); in Call mode the result
        /// of visiting a constant that holds the initialized object.
        /// </returns>
        /// <exception cref="NotSupportedException">
        /// A binding is not a plain member assignment (nested member or list binding).
        /// </exception>
        protected override Expression VisitMemberInit(MemberInitExpression node)
        {
            if (resolveType == ResolveType.SqlString)
            {
                foreach (var binding in node.Bindings)
                {
                    Visit(RequireAssignment(binding).Expression);

                    // Falls back to the member name when [Column] is absent or carries no name.
                    var name = AddIdentifier(binding.Member.GetCustomAttribute<ColumnAttribute>()?.Name ?? binding.Member.Name);
                    var value = string.Join("", sqlInfo.SqlStack.ToList());

                    sqlInfo.NewKeyValues.Add(name, value);
                    sqlInfo.NewNames.Add(name);
                    sqlInfo.NewValues.Add(value);
                    sqlInfo.NewAssignMapper.Add($"{name} = {value}");
                    sqlInfo.NewAsMapper.Add($"{value} AS {name}");

                    // The stack only held this binding's fragments; start clean for the next one.
                    sqlInfo.SqlStack.Clear();
                }
                return node;
            }

            if (resolveType == ResolveType.Call)
            {
                var evaluator = new SqlExpressionVisitor(options, ResolveType.Call);
                var instance = (evaluator.Visit(node.NewExpression) as ConstantExpression).Value;

                foreach (var binding in node.Bindings)
                {
                    var assigned = (evaluator.Visit(RequireAssignment(binding).Expression) as ConstantExpression).Value;
                    switch (binding.Member)
                    {
                        case FieldInfo field:
                            field.SetValue(instance, assigned);
                            break;
                        case PropertyInfo property:
                            property.SetValue(instance, assigned);
                            break;
                    }
                }

                return VisitConstant(Expression.Constant(instance));
            }

            return node;
        }

        /// <summary>
        /// Returns <paramref name="binding"/> as a member assignment, or throws for binding
        /// kinds this visitor cannot translate.
        /// </summary>
        private static MemberAssignment RequireAssignment(MemberBinding binding)
        {
            if (binding is MemberAssignment assignment)
            {
                return assignment;
            }
            throw new NotSupportedException($"暂不支持{binding.BindingType}类型解析");
        }

        /// <summary>
        /// True when a column (a member accessed on a lambda parameter) has been emitted since
        /// the last constant was handled. The next constant is then emitted as a named
        /// parameter rather than inline, which keeps the inline/parameter choice the visitor
        /// made before column members stopped lingering on the member stack.
        /// </summary>
        private bool columnMemberPending;

        /// <summary>
        /// Visits a member access. A member read from a lambda parameter is emitted as a
        /// column name (taken from <see cref="ColumnAttribute"/> when it supplies one). Any
        /// other member is recorded on the member stack so that the constant at the root of
        /// the chain can be turned into the member's value.
        /// </summary>
        /// <param name="node">Member-access node.</param>
        /// <returns>The result of visiting the expression the member is read from.</returns>
        protected override Expression VisitMember(MemberExpression node)
        {
            memberInfos.Push(node.Member);

            if (node.Expression == null)
            {
                // Static member: there is no target expression, so a null constant serves as
                // the chain root and the value is read from the member stack.
                return Visit(Expression.Constant(null));
            }

            if (node.Expression.NodeType == ExpressionType.Parameter)
            {
                var column = node.Member.GetCustomAttribute<ColumnAttribute>();
                sqlInfo.SqlStack.Push(AddIdentifier(column?.Name ?? node.Member.Name));

                // A parameter-rooted chain is a column, not a value: its members must not be
                // applied to the next captured value or literal that is visited.
                memberInfos.Clear();
                columnMemberPending = true;
            }

            return Visit(node.Expression);
        }

        /// <summary>
        /// Handles a constant. Pending members are applied first, so a closure object,
        /// <c>this</c>, or a static member resolves to the actual value. Outside Call mode the
        /// value is then emitted as SQL: as a named parameter when it came from a member chain
        /// or follows a column, otherwise inline with quotes added where needed.
        /// </summary>
        /// <param name="node">Constant node.</param>
        /// <returns>
        /// A new constant holding the resolved value (in SQL mode, the formatted value).
        /// </returns>
        protected override Expression VisitConstant(ConstantExpression node)
        {
            var value = node.Value;

            // The stack enumerates innermost member first, so applying the members in order
            // walks root -> leaf. Static members ignore the (null) target.
            foreach (var member in memberInfos)
            {
                if (member is FieldInfo field)
                {
                    value = field.GetValue(value);
                }
                else if (member is PropertyInfo property)
                {
                    value = property.GetValue(value);
                }
            }

            if (resolveType != ResolveType.Call)
            {
                if (value is DateTime moment)
                {
                    value = moment.ToString(options.DateTimeFormat);
                }

                if (memberInfos.Count > 0 || columnMemberPending)
                {
                    var parameterName = Guid.NewGuid().ToString().Replace("-", "");
                    sqlInfo.SqlStack.Push($"{options.ParameterNotation}{parameterName}");
                    sqlInfo.SqlParameters.Add(parameterName, AddWildcard(value));
                }
                else
                {
                    value = AddQuotes(AddWildcard(value));
                    // Invariant culture keeps numeric literals valid SQL (e.g. "1.5", never "1,5").
                    sqlInfo.SqlStack.Push(Convert.ToString(value, System.Globalization.CultureInfo.InvariantCulture));
                }
            }

            // The chain ends here; reset the per-chain state.
            memberInfos.Clear();
            columnMemberPending = false;
            return Expression.Constant(value);
        }

        #region 私有公共方法

        /// <summary>
        /// Wraps text-like values (<see cref="string"/>, <see cref="char"/>,
        /// <see cref="DateTime"/>) in single quotes for inline use in SQL. Single quotes inside
        /// the text are doubled so the value cannot terminate the literal early.
        /// </summary>
        /// <param name="value">Value to render; may be null.</param>
        /// <returns>
        /// The quoted string for text-like values; otherwise <paramref name="value"/> unchanged
        /// (null stays null).
        /// </returns>
        private static object AddQuotes(object value)
        {
            switch (value)
            {
                case null:
                    return null;
                case string text:
                    return $"'{text.Replace("'", "''")}'";
                case char symbol:
                    return $"'{symbol.ToString().Replace("'", "''")}'";
                case DateTime _:
                    return $"'{value}'";
                default:
                    return value;
            }
        }

        /// <summary>
        /// Surrounds <paramref name="name"/> with the identifier delimiters configured in the
        /// options by inserting it at index 1 of the delimiter string (a two-character
        /// delimiter such as <c>[]</c> therefore yields <c>[name]</c>).
        /// </summary>
        /// <param name="name">Column or alias name.</param>
        /// <returns>
        /// The delimited name, or <paramref name="name"/> unchanged when no delimiter is configured.
        /// </returns>
        private string AddIdentifier(string name)
        {
            var delimiters = options.Identifier;
            return string.IsNullOrWhiteSpace(delimiters)
                ? name
                : delimiters.Insert(1, name);
        }

        /// <summary>
        /// Adds LIKE wildcards to <paramref name="value"/> according to the current resolve
        /// mode: <c>%value</c> for LeftLike, <c>value%</c> for RightLike, <c>%value%</c> for
        /// Like. In every other mode the value is returned untouched.
        /// </summary>
        /// <param name="value">Value to decorate.</param>
        /// <returns>The decorated string, or the original value.</returns>
        private object AddWildcard(object value)
        {
            switch (resolveType)
            {
                case ResolveType.LeftLike:
                    return $"%{value}";
                case ResolveType.RightLike:
                    return $"{value}%";
                case ResolveType.Like:
                    return $"%{value}%";
                default:
                    return value;
            }
        }
        #endregion

    }
}

 

https://gitee.com/China-Mr-zhong/Fast.Framework 欢迎Star

 

你可能感兴趣的:(玩转NET Expression)