自己实现一个简单Golang ORM函数库

前言

通过该项目,对go的反射有了更深入的了解。特意记录下。将要使用的sql驱动为github.com/go-sql-driver/mysql

正文

数据库初始化

任何sql操作都离不开初始化,调用sql.Open(dbType,dataSource);即可初始化数据库。但需要注意的是该函数是golang官方的数据库规范接口其具体实现交由第三处理。所以需要在需要初始化的包中导入并初始化第三方包的想关函数。

package db

import (
    "database/sql"
    "fmt"
    "time"

    //init sql
    _ "github.com/go-sql-driver/mysql"
)

//DB db oprate instance
var DB *sql.DB

var sqlType = "mysql"
var dataSource = "root:123456@tcp(localhost)/alming"

func init() {
    DB, _ = sql.Open(sqlType, dataSource)
    DB.SetConnMaxLifetime(time.Minute * 3)
    DB.SetMaxOpenConns(10)
    DB.SetMaxIdleConns(10)
    if err := DB.Ping(); err == nil {
        fmt.Println("Connect success:")
    } else {
        fmt.Println("Connect fail:", err)
    }
}

单结果查询操作

需要解决的问题是如何将Sql操作的结果集映射到struct中。首先看下常规查询操作

//以下代码基本为伪代码,未经测试仅展示流程
type User struct{
    Username string
    Password string
}
rows,err:=DB.Query("select * from user")
user:=new(User)
rows.Next()
rows.Scan(&user.Username,&user.Password)

可以看到,Scan()方法接收的是指针类型参数,所以说要创建一个指针容器用于存放结果集。那么有两个问题:容器内指针是什么类型,容器的大小又是多少。这时我们需要使用rows实例的另一个函数ColumnTypes()它返回一个[]*sql.ColumnType其数组内元素包含每列结果的数据库类型。有了数据库类型就可以根据数据库类型创建go类型参数。而该返回值的个数也就是我们需要创建的容器大小。详情见代码

//aldb.go
//获得结果集所有列信息以创建接收结果的容器
rc, err := rows.ColumnTypes()
if err != nil {
    log.Println("Get column types fail")
}
container := createContainer(rc)
if !rows.Next() {
    return false
}
column, _ := rows.Columns()
rows.Scan(container...)
success := mapResult(container, column, rs)
//dbutil.go
func createContainer(columnTyes []*sql.ColumnType) (params []interface{}) {
    params = make([]interface{}, len(columnTyes))
    for i, ct := range columnTyes {
        params[i] = createSlot(ct.DatabaseTypeName())
    }
    return
}
//这里也是仅列出了常用的类型,如需扩展再进行类型添加
func createSlot(dbType string) interface{} {
    switch dbType {
    case "INT", "TINYINT", "BIGINT":
        return new(int)
    case "MEDIUMINT":
    case "DOUBLE":
        return new(float32)
    case "DECIMAL":
    case "CHAR":
        return new(byte)
    case "VARCHAR", "TEXT", "LONGTEXT":
        return &sql.NullString{String: "", Valid: true}
    case "BIT":
        return new(interface{})
    case "DATE":
        return &sql.NullString{String: "", Valid: false}
    case "DATETIME":
        return &sql.NullString{String: "", Valid: false}
    case "TIMESTAMP":
        return &sql.NullString{String: "", Valid: false}
    }
    return nil
}

这里有一个坑就是想要映射为golang的string类型时需要使用sql.NullString,否则当驱动扫描到一个值为NULL的列时将不会继续扫描后面的结果将会获取不到

另外单结果查询我们还需要判断结果集是否为多个,因为有些业务只允许返回一个结果集,返回多个视为错误。实现起来也非常简单

if rows.Next() {
    panic("QueryOne except one result but get no more one")
}

多结果查询操作

多结果查询与单结果类似,只是在单结果上多了一个for循环

rc, err := rows.ColumnTypes()
if err != nil {
    log.Println("Get column types fail")
}

column, _ := rows.Columns()
var oneMoreSet bool = false
for rows.Next() {
    container := createContainer(rc)
    err = rows.Scan(container...)
    if err != nil {
        panic("Scan rows error")
    }
    oneMoreSet = mapResult(container, column, rs)
}

结果集映射

可以看到前文中mapResult(container, column, rs)即为结果集映射函数,多结果与单结果共用一个函数,内部通过if判断区分以写操作。在接下来的源码中您可能会看到toPascalCase(columns[i])函数,该函数是一个工具函数它将sql列命映射成为Golang命名规范的变量命方便使用反射。映射规则是将首字母大写,_后第一个字母大写,其源码为

func toPascalCase(src string) string {
    var dst = make([]uint8, 0)
    if src[0] > 96 && src[0] < 123 {
        dst = append(dst, src[0]-32)
    } else {
        dst = append(dst, src[0])
    }
    for i := 1; i < len(src); {
        if src[i] == '_' {
            if src[0] > 96 && src[0] < 123 {
                dst = append(dst, src[i+1]-32)
            }
            i += 2
        } else {
            dst = append(dst, src[i])
            i++

        }
    }
    return string(dst)
}

然后继续看映射部分

//mapResult 将sql rows扫描到的数据填入给定的结构中(结构体或slice)
//container :单条结果容器,columns 结果集对应数据库中的列名,value
//被映射对象
func mapResult(container []interface{}, columns []string, value reflect.Value) bool {
    var slot reflect.Value
    var arr = make([]reflect.Value, 0)
    //判断待映射类型,结构以与slice分别处理
    if value.Elem().Kind() == reflect.Struct {
        slot = value.Elem()
    } else {
        //slice内数据类型的实例
        slot = reflect.New(value.Type().Elem().Elem()).Elem()
    }
    var oneMoreSet = false
    //遍历一行结果集找到其在结构体中的位置并赋值
    for i, v := range container {
        //找到对应结构体的属性
        slotField := slot.FieldByName(toPascalCase(columns[i]))
        if slotField.CanSet() {
            switch value := v.(type) {
            case *int:
                //只有与其结构体类型匹配才赋值
                if slotField.Kind() == reflect.Int {
                    slotField.SetInt(int64(*value))
                }
                oneMoreSet = true
            case *string:
                if slotField.Kind() == reflect.String {
                    slotField.SetString(*value)
                }
                oneMoreSet = true
            case *sql.NullString:
                if slotField.Kind() == reflect.String {
                    slotField.SetString(value.String)
                }
                oneMoreSet = true
            }
        }
    }
    //如果被映射对象是slice也就是多结果集映射要通过反射将映射出的
    //结构体实例追加到结果集中
    if value.Elem().Kind() == reflect.Slice {
        arr = append(arr, slot)
        added := reflect.Append(value.Elem(), arr...)
        value.Elem().Set(added)
    }
    return oneMoreSet
}

插入更新操作

这部分我实现了一个自定义SQL格式,使用时需按该格式编写sql。规定:sql中参数都使用:数据库列名代替,它看起来是下面这样

update user set username=:username where id=:id

使用时会像下面这样

u := user{
    Id:       2,
    Username: "alming_update",
}
Exec(&u, "update user set username=:username where id=:id")

其内部实现原理也非常简单,直接看源码

//Exec excute sql with the params in the struct you give
func Exec(structure interface{}, sqlStr string) (success bool) {
    rs := reflect.ValueOf(structure)
    pointTo := rs.Elem()
    //自定义sql 表达式中 ?由[]:变量名]代替,找到这些变量名并由反射根据改名称获取所给
    //结构体实例当中的数据作为参数传递给Exec函数
    reg, _ := regexp.Compile(`:[a-zA-z_]+`)
    regFind := reg.FindAllString(sqlStr, -1)
    //通过反射创建参数列表的容器
    params := make([]interface{}, len(regFind))
    //通过自定义sql表达式获取sql
    SQLParsed := reg.ReplaceAllString(sqlStr, "?")
    //通过自定义sql中:找到对应的参数
    for i, sqlArgs := range regFind {
        parseArg := strings.TrimPrefix(sqlArgs, `:`)
        fieldName := toPascalCase(parseArg)
        field := pointTo.FieldByName(fieldName)
        switch field.Kind() {
        case reflect.Int:
            //将参数添加到参数容器中
            params[i] = field.Int()
        case reflect.String:
            params[i] = field.String()
        case reflect.Float32, reflect.Float64:
            params[i] = field.Float()
        }
    }
    var res sql.Result
    var err error
    if len(params) > 0 {
        res, err = DB.Exec(SQLParsed, params...)
    } else {
        res, err = DB.Exec(SQLParsed)
    }
    if err == nil {
        rowAf, _ := res.RowsAffected()
        return rowAf > 0
    }
    return false
}

关于一对多问题

该操作实现的过于笨重且限制较多,就不班门弄斧了。感兴趣可以看下源码。

func QueryOneToMany(slice interface{}, sqlStr string, outPk string, inPk string, params ...interface{}) (resMatched bool) {
    defer catchPanic()
    rs := reflect.ValueOf(slice)
    pointTo := rs.Elem()
    if pointTo.Kind() != reflect.Slice {
        panic("QueryOne must to map to a slice,please check your structure parameter")
    }
    var rows *sql.Rows
    var err error
    if len(params) == 0 {
        rows, err = DB.Query(sqlStr)
    } else {
        rows, err = DB.Query(sqlStr, params...)
    }
    if err != nil {
        log.Println("An error occerred when exec query sql", err)
    }

    rc, err := rows.ColumnTypes()
    if err != nil {
        log.Println("Get column types fail")
    }

    column, _ := rows.Columns()
    var allRows = make([][]interface{}, 0)
    for rows.Next() {
        container := createContainer(rc)
        err = rows.Scan(container...)
        if err != nil {
            panic("Scan rows error")
        }
        allRows = append(allRows, container)
    }
    //outPk,对应“一”的主键,inPk对应“多”的主键
    mapRes(allRows, column, rs, 0, outPk, inPk)
    //别忘改
    return true
}

//mapRes 将查询的结果集按一对多形式映射到结构当中
//allRows 所有结果集,columns 结果集对应数据库中的列名,value
//被映射对象,height工具属性与可变参数pk配合使用,pk(primary
//key)设计目的是为了兼容QueryOne与Query的结果集映射。实际
//这两个方法有单独的映射函数
func mapRes(allRows [][]interface{}, columns []string, value reflect.Value, height int, pk ...string) {
    in := value.Elem()
    inType := in.Type().Elem()
    var inSlot reflect.Value
    var inSlotName string
    //查找给定结构的slice属性并为其
    for i := 0; i < inType.NumField(); i++ {
        if inType.Field(i).Type.Kind() == reflect.Slice {
            //记录改属性属性名方便之后通过反射获取改属性并为其赋值
            inSlotName = inType.Field(i).Name
            inSlot = reflect.New(inType.Field(i).Type)
            mapRes(allRows, columns, inSlot, height+1, pk...)
        }
    }
    //mark为一个标识,以sql primary key为map,通过它标识同一元素是否被重复扫描
    mark := make(map[interface{}]byte)
    //主键在column中索引位置,方便获取主键值并配合mark判断是否重复扫描
    var pkIdx = -1
    if len(pk) > 0 {
        pkIdx = getColIndex(columns, pk[height])
    }
    var arr = make([]reflect.Value, 0)
    for _, row := range allRows {
        if mark[pkValue(row[pkIdx])] == 1 {
            continue
        }
        outSlot := reflect.New(inType).Elem()
        var oneMoreSet = false
        for i, v := range row {
            slot := outSlot.FieldByName(toPascalCase(columns[i]))
            if slot.CanSet() {
                switch setValue := v.(type) {
                case *int:
                    if slot.Kind() == reflect.Int {
                        slot.SetInt(int64(*setValue))
                        oneMoreSet = true
                    }
                case *string:
                    if slot.Kind() == reflect.String {
                        slot.SetString(*setValue)
                        oneMoreSet = true
                    }
                case *sql.NullString:
                    if slot.Kind() == reflect.String {
                        slot.SetString(setValue.String)
                        oneMoreSet = true
                    }
                }
            }
        }
        slot := outSlot.FieldByName(inSlotName)
        if slot.CanSet() {
            slot.Set(inSlot.Elem())
        }
        if oneMoreSet {
            if len(pk) > 0 {
                mark[pkValue(row[pkIdx])] = 1
            }
        }
        arr = append(arr, outSlot)
    }
    added := reflect.Append(in, arr...)
    in.Set(added)
}

func getColIndex(colunms []string, col string) int {
    for idx, item := range colunms {
        if item == col {
            return idx
        }
    }
    return -1
}
func pkValue(pkContent interface{}) interface{} {
    switch v := pkContent.(type) {
    case *int:
        return *v
    case *byte:
        return *v
    case *float32:
        return *v
    case *string:
        return *v
    case *sql.NullString:
        return v.String
    default:
        return nil
    }
}

总结

关于Go反射

  1. go反射不像java,go必须在已有实例上进行反射。

  2. go使用反射修改实例内容时需要反射的内容必须为指针类型(可通过CanSet()判断该属性是否可以赋值),并且修改时需要调用Elem()方法获取其指向的元素。

  3. 反射slice添加元素比较复杂详情见代码。

  4. Elem()返回指针所指向的元素,如果是数组类型则返回其内部元素的类型。

  5. 可以通过reflect.New()创建新的实例,但与第一条不冲突(创建实例所需的类型参数由反射已有实例获得)

附录

源代码:alming_backend

一些平台禁止外链 https://github.com/ALMing530/alming_backend

进入该项目db文件夹下查看

你可能感兴趣的:(自己实现一个简单Golang ORM函数库)