// 大佬写的orm框架,mark一下稍后阅读 (An ORM framework written by an experienced developer; bookmarked to read later.)

package db

import (
	"database/sql"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"log"
	"reflect"
	"strconv"
	"util"
)

// DB is a chainable MySQL query builder bound to a connection pool.
// Builder methods use value receivers and return a modified copy, so the
// struct fields themselves are never shared between copies. The args map,
// however, IS shared by every copy, so it should be replaced rather than
// written into.
type DB struct {
	// pool is the connection pool created by Open.
	pool   *sql.DB
	// prefix is prepended to table names by Table and Pre.
	prefix string
	// table is the prefixed table name set by Table; empty means no query can be built.
	table  string
	// field is the SELECT column list, used verbatim; empty means "*".
	field  string
	// where, group, having, order and limit are raw SQL fragments inserted
	// verbatim after their keyword.
	where  string
	group  string
	having string
	order  string
	limit  string
	// args holds placeholder values keyed by clause name ("where", "having").
	args   map[string][]interface{}
}

// Open builds a MySQL DSN from the configuration map c, creates a connection
// pool for it, checks the server is reachable with Ping, and returns a DB
// builder bound to that pool.
//
// Keys read from c: db_user, db_pwd, db_host, db_port, db_name (DSN parts;
// the charset is fixed to "utf8mb4,utf8"), db_max_open and db_max_idle (pool
// limits, parsed by util.Str_to_int with 10 and 5 as its second argument,
// presumably the fallback values; confirm in util), and db_prefix (the
// table-name prefix applied by Table and Pre).
//
// On error the returned DB is the zero value and must not be used. The pool
// is closed before returning, so no connections are leaked.
func Open(c map[string]string) (dbPool DB, err error) {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s", c["db_user"], c["db_pwd"], c["db_host"], c["db_port"], c["db_name"], "utf8mb4,utf8")
	// sql.Open only sets up the pool; it does not contact the server yet.
	connPool, err := sql.Open("mysql", dsn)
	if err != nil {
		return DB{}, err
	}
	connMaxOpen := util.Str_to_int(c["db_max_open"], 10)
	connMaxIdle := util.Str_to_int(c["db_max_idle"], 5)
	// Hard cap on open connections regardless of concurrency. A *sql.Rows keeps
	// its connection until closed, so unclosed rows can exhaust the pool and
	// block other callers.
	connPool.SetMaxOpenConns(connMaxOpen)
	// Maximum number of idle connections kept for reuse (a cap, not a minimum).
	connPool.SetMaxIdleConns(connMaxIdle)
	// SetConnMaxLifetime is deliberately not called, so idle connections are
	// kept and reused indefinitely.

	// Ping forces a real connection so configuration errors surface here.
	if err = connPool.Ping(); err != nil {
		// Release the pool so it is not leaked; the Ping error is the one the
		// caller needs, so a Close error is intentionally ignored.
		_ = connPool.Close()
		return DB{}, err
	}
	dbPool.pool = connPool
	dbPool.prefix = c["db_prefix"]
	return dbPool, nil
}

// Table returns a copy of the builder that targets table name, with the
// connection's table prefix prepended to it.
func (db DB) Table(name string) DB {
	// fmt.Sprint inserts no separator between two string operands.
	db.table = fmt.Sprint(db.prefix, name)
	return db
}

// Field returns a copy of the builder whose SELECT column list is name, used
// verbatim (for example "id,title"). An empty name selects "*".
func (db DB) Field(name string) DB {
	next := db
	next.field = name
	return next
}

// Where returns a copy of the builder whose WHERE clause is query (inserted
// verbatim). The "?" placeholders in query are bound to params in order.
// Any HAVING parameters already set are kept.
func (db DB) Where(query string, params ...interface{}) DB {
	db.where = query
	// Copy the argument map before writing. DB is passed by value, but the map
	// is shared by every copy, so an in-place write would leak these params
	// into the receiver and any sibling builder derived from the same base.
	args := make(map[string][]interface{}, len(db.args)+1)
	for clause, values := range db.args {
		args[clause] = values
	}
	args["where"] = params
	db.args = args
	return db
}

// Order returns a copy of the builder whose ORDER BY clause is params,
// inserted verbatim (for example "id DESC"). An empty string drops the clause.
func (db DB) Order(params string) DB {
	sorted := db
	sorted.order = params
	return sorted
}

// Group returns a copy of the builder whose GROUP BY clause is params,
// inserted verbatim (for example "uid,status"). An empty string drops the clause.
func (db DB) Group(params string) DB {
	grouped := db
	grouped.group = params
	return grouped
}

// Having returns a copy of the builder whose HAVING clause is query (inserted
// verbatim). The "?" placeholders in query are bound to params. Existing
// WHERE parameters are kept; at execution they are bound before the HAVING ones.
func (db DB) Having(query string, params ...interface{}) DB {
	db.having = query
	// Copy the argument map before writing. It is shared by every copy of this
	// DB, so an in-place write would leak these params into the receiver and
	// sibling builders.
	args := make(map[string][]interface{}, len(db.args)+1)
	for clause, values := range db.args {
		args[clause] = values
	}
	args["having"] = params
	db.args = args
	return db
}

// Limit returns a copy of the builder whose LIMIT clause is params, inserted
// verbatim (for example "10" or "20,10"). An empty string drops the clause.
func (db DB) Limit(params string) DB {
	limited := db
	limited.limit = params
	return limited
}

//-------------------------------------------------------------------------------------------------------------------------------------------------
// Get runs the built SELECT with LIMIT forced to 1. It returns the first row as
// a column-name -> value map, or nil when the query fails (the error is logged)
// or matches nothing.
func (db DB) Get() map[string]interface{} {
	db.limit = "1"
	query, args := db.build_sql(), db.build_arg()
	rows, err := db.pool.Query(query, args...)
	if err != nil {
		log.Print("Get查询错误 ", query, err)
		return nil
	}
	if list := dealMysqlRows(rows); len(list) > 0 {
		return list[0]
	}
	return nil
}
// GetVal selects only field and returns that column from the first matching
// row. It returns nil when there is no row or the column is NULL. The lookup
// uses field as the column name, so it must match the name the server reports.
func (db DB) GetVal(field string) interface{} {
	db.field = field
	return db.Get()[field]
}
// GetList runs the built SELECT and returns every row as a column-name ->
// value map. It returns nil when the query fails (the error is logged) or
// matches nothing.
func (db DB) GetList() []map[string]interface{} {
	query := db.build_sql()
	rows, err := db.pool.Query(query, db.build_arg()...)
	if err == nil {
		return dealMysqlRows(rows)
	}
	log.Print("GetList查询错误 ", query, err)
	return nil
}
// Has reports whether the built SELECT matches at least one row by running
// SELECT EXISTS(<query>). It returns the integer flag from the server, 1 when
// a row exists, and 0 otherwise. It also returns 0 when the query fails (the
// error is logged) or the flag cannot be read.
func (db DB) Has() int {
	query := "SELECT EXISTS(" + db.build_sql() + ")"
	rows, err := db.pool.Query(query, db.build_arg()...)
	if err != nil {
		log.Print("Has查询错误 ", query, err)
		return 0
	}
	result := dealMysqlRows(rows)
	if len(result) == 0 {
		// EXISTS should always yield one row, but an empty result must not panic.
		return 0
	}
	found := 0
	for _, v := range result[0] {
		found = 0
		// dealMysqlRows returns the flag either as int or as string, depending
		// on the column type reported for it, so accept both.
		rv := reflect.ValueOf(v)
		switch rv.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			found = int(rv.Int())
		case reflect.String:
			if i, convErr := strconv.Atoi(rv.String()); convErr == nil {
				found = i
			}
		}
	}
	return found
}
// Insert inserts one row into the builder's table and returns the
// auto-increment id reported by the driver (LastInsertId). Each key of d is a
// column name; its value is bound to a "?" placeholder.
//
// Column names are interpolated into the SQL verbatim, so keys of d must come
// from trusted code; only the values are parameterized. On failure the error
// is logged and returned with a zero id.
func (db DB) Insert(d map[string]interface{}) (int, error) {
	args := make([]interface{}, 0, len(d))
	cols, holders := "", ""
	for k, v := range d {
		if cols != "" {
			cols += ","
			holders += ","
		}
		cols += k
		holders += "?"
		// Collected in the same iteration as cols, so the order matches even
		// though map iteration order is random.
		args = append(args, v)
	}
	query := fmt.Sprintf("INSERT INTO %v (%v) VALUES (%v);", db.table, cols, holders)
	retData, err := db.pool.Exec(query, args...)
	if err != nil {
		log.Print("Insert错误 ", query, err)
		return 0, err
	}
	lastID, err := retData.LastInsertId()
	if err != nil {
		log.Print("Insert错误 ", query, err)
		return 0, err
	}
	return int(lastID), nil
}
// Update sets the columns in d on the rows of the builder's table that match
// the Where condition, honoring Order and Limit. It returns the number of
// affected rows. Each key of d is a column name; its value is bound to a "?"
// placeholder.
//
// Column names are interpolated verbatim, so keys of d must come from trusted
// code. An Update without a Where condition is refused: it is logged and
// returned as an error rather than touching the whole table. Execution errors
// are logged and returned.
func (db DB) Update(d map[string]interface{}) (int, error) {
	if len(db.where) == 0 {
		log.Print("Update错误 无更新条件")
		return 0, fmt.Errorf("update %s: refusing to run without a WHERE condition", db.table)
	}
	var args []interface{}
	set := ""
	for k, v := range d {
		if set != "" {
			set += ","
		}
		set += k + "=?"
		args = append(args, v)
	}
	// WHERE placeholders follow the SET placeholders in the statement.
	args = append(args, db.args["where"]...)
	query := fmt.Sprintf("UPDATE %v SET %v WHERE %v", db.table, set, db.where)
	if len(db.order) > 0 {
		query += " ORDER BY " + db.order
	}
	if len(db.limit) > 0 {
		query += " LIMIT " + db.limit
	}
	query += ";"
	retData, err := db.pool.Exec(query, args...)
	if err != nil {
		log.Print("Update错误 ", query, err)
		return 0, err
	}
	affected, err := retData.RowsAffected()
	if err != nil {
		log.Print("Update错误 ", query, err)
		return 0, err
	}
	return int(affected), nil
}
// Delete removes the rows of the builder's table that match the Where
// condition and its "?" parameters, honoring Order and Limit. It returns the
// number of affected rows.
//
// A Delete without a Where condition is refused: it is logged and returned as
// an error rather than emptying the table. Execution errors are logged and
// returned.
func (db DB) Delete() (int, error) {
	if len(db.where) == 0 {
		log.Print("Delete错误 无删除条件")
		return 0, fmt.Errorf("delete from %s: refusing to run without a WHERE condition", db.table)
	}
	query := fmt.Sprintf("DELETE FROM %v WHERE %v", db.table, db.where)
	if len(db.order) > 0 {
		query += " ORDER BY " + db.order
	}
	if len(db.limit) > 0 {
		query += " LIMIT " + db.limit
	}
	query += ";"
	retData, err := db.pool.Exec(query, db.args["where"]...)
	if err != nil {
		log.Print("Delete错误 ", query, err)
		return 0, err
	}
	affected, err := retData.RowsAffected()
	if err != nil {
		log.Print("Delete错误 ", query, err)
		return 0, err
	}
	return int(affected), nil
}
// Query runs a raw SQL statement with optional placeholder values and returns
// every row as a column-name -> value map. The builder's own clauses (table,
// where, ...) are not used. On error it logs and returns nil.
func (db DB) Query(Sql string, Args ...interface{}) []map[string]interface{} {
	rows, err := db.pool.Query(Sql, Args...)
	if err == nil {
		return dealMysqlRows(rows)
	}
	log.Print("Query查询错误 ", Sql, err)
	return nil
}
// Pre returns table_name with the connection's table prefix prepended, which
// is useful when writing raw SQL for Query.
func (db DB) Pre(table_name string) string {
	// fmt.Sprint inserts no separator between two string operands.
	return fmt.Sprint(db.prefix, table_name)
}
// GetSql returns the SELECT statement the builder would run, without
// executing it. Placeholder values are not included. It returns "" when no
// table has been set.
func (db DB) GetSql() string {
	statement := db.build_sql()
	return statement
}
//-------------------------------------------------------------------------------------------------------------------------------------------------

// build_sql assembles the SELECT statement described by the builder:
//
//	SELECT <field|*> FROM <table>[ WHERE ..][ GROUP BY ..][ HAVING ..][ ORDER BY ..][ LIMIT ..]
//
// Every fragment is inserted verbatim, and empty fragments are omitted.
// Placeholder values are produced separately by build_arg. It returns "" when
// no table has been set.
func (db DB) build_sql() string {
	if db.table == "" {
		return ""
	}
	columns := db.field
	if columns == "" {
		columns = "*"
	}
	tail := ""
	for _, part := range [...]struct{ keyword, value string }{
		{" WHERE ", db.where},
		{" GROUP BY ", db.group},
		{" HAVING ", db.having},
		{" ORDER BY ", db.order},
		{" LIMIT ", db.limit},
	} {
		if part.value != "" {
			tail += part.keyword + part.value
		}
	}
	return fmt.Sprintf("SELECT %s FROM %s%s", columns, db.table, tail)
}
// build_arg returns the placeholder values for the statement produced by
// build_sql. WHERE values come first, then HAVING values, matching their order
// in the SQL text. It returns nil when there are none.
func (db DB) build_arg() []interface{} {
	var out []interface{}
	for _, clause := range [...]string{"where", "having"} {
		// Appending an empty or nil slice leaves out unchanged (still nil).
		out = append(out, db.args[clause]...)
	}
	return out
}

// dealMysqlRows drains rows into a slice of column-name -> value maps and
// always closes rows.
//
// Values are converted according to each column's database type name. MEDIUMINT,
// INT and TINYINT become int via strconv.Atoi, whose error is ignored. Every
// other type, including BIGINT and DATETIME, is returned as string. SQL NULL
// columns are left out of the row map. When several columns share a name,
// the right-most non-NULL one wins.
//
// A row that fails to scan is logged and skipped, and an iteration error is
// logged. It returns nil when no rows were collected.
func dealMysqlRows(rows *sql.Rows) []map[string]interface{} {
	defer rows.Close()
	columns, err := rows.Columns()
	if err != nil {
		log.Print("dealMysqlRows columns error ", err)
		return nil
	}
	// Resolve types by column index rather than by name, so duplicate column
	// names (e.g. from a JOIN) each keep their own type.
	typeNames := make([]string, len(columns))
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		// Without type information every value falls back to string below.
		log.Print("dealMysqlRows column types error ", err)
	}
	for i, ct := range columnTypes {
		if i < len(typeNames) {
			typeNames[i] = ct.DatabaseTypeName()
		}
	}
	raw := make([]sql.RawBytes, len(columns))
	dest := make([]interface{}, len(raw))
	for i := range raw {
		dest[i] = &raw[i]
	}
	var resList []map[string]interface{}
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			// raw may still hold values from the previous row; skip instead of
			// returning a half-filled, stale row.
			log.Print("dealMysqlRows scan error ", err)
			continue
		}
		rowMap := make(map[string]interface{}, len(columns))
		for i, colVal := range raw {
			if colVal == nil {
				continue // SQL NULL
			}
			// string() copies: RawBytes is only valid until the next Next/Scan.
			value := string(colVal)
			switch typeNames[i] {
			case "MEDIUMINT", "INT", "TINYINT":
				n, _ := strconv.Atoi(value)
				rowMap[columns[i]] = n
			default:
				rowMap[columns[i]] = value
			}
		}
		resList = append(resList, rowMap)
	}
	if err := rows.Err(); err != nil {
		log.Print("dealMysqlRows iteration error ", err)
	}
	return resList
}

// 大佬写的orm框架,mark一下稍后阅读 (An ORM framework written by an experienced developer; bookmarked to read later.)

// 你可能感兴趣的:(golang) (You may also be interested in: golang.)