Dapper就是一堆Connection的扩展方法,我们也用相同的方法实现,为了练习反射写的,原创~
使用技术:泛型、反射、表达式树...
客户端调用:
/// <summary>
/// Demo client: runs one insert, batch insert, delete, update and select against the
/// <c>Person</c> table through the ORM extension methods.
/// </summary>
/// <param name="args">Command-line arguments (unused).</param>
static void Main(string[] args)
{
    // NOTE(review): the connection string (including the sa password) is hard-coded for the demo;
    // real code should read it from configuration / a secret store.
    using (var connection = new SqlConnection("Data Source=.;User Id=sa;Password=123456;Database=fanDB;"))
    {
        // Create: a single entity, then a batch.
        connection.Insert(new Person() { Name = "fan11", Age = 1 });
        connection.Insert(new List<Person>
        {
            new Person() { Name = "fan432", Age = 24 },
            new Person() { Name = "fan", Age = 4 }
        });

        // Delete: by primary key, then by entity (only the ID of the entity is used).
        connection.Delete<Person>(5);
        connection.Delete(new Person() { ID = 6 });

        // Update: the row whose ID matches the entity.
        connection.Update(new Person() { ID = 17, Name = "fanfan", Age = 18 });

        // Read: note that && binds tighter than ||, so "Age > 3" only constrains the EndsWith branch.
        var list = connection.Select<Person>(p => p.Name == "fan"
            || p.Name.Contains("fan1")
            || p.Name.StartsWith("fan")
            || p.Name.EndsWith("fan") && p.Age > 3);
    }

    Console.ReadKey();
}
ORM:
/// <summary>
/// A minimal Dapper-style ORM implemented as extension methods on <see cref="SqlConnection"/>.
/// The table name is the entity type's name, columns are its public instance properties,
/// and the primary key is the property named <c>ID</c>.
/// </summary>
public static class ORM
{
    private const string ID_NAME = "ID";
    private const string INSERT_SQL = "INSERT INTO @TABLE_NAME(@COLUMNS) VALUES(@VALUES)";
    private const string SELECT_SQL = "SELECT * FROM @TABLE_NAME WHERE @WHERE";
    private const string DELETE_SQL = "DELETE FROM @TABLE_NAME WHERE @WHERE";
    private const string UPDATE_SQL = "UPDATE @TABLE_NAME SET @UPDATE_COLUMNS WHERE @WHERE";

    // Reflection results are cached per entity type so GetProperties() runs once per type.
    private static readonly ConcurrentDictionary<Type, PropertyInfo[]> PROPERTIES_CACHE =
        new ConcurrentDictionary<Type, PropertyInfo[]>();

    // Turns a predicate expression into a WHERE clause; columns are quoted with [ ].
    private static readonly WhereBuilder WHERE_BUILDER = new WhereBuilder('[', ']');

    /// <summary>Inserts one entity; the <c>ID</c> property is not included in the statement.</summary>
    /// <typeparam name="T">Entity type; its name is used as the table name.</typeparam>
    /// <param name="connection">Connection to use; opened if it is not open already and left open.</param>
    /// <param name="entity">Entity whose readable properties become the column values.</param>
    /// <returns>The value returned by <c>ExecuteNonQuery</c> for the INSERT.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="entity"/> is null.</exception>
    public static int Insert<T>(this SqlConnection connection, T entity)
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }
        if (entity == null) { throw new ArgumentNullException(nameof(entity)); }

        var columns = GetColumnInfos(entity).Where(c => c.Name != ID_NAME).ToList();

        // string.Join(string, ...) is used instead of the char overload so the code also builds on .NET Framework.
        string sql = INSERT_SQL
            .Replace("@TABLE_NAME", QuoteIdentifier(typeof(T).Name))
            .Replace("@COLUMNS", string.Join(",", columns.Select(c => QuoteIdentifier(c.Name))))
            .Replace("@VALUES", string.Join(",", columns.Select(c => "@" + c.Name)));

        return ExecuteNonQuery(connection, sql, columns.Select(ToParameter).ToArray());
    }

    /// <summary>Inserts every entity of <paramref name="list"/>, one INSERT statement per entity.</summary>
    /// <param name="connection">Connection to use.</param>
    /// <param name="list">Entities to insert.</param>
    /// <returns>The sum of the per-entity results.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="list"/> is null.</exception>
    public static int Insert<T>(this SqlConnection connection, List<T> list)
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }
        if (list == null) { throw new ArgumentNullException(nameof(list)); }

        int result = 0;
        foreach (var entity in list)
        {
            result += connection.Insert(entity);
        }
        return result;
    }

    /// <summary>Selects all rows matching <paramref name="whereExp"/> and maps each row to a new <typeparamref name="T"/>.</summary>
    /// <param name="connection">Connection to use.</param>
    /// <param name="whereExp">Predicate translated to a parameterized WHERE clause.</param>
    /// <returns>The mapped entities (empty when nothing matches).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="whereExp"/> is null.</exception>
    public static List<T> Select<T>(this SqlConnection connection, Expression<Func<T, bool>> whereExp) where T : new()
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }
        if (whereExp == null) { throw new ArgumentNullException(nameof(whereExp)); }

        var wherePart = WHERE_BUILDER.ToSql(whereExp);
        var paras = wherePart.Parameters
            .Select(p => new SqlParameter(ToParameterName(p.Key), p.Value ?? DBNull.Value))
            .ToArray();

        string sql = SELECT_SQL
            .Replace("@TABLE_NAME", QuoteIdentifier(typeof(T).Name))
            .Replace("@WHERE", wherePart.Sql);

        var list = new List<T>();
        OpenConnection(connection);
        using (var command = connection.CreateCommand())
        {
            command.CommandType = CommandType.Text;
            command.CommandText = sql;
            command.Parameters.AddRange(paras);
            using (var reader = command.ExecuteReader())
            {
                while (reader.Read())
                {
                    list.Add(ReaderToEntity<T>(reader));
                }
            }
        }
        return list;
    }

    /// <summary>Deletes the row of table <typeparamref name="T"/> whose <c>ID</c> equals <paramref name="ID"/>.</summary>
    /// <returns>The value returned by <c>ExecuteNonQuery</c> for the DELETE.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
    public static int Delete<T>(this SqlConnection connection, int ID)
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }

        string sql = DELETE_SQL
            .Replace("@TABLE_NAME", QuoteIdentifier(typeof(T).Name))
            .Replace("@WHERE", QuoteIdentifier(ID_NAME) + "=@" + ID_NAME);

        // Boxing the id keeps overload resolution on SqlParameter(string, object).
        var paras = new[] { new SqlParameter("@" + ID_NAME, (object)ID) };
        return ExecuteNonQuery(connection, sql, paras);
    }

    /// <summary>Deletes the row identified by the <c>ID</c> property of <paramref name="entity"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="entity"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The entity has no readable, non-null <c>ID</c> property.</exception>
    public static int Delete<T>(this SqlConnection connection, T entity)
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }
        if (entity == null) { throw new ArgumentNullException(nameof(entity)); }

        var idProperty = entity.GetType().GetProperty(ID_NAME);
        if (idProperty == null || !idProperty.CanRead)
        {
            throw new InvalidOperationException(
                "Type " + entity.GetType().Name + " has no readable property named " + ID_NAME + ".");
        }

        object idValue = idProperty.GetValue(entity);
        if (idValue == null)
        {
            // Convert.ToInt32(null) would silently yield 0 and delete the wrong row.
            throw new InvalidOperationException("The " + ID_NAME + " of the entity to delete is null.");
        }

        int id = Convert.ToInt32(idValue, System.Globalization.CultureInfo.InvariantCulture);
        return connection.Delete<T>(id);
    }

    /// <summary>Updates every non-ID column of the row whose <c>ID</c> matches the entity's <c>ID</c>.</summary>
    /// <returns>The value returned by <c>ExecuteNonQuery</c> for the UPDATE.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="entity"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The entity type has no readable <c>ID</c> property.</exception>
    public static int Update<T>(this SqlConnection connection, T entity)
    {
        if (connection == null) { throw new ArgumentNullException(nameof(connection)); }
        if (entity == null) { throw new ArgumentNullException(nameof(entity)); }

        var columns = GetColumnInfos(entity);
        if (!columns.Any(c => c.Name == ID_NAME))
        {
            // Without it the statement would reference @ID with no matching parameter.
            throw new InvalidOperationException(
                "Type " + typeof(T).Name + " has no readable property named " + ID_NAME + ".");
        }

        var assignments = columns
            .Where(c => c.Name != ID_NAME)
            .Select(c => QuoteIdentifier(c.Name) + "=@" + c.Name);

        string sql = UPDATE_SQL
            .Replace("@TABLE_NAME", QuoteIdentifier(typeof(T).Name))
            .Replace("@UPDATE_COLUMNS", string.Join(",", assignments))
            .Replace("@WHERE", QuoteIdentifier(ID_NAME) + "=@" + ID_NAME);

        // All columns (including ID) are bound: ID feeds the WHERE clause.
        return ExecuteNonQuery(connection, sql, columns.Select(ToParameter).ToArray());
    }

    /// <summary>Opens the connection if needed, runs <paramref name="sql"/> and returns the ExecuteNonQuery result.</summary>
    private static int ExecuteNonQuery(SqlConnection connection, string sql, SqlParameter[] paras)
    {
        OpenConnection(connection);
        using (var command = connection.CreateCommand())
        {
            command.CommandType = CommandType.Text;
            command.CommandText = sql;
            command.Parameters.AddRange(paras);
            return command.ExecuteNonQuery();
        }
    }

    /// <summary>Builds a parameter for a column; a null CLR value is sent as <see cref="DBNull.Value"/>.</summary>
    private static SqlParameter ToParameter(ColumnInfo column)
    {
        return new SqlParameter("@" + column.Name, column.Value ?? DBNull.Value);
    }

    /// <summary>Ensures a WHERE-clause parameter key carries the leading '@' used in the SQL text.</summary>
    private static string ToParameterName(string key)
    {
        return key.StartsWith("@", StringComparison.Ordinal) ? key : "@" + key;
    }

    /// <summary>Wraps an identifier in [ ] (escaping ']' as ']]') so reserved words are valid names.</summary>
    private static string QuoteIdentifier(string identifier)
    {
        return "[" + identifier.Replace("]", "]]") + "]";
    }

    /// <summary>Maps the current row of <paramref name="reader"/> to a new <typeparamref name="T"/> by property name.</summary>
    private static T ReaderToEntity<T>(SqlDataReader reader) where T : new()
    {
        // Boxed once so that SetValue mutates the same instance even when T is a struct.
        object entity = new T();
        foreach (var propertyInfo in GetPropertys<T>())
        {
            if (!propertyInfo.CanWrite)
            {
                continue;
            }

            var value = reader[propertyInfo.Name];
            // A NULL column comes back as DBNull, which cannot be assigned to a CLR property.
            propertyInfo.SetValue(entity, value is DBNull ? null : value);
        }
        return (T)entity;
    }

    /// <summary>Returns the cached public, non-indexer properties of <typeparamref name="T"/>.</summary>
    private static PropertyInfo[] GetPropertys<T>()
    {
        return PROPERTIES_CACHE.GetOrAdd(
            typeof(T),
            t => t.GetProperties().Where(p => p.GetIndexParameters().Length == 0).ToArray());
    }

    /// <summary>Snapshots name, CLR type name and current value of every readable property of <paramref name="entity"/>.</summary>
    private static List<ColumnInfo> GetColumnInfos<T>(T entity)
    {
        var columnInfos = new List<ColumnInfo>();
        foreach (var prop in GetPropertys<T>())
        {
            if (!prop.CanRead)
            {
                continue;
            }
            columnInfos.Add(new ColumnInfo(prop.Name, prop.PropertyType.FullName, prop.GetValue(entity)));
        }
        return columnInfos;
    }

    /// <summary>
    /// Maps a CLR type's full name to a <see cref="DbType"/>; unknown names map to <see cref="DbType.String"/>.
    /// Not referenced by the other members of this class; extend with Guid, DateTime, ... as needed.
    /// </summary>
    private static DbType GetDbType(string typeName)
    {
        switch (typeName)
        {
            case "System.Int32": return DbType.Int32;
            case "System.Decimal": return DbType.Decimal;
            case "System.String": return DbType.String;
            default: return DbType.String;
        }
    }

    /// <summary>Opens <paramref name="connection"/> unless it is already open.</summary>
    private static void OpenConnection(IDbConnection connection)
    {
        if (connection.State != ConnectionState.Open)
        {
            connection.Open();
        }
    }
}

/// <summary>Name, CLR type name and value of one entity property, used to build SQL parameters.</summary>
public class ColumnInfo
{
    public ColumnInfo(string name, string typeName, object value)
    {
        this.Name = name;
        this.TypeName = typeName;
        this.Value = value;
    }

    /// <summary>Property name, used as the column name.</summary>
    public string Name { get; set; }

    /// <summary>Full name of the property's CLR type.</summary>
    public string TypeName { get; set; }

    /// <summary>Property value at the time the snapshot was taken (may be null).</summary>
    public object Value { get; set; }
}
WhereBuilder:将表达式树转成where子句(从第三方扒下来的)
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;

/// <summary>
/// Generates the SQL of a WHERE condition from a predicate expression tree.
/// Column references become quoted identifiers, values become numbered parameters (@1, @2, ...).
/// The builder keeps no per-call state, so one instance can be shared between callers.
/// </summary>
public class WhereBuilder
{
    private readonly char _columnBeginChar = '[';
    private readonly char _columnEndChar = ']';

    // NOTE(review): both constructors have only optional parameters, so a zero-argument
    // "new WhereBuilder()" looks ambiguous to overload resolution; pass at least one argument.

    /// <summary>Creates a builder that quotes columns with the same character on both sides (e.g. a backtick).</summary>
    public WhereBuilder(char columnChar = '`')
    {
        this._columnBeginChar = this._columnEndChar = columnChar;
    }

    /// <summary>Creates a builder that quotes columns with distinct opening and closing characters (e.g. [ and ]).</summary>
    public WhereBuilder(char columnBeginChar = '[', char columnEndChar = ']')
    {
        this._columnBeginChar = columnBeginChar;
        this._columnEndChar = columnEndChar;
    }

    /// <summary>Translates a LINQ predicate to SQL, numbering parameters from 1.</summary>
    /// <param name="expression">Predicate over <typeparamref name="T"/>.</param>
    /// <returns>The SQL fragment and its parameter values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public WherePart ToSql<T>(Expression<Func<T, bool>> expression)
    {
        var i = 1;
        return ToSql(ref i, expression);
    }

    /// <summary>Translates a LINQ predicate to SQL, numbering parameters from a caller-supplied seed.</summary>
    /// <param name="i">Seed value: first parameter number to use; on return, the next free number.</param>
    /// <param name="expression">Predicate over <typeparamref name="T"/>.</param>
    /// <returns>The SQL fragment and its parameter values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public WherePart ToSql<T>(ref int i, Expression<Func<T, bool>> expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }
        return Recurse(ref i, expression.Body, expression.Parameters, isUnary: true);
    }

    /// <summary>Recursively translates one expression node.</summary>
    /// <param name="i">Next free parameter number.</param>
    /// <param name="expression">Node to translate.</param>
    /// <param name="lambdaParameters">Parameters of the lambda being translated; used to tell columns from values.</param>
    /// <param name="isUnary">True when the node stands alone as a predicate, so a bool column or constant must be turned into a comparison.</param>
    /// <param name="prefix">Text prepended to string values (LIKE wildcard).</param>
    /// <param name="postfix">Text appended to string values (LIKE wildcard).</param>
    private WherePart Recurse(ref int i, Expression expression, ReadOnlyCollection<ParameterExpression> lambdaParameters,
        bool isUnary = false, string prefix = null, string postfix = null)
    {
        // Unary operators and conversions.
        var unary = expression as UnaryExpression;
        if (unary != null)
        {
            if (unary.NodeType == ExpressionType.Convert)
            {
                // A conversion wrapped around a column (enum, nullable or widening comparison) must be
                // translated as that column; evaluating it would fail on the unbound lambda parameter.
                if (ReferencesParameter(unary.Operand, lambdaParameters))
                {
                    return Recurse(ref i, unary.Operand, lambdaParameters, isUnary, prefix, postfix);
                }

                // Otherwise it is a value, e.g. the right side of m.Birthday == DateTime.Now.
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(expression), prefix, postfix));
            }

            // e.g. !m.IsActive
            return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, lambdaParameters, true));
        }

        var binary = expression as BinaryExpression;
        if (binary != null)
        {
            bool isEquality = binary.NodeType == ExpressionType.Equal || binary.NodeType == ExpressionType.NotEqual;
            if (isEquality && IsNullConstant(binary.Right))
            {
                // "= @p" with a null value never matches in SQL; a null literal needs IS [NOT] NULL.
                string nullOperator = binary.NodeType == ExpressionType.Equal ? "IS" : "IS NOT";
                return WherePart.Concat(Recurse(ref i, binary.Left, lambdaParameters), nullOperator, WherePart.IsSql("NULL"));
            }

            // Operands of AND / OR are predicates themselves, so a bare bool column must become "[col] = @p".
            bool operandsArePredicates = binary.NodeType == ExpressionType.AndAlso || binary.NodeType == ExpressionType.OrElse;
            var left = Recurse(ref i, binary.Left, lambdaParameters, operandsArePredicates);
            var right = Recurse(ref i, binary.Right, lambdaParameters, operandsArePredicates);
            return WherePart.Concat(left, NodeTypeToString(binary.NodeType), right);
        }

        // Constant values, e.g. the right side of m.ID == 123.
        var constant = expression as ConstantExpression;
        if (constant != null)
        {
            var value = constant.Value;
            if (value is int)
            {
                // Int literals are inlined; invariant formatting keeps the minus sign ASCII.
                return WherePart.IsSql(((int)value).ToString(CultureInfo.InvariantCulture));
            }

            value = ApplyAffixes(value, prefix, postfix);
            if (value is bool && isUnary)
            {
                return WherePart.Concat(WherePart.IsParameter(i++, value), "=", WherePart.IsSql("1"));
            }
            return WherePart.IsParameter(i++, value);
        }

        // Member access: either a column (property of the lambda parameter) or a captured value.
        var member = expression as MemberExpression;
        if (member != null)
        {
            bool isContainsParameterExpress = false;
            IsContainsParameterExpress(member, lambdaParameters, ref isContainsParameterExpress);

            if (member.Member is PropertyInfo && isContainsParameterExpress)
            {
                var colName = ((PropertyInfo)member.Member).Name;
                if (isUnary && member.Type == typeof(bool))
                {
                    // A bool column used as a predicate: [col] = @p (true).
                    return WherePart.Concat(Recurse(ref i, expression, lambdaParameters), "=", WherePart.IsParameter(i++, true));
                }
                return WherePart.IsSql(string.Format("{0}{1}{2}", this._columnBeginChar, colName, this._columnEndChar));
            }

            if (member.Member is FieldInfo || !isContainsParameterExpress)
            {
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(member), prefix, postfix));
            }

            throw new NotSupportedException($"Expression does not refer to a property or field: {expression}");
        }

        // Method calls.
        var methodCall = expression as MethodCallExpression;
        if (methodCall != null)
        {
            // Does the call involve the lambda parameter (directly or through nested member/method expressions)?
            bool isContainsParameterExpress = false;
            IsContainsParameterExpress(methodCall, lambdaParameters, ref isContainsParameterExpress);

            if (!isContainsParameterExpress)
            {
                // Pure value computation: evaluate it now and bind the result.
                return WherePart.IsParameter(i++, ApplyAffixes(GetValue(expression), prefix, postfix));
            }

            // LIKE queries:
            if (methodCall.Method == typeof(string).GetMethod("Contains", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object, lambdaParameters), "LIKE",
                    Recurse(ref i, methodCall.Arguments[0], lambdaParameters, prefix: "%", postfix: "%"));
            }
            if (methodCall.Method == typeof(string).GetMethod("StartsWith", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object, lambdaParameters), "LIKE",
                    Recurse(ref i, methodCall.Arguments[0], lambdaParameters, postfix: "%"));
            }
            if (methodCall.Method == typeof(string).GetMethod("EndsWith", new[] { typeof(string) }))
            {
                return WherePart.Concat(Recurse(ref i, methodCall.Object, lambdaParameters), "LIKE",
                    Recurse(ref i, methodCall.Arguments[0], lambdaParameters, prefix: "%"));
            }

            // IN queries:
            if (methodCall.Method.Name == "Contains")
            {
                Expression collection;
                Expression property;
                bool isExtension = methodCall.Method.IsDefined(typeof(ExtensionAttribute));
                if (isExtension && methodCall.Arguments.Count == 2)
                {
                    // Extension form: Enumerable.Contains(collection, m.ID)
                    collection = methodCall.Arguments[0];
                    property = methodCall.Arguments[1];
                }
                else if (!isExtension && methodCall.Arguments.Count == 1)
                {
                    // Instance form: list.Contains(m.ID)
                    collection = methodCall.Object;
                    property = methodCall.Arguments[0];
                }
                else
                {
                    throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
                }

                var values = (IEnumerable)GetValue(collection);
                return WherePart.Concat(Recurse(ref i, property, lambdaParameters), "IN", WherePart.IsCollection(ref i, values));
            }

            throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
        }

        // Constructor calls, e.g. new DateTime(2018, 10, 31).
        if (expression is NewExpression)
        {
            return WherePart.IsParameter(i++, ApplyAffixes(GetValue(expression), prefix, postfix));
        }

        throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
    }

    /// <summary>Adds the LIKE wildcards to string values; other values are returned unchanged.</summary>
    private static object ApplyAffixes(object value, string prefix, string postfix)
    {
        var text = value as string;
        return text != null ? prefix + text + postfix : value;
    }

    /// <summary>True when the expression is a null literal, possibly wrapped in conversions.</summary>
    private static bool IsNullConstant(Expression expression)
    {
        var constant = StripConversions(expression) as ConstantExpression;
        return constant != null && constant.Value == null;
    }

    /// <summary>Removes any Convert / ConvertChecked nodes wrapped around an expression.</summary>
    private static Expression StripConversions(Expression expression)
    {
        var unary = expression as UnaryExpression;
        while (unary != null && (unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked))
        {
            expression = unary.Operand;
            unary = expression as UnaryExpression;
        }
        return expression;
    }

    /// <summary>True when the expression (ignoring conversions) is bound to a lambda parameter.</summary>
    private static bool ReferencesParameter(Expression expression, ReadOnlyCollection<ParameterExpression> lambdaParameters)
    {
        bool result = false;
        IsContainsParameterExpress(StripConversions(expression), lambdaParameters, ref result);
        return result;
    }

    /// <summary>True when <paramref name="expression"/> is itself one of the lambda's parameters.</summary>
    private static bool IsLambdaParameter(Expression expression, ReadOnlyCollection<ParameterExpression> lambdaParameters)
    {
        var parameter = expression as ParameterExpression;
        return parameter != null && lambdaParameters.Contains(parameter);
    }

    /// <summary>
    /// Determines whether a member or method-call expression involves a lambda parameter.
    /// <paramref name="result"/> is only ever set to true, never reset.
    /// </summary>
    private static void IsContainsParameterExpress(Expression expression,
        ReadOnlyCollection<ParameterExpression> lambdaParameters, ref bool result)
    {
        if (lambdaParameters == null || lambdaParameters.Count == 0 || expression == null)
        {
            return;
        }

        var member = expression as MemberExpression;
        if (member != null)
        {
            if (IsLambdaParameter(member.Expression, lambdaParameters))
            {
                result = true;
            }
            return;
        }

        var call = expression as MethodCallExpression;
        if (call == null)
        {
            return;
        }

        if (call.Object is MethodCallExpression)
        {
            // Example 1: m.ID.ToString().Contains("123")
            IsContainsParameterExpress(call.Object, lambdaParameters, ref result);
        }
        else
        {
            // Example 2: m.Name.Contains("abc")
            var target = call.Object as MemberExpression;
            if (target != null && IsLambdaParameter(target.Expression, lambdaParameters))
            {
                result = true;
            }
        }

        if (result)
        {
            return;
        }

        // Example 3: int[] ids = { 1, 2, 3 }; ids.Contains(m.ID)
        foreach (Expression argument in call.Arguments)
        {
            if (argument is MemberExpression || argument is MethodCallExpression)
            {
                IsContainsParameterExpress(argument, lambdaParameters, ref result);
            }
            else if (IsLambdaParameter(argument, lambdaParameters))
            {
                result = true;
            }

            if (result)
            {
                break;
            }
        }
    }

    /// <summary>Evaluates a parameter-free expression by compiling and invoking it.</summary>
    private static object GetValue(Expression member)
    {
        // source: http://stackoverflow.com/a/2616980/291955
        var objectMember = Expression.Convert(member, typeof(object));
        var getterLambda = Expression.Lambda<Func<object>>(objectMember);
        var getter = getterLambda.Compile();
        return getter();
    }

    /// <summary>Maps an expression node type to its SQL operator.</summary>
    /// <exception cref="NotSupportedException">The node type has no SQL operator mapping.</exception>
    private static string NodeTypeToString(ExpressionType nodeType)
    {
        switch (nodeType)
        {
            case ExpressionType.Add: return "+";
            case ExpressionType.And: return "&";
            case ExpressionType.AndAlso: return "AND";
            case ExpressionType.Divide: return "/";
            case ExpressionType.Equal: return "=";
            case ExpressionType.ExclusiveOr: return "^";
            case ExpressionType.GreaterThan: return ">";
            case ExpressionType.GreaterThanOrEqual: return ">=";
            case ExpressionType.LessThan: return "<";
            case ExpressionType.LessThanOrEqual: return "<=";
            case ExpressionType.Modulo: return "%";
            case ExpressionType.Multiply: return "*";
            case ExpressionType.Negate: return "-";
            case ExpressionType.Not: return "NOT";
            case ExpressionType.NotEqual: return "<>";
            case ExpressionType.Or: return "|";
            case ExpressionType.OrElse: return "OR";
            case ExpressionType.Subtract: return "-";
        }
        throw new NotSupportedException($"Unsupported node type: {nodeType}");
    }
}

/// <summary>A SQL fragment together with the values of the parameters it references.</summary>
public class WherePart
{
    /// <summary>SQL text containing parameter placeholders (@1, @2, ...).</summary>
    public string Sql { get; set; }

    /// <summary>Parameter values keyed by parameter number (without the leading '@').</summary>
    public Dictionary<string, object> Parameters { get; set; } = new Dictionary<string, object>();

    /// <summary>Wraps literal SQL that needs no parameters.</summary>
    public static WherePart IsSql(string sql)
    {
        return new WherePart() { Parameters = new Dictionary<string, object>(), Sql = sql };
    }

    /// <summary>Creates the placeholder @count bound to <paramref name="value"/>.</summary>
    public static WherePart IsParameter(int count, object value)
    {
        // Key and placeholder share one invariant rendering so they always match.
        string name = count.ToString(CultureInfo.InvariantCulture);
        return new WherePart()
        {
            Parameters = new Dictionary<string, object> { { name, value } },
            Sql = "@" + name
        };
    }

    /// <summary>
    /// Creates "(@n,@n+1,...)" for an IN list, one parameter per value, advancing <paramref name="countStart"/>.
    /// An empty sequence yields "(null)".
    /// </summary>
    public static WherePart IsCollection(ref int countStart, IEnumerable values)
    {
        var parameters = new Dictionary<string, object>();
        var sql = new StringBuilder("(");
        foreach (var value in values)
        {
            string name = countStart.ToString(CultureInfo.InvariantCulture);
            parameters.Add(name, value);
            sql.Append('@').Append(name).Append(',');
            countStart++;
        }

        if (sql.Length == 1)
        {
            sql.Append("null,");
        }

        // Replace the trailing comma with the closing parenthesis.
        sql[sql.Length - 1] = ')';
        return new WherePart() { Parameters = parameters, Sql = sql.ToString() };
    }

    /// <summary>Applies a prefix operator: "(op operand)".</summary>
    public static WherePart Concat(string @operator, WherePart operand)
    {
        return new WherePart() { Parameters = operand.Parameters, Sql = $"({@operator} {operand.Sql})" };
    }

    /// <summary>Joins two parts with an infix operator, "(left op right)", merging their parameters.</summary>
    public static WherePart Concat(WherePart left, string @operator, WherePart right)
    {
        return new WherePart()
        {
            Parameters = left.Parameters.Union(right.Parameters).ToDictionary(kvp => kvp.Key, kvp => kvp.Value),
            Sql = $"({left.Sql} {@operator} {right.Sql})"
        };
    }
}