序
本文主要研究一下gorm的Model
Model
gorm.io/[email protected]/model.go
// Model is a basic Go struct that bundles four fields: ID, CreatedAt, UpdatedAt and DeletedAt.
// It may be embedded into your own model, or you may build your own model without it:
//
//	type User struct {
//		gorm.Model
//	}
type Model struct {
// ID carries the `primarykey` gorm tag.
ID uint `gorm:"primarykey"`
// CreatedAt and UpdatedAt are untagged time.Time fields.
CreatedAt time.Time
UpdatedAt time.Time
// DeletedAt carries the `index` gorm tag.
DeletedAt DeletedAt `gorm:"index"`
}
Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性
ParseField
gorm.io/[email protected]/schema/field.go
// ParseField derives the schema Field description for one struct field.
//
// It seeds the Field from the reflect.StructField (name, type, raw tag and the
// parsed `gorm` tag settings), resolves pointer types down to their element
// type, and decides whether the field is an auto-maintained create/update
// timestamp. Parts of the function are elided in this excerpt and are marked
// with "//......".
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
	var err error

	field := &Field{
		Name:                   fieldStruct.Name,
		BindNames:              []string{fieldStruct.Name},
		FieldType:              fieldStruct.Type,
		IndirectFieldType:      fieldStruct.Type,
		StructField:            fieldStruct,
		Creatable:              true,
		Updatable:              true,
		Readable:               true,
		Tag:                    fieldStruct.Tag,
		TagSettings:            ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
		Schema:                 schema,
		AutoIncrementIncrement: 1,
	}

	// Strip every pointer level so IndirectFieldType names the pointed-to type.
	for field.IndirectFieldType.Kind() == reflect.Ptr {
		field.IndirectFieldType = field.IndirectFieldType.Elem()
	}

	fieldValue := reflect.New(field.IndirectFieldType)
	// if field is valuer, used its value or first fields as data type
	valuer, isValuer := fieldValue.Interface().(driver.Valuer)
	//......

	// Remember the data type resolved so far before a GormDataTypeInterface
	// implementation gets the chance to override DataType.
	field.GORMDataType = field.DataType
	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
		field.DataType = DataType(dataTyper.GormDataType())
	}

	// Auto create time: enabled by the AUTOCREATETIME tag setting, or by the
	// conventional name "CreatedAt" on a Time/Int/Uint field. The tag value
	// selects the precision; anything other than NANO or MILLI means seconds.
	if v, ok := field.TagSettings["AUTOCREATETIME"]; ok ||
		(field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		switch strings.ToUpper(v) {
		case "NANO":
			field.AutoCreateTime = UnixNanosecond
		case "MILLI":
			field.AutoCreateTime = UnixMillisecond
		default:
			field.AutoCreateTime = UnixSecond
		}
	}

	// Auto update time: same rules, keyed on AUTOUPDATETIME / "UpdatedAt".
	if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok ||
		(field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		switch strings.ToUpper(v) {
		case "NANO":
			field.AutoUpdateTime = UnixNanosecond
		case "MILLI":
			field.AutoUpdateTime = UnixMillisecond
		default:
			field.AutoUpdateTime = UnixSecond
		}
	}
	//......
	return field
}
ParseField方法会解析field的属性:如果tag标注了AUTOCREATETIME或者AUTOUPDATETIME,或者field的name为CreatedAt或者UpdatedAt且dataType为Time、Int、Uint,则会设置field.AutoCreateTime或者field.AutoUpdateTime;tag的值为NANO、MILLI时分别对应UnixNanosecond、UnixMillisecond,其余情况为UnixSecond
TimeType
gorm.io/[email protected]/schema/field.go
// TimeType identifies the precision of an automatically maintained timestamp.
type TimeType int64

// Supported precisions. Numbering starts at 1, so the zero value of TimeType
// is distinct from all three constants.
const (
	UnixSecond TimeType = iota + 1
	UnixMillisecond
	UnixNanosecond
)
field.AutoCreateTime、field.AutoUpdateTime属性为TimeType类型,该类型有UnixSecond、UnixMillisecond、UnixNanosecond三种取值,分别表示秒、毫秒、纳秒精度
ConvertToCreateValues
gorm.io/[email protected]/callbacks/create.go
// ConvertToCreateValues converts stmt.Dest into the column list and value rows
// used for a create statement.
//
// Map destinations (map[string]interface{}, a slice of such maps, or a pointer
// to either) are delegated to ConvertMapToValuesForCreate and
// ConvertSliceOfMapToValuesForCreate. For any other destination the columns
// come from stmt.Schema and the values are read from stmt.ReflectValue, which
// must be a struct or a slice/array; other kinds record gorm.ErrInvalidData on
// the statement. An empty slice records gorm.ErrEmptySlice and an invalid
// element records a wrapped gorm.ErrInvalidData; both return immediately.
//
// A zero-valued field is filled with field.DefaultValueInterface when that is
// set; otherwise, if the field has AutoCreateTime or AutoUpdateTime, it is
// filled with the time returned by stmt.DB.NowFunc(). The filled value is also
// written back into the model through field.Set.
//
// When the statement has an "ON CONFLICT" clause with UpdateAll set, a new
// OnConflict clause is added whose conflict columns are the schema's primary
// fields and whose updates cover every inserted column except primary keys,
// columns with a database-side default, and auto-create-time columns.
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
	switch value := stmt.Dest.(type) {
	case map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, value)
	case *map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, *value)
	case []map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, value)
	case *[]map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
	default:
		var (
			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
			curTime                   = stmt.DB.NowFunc()
			isZero                    bool
		)
		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}

		// Pick the columns to insert. Fields whose default is produced by the
		// database (HasDefaultValue without a DefaultValueInterface) are skipped
		// here and handled through FieldsWithDefaultDBValue below. Auto-time
		// fields are included even when a restricting Select does not name them.
		for _, db := range stmt.Schema.DBNames {
			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
				if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
					values.Columns = append(values.Columns, clause.Column{Name: db})
				}
			}
		}

		switch stmt.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
			values.Values = make([][]interface{}, stmt.ReflectValue.Len())
			// Per DB-default field: the explicit values supplied by each row
			// (nil entries mean the row left the field at its zero value).
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
			if stmt.ReflectValue.Len() == 0 {
				stmt.AddError(gorm.ErrEmptySlice)
				return
			}

			for i := 0; i < stmt.ReflectValue.Len(); i++ {
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
				if !rv.IsValid() {
					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
					return
				}

				values.Values[i] = make([]interface{}, len(values.Columns))
				for idx, column := range values.Columns {
					field := stmt.Schema.FieldsByDBName[column.Name]
					if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
						if field.DefaultValueInterface != nil {
							values.Values[i][idx] = field.DefaultValueInterface
							field.Set(rv, field.DefaultValueInterface)
						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
							field.Set(rv, curTime)
							values.Values[i][idx], _ = field.ValueOf(rv)
						}
					} else if field.AutoUpdateTime > 0 {
						if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
							field.Set(rv, curTime)
							// Store the refreshed time in the current row. The
							// previous code wrote it to row 0, which left row i
							// with its stale value and clobbered the first row.
							values.Values[i][idx], _ = field.ValueOf(rv)
						}
					}
				}

				for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
						if v, isZero := field.ValueOf(rv); !isZero {
							if len(defaultValueFieldsHavingValue[field]) == 0 {
								defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
							}
							defaultValueFieldsHavingValue[field][i] = v
						}
					}
				}
			}

			// A DB-default column becomes part of the insert as soon as one row
			// supplies a value; the other rows get the dialect's default value
			// expression for that column.
			for field, vs := range defaultValueFieldsHavingValue {
				values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
				for idx := range values.Values {
					if vs[idx] == nil {
						values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
					} else {
						values.Values[idx] = append(values.Values[idx], vs[idx])
					}
				}
			}
		case reflect.Struct:
			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
			for idx, column := range values.Columns {
				field := stmt.Schema.FieldsByDBName[column.Name]
				if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
					if field.DefaultValueInterface != nil {
						values.Values[0][idx] = field.DefaultValueInterface
						field.Set(stmt.ReflectValue, field.DefaultValueInterface)
					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
						field.Set(stmt.ReflectValue, curTime)
						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
					}
				}
			}

			// DB-default columns are only inserted when the model carries a
			// non-zero value for them.
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
					if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
						values.Values[0] = append(values.Values[0], v)
					}
				}
			}
		default:
			stmt.AddError(gorm.ErrInvalidData)
		}
	}

	if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
			if stmt.Schema != nil && len(values.Columns) > 1 {
				columns := make([]string, 0, len(values.Columns)-1)
				for _, column := range values.Columns {
					if field := stmt.Schema.LookUpField(column.Name); field != nil {
						if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
							columns = append(columns, column.Name)
						}
					}
				}

				onConflict := clause.OnConflict{
					Columns:   make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
					DoUpdates: clause.AssignmentColumns(columns),
				}

				for idx, field := range stmt.Schema.PrimaryFields {
					onConflict.Columns[idx] = clause.Column{Name: field.DBName}
				}

				stmt.AddClause(onConflict)
			}
		}
	}

	return values
}
ConvertToCreateValues从stmt.DB.NowFunc()获取curTime,然后对于field.AutoCreateTime或者field.AutoUpdateTime大于0且当前值为零值的字段,会通过field.Set把curTime设置进去,并将其作为该列的插入值
setupValuerAndSetter
gorm.io/[email protected]/schema/field.go
// setupValuerAndSetter creates the field's valuer and setter when the struct is parsed.
//
// This excerpt shows only the part that installs field.Set; the portions marked
// with "//......" are elided. The setter is chosen once, based on
// field.FieldType.Kind(), and converts the incoming value v into the field's Go
// type before storing it in the model passed as `value`. Inputs a setter does
// not handle explicitly are passed to fallbackSetter, which is not defined in
// this excerpt.
func (field *Field) setupValuerAndSetter() {
//......
// Set: pick the setter implementation matching the field's kind.
switch field.FieldType.Kind() {
case reflect.Bool:
// Bool fields accept bool, *bool, int64 and string inputs.
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case bool:
field.ReflectValueOf(value).SetBool(data)
case *bool:
// A nil *bool is stored as false.
if data != nil {
field.ReflectValueOf(value).SetBool(*data)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case int64:
// Any positive integer counts as true; zero and negatives as false.
if data > 0 {
field.ReflectValueOf(value).SetBool(true)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case string:
// The parse error is discarded; b is stored regardless.
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b)
default:
return fallbackSetter(value, v, field.Set)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Signed integer fields: numeric inputs are converted with int64(...) and no
// range check; strings are parsed; time values become Unix timestamps.
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case int64:
field.ReflectValueOf(value).SetInt(data)
case int:
field.ReflectValueOf(value).SetInt(int64(data))
case int8:
field.ReflectValueOf(value).SetInt(int64(data))
case int16:
field.ReflectValueOf(value).SetInt(int64(data))
case int32:
field.ReflectValueOf(value).SetInt(int64(data))
case uint:
field.ReflectValueOf(value).SetInt(int64(data))
case uint8:
field.ReflectValueOf(value).SetInt(int64(data))
case uint16:
field.ReflectValueOf(value).SetInt(int64(data))
case uint32:
field.ReflectValueOf(value).SetInt(int64(data))
case uint64:
field.ReflectValueOf(value).SetInt(int64(data))
case float32:
field.ReflectValueOf(value).SetInt(int64(data))
case float64:
field.ReflectValueOf(value).SetInt(int64(data))
case []byte:
// Re-enter the setter so the bytes go through the string case.
return field.Set(value, string(data))
case string:
// Base 0: the base is taken from the string's prefix.
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetInt(i)
} else {
return err
}
case time.Time:
// Store the time as a Unix timestamp in the precision configured through
// AutoCreateTime/AutoUpdateTime: nanoseconds, milliseconds
// (UnixNano()/1e6), or seconds by default.
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
case *time.Time:
// Same conversion as time.Time; a nil pointer is stored as 0.
if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(value).SetInt(0)
}
default:
return fallbackSetter(value, v, field.Set)
}
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Unsigned integer fields: numeric inputs are converted with uint64(...) and no
// range check. Unlike the signed setter there is no explicit *time.Time case,
// so that input reaches fallbackSetter.
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case uint64:
field.ReflectValueOf(value).SetUint(data)
case uint:
field.ReflectValueOf(value).SetUint(uint64(data))
case uint8:
field.ReflectValueOf(value).SetUint(uint64(data))
case uint16:
field.ReflectValueOf(value).SetUint(uint64(data))
case uint32:
field.ReflectValueOf(value).SetUint(uint64(data))
case int64:
field.ReflectValueOf(value).SetUint(uint64(data))
case int:
field.ReflectValueOf(value).SetUint(uint64(data))
case int8:
field.ReflectValueOf(value).SetUint(uint64(data))
case int16:
field.ReflectValueOf(value).SetUint(uint64(data))
case int32:
field.ReflectValueOf(value).SetUint(uint64(data))
case float32:
field.ReflectValueOf(value).SetUint(uint64(data))
case float64:
field.ReflectValueOf(value).SetUint(uint64(data))
case []byte:
// Re-enter the setter so the bytes go through the string case.
return field.Set(value, string(data))
case time.Time:
// Unix timestamp in the configured precision (nanoseconds, milliseconds,
// or seconds by default).
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
}
case string:
// Base 0: the base is taken from the string's prefix.
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i)
} else {
return err
}
default:
return fallbackSetter(value, v, field.Set)
}
return err
}
case reflect.Float32, reflect.Float64:
// Float fields accept every built-in numeric type plus string/[]byte inputs
// that parse as a float.
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case float64:
field.ReflectValueOf(value).SetFloat(data)
case float32:
field.ReflectValueOf(value).SetFloat(float64(data))
case int64:
field.ReflectValueOf(value).SetFloat(float64(data))
case int:
field.ReflectValueOf(value).SetFloat(float64(data))
case int8:
field.ReflectValueOf(value).SetFloat(float64(data))
case int16:
field.ReflectValueOf(value).SetFloat(float64(data))
case int32:
field.ReflectValueOf(value).SetFloat(float64(data))
case uint:
field.ReflectValueOf(value).SetFloat(float64(data))
case uint8:
field.ReflectValueOf(value).SetFloat(float64(data))
case uint16:
field.ReflectValueOf(value).SetFloat(float64(data))
case uint32:
field.ReflectValueOf(value).SetFloat(float64(data))
case uint64:
field.ReflectValueOf(value).SetFloat(float64(data))
case []byte:
// Re-enter the setter so the bytes go through the string case.
return field.Set(value, string(data))
case string:
if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(value).SetFloat(i)
} else {
return err
}
default:
return fallbackSetter(value, v, field.Set)
}
return err
}
case reflect.String:
// String fields accept strings and bytes as-is and stringify numbers.
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case string:
field.ReflectValueOf(value).SetString(data)
case []byte:
field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(utils.ToString(data))
case float64, float32:
// Floats are formatted with field.Precision digits after the decimal point.
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return fallbackSetter(value, v, field.Set)
}
return err
}
default:
// Remaining kinds are dispatched on the concrete field type, using a zero
// value of that type: time.Time, *time.Time, sql.Scanner implementations,
// then everything else.
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
case time.Time:
// time.Time fields accept time.Time, *time.Time and parseable strings.
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case *time.Time:
// A nil pointer resets the field to the zero time.
if data != nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
} else {
field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
}
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(value, v, field.Set)
}
return nil
}
case *time.Time:
// *time.Time fields: the pointer is allocated on demand before a value is
// written through it.
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case time.Time:
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case string:
if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
// An empty string leaves a nil pointer field nil.
if v == "" {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(value, v, field.Set)
}
return nil
}
default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner: the field type itself implements sql.Scanner.
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
// v is a nil interface: reset the field to its zero value.
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
// Nil pointers reset the field; other pointers are dereferenced and retried.
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
// Allocate the pointer if needed, then let the field scan the value.
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
// A driver.Valuer input is unwrapped first; its error is discarded.
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = fieldValue.Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner: a pointer to the field type implements sql.Scanner.
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
// v is a nil interface: reset the field to its zero value.
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
// Nil pointers reset the field; other pointers are dereferenced and retried.
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
// A driver.Valuer input is unwrapped first; its error is discarded.
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
// Scan through the field's address, since the pointer type is the Scanner.
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
// No specialised handling: every input goes to fallbackSetter.
field.Set = func(value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set)
}
}
}
}
}
setupValuerAndSetter方法会根据字段类型设置对应的setter;其中整型(Int、Uint)字段的setter在接收到time.Time(Int字段还包括*time.Time)类型的值时,会根据AutoCreateTime、AutoUpdateTime的TimeType做时间精度处理,转换为纳秒、毫秒或秒级的Unix时间戳
实例
// Product is an example model that embeds gorm.Model and adds two fields of its own.
type Product struct {
// Embedding gorm.Model brings in its ID, CreatedAt, UpdatedAt and DeletedAt fields.
gorm.Model
Code string
Price uint
}
Product内嵌了gorm.Model,内置了ID、CreatedAt、UpdatedAt、DeletedAt属性,同时Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt
小结
gorm的Model定义了ID、CreatedAt、UpdatedAt、DeletedAt属性;其中Create的时候会自动设置CreatedAt、UpdatedAt,Update的时候会自动更新UpdatedAt;CreatedAt、UpdatedAt支持UnixSecond、UnixMillisecond、UnixNanosecond三种时间精度。