beego源码学习-ORM,SQL解释器


配置ORM

type ORMdemoController struct {
    beego.Controller
}

func (this * ORMdemoController) Get(){
    //注册数据驱动
    orm.RegisterDriver("mysql", orm.DRMySQL) // mysql、sqlite3、postgres 这三种是beego默认已经注册过的,所以可以无需设置
    //注册数据库 ORM 必须注册一个别名为 default 的数据库,作为默认使用
    //五个参数:1、数据库别名;2、数据库驱动;3、数据库账户:密码@链接地址/数据库名称;4、最大空闲连接数;5、最大数据连接
    orm.RegisterDataBase("default", "mysql", "root:passwd@tcp(127.0.0.1:3306)/db_user?charset=utf8")
    //设置数据库时区
    //orm.DefaultTimeLoc = time.UTC
    //注册模型
    orm.RegisterModel(new(UserTable))
    //自动创建表:1、默认数据;2、是否开启创建表;3、是否更新表
    orm.RunSyncdb("default", true, true)
    this.Ctx.WriteString("表创建成功")
}

创建Model

  • 定义model时注意事项

1、我们定义结构体作为表,必须要有主键。

2、当 Field 类型为 int, int32, int64, uint, uint32, uint64 时,可以设置字段为自增健。

3、当模型定义里没有主键时,符合int类型且名称为 Id 的 Field 将被视为自增健。

4、属性的首字母最好是大写,设置属性为公开访问性。

5、未定义其他规则时,自动生表时,所有字段都为NOT NULL,id为自增主键,其他都有其类型默认值

type UserTable struct {
    User     int //添加 int型字段ID,作为主键
    Pwd      string
    RealName string
    Age      string
    IdCard   string
    Email    string
    Tel      string
}
  • 使用结构体tag进行表的详细属性设置
type UserTable struct {
    ID       int    `orm:"pk;auto;column(id)"`        //设置主键自增长 字段名为 id
    User     string `orm:"size(15);column(user)"`     //设置长度为15 字段名为 user
    Pwd      string `orm:"size(20);column(pwd)"`      //设置长度为20 字段名为 pwd
    RealName string `orm:"size(10);column(realname)"` //设置长度为10 字段名为 realname
    Age      string `orm:column(age)"`                //设置字段名为age
    IdCard   string `orm:"size(18);column(idcard)"`   //设置长度为18 字段名为 idcard
    Email    string `orm:"size(100);column(email)"`   //设置长度为100 字段名为 email
    Tel      string `orm:"size(11);column(tel)"`      //设置长度为11 字段名为 tel
}
  • 表名

beego中表名默认使用驼峰:UserTable

遇到大写会增加 _,原名称中的下划线保留:UserTable->user_table

//自定义表名
func (u *UserTable) TableName() string {
    return "user"  //表名被改为user
}

//使用RegisterModelWithPrefix为表名设置前缀:prefix_user_table
orm.RegisterModelWithPrefix("prefix_", new(UserTable)) 
  • 关联关系定义
//一对一,rel(one): 创建对应关系的字段:表_id,即对应主键,有唯一约束
type User struct {
    ......
    Profile     *Profile   `orm:"rel(one)"` 
}

//反向一对一,reverse(one): 不会创建字段,可选tag
type Profile struct {
    ......
    User        *User   `orm:"reverse(one)"` 
}

//一对多,rel(fk): 创建对应关系的字段:表名_id,即对应主键,没有约束
type Post struct {
    ......
    User  *User  `orm:"rel(fk)"`
}

//反向一对多,reverse(many): 不会创建字段,写在关系为多的类里
type User struct {
    ......
    Post        []*Post `orm:"reverse(many)"` 
}

//多对多,rel(m2m): 不会创建字段,声明处的表名,为自动创建的表的表名前缀(post)
type Post struct {
    ......
    Tags  []*Tag `orm:"rel(m2m)"`
}

//反向多对多,reverse(many): 不会创建字段,声明处的表名,为自动创建的表的表名后缀+s(tags)
type Tag struct {
    ......
    Posts []*Post `orm:"reverse(many)"` 
}

ps:反向一对多和反向多对多,关键词一样


ORM的使用

  • 一些基本操作
type Ormer interface {
    Read(interface{}, …string) error
    ReadOrCreate(interface{}, string, …string) (bool, int64, error)
    Insert(interface{}) (int64, error)
    InsertMulti(int, interface{}) (int64, error)
    Update(interface{}, …string) (int64, error)
    Delete(interface{}) (int64, error)
    LoadRelated(interface{}, string, …interface{}) (int64, error)
    QueryM2M(interface{}, string) QueryM2Mer
    QueryTable(interface{}) QuerySeter
    Using(string) error
    Begin() error
    Commit() error
    Rollback() error
    Raw(string, …interface{}) RawSeter
    Driver() Driver
}
  • 基本操作符
  • 基本的CRUD
type UserTable struct {......}

func Create(param interface{}) (int, error) {
    return orm.NewOrm().Insert(param)
}
func Update(param interface{}, fields ...string) (int, error) {
    return orm.NewOrm().Update(param, fields...)
}
func Delete(param interface{}, cols ...string) (int, error) {
    return orm.NewOrm().Delete(param, cols...)
}
func Read(md interface{}, cols ...string) error {
    return orm.NewOrm().Read(md, cols...)
}

ORM源码分析

主要流程

  1. 启动应用时,完成orm相关配置的注册:数据库、model、rel

  2. 使用时,实例化orm对象、关联关系处理、进行sql解析

  • 注册ORM相关配置
    //注册数据库 ORM 必须注册一个别名为 default 的数据库,作为默认使用
    //五个参数:1、数据库别名;2、数据库驱动;3、数据库账户:密码@链接地址/数据库名称;4、最大空闲连接数;5、最大数据连接
    orm.RegisterDataBase("default", "mysql", "root:passwd@tcp(127.0.0.1:3306)/db_user?charset=utf8")
    //注册模型
    orm.RegisterModel(new(UserTable))
    //自动创建表:1、默认数据;2、是否开启创建表;3、是否更新表
    orm.RunSyncdb("default", true, true)

注册操作做了啥

  • 注册db
//使用当前驱动的连接配置信息(数据库账户:密码@链接地址/数据库名称),设置数据库连接参数
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
    ......
}
  • 注册model
//===== github.com/astaxie/beego/orm/orm.go =====

func RegisterModel(models ...interface{}) {
    ......
    RegisterModelWithPrefix("", models...)
}
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
    ......
    for _, model := range models {
        registerModel(prefix, model, true)
    }
}
//RegisterModel 和 RegisterModelWithPrefix 都是对 registerModel 的封装
func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
    //通过反射获取模型信息
    val := reflect.ValueOf(model)
    typ := reflect.Indirect(val).Type()
    ......
    //处理表名和前缀,完整路径
    table := getTableName(val)
    if PrefixOrSuffix != "" {
        if isPrefix {
            table = PrefixOrSuffix + table
        } else {
            table = table + PrefixOrSuffix
        }
    }
    // models's fullname is pkgpath + struct name
    name := getFullName(typ)
    if _, ok := modelCache.getByFullName(name); ok {
        fmt.Printf(" model `%s` repeat register, must be unique\n", name)
        os.Exit(2)
    }

    if _, ok := modelCache.get(table); ok {
        fmt.Printf(" table name `%s` repeat register, must be unique\n", table)
        os.Exit(2)
    }

    //通过反射判断是否有id字段,设置主键
    mi := newModelInfo(val)
    if mi.fields.pk == nil {
    outFor:
        for _, fi := range mi.fields.fieldsDB {
            if strings.ToLower(fi.name) == "id" {
                switch fi.addrValue.Elem().Kind() {
                case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
                    fi.auto = true
                    fi.pk = true
                    mi.fields.pk = fi
                    break outFor
                }
            }
        }

        if mi.fields.pk == nil {
            fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name)
            os.Exit(2)
        }

    }
    //处理完成的model信息,设置到model缓存中
    mi.table = table
    mi.pkg = typ.PkgPath()
    mi.model = model
    mi.manual = true

    modelCache.set(table, mi)
}

ORM初始化过程

以这样一段orm操作作为示例:

func main() {
    o := orm.NewOrm()
    user := User{Name: "user"}
    u := User{Id: user.Id}
    err = o.Read(&u)
    fmt.Println(err)
}
  • 初始化orm

NewOrm时,先调用BootStrap方法完成模型信息的加载,然后调用Using选中默认数据库:

源码github.com/astaxie/beego/orm/orm.go

type orm struct {
    alias *alias
    db    dbQuerier
    isTx  bool
}

func NewOrm() Ormer {
    //orm的构造函数一开始就会执行启动引导
    BootStrap() // execute only once

    //选择一个默认的数据库驱动
    o := new(orm)
    err := o.Using("default")
    if err != nil {
        panic(err)
    }
    return o
}

BootStrap方法:加载缓存中的model实例,把model中定义的字段跟数据表中的字段进行绑定关联,无缓存会进行首次初始化。

Using方法:读取缓存的数据库信息,赋值给 orm 的 alias 属性,之后就可以进行正常的crud操作了

//====== orm/orm.go ======
func (o *orm) Using(name string) error {
    ......
    //将数据库缓存信息赋值给 orm 的 alias 属性
    if al, ok := dataBaseCache.get(name); ok {
        o.alias = al
        if Debug {//debug模式下开启sql日志
            o.db = newDbQueryLog(al, al.DB)
        } else {
            o.db = al.DB
        }
    } else {
        return fmt.Errorf(" unknown db alias name `%s`", name)
    }
    return nil
}

//============ orm/models_boot.go ============
func BootStrap() {
    //对model的操作缓存进行加锁,保证只会执行一次
    modelCache.Lock()
    defer modelCache.Unlock()
    if modelCache.done {
        return
    }
    bootStrap()
    modelCache.done = true
}

func bootStrap() {
    ......
    var (
        err    error
        models map[string]*modelInfo
    )
    
    //遍历模型的字段
    models = modelCache.all()
    for _, mi := range models {
        for _, fi := range mi.fields.columns {
            //如果有设置rel或者reverse关系
            if fi.rel || fi.reverse {
                elm := fi.addrValue.Type().Elem()
                if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
                    elm = elm.Elem()
                }
                //检查关联的模型是否注册,并获取模型信息
                name := getFullName(elm)
                mii, ok := modelCache.getByFullName(name)
                if !ok || mii.pkg != elm.PkgPath() {
                    err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
                    goto end
                }
                fi.relModelInfo = mii

                switch fi.fieldType {
                case RelManyToMany:
                    //如果是多对多关系,且声明了中间表,则根据中间表获取关联表信息;并为当前字段设置模型关联信息
                    if fi.relThrough != "" {
                        if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
                            pn := fi.relThrough[:i]
                            rmi, ok := modelCache.getByFullName(fi.relThrough)
                            if !ok || pn != rmi.pkg {
                                err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
                                goto end
                            }
                            fi.relThroughModelInfo = rmi
                            fi.relTable = rmi.table
                        } else {
                            err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
                            goto end
                        }
                    } else {
                        //未主动声明中间表关系,创建新的关联模型实例,判断关联表是否被注册
                        i := newM2MModelInfo(mi, mii)
                        if fi.relTable != "" {
                            i.table = fi.relTable
                        }
                        if v := modelCache.set(i.table, i); v != nil {
                            err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
                            goto end
                        }
                        fi.relTable = i.table
                        fi.relThroughModelInfo = i
                    }

                    fi.relThroughModelInfo.isThrough = true
                }
            }
        }
    }

    //后面是一些模板代码
    //遍历字段关联,自动生成字段反向关联信息
    models = modelCache.all()
    for _, mi := range models {
        for _, fi := range mi.fields.fieldsRel {......}
    }
    //遍历字段关联,检查反向多对多关联关系设置
    models = modelCache.all()
    for _, mi := range models {
        for _, fi := range mi.fields.fieldsRel {......}
    }
    //遍历字段反向关联,检查字段是否在模型中设置
    models = modelCache.all()
    for _, mi := range models {
        for _, fi := range mi.fields.fieldsReverse {.....}
    }
}
  • 当我们使用QueryTable设置表时
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
    var name string
    if table, ok := ptrStructOrTableName.(string); ok {
        //如果是指针结构体或表名,根据命名策略,到模型缓存中获取模型
        name = nameStrategyMap[defaultNameStrategy](table)
        if mi, ok := modelCache.get(name); ok {
            qs = newQuerySet(o, mi)
        }
    } else {
        //否则,通过反射获取到表名,再根据表名获取缓存中的模型
        name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
        if mi, ok := modelCache.getByFullName(name); ok {
            qs = newQuerySet(o, mi)
        }
    }
    //最后返回一个QuerySet
    return
}
  • QuerySeter

QuerySeter接口定义了一系列的查询方法:

//====== orm/types.go ======
type QuerySeter interface {
    // 筛选条件 where 
    Filter(string, ...interface{}) QuerySeter
    // 原生过滤语句
    // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
    FilterRaw(string, string) QuerySeter
    // 排除筛选条件 where not in
    Exclude(string, ...interface{}) QuerySeter
    // 设置单个筛选条件
    SetCond(*Condition) QuerySeter
    // 获取指定的筛选条件
    GetCond() *Condition
    // 分页
    Limit(limit interface{}, args ...interface{}) QuerySeter
    Offset(offset interface{}) QuerySeter
    // 分组
    GroupBy(exprs ...string) QuerySeter
    // 排序
    OrderBy(exprs ...string) QuerySeter
    // 模型关联查询
    RelatedSel(params ...interface{}) QuerySeter
    // 去重
    Distinct() QuerySeter
    // 给构造器设置FOR UPDATE
    ForUpdate() QuerySeter
    // 计数
    Count() (int64, error)
    // 是否存在
    Exist() bool
    // 更新
    Update(values Params) (int64, error)
    // 删除
    Delete() (int64, error)
    // 返回一个插入queryer
    PrepareInsert() (Inserter, error)
    // 查询多条数据
    All(container interface{}, cols ...string) (int64, error)
    // 查询单条数据
    One(container interface{}, cols ...string) error
    // 查询多条结果并将结果存入字符串类型的map指针变量中
    Values(results *[]Params, exprs ...string) (int64, error)
    // 查询多条结果并将结果存入接口类型的map指针变量中
    ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
    // 将所有结果集存入map变量中,没有字段名
    ValuesFlat(result *ParamsList, expr string) (int64, error)
    // ptrStruct存放查询结果的指针map变量, keyCol查询字段名,valueCol字段值
    RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
    // ptrStruct存放查询结果的结构体指针变量, keyCol查询字段名,valueCol字段值
    RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
}
  • QuerySeter接口的实现
//====== orm/orm_queryset.go ======
type querySet struct {
    mi         *modelInfo
    cond       *Condition
    related    []string
    relDepth   int
    limit      int64
    offset     int64
    groups     []string
    orders     []string
    distinct   bool
    forupdate  bool
    orm        *orm
    ctx        context.Context
    forContext bool
}

func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
    o := new(querySet)
    o.mi = mi
    o.orm = orm
    return o
}

//这里就以 All 方法的实现说明一下,其他的类似
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
    //All最终会调用DbBaser接口提供的ReadBatch方法
    return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
  • dbBaser 接口
//====== orm/types.go ======
type dbBaser interface {
    Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
    Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
    InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
    InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
    InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
    InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
    Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
    Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
    ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
    ......
}
  • DbBaser 接口 ReadBatch 方法的实现
//====== orm/db.go ======
//ReadBatch方法的实现
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {

    val := reflect.ValueOf(container)
    ind := reflect.Indirect(val)

    errTyp := true
    one := true
    isPtr := true
    
    //通过反射判断容器类型,标记是否指针类型
    if val.Kind() == reflect.Ptr {
        fn := ""
        if ind.Kind() == reflect.Slice {
            one = false
            typ := ind.Type().Elem()
            switch typ.Kind() {
            case reflect.Ptr:
                fn = getFullName(typ.Elem())
            case reflect.Struct:
                isPtr = false
                fn = getFullName(typ)
            }
        } else {
            fn = getFullName(ind.Type())
        }
        errTyp = fn != mi.fullName
    }

    if errTyp {
        if one {
            panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
        } else {
            panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
        }
    }
    //从querySeter中获取分页信息
    rlimit := qs.limit
    offset := qs.offset

    Q := d.ins.TableQuote()

    var tCols []string
    if len(cols) > 0 {
        //判断模型是否有关联关系
        hasRel := len(qs.related) > 0 || qs.relDepth > 0
        tCols = make([]string, 0, len(cols))
        var maps map[string]bool
        if hasRel {
            maps = make(map[string]bool)
        }
        //从模型信息获取所有字段
        for _, col := range cols {
            if fi, ok := mi.fields.GetByAny(col); ok {
                tCols = append(tCols, fi.column)
                if hasRel {
                    maps[fi.column] = true
                }
            } else {
                return 0, fmt.Errorf("wrong field/column name `%s`", col)
            }
        }
        //如果有关联关系,从模型信息中读取关联字段,并赋值给tCols,完成数据库字段映射
        if hasRel {
            for _, fi := range mi.fields.fieldsDB {
                if fi.fieldType&IsRelField > 0 {
                    if !maps[fi.column] {
                        tCols = append(tCols, fi.column)
                    }
                }
            }
        }
    } else {
        tCols = mi.fields.dbcols
    }

    colsNum := len(tCols)
    sep := fmt.Sprintf("%s, T0.%s", Q, Q)
    sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)

    //初始dbTables对象,解析关联关系
    tables := newDbTables(mi, d.ins)
    tables.parseRelated(qs.related, qs.relDepth)

    //构建sql语句各节点:where、groupBy、limit、join
    //后续深入sql组装
    where, args := tables.getCondSQL(cond, false, tz)
    groupBy := tables.getGroupSQL(qs.groups)
    orderBy := tables.getOrderSQL(qs.orders)
    limit := tables.getLimitSQL(mi, offset, rlimit)
    join := tables.getJoinSQL()

    for _, tbl := range tables.tables {
        if tbl.sel {
            colsNum += len(tbl.mi.fields.dbcols)
            sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
            sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
        }
    }

    //组装sql
    sqlSelect := "SELECT"
    if qs.distinct {
        sqlSelect += " DISTINCT"
    }
    query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)

    if qs.forupdate {
        query += " FOR UPDATE"
    }

    d.ins.ReplaceMarks(&query)

    var rs *sql.Rows
    var err error
    if qs != nil && qs.forContext {
        rs, err = q.QueryContext(qs.ctx, query, args...)
        if err != nil {
            return 0, err
        }
    } else {
        rs, err = q.Query(query, args...)
        if err != nil {
            return 0, err
        }
    }

    refs := make([]interface{}, colsNum)
    for i := range refs {
        var ref interface{}
        refs[i] = &ref
    }

    defer rs.Close()

    slice := ind

    var cnt int64
    //如果存在下一个结果行
    //通过反射获取字段的类型信息,并设置给data结构体的属性
    //遍历dbTables实例,循环tables和models
    //通过反射获取字段发名称、值、索引
    for rs.Next() { 
        //if嵌套看着头大,就不伤害大家的眼睛了
        if one && cnt == 0 || !one {......}
    }

    //如果存在数据行,将slice赋值给ind,否则根据ind的类型创建一个空切片并赋值
    if !one {
        if cnt > 0 {
            ind.Set(slice)
        } else {
            // when a result is empty and container is nil
            // to set a empty container
            if ind.IsNil() {
                ind.Set(reflect.MakeSlice(ind.Type(), 0, 0))
            }
        }
    }

    return cnt, nil
}
  • getOrderSQL方法,组装排序sql
//====== orm/db_tables.go ======
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
    if len(orders) == 0 {
        return
    }

    Q := t.base.TableQuote()
    
    //根据orders切片长度,生成一个新的切片用于保存sql
    orderSqls := make([]string, 0, len(orders))
    for _, order := range orders {
        //设置排序方式,以'-'开头为降序,否则为升序
        asc := "ASC"
        if order[0] == '-' {
            asc = "DESC"
            order = order[1:]
        }
        //根据ExprSep分隔符常量,将order字符串拆分为切片
        exprs := strings.Split(order, ExprSep)
        //解析模型字段
        index, _, fi, suc := t.parseExprs(t.mi, exprs)
        if !suc {
            panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
        }
        //将解析结果格式化为字符串后追加到orderSqls切片
        orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
    }
    //将sql结果以逗号拼接为支付串后,组装"ORDER BY "
    orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
    return
}

你可能感兴趣的:(beego源码学习-ORM,SQL解释器)