封装的golang model

这个 model 算是封装得比较完整的核心部分，欢迎大家留言交流。

package models

import (
	"database/sql"
	"encoding/json"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"lianxi/log"
	// "log"
	"reflect"
	// "strconv"
	"strings"
	"time"
)

// Connection settings that connect joins into the MySQL DSN
// (userName:password@tcp(ip:port)/dbName?charset=utf8).
// NOTE(review): the credentials are hard-coded in source; presumably these
// are development defaults -- confirm before using outside a local setup.
const (
	userName = "root"
	password = "12312312"
	ip       = "127.0.0.1"
	port     = "3306"
	dbName   = "acc_pro"
)

// Model is a small chainable query builder bound to a single table.
// The clause fields are plain SQL fragments that the query methods
// concatenate; pramas holds the values bound to the `?` placeholders
// contained in where.
type Model struct {
	link      *sql.DB                // database handle opened by connect
	tableName string                 // table every statement is built against
	field     string                 // column list used in SELECT; M sets it to "*"
	allFields []string               // column names loaded by getFields from `DESC <table>`
	where     string                 // WHERE clause; M starts it as ` WHERE 1` so conditions can be appended with AND
	order     string                 // ORDER BY fragment appended to SELECT statements
	limit     string                 // LIMIT fragment appended to SELECT statements
	group     string                 // GROUP BY fragment appended to SELECT statements
	pramas    []interface{}          // bound arguments for the placeholders in where, in order
	tables    map[string]interface{} // registry filled by Register, keyed by table name
}

// M returns a Model bound to tableName. The model selects every column
// ("*") by default and starts with the neutral clause ` WHERE 1`, so later
// conditions can always be appended with AND. It opens the database
// connection and loads the table's column names before returning.
func M(tableName string) Model {
	// order, limit and group are left at their zero value (empty string).
	obj := Model{
		tableName: tableName,
		field:     "*",
		where:     ` WHERE 1`,
	}

	obj.connect()
	obj.getFields()

	return obj
}

// Register stores tabSrurt in the model's table registry under tabName.
// The registry is created on first use, and an entry that already exists
// is left untouched. It returns the receiver so calls can be chained.
func (this *Model) Register(tabName string, tabSrurt interface{}) *Model {
	if this.tables == nil {
		this.tables = map[string]interface{}{}
	}

	// First registration wins: never overwrite an existing entry.
	if _, exists := this.tables[tabName]; exists {
		return this
	}

	this.tables[tabName] = tabSrurt
	return this
}

// in_array reports whether need is a string equal to one of the elements
// of needArr. A need whose dynamic type is not string never matches.
func in_array(need interface{}, needArr []string) bool {
	// Comparing an interface value against a string is only ever true when
	// the interface holds a string, so resolve the type once up front.
	target, isString := need.(string)
	if !isString {
		return false
	}

	for i := range needArr {
		if needArr[i] == target {
			return true
		}
	}
	return false
}

// connect opens the MySQL handle described by the package connection
// constants, configures the connection pool and pings the server.
//
// Failures are reported on stdout. When sql.Open itself fails, this.link is
// left nil instead of being dereferenced.
func (this *Model) connect() {
	path := strings.Join([]string{userName, ":", password, "@tcp(", ip, ":", port, ")/", dbName, "?charset=utf8"}, "")

	link, err := sql.Open("mysql", path)
	if err != nil {
		fmt.Println("open database fail", err)
		return
	}
	this.link = link

	this.link.SetConnMaxLifetime(100 * time.Second)
	// The original called SetConnMaxLifetime(100) a second time, which reset
	// the lifetime to 100ns; the intended setting is the open-connection cap.
	this.link.SetMaxOpenConns(100)
	this.link.SetMaxIdleConns(10)

	if err := this.link.Ping(); err != nil {
		fmt.Println("open database fail", err)
	}
}

// getFields loads the column names of this.tableName into this.allFields
// by running `DESC <table>`. On any failure the error is printed and
// allFields keeps whatever was collected so far (possibly nothing).
//
// NOTE(review): tableName is concatenated into the statement unescaped, so
// it must come from trusted code, never from user input.
func (this *Model) getFields() {
	this.allFields = make([]string, 0)

	if this.link == nil {
		fmt.Println("getfields: no database connection")
		return
	}

	res, err := this.link.Query("DESC " + this.tableName)
	if err != nil {
		fmt.Printf("getfields %s", err)
		// res is nil here; iterating it would panic.
		return
	}
	defer res.Close()

	for res.Next() {
		// DESC yields: Field, Type, Null, Key, Default, Extra. Only Field is
		// kept; the rest are scanned just to satisfy the column count.
		var (
			field, Null, Key, Extra string
			Type, Default           interface{}
		)
		if err := res.Scan(&field, &Type, &Null, &Key, &Default, &Extra); err != nil {
			fmt.Printf("scan fail ! [%s]", err)
			// Skip the row instead of recording an empty column name.
			continue
		}
		this.allFields = append(this.allFields, field)
	}

	if err := res.Err(); err != nil {
		fmt.Printf("getfields %s", err)
	}
}

// formatRes encodes errno, errmsg and res as a JSON object with the keys
// "errno", "errmsg" and "result" and returns it as a string. If res cannot
// be marshalled the encoding error is dropped and the result is "".
func formatRes(errno int, errmsg string, res interface{}) string {
	payload := map[string]interface{}{
		"errno":  errno,
		"errmsg": errmsg,
		"result": res,
	}

	encoded, _ := json.Marshal(payload)
	return string(encoded)
}

// exec prepares and runs a write statement with the given bound params.
//
// On success it returns an int64: the last insert id when the statement
// text contains "INSERT", the affected-row count when it contains "UPDATE"
// or "DELETE", and 0 otherwise. On failure it returns a JSON string built
// by formatRes (errno 33061 for prepare/exec errors, 33062 when the result
// metadata cannot be read), with the error text in the "result" key.
func (this *Model) exec(sql string, params []interface{}) interface{} {
	stmt, err := this.link.Prepare(sql)
	if err != nil {
		// stmt is nil when Prepare fails; bail out before touching it.
		return formatRes(33061, ``, err.Error())
	}
	defer stmt.Close()

	res, err := stmt.Exec(params...)
	if err != nil {
		return formatRes(33061, ``, err.Error())
	}

	var idRow int64
	switch {
	case strings.Contains(sql, "INSERT"):
		idRow, err = res.LastInsertId()
	case strings.Contains(sql, "UPDATE") || strings.Contains(sql, "DELETE"):
		idRow, err = res.RowsAffected()
	}
	if err != nil {
		return formatRes(33062, ``, err.Error())
	}

	return idRow
}

// Where appends conditions to the model's WHERE clause (joined with AND)
// and returns the receiver for chaining.
//
// where may be one of:
//   - string: appended verbatim. It is NOT escaped or parameterized, so it
//     must never contain untrusted input.
//   - map[string]interface{}: one condition per key. Keys that are not
//     columns of the table (see allFields) are silently skipped. A value of
//     type int, int64 or string produces `col=?`; a []interface{} value is
//     an [operator, operand] pair, e.g.
//     {"age": {"<", 80}, "sex": {"in", {"n", "x"}}, "id": {"between", {1, 9}}}.
//
// Any other value type in the map panics, as do malformed operator pairs.
// Other types of where are ignored.
func (this *Model) Where(where interface{}) *Model {
	switch cond := where.(type) {
	case string:
		this.where += ` AND ` + cond

	case map[string]interface{}:
		for k, v := range cond {
			if !in_array(k, this.allFields) {
				continue
			}
			switch val := v.(type) {
			case int, int64, string:
				this.where += ` AND ` + k + `=?`
				this.pramas = append(this.pramas, val)
			case []interface{}:
				this.whereExpr(k, val)
			default:
				panic(fmt.Sprintf("invalid type of where prama:%s", reflect.TypeOf(v)))
			}
		}
	}

	return this
}

// whereExpr appends one [operator, operand] condition on column col.
// Supported operators are >, >=, <=, <, "in" and "between"; unknown
// operators are ignored. It panics with a descriptive message when the pair
// is malformed instead of failing on an out-of-range index.
func (this *Model) whereExpr(col string, expr []interface{}) {
	if len(expr) == 0 {
		panic("where condition must be [operator, value], got empty slice!")
	}

	// A non-string operator yields "" and falls through as unknown.
	op, _ := expr[0].(string)

	switch op {
	case ">", ">=", "<=", "<":
		if len(expr) < 2 {
			panic("missing value for where operator " + op + "!")
		}
		this.where += ` AND ` + col + op + ` ?`
		this.pramas = append(this.pramas, expr[1])

	case "in":
		var list []interface{}
		ok := false
		if len(expr) > 1 {
			list, ok = expr[1].([]interface{})
		}
		if !ok {
			panic("must be slice type on second prama of where in!")
		}
		if len(list) == 0 {
			// `IN ()` is a syntax error in MySQL; an empty set matches no row.
			this.where += ` AND 1=0`
			return
		}
		this.where += ` AND ` + col + ` IN (` + strings.Repeat(`?,`, len(list)-1) + `?)`
		this.pramas = append(this.pramas, list...)

	case "between":
		var bounds []interface{}
		ok := false
		if len(expr) > 1 {
			bounds, ok = expr[1].([]interface{})
		}
		if !ok {
			panic("must be slice type on second prama of where between!")
		}
		if len(bounds) != 2 {
			// Any other count would desynchronise placeholders and params.
			panic("where between needs exactly two values!")
		}
		this.where += ` AND ` + col + ` BETWEEN ? AND ?`
		this.pramas = append(this.pramas, bounds...)
	}
}

// Add inserts one row built from arr (column name -> value). Keys that are
// not columns of the table are dropped. It returns the value produced by
// exec (the last insert id as int64) on success; on failure the error
// payload is logged and the int 0 is returned.
func (this *Model) Add(arr map[string]interface{}) interface{} {
	columns := make([]string, 0, len(arr))
	marks := make([]string, 0, len(arr))
	args := make([]interface{}, 0, len(arr))

	for col, val := range arr {
		if !in_array(col, this.allFields) {
			continue
		}
		columns = append(columns, col)
		marks = append(marks, `?`)
		args = append(args, val)
	}

	query := `INSERT INTO ` + this.tableName +
		`(` + strings.Join(columns, `,`) + `) VALUES (` + strings.Join(marks, `,`) + `)`

	result := this.exec(query, args)
	if _, inserted := result.(int64); !inserted {
		// exec reports failures as a JSON string.
		plog.Debug("modelDebug.go", result.(string))
		return 0
	}

	return result
}

// Update sets the columns given in option (column name -> value) on every
// row matched by the current WHERE clause. Keys that are not columns of
// the table are dropped. It returns the value produced by exec (the
// affected-row count as int64) on success; on failure the error payload is
// logged and the int 0 is returned.
func (this *Model) Update(option map[string]interface{}) interface{} {
	assignments := make([]string, 0, len(option))
	args := make([]interface{}, 0, len(option)+len(this.pramas))

	for col, val := range option {
		if in_array(col, this.allFields) {
			assignments = append(assignments, col+`=?`)
			args = append(args, val)
		}
	}

	// SET placeholders come first in the statement, WHERE placeholders after.
	args = append(args, this.pramas...)

	query := `UPDATE ` + this.tableName + ` SET ` + strings.Join(assignments, `,`) + ` ` + this.where
	result := this.exec(query, args)

	switch result.(type) {
	case int64:
		return result
	default:
		// exec reports failures as a JSON string.
		plog.Debug("modelDebug.go", result.(string))
		return 0
	}
}

// Delete removes every row matched by the current WHERE clause. It returns
// the value produced by exec (the affected-row count as int64) on success;
// on failure the error payload is logged and the int 0 is returned.
func (this *Model) Delete() interface{} {
	query := `DELETE FROM ` + this.tableName + ` ` + this.where
	result := this.exec(query, this.pramas)

	switch result.(type) {
	case int64:
		return result
	default:
		// exec reports failures as a JSON string.
		plog.Debug("modelDebug.go", result.(string))
		return 0
	}
}

// Field replaces the SELECT column list with the given names, joined by
// commas, and returns the receiver for chaining.
func (this *Model) Field(field []string) *Model {
	columns := strings.Join(field, ",")
	this.field = columns
	return this
}

// Query runs sql with the model's bound WHERE parameters and returns every
// row as a map of column name to the column's value converted to string.
//
// The return value is []map[string]string on success (nil when there are
// no rows), the bool false when the query cannot be started or its columns
// cannot be read, and the int 0 when reading the rows fails.
func (this *Model) Query(sql string) interface{} {
	rows, err := this.link.Query(sql, this.pramas...)
	if err != nil {
		fmt.Println("query err:", err)
		return false
	}
	// Deferred only after the error check: rows is nil when Query fails.
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		fmt.Println("query err:", err)
		return false
	}

	// Scan every column as raw bytes; scanField holds pointers into scanVal
	// so one Scan call refreshes the whole row buffer.
	scanVal := make([][]byte, len(cols))
	scanField := make([]interface{}, len(cols))
	for i := range scanVal {
		scanField[i] = &scanVal[i]
	}

	var result []map[string]string
	for rows.Next() {
		if err := rows.Scan(scanField...); err != nil {
			plog.Debug("modelDebug.log", fmt.Sprintf("rows.Scan():%s\n", err))
			return 0
		}
		row := make(map[string]string, len(cols))
		for i, raw := range scanVal {
			row[cols[i]] = string(raw)
		}
		result = append(result, row)
	}

	if err := rows.Err(); err != nil {
		plog.Debug("modelDebug.log", fmt.Sprintf("rows.Err():%s\n", err))
		return 0
	}

	return result
}

// Limit sets the LIMIT clause to the given text (for example "10" or
// "0,20") and returns the receiver for chaining. The text is not validated.
func (this *Model) Limit(limit string) *Model {
	const keyword = ` LIMIT `
	this.limit = keyword + limit
	return this
}

// Order sets the ORDER BY clause to the given text (for example "id DESC")
// and returns the receiver for chaining. The text is not validated.
func (this *Model) Order(order string) *Model {
	const keyword = ` ORDER BY `
	this.order = keyword + order
	return this
}

// Group sets the GROUP BY clause to the given text and returns the
// receiver for chaining. The text is not validated.
func (this *Model) Group(group string) *Model {
	// Store into this.group (the original wrote to this.order, which both
	// discarded any ORDER BY and left the GROUP BY clause empty).
	this.group = ` GROUP BY ` + group
	return this
}

// All runs the SELECT assembled from the model's field list, table and
// WHERE / GROUP BY / ORDER BY / LIMIT clauses, and returns whatever Query
// returns for it.
func (this *Model) All() interface{} {
	parts := []string{
		`SELECT `, this.field,
		` FROM `, this.tableName,
		this.where, this.group, this.order, this.limit,
	}
	statement := strings.Trim(strings.Join(parts, ""), " ")
	return this.Query(statement)
}

// One runs the SELECT assembled from the model's clauses and returns the
// first row as a map[string]string. It returns nil when the query fails or
// matches no row, instead of panicking on a failed type assertion or an
// out-of-range index.
func (this *Model) One() interface{} {
	// Only the first row is used, so cap the result set unless the caller
	// already chose a limit.
	limit := this.limit
	if limit == "" {
		limit = ` LIMIT 1`
	}

	query := `SELECT ` + this.field + ` FROM ` + this.tableName + this.where + this.group + this.order + limit
	query = strings.Trim(query, " ")

	// Query returns false or 0 on failure; only a non-empty row slice counts.
	rows, ok := this.Query(query).([]map[string]string)
	if !ok || len(rows) == 0 {
		return nil
	}
	return rows[0]
}

// returnFormat is an empty placeholder: it accepts a result and does
// nothing with it. Nothing in this file calls it.
func returnFormat(result interface{}) {

}

// Count runs `SELECT count(<field>) as count` with the model's WHERE /
// GROUP BY / ORDER BY / LIMIT clauses and returns the first row's value as
// int64. When the row cannot be scanned, the error is printed and returned
// in place of the count.
func (this *Model) Count() interface{} {
	parts := []string{
		`SELECT count(`, this.field, `) as count FROM `, this.tableName,
		this.where, this.group, this.order, this.limit,
	}
	query := strings.Trim(strings.Join(parts, ""), " ")

	var total int64
	if err := this.link.QueryRow(query, this.pramas...).Scan(&total); err != nil {
		fmt.Println("count scan err:", err)
		return err
	}

	return total
}

你可能感兴趣的:(golang)