聊聊gorm的Model

本文主要研究一下gorm的Model

Model

gorm.io/gorm/model.go

// Model is a basic Go struct that includes the fields ID, CreatedAt,
// UpdatedAt and DeletedAt. It may be embedded into your own model, or you may
// build your own model without it:
//    type User struct {
//      gorm.Model
//    }
//
// CreatedAt and UpdatedAt are picked up by name when the schema is parsed
// (ParseField checks for a field named CreatedAt/UpdatedAt with a Time, Int or
// Uint data type) and become auto create/update time fields; when a record is
// created, such fields that are still zero are filled from the DB's NowFunc.
type Model struct {
    ID        uint `gorm:"primarykey"` // tagged as the primary key
    CreatedAt time.Time
    UpdatedAt time.Time
    // DeletedAt is a gorm-defined type (not shown here) tagged with an index;
    // presumably it backs soft deletes — TODO confirm against its definition.
    DeletedAt DeletedAt `gorm:"index"`
}
Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性

ParseField

gorm.io/gorm/schema/field.go

// ParseField builds the *Field metadata for one struct field of the schema.
//
// This is an excerpt: the "//......" lines mark omitted code (including the
// code that determines field.DataType). The visible steps are:
//   - initialise the Field from the reflect.StructField and its `gorm` tag
//     (split on ";" by ParseTagSetting);
//   - strip every pointer level from the type into IndirectFieldType;
//   - copy the DataType computed so far into GORMDataType, then let a type
//     implementing GormDataTypeInterface override DataType;
//   - enable AutoCreateTime / AutoUpdateTime and choose their precision.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
    // NOTE(review): err is unused in the visible code; presumably used by the
    // omitted sections.
    var err error

    field := &Field{
        Name:                   fieldStruct.Name,
        BindNames:              []string{fieldStruct.Name},
        FieldType:              fieldStruct.Type,
        IndirectFieldType:      fieldStruct.Type,
        StructField:            fieldStruct,
        Creatable:              true,
        Updatable:              true,
        Readable:               true,
        Tag:                    fieldStruct.Tag,
        TagSettings:            ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
        Schema:                 schema,
        AutoIncrementIncrement: 1,
    }

    // Dereference pointer types (e.g. *time.Time -> time.Time) so the type
    // checks below see the underlying type.
    for field.IndirectFieldType.Kind() == reflect.Ptr {
        field.IndirectFieldType = field.IndirectFieldType.Elem()
    }

    // Pointer to a fresh zero value of the underlying type; its Interface() is
    // type-asserted below to discover which interfaces the type implements.
    fieldValue := reflect.New(field.IndirectFieldType)
    // If the field type is a driver.Valuer, the omitted code uses its value
    // (or its first field) to decide the data type.
    valuer, isValuer := fieldValue.Interface().(driver.Valuer)
    
    //......

    // GORMDataType keeps the DataType determined above (in the omitted code);
    // DataType itself may then be overridden by a custom GormDataType().
    field.GORMDataType = field.DataType

    if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
        field.DataType = DataType(dataTyper.GormDataType())
    }

    // Auto create time: enabled by an AUTOCREATETIME tag, or by convention for
    // a field named CreatedAt whose DataType is Time, Int or Uint. A tag value
    // of NANO / MILLI (case-insensitive) selects nanosecond / millisecond
    // precision; anything else, including an empty value, selects UnixSecond.
    // NOTE(review): only the tag's presence is checked, so a value such as
    // "false" still enables auto create time.
    if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
        if strings.ToUpper(v) == "NANO" {
            field.AutoCreateTime = UnixNanosecond
        } else if strings.ToUpper(v) == "MILLI" {
            field.AutoCreateTime = UnixMillisecond
        } else {
            field.AutoCreateTime = UnixSecond
        }
    }

    // Auto update time: same rules as above, for the AUTOUPDATETIME tag or a
    // field named UpdatedAt.
    if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
        if strings.ToUpper(v) == "NANO" {
            field.AutoUpdateTime = UnixNanosecond
        } else if strings.ToUpper(v) == "MILLI" {
            field.AutoUpdateTime = UnixMillisecond
        } else {
            field.AutoUpdateTime = UnixSecond
        }
    }

    //......

    return field
}
ParseField方法会解析field的属性:如果tag标注了AUTOCREATETIME(或AUTOUPDATETIME),或者field的name为CreatedAt(或UpdatedAt)且dataType为Time、Int、Uint,则会设置field.AutoCreateTime(或field.AutoUpdateTime);其中tag值为NANO时取UnixNanosecond,为MILLI时取UnixMillisecond,其余情况(包括未设置tag值)取UnixSecond

TimeType

gorm.io/gorm/schema/field.go

// TimeType is the precision used for an auto create/update time field when
// the timestamp is stored in an integer column. The zero value means the
// field is not auto-tracked (callers test AutoCreateTime/AutoUpdateTime > 0).
type TimeType int64

// Supported precisions, numbered from 1 so that 0 stays "disabled".
const (
    UnixSecond TimeType = iota + 1 // whole seconds since the Unix epoch
    UnixMillisecond                // milliseconds since the Unix epoch
    UnixNanosecond                 // nanoseconds since the Unix epoch
)
field.AutoCreateTime、AutoUpdateTime属性为TimeType类型,该类型有UnixSecond(1)、UnixMillisecond(2)、UnixNanosecond(3)三种取值;为0时表示该字段不自动维护时间

ConvertToCreateValues

gorm.io/gorm/callbacks/create.go

// ConvertToCreateValues converts stmt.Dest into the column list and row values
// of an INSERT statement.
//
// Map destinations (map[string]interface{}, []map[string]interface{} and
// pointers to them) are delegated to the map helpers. Any other destination is
// read through stmt.Schema from stmt.ReflectValue, which must be a struct or a
// slice/array of structs (or of pointers to structs):
//   - a field without a DB-side default (or with a Go-side default value) is
//     emitted as a column when it is explicitly selected, or when it is not in
//     the select/omit set and the statement is unrestricted or the field is an
//     auto create/update time field;
//   - a zero-valued column gets its Go-side default value, or the current time
//     from stmt.DB.NowFunc() for auto create/update time fields; the chosen
//     value is also written back into the model;
//   - for slices, when the "gorm:update_track_time" instance setting is
//     present, non-zero auto-update-time fields are refreshed to the current
//     time as well;
//   - fields with a DB-side default are emitted only when a row actually sets
//     them; in a slice, rows that leave them unset get
//     stmt.Dialector.DefaultValueOf(field).
//
// Errors (empty slice, invalid element, unsupported kind) are added to stmt
// rather than returned. If an ON CONFLICT clause with UpdateAll is present, it
// is replaced with an explicit clause on the primary key columns that updates
// every emitted non-primary column except auto-create-time ones.
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
    switch value := stmt.Dest.(type) {
    case map[string]interface{}:
        values = ConvertMapToValuesForCreate(stmt, value)
    case *map[string]interface{}:
        values = ConvertMapToValuesForCreate(stmt, *value)
    case []map[string]interface{}:
        values = ConvertSliceOfMapToValuesForCreate(stmt, value)
    case *[]map[string]interface{}:
        values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
    default:
        var (
            selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
            curTime                   = stmt.DB.NowFunc()
            isZero                    bool
        )
        values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}

        for _, db := range stmt.Schema.DBNames {
            if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
                if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
                    values.Columns = append(values.Columns, clause.Column{Name: db})
                }
            }
        }

        switch stmt.ReflectValue.Kind() {
        case reflect.Slice, reflect.Array:
            stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
            values.Values = make([][]interface{}, stmt.ReflectValue.Len())
            defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
            if stmt.ReflectValue.Len() == 0 {
                stmt.AddError(gorm.ErrEmptySlice)
                return
            }

            // The setting cannot change while rows are being built, so look
            // it up once rather than per row and column.
            _, updateTrackTime := stmt.DB.InstanceGet("gorm:update_track_time")

            for i := 0; i < stmt.ReflectValue.Len(); i++ {
                rv := reflect.Indirect(stmt.ReflectValue.Index(i))
                if !rv.IsValid() {
                    stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
                    return
                }

                values.Values[i] = make([]interface{}, len(values.Columns))
                for idx, column := range values.Columns {
                    field := stmt.Schema.FieldsByDBName[column.Name]
                    if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
                        if field.DefaultValueInterface != nil {
                            values.Values[i][idx] = field.DefaultValueInterface
                            field.Set(rv, field.DefaultValueInterface)
                        } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
                            field.Set(rv, curTime)
                            values.Values[i][idx], _ = field.ValueOf(rv)
                        }
                    } else if field.AutoUpdateTime > 0 && updateTrackTime {
                        field.Set(rv, curTime)
                        // Store into row i: each element carries its own
                        // refreshed timestamp (writing row 0 would clobber it).
                        values.Values[i][idx], _ = field.ValueOf(rv)
                    }
                }

                for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
                    if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
                        if v, isZero := field.ValueOf(rv); !isZero {
                            if len(defaultValueFieldsHavingValue[field]) == 0 {
                                defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
                            }
                            defaultValueFieldsHavingValue[field][i] = v
                        }
                    }
                }
            }

            // Walk the schema's ordered slice instead of ranging over the map,
            // so the column order (and hence the generated SQL) is stable.
            for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
                vs, ok := defaultValueFieldsHavingValue[field]
                if !ok {
                    continue
                }
                values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
                for idx := range values.Values {
                    if vs[idx] == nil {
                        values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
                    } else {
                        values.Values[idx] = append(values.Values[idx], vs[idx])
                    }
                }
            }
        case reflect.Struct:
            values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
            for idx, column := range values.Columns {
                field := stmt.Schema.FieldsByDBName[column.Name]
                if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
                    if field.DefaultValueInterface != nil {
                        values.Values[0][idx] = field.DefaultValueInterface
                        field.Set(stmt.ReflectValue, field.DefaultValueInterface)
                    } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
                        field.Set(stmt.ReflectValue, curTime)
                        values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
                    }
                }
            }

            for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
                if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
                    if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
                        values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
                        values.Values[0] = append(values.Values[0], v)
                    }
                }
            }
        default:
            stmt.AddError(gorm.ErrInvalidData)
        }
    }

    // Expand ON CONFLICT ... UpdateAll into an explicit upsert on the primary
    // key that updates every emitted non-primary, non-auto-create-time column.
    if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
        if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
            if stmt.Schema != nil && len(values.Columns) > 1 {
                columns := make([]string, 0, len(values.Columns)-1)
                for _, column := range values.Columns {
                    if field := stmt.Schema.LookUpField(column.Name); field != nil {
                        if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
                            columns = append(columns, column.Name)
                        }
                    }
                }

                onConflict := clause.OnConflict{
                    Columns:   make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
                    DoUpdates: clause.AssignmentColumns(columns),
                }

                for idx, field := range stmt.Schema.PrimaryFields {
                    onConflict.Columns[idx] = clause.Column{Name: field.DBName}
                }

                stmt.AddClause(onConflict)
            }
        }
    }

    return values
}
ConvertToCreateValues从stmt.DB.NowFunc()获取curTime,然后对于field.AutoCreateTime或者field.AutoUpdateTime大于0且当前值为零值的字段,会将其设置为curTime;另外在批量插入(slice)时,如果设置了gorm:update_track_time,AutoUpdateTime字段即使非零也会被刷新为curTime

setupValuerAndSetter

gorm.io/gorm/schema/field.go

// setupValuerAndSetter creates the field's valuer and setter while the struct
// schema is being parsed. Only the setter half is shown in this excerpt; the
// valuer half is omitted (the "//......" line).
//
// field.Set(value, v) stores v into this field of the struct held in `value`,
// going through field.ReflectValueOf (defined elsewhere — presumably it
// resolves the field's reflect.Value inside `value`; TODO confirm). The
// implementation is chosen once from field.FieldType.Kind(); each variant
// converts the Go types it recognizes and hands anything else to
// fallbackSetter (defined elsewhere; it is given field.Set, so presumably it
// normalizes v and retries — TODO confirm). Most parse/scan errors are
// returned; the exceptions are noted inline.
func (field *Field) setupValuerAndSetter() {
    //......

    // Set
    switch field.FieldType.Kind() {
    case reflect.Bool:
        // bool: accepts bool, *bool (nil -> false), int64 (> 0 -> true) and
        // string via strconv.ParseBool (a parse error silently yields false).
        field.Set = func(value reflect.Value, v interface{}) error {
            switch data := v.(type) {
            case bool:
                field.ReflectValueOf(value).SetBool(data)
            case *bool:
                if data != nil {
                    field.ReflectValueOf(value).SetBool(*data)
                } else {
                    field.ReflectValueOf(value).SetBool(false)
                }
            case int64:
                if data > 0 {
                    field.ReflectValueOf(value).SetBool(true)
                } else {
                    field.ReflectValueOf(value).SetBool(false)
                }
            case string:
                b, _ := strconv.ParseBool(data)
                field.ReflectValueOf(value).SetBool(b)
            default:
                return fallbackSetter(value, v, field.Set)
            }
            return nil
        }
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        // Signed integers: every built-in integer and float type is converted
        // with a plain Go conversion (uint64 values above MaxInt64 wrap, floats
        // are truncated toward zero), []byte is retried as a string, strings
        // are parsed by strconv.ParseInt with base 0 (so prefixes such as "0x"
        // are honored), and time.Time / *time.Time become Unix timestamps.
        field.Set = func(value reflect.Value, v interface{}) (err error) {
            switch data := v.(type) {
            case int64:
                field.ReflectValueOf(value).SetInt(data)
            case int:
                field.ReflectValueOf(value).SetInt(int64(data))
            case int8:
                field.ReflectValueOf(value).SetInt(int64(data))
            case int16:
                field.ReflectValueOf(value).SetInt(int64(data))
            case int32:
                field.ReflectValueOf(value).SetInt(int64(data))
            case uint:
                field.ReflectValueOf(value).SetInt(int64(data))
            case uint8:
                field.ReflectValueOf(value).SetInt(int64(data))
            case uint16:
                field.ReflectValueOf(value).SetInt(int64(data))
            case uint32:
                field.ReflectValueOf(value).SetInt(int64(data))
            case uint64:
                field.ReflectValueOf(value).SetInt(int64(data))
            case float32:
                field.ReflectValueOf(value).SetInt(int64(data))
            case float64:
                field.ReflectValueOf(value).SetInt(int64(data))
            case []byte:
                return field.Set(value, string(data))
            case string:
                if i, err := strconv.ParseInt(data, 0, 64); err == nil {
                    field.ReflectValueOf(value).SetInt(i)
                } else {
                    return err
                }
            case time.Time:
                // Precision follows the auto-time TimeType: nanoseconds,
                // milliseconds, otherwise seconds.
                // NOTE(review): UnixNano is undefined outside roughly the years
                // 1678–2262, which also affects the millisecond path below.
                if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
                    field.ReflectValueOf(value).SetInt(data.UnixNano())
                } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
                    field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
                } else {
                    field.ReflectValueOf(value).SetInt(data.Unix())
                }
            case *time.Time:
                // Same precision rules; a nil pointer stores 0.
                if data != nil {
                    if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
                        field.ReflectValueOf(value).SetInt(data.UnixNano())
                    } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
                        field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
                    } else {
                        field.ReflectValueOf(value).SetInt(data.Unix())
                    }
                } else {
                    field.ReflectValueOf(value).SetInt(0)
                }
            default:
                return fallbackSetter(value, v, field.Set)
            }
            return err
        }
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        // Unsigned integers: same conversions as the signed branch (negative
        // integers wrap), strings parsed by strconv.ParseUint with base 0, and
        // time.Time stored as a Unix timestamp. Unlike the signed branch there
        // is no *time.Time case, so such values go to fallbackSetter.
        field.Set = func(value reflect.Value, v interface{}) (err error) {
            switch data := v.(type) {
            case uint64:
                field.ReflectValueOf(value).SetUint(data)
            case uint:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case uint8:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case uint16:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case uint32:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case int64:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case int:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case int8:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case int16:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case int32:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case float32:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case float64:
                field.ReflectValueOf(value).SetUint(uint64(data))
            case []byte:
                return field.Set(value, string(data))
            case time.Time:
                // Precision follows the auto-time TimeType, as in the signed branch.
                if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
                    field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
                } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
                    field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
                } else {
                    field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
                }
            case string:
                if i, err := strconv.ParseUint(data, 0, 64); err == nil {
                    field.ReflectValueOf(value).SetUint(i)
                } else {
                    return err
                }
            default:
                return fallbackSetter(value, v, field.Set)
            }
            return err
        }
    case reflect.Float32, reflect.Float64:
        // Floats: every built-in numeric type is converted directly, []byte is
        // retried as a string, and strings are parsed with strconv.ParseFloat.
        field.Set = func(value reflect.Value, v interface{}) (err error) {
            switch data := v.(type) {
            case float64:
                field.ReflectValueOf(value).SetFloat(data)
            case float32:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case int64:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case int:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case int8:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case int16:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case int32:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case uint:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case uint8:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case uint16:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case uint32:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case uint64:
                field.ReflectValueOf(value).SetFloat(float64(data))
            case []byte:
                return field.Set(value, string(data))
            case string:
                if i, err := strconv.ParseFloat(data, 64); err == nil {
                    field.ReflectValueOf(value).SetFloat(i)
                } else {
                    return err
                }
            default:
                return fallbackSetter(value, v, field.Set)
            }
            return err
        }
    case reflect.String:
        // Strings: string/[]byte are stored as-is, integers via utils.ToString,
        // and floats with field.Precision digits after the decimal point.
        // NOTE(review): if Precision is 0 (presumably the unset default — TODO
        // confirm), "%.0f" rounds the value to a whole number.
        field.Set = func(value reflect.Value, v interface{}) (err error) {
            switch data := v.(type) {
            case string:
                field.ReflectValueOf(value).SetString(data)
            case []byte:
                field.ReflectValueOf(value).SetString(string(data))
            case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
                field.ReflectValueOf(value).SetString(utils.ToString(data))
            case float64, float32:
                field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
            default:
                return fallbackSetter(value, v, field.Set)
            }
            return err
        }
    default:
        // Any other kind (struct, pointer, ...): inspect the zero value of the
        // field type to choose a setter.
        fieldValue := reflect.New(field.FieldType)
        switch fieldValue.Elem().Interface().(type) {
        case time.Time:
            // time.Time fields: accept time.Time, *time.Time (nil -> zero time)
            // and strings parsed by now.Parse (the `now` package imported
            // elsewhere in this file). No TimeType precision is applied here.
            field.Set = func(value reflect.Value, v interface{}) error {
                switch data := v.(type) {
                case time.Time:
                    field.ReflectValueOf(value).Set(reflect.ValueOf(v))
                case *time.Time:
                    if data != nil {
                        field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
                    } else {
                        field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
                    }
                case string:
                    if t, err := now.Parse(data); err == nil {
                        field.ReflectValueOf(value).Set(reflect.ValueOf(t))
                    } else {
                        return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
                    }
                default:
                    return fallbackSetter(value, v, field.Set)
                }
                return nil
            }
        case *time.Time:
            // *time.Time fields: a time.Time is copied into the pointee
            // (allocating one if the field is nil); a *time.Time is assigned
            // as-is (nil included); strings are parsed by now.Parse.
            field.Set = func(value reflect.Value, v interface{}) error {
                switch data := v.(type) {
                case time.Time:
                    fieldValue := field.ReflectValueOf(value)
                    if fieldValue.IsNil() {
                        fieldValue.Set(reflect.New(field.FieldType.Elem()))
                    }
                    fieldValue.Elem().Set(reflect.ValueOf(v))
                case *time.Time:
                    field.ReflectValueOf(value).Set(reflect.ValueOf(v))
                case string:
                    if t, err := now.Parse(data); err == nil {
                        fieldValue := field.ReflectValueOf(value)
                        if fieldValue.IsNil() {
                            // Keep a nil field nil for an empty string. This
                            // check is only reached if now.Parse accepted it.
                            if v == "" {
                                return nil
                            }
                            fieldValue.Set(reflect.New(field.FieldType.Elem()))
                        }
                        fieldValue.Elem().Set(reflect.ValueOf(t))
                    } else {
                        return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
                    }
                default:
                    return fallbackSetter(value, v, field.Set)
                }
                return nil
            }
        default:
            if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
                // pointer scanner: the field type itself implements sql.Scanner
                // (a nillable pointer type, given the IsNil/New calls below).
                field.Set = func(value reflect.Value, v interface{}) (err error) {
                    reflectV := reflect.ValueOf(v)
                    if !reflectV.IsValid() {
                        // v is an untyped nil: reset the field to its zero value.
                        field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
                    } else if reflectV.Type().AssignableTo(field.FieldType) {
                        field.ReflectValueOf(value).Set(reflectV)
                    } else if reflectV.Kind() == reflect.Ptr {
                        // NOTE(review): `!reflectV.IsValid()` can never be true
                        // here; validity was already checked above.
                        if reflectV.IsNil() || !reflectV.IsValid() {
                            field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
                        } else {
                            return field.Set(value, reflectV.Elem().Interface())
                        }
                    } else {
                        // Allocate the pointee if needed, unwrap a
                        // driver.Valuer (its error is ignored), then let Scan
                        // perform the conversion.
                        fieldValue := field.ReflectValueOf(value)
                        if fieldValue.IsNil() {
                            fieldValue.Set(reflect.New(field.FieldType.Elem()))
                        }

                        if valuer, ok := v.(driver.Valuer); ok {
                            v, _ = valuer.Value()
                        }

                        err = fieldValue.Interface().(sql.Scanner).Scan(v)
                    }
                    return
                }
            } else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
                // struct scanner: a pointer to the field type implements
                // sql.Scanner (e.g. a struct with a pointer-receiver Scan).
                field.Set = func(value reflect.Value, v interface{}) (err error) {
                    reflectV := reflect.ValueOf(v)
                    if !reflectV.IsValid() {
                        field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
                    } else if reflectV.Type().AssignableTo(field.FieldType) {
                        field.ReflectValueOf(value).Set(reflectV)
                    } else if reflectV.Kind() == reflect.Ptr {
                        // NOTE(review): `!reflectV.IsValid()` is always false here.
                        if reflectV.IsNil() || !reflectV.IsValid() {
                            field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
                        } else {
                            return field.Set(value, reflectV.Elem().Interface())
                        }
                    } else {
                        // Unwrap a driver.Valuer (error ignored), then Scan into
                        // the field through its address.
                        if valuer, ok := v.(driver.Valuer); ok {
                            v, _ = valuer.Value()
                        }

                        err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
                    }
                    return
                }
            } else {
                // No specialised handling: always delegate to fallbackSetter.
                field.Set = func(value reflect.Value, v interface{}) (err error) {
                    return fallbackSetter(value, v, field.Set)
                }
            }
        }
    }
}
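setupValuerAndSetter方法中,int系列字段的setter在接收到time.Time或*time.Time类型的值时、uint系列字段的setter在接收到time.Time类型的值时,会根据TimeType做时间精度处理(秒、毫秒或纳秒);而time.Time、*time.Time类型字段的setter则直接赋值,不做精度转换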

实例

// Product is an example model. Embedding gorm.Model gives it the ID,
// CreatedAt, UpdatedAt and DeletedAt fields in addition to its own columns.
type Product struct {
    gorm.Model
    Code  string
    Price uint
}
Product内嵌了gorm.Model,内置了ID、CreatedAt、UpdatedAt、DeletedAt属性,同时Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt

小结

gorm.Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性;其中Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt;当CreatedAt、UpdatedAt为整型字段时,支持UnixSecond、UnixMillisecond、UnixNanosecond三种时间精度(通过AUTOCREATETIME、AUTOUPDATETIME tag的值NANO、MILLI指定,默认为秒)。

doc

你可能感兴趣的:(golang)