golang struct拷贝工具(类似于java中 BeanUtils.copyProperties())

golang struct拷贝工具(类似于java中 BeanUtils.copyProperties())

在开发中我们经常需要拷贝结构体的属性。例如,前端传来的参数往往比数据库表对应的结构体多出一个属性。

我们用gin开发时,一般都会用一个结构体来保存前端传来的json数据,例如:

// UserParams is the JSON request payload: every field of models.UserBasic
// plus RePassword, which is only needed for validation and has no
// counterpart in the DB model.
type UserParams struct {
		models.UserBasic
		// RePassword must equal the embedded Password field (validator
		// "eqfield" rule). reg_error_info looks like a project-specific tag
		// carrying the message shown on failure ("the two passwords differ").
		// NOTE(review): gin's built-in validation reads the `binding` tag,
		// not `validate`, so this rule presumably runs in a separate
		// validation step — confirm.
		RePassword string `validate:"eqfield=Password" reg_error_info:"两次密码不一样"`
}
// Decode the request body into userParams; on failure reply with an
// "invalid parameters" message that includes the bind error, then stop.
userParams := &UserParams{}
	if err := c.BindJSON(userParams); err != nil {
		// NOTE(review): a status of -1 presumably keeps the status code that
		// BindJSON already set on failure — confirm for the gin version used.
		c.JSON(-1, response.Err.WithMsg("参数有误,err:"+err.Error()))
		return
}

但在把数据传给其他服务时(例如保存到数据库),我们只需要数据库表对应的结构体。这时只需:

// Destination: the DB model, which has no RePassword field.
userBasic := &models.UserBasic{}
// Copy the struct's fields from the source (from) into the destination (to).
// Values are assigned via reflection, so pointer, slice and map fields end up
// shared with the source rather than duplicated.
// NOTE(review): err should be checked in real code.
err := utils.Copy(userParams).To(userBasic)

下面是 GitHub 上一个比较知名的开源工具:

  • https://github.com/ulule/deepcopier

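它其实只有一个文件(代码如下),直接拷贝到项目中即可使用。注意:该文件的包名是 deepcopier,而上面的示例调用的是 utils.Copy,所以要么把文件放进 utils 包(并把 package 改为 utils),要么改为调用 deepcopier.Copy。另外,虽然称为“深拷贝”,它实际上是通过反射直接赋值,指针、切片和 map 类型的字段会与源对象共享同一份数据。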

golang struct拷贝工具(类似于java中 BeanUtils.copyProperties())_第1张图片

package deepcopier

import (
	"database/sql/driver"
	"fmt"
	"reflect"
	"strings"
)

const (
	// TagName is the struct tag key read by deepcopier, e.g. `deepcopier:"..."`.
	TagName = "deepcopier"
	// FieldOptionName is the tag option that names the counterpart field in
	// the other struct, e.g. `deepcopier:"field:Name"`.
	FieldOptionName = "field"
	// ContextOptionName is the tag option that makes the copier pass the map
	// given to WithContext as the argument of the matching source method.
	ContextOptionName = "context"
	// SkipOptionName is the tag option that excludes a field from copying.
	SkipOptionName = "skip"
	// ForceOptionName is the tag option that enables conversions skipped by
	// default: assigning to interface fields, unwrapping nullable
	// (driver.Valuer) values, and value <-> pointer wrapping of method results.
	ForceOptionName = "force"
)

// TagOptions holds the options parsed from a deepcopier struct tag, keyed by
// option name; bare keywords such as "skip" are stored with an empty value.
type TagOptions map[string]string

// Options tunes a single copy operation.
type Options struct {
	// Context is the map handed to source methods whose matching destination
	// field is tagged "context"; it comes from WithContext.
	Context map[string]interface{}
	// Reversed switches tag lookup from the destination struct to the
	// source struct (used by From).
	Reversed bool
}

// DeepCopier carries the state of one copy operation: the source given to
// Copy, the destination chosen by To or From, and the optional context map
// supplied through WithContext.
type DeepCopier struct {
	dst, src interface{}
	ctx      map[string]interface{}
}

// Copy begins a copy operation whose source is src. Complete it with To, to
// copy src into a destination, or with From, to copy another value into src.
func Copy(src interface{}) *DeepCopier {
	dc := new(DeepCopier)
	dc.src = src
	return dc
}

// WithContext stores ctx on the copier and returns the same copier, so the
// call can be chained before To or From. The map is later passed to source
// methods whose matching destination field carries the "context" tag option.
func (dc *DeepCopier) WithContext(ctx map[string]interface{}) *DeepCopier {
	dc.ctx = ctx
	return dc
}

// To records dst as the destination and copies the source given to Copy into
// it, forwarding any context set with WithContext. dst must be a pointer;
// an error is returned when the destination is not addressable.
func (dc *DeepCopier) To(dst interface{}) error {
	dc.dst = dst
	opts := Options{Context: dc.ctx}
	return process(dst, dc.src, opts)
}

// From reverses the copy direction: the value originally passed to Copy
// becomes the destination and src becomes the source. Struct tags are then
// read from the source's fields (Options.Reversed), and any context set with
// WithContext is forwarded.
func (dc *DeepCopier) From(src interface{}) error {
	dc.dst, dc.src = dc.src, src
	return process(dc.dst, dc.src, Options{Reversed: true, Context: dc.ctx})
}

// process copies data from src into the struct that dst points to.
//
// It works in two passes:
//
//  1. Fields: every exported field of src (including fields promoted from
//     embedded, non-pointer structs) is matched to a dst field. Normally the
//     match is looked up on dst — a field with the same name, or one tagged
//     `deepcopier:"field:<SrcName>"` — and that dst field's tag options apply.
//     With options.Reversed the tag is read from the src field instead, and
//     its "field" option names the dst field.
//  2. Methods: every method of src whose name matches a dst field (looked up
//     on dst as above) is called and its first result is stored in the field.
//
// Values are assigned with reflect.Value.Set, so pointers, slices and maps end
// up shared between src and dst rather than duplicated. Fields with
// incompatible types, unexported destination fields and fields tagged "skip"
// are left untouched.
//
// An error is returned when src is nil, when dst is not a non-nil pointer to
// a struct, or when a matched method is invalid or cannot be called with the
// arguments implied by its tag.
func process(dst interface{}, src interface{}, args ...Options) error {
	options := Options{}
	if len(args) > 0 {
		options = args[0]
	}

	srcValue := reflect.Indirect(reflect.ValueOf(src))
	if !srcValue.IsValid() {
		// nil interface or nil pointer: reading its fields would panic.
		return fmt.Errorf("source must not be nil")
	}

	dstValue := reflect.Indirect(reflect.ValueOf(dst))
	if !dstValue.IsValid() {
		// Calling Interface() on the zero Value below would panic.
		return fmt.Errorf("destination must be a non-nil pointer")
	}
	if !dstValue.CanAddr() {
		return fmt.Errorf("destination %+v is unaddressable", dstValue.Interface())
	}
	if dstValue.Kind() != reflect.Struct {
		// FieldByName/NumField panic on non-struct types.
		return fmt.Errorf("destination %+v is not a struct", dstValue.Interface())
	}

	for _, name := range getFieldNames(src) {
		copyField(dst, dstValue, srcValue, name, options)
	}

	for _, name := range getMethodNames(src) {
		if err := copyMethodResult(dst, dstValue, src, name, options); err != nil {
			return err
		}
	}

	return nil
}

// copyField copies the src field named srcName into its matching dst field.
// Unmatched, skipped, unexported or type-incompatible fields are ignored.
func copyField(dst interface{}, dstValue, srcValue reflect.Value, srcName string, options Options) {
	srcField, found := srcValue.Type().FieldByName(srcName)
	if !found {
		return
	}
	srcFieldValue := srcValue.FieldByName(srcName)

	dstName := srcField.Name
	var tagOptions TagOptions
	if options.Reversed {
		tagOptions = getTagOptions(srcField.Tag.Get(TagName))
		if v, ok := tagOptions[FieldOptionName]; ok && v != "" {
			dstName = v
		}
	} else if name, opts := getRelatedField(dst, srcField.Name); name != "" {
		dstName, tagOptions = name, opts
	}

	if _, skip := tagOptions[SkipOptionName]; skip {
		return
	}

	dstField, found := dstValue.Type().FieldByName(dstName)
	if !found {
		return
	}
	dstFieldValue := dstValue.FieldByName(dstName)
	if !dstFieldValue.CanSet() {
		// Unexported destination field (reachable through a tag mapping):
		// Set would panic.
		return
	}

	// force relaxes the rules for interface and nullable types.
	_, force := tagOptions[ForceOptionName]
	srcType, dstType := srcField.Type, dstField.Type

	switch {
	case isNullableType(srcType):
		copyNullable(dstFieldValue, srcFieldValue, dstType, force)
	case dstType.Kind() == reflect.Interface:
		// Interfaces are only filled when forced, and only if the source
		// actually satisfies the interface (otherwise Set would panic).
		if force && srcType.AssignableTo(dstType) {
			dstFieldValue.Set(srcFieldValue)
		}
	default:
		// Ptr -> value: dereference a non-nil pointer when its target fits.
		if srcType.Kind() == reflect.Ptr && dstType.Kind() != reflect.Ptr && !srcFieldValue.IsNil() {
			if elem := srcFieldValue.Elem(); elem.Type().AssignableTo(dstType) {
				dstFieldValue.Set(elem)
				return
			}
		}
		if srcType.AssignableTo(dstType) {
			dstFieldValue.Set(srcFieldValue)
		}
	}
}

// copyNullable copies a source field whose type implements driver.Valuer
// (e.g. sql.NullString). Assignable types are copied as-is. Otherwise, only
// when force is set, the value returned by Value() is stored directly or,
// for a pointer destination, through a newly allocated pointer.
func copyNullable(dstFieldValue, srcFieldValue reflect.Value, dstType reflect.Type, force bool) {
	// Same nullable type on both sides.
	if srcFieldValue.Type().AssignableTo(dstType) {
		dstFieldValue.Set(srcFieldValue)
		return
	}
	if !force {
		return
	}

	v, ok := driverValue(srcFieldValue)
	if !ok {
		return
	}
	rv := reflect.ValueOf(v)

	if dstType.Kind() == reflect.Ptr {
		// Valuer -> ptr. Allocate the destination's element type so that a
		// merely assignable (not identical) value does not make Set panic.
		if rv.Type().AssignableTo(dstType.Elem()) {
			ptr := reflect.New(dstType.Elem())
			ptr.Elem().Set(rv)
			dstFieldValue.Set(ptr)
		}
		return
	}

	// Valuer -> value
	if rv.Type().AssignableTo(dstType) {
		dstFieldValue.Set(rv)
	}
}

// driverValue calls Value() on the driver.Valuer held in v. ok is false when v
// is a nil pointer/interface (calling the method could panic), when the value
// is NULL (nil), or when Value reports an error; such fields are skipped.
func driverValue(v reflect.Value) (driver.Value, bool) {
	switch v.Kind() {
	case reflect.Ptr, reflect.Interface:
		if v.IsNil() {
			return nil, false
		}
	}
	valuer, isValuer := v.Interface().(driver.Valuer)
	if !isValuer {
		return nil, false
	}
	val, err := valuer.Value()
	if err != nil || val == nil {
		return nil, false
	}
	return val, true
}

// copyMethodResult calls the src method named methodName when a dst field
// matches it, and stores the method's first result in that field.
func copyMethodResult(dst interface{}, dstValue reflect.Value, src interface{}, methodName string, options Options) error {
	name, opts := getRelatedField(dst, methodName)
	if name == "" {
		return nil
	}
	if _, skip := opts[SkipOptionName]; skip {
		return nil
	}

	method := reflect.ValueOf(src).MethodByName(methodName)
	if !method.IsValid() {
		return fmt.Errorf("method %s is invalid", methodName)
	}

	dstField, found := dstValue.Type().FieldByName(name)
	if !found {
		// e.g. the name is ambiguous between two embedded structs.
		return nil
	}
	dstFieldValue := dstValue.FieldByName(name)
	if !dstFieldValue.CanSet() {
		return nil
	}
	dstType := dstField.Type

	_, withContext := opts[ContextOptionName]
	_, force := opts[ForceOptionName]

	var in []reflect.Value
	if withContext {
		// Context-aware methods receive the map given to WithContext.
		in = []reflect.Value{reflect.ValueOf(options.Context)}
	}
	if !canCallWith(method.Type(), in) || method.Type().NumOut() == 0 {
		return fmt.Errorf("method %s cannot be called with %d argument(s) to fill field %s", methodName, len(in), name)
	}

	// Going through Interface() unwraps an interface-typed result to its
	// dynamic value, so the checks and Set below use the concrete type.
	result := reflect.ValueOf(method.Call(in)[0].Interface())
	if !result.IsValid() {
		// The method returned a nil interface: nothing to copy.
		return nil
	}
	resultType := result.Type()

	switch {
	case dstType.Kind() == reflect.Ptr && force:
		// Value -> Ptr (or direct assignment when the result already fits).
		if resultType.AssignableTo(dstType) {
			dstFieldValue.Set(result)
		} else if resultType.AssignableTo(dstType.Elem()) {
			ptr := reflect.New(dstType.Elem())
			ptr.Elem().Set(result)
			dstFieldValue.Set(ptr)
		}
	case resultType.Kind() == reflect.Ptr && force:
		// Ptr -> value; a nil pointer has nothing to dereference.
		if !result.IsNil() && resultType.Elem().AssignableTo(dstType) {
			dstFieldValue.Set(result.Elem())
		}
	case resultType.AssignableTo(dstType):
		dstFieldValue.Set(result)
	}

	return nil
}

// canCallWith reports whether a function of type fn accepts the arguments in
// (count and assignability, honouring variadic parameters), i.e. whether
// reflect.Value.Call would not panic on arity or argument types.
func canCallWith(fn reflect.Type, in []reflect.Value) bool {
	n := fn.NumIn()
	variadic := fn.IsVariadic()
	switch {
	case variadic && len(in) < n-1:
		return false
	case !variadic && len(in) != n:
		return false
	}

	for i, arg := range in {
		param := fn.In(i)
		if variadic && i >= n-1 {
			param = fn.In(n - 1).Elem()
		}
		if !arg.Type().AssignableTo(param) {
			return false
		}
	}
	return true
}

// getTagOptions parses the value of a deepcopier struct tag into options.
// Options are separated by ";". Each option is either a bare keyword such as
// "skip", "force" or "context", which is stored with an empty value, or a
// "key:value" pair such as "field:Name". Keys and values are trimmed of
// surrounding whitespace, so `deepcopier:"field:Name; skip"` yields
// {"field": "Name", "skip": ""}. Empty segments, such as the one left by a
// trailing ";", are ignored. Only the first ":" separates key from value.
func getTagOptions(value string) TagOptions {
	options := TagOptions{}

	for _, opt := range strings.Split(value, ";") {
		kv := strings.SplitN(opt, ":", 2)

		key := strings.TrimSpace(kv[0])
		if key == "" {
			continue
		}

		if len(kv) == 1 {
			// deepcopier:"keyword; without; value;"
			options[key] = ""
			continue
		}

		// deepcopier:"key:value; anotherkey:anothervalue"
		options[key] = strings.TrimSpace(kv[1])
	}

	return options
}

// getRelatedField returns the name and parsed deepcopier tag options of the
// first field of instance that corresponds to name: a field called name, or
// a field tagged `deepcopier:"field:<name>"`. instance may be a struct or a
// pointer to one. Fields of embedded (anonymous, non-pointer) structs are
// searched depth-first before the embedded field itself is considered.
// It returns "" and nil options when nothing matches or when instance is not
// a struct.
func getRelatedField(instance interface{}, name string) (string, TagOptions) {
	t := reflect.TypeOf(instance)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return findRelatedField(t, name)
}

// findRelatedField implements getRelatedField on a type. It walks types rather
// than values so that embedded structs of unexported types, whose values
// cannot be extracted with Interface(), are searched without panicking.
func findRelatedField(t reflect.Type, name string) (string, TagOptions) {
	if t == nil || t.Kind() != reflect.Struct {
		return "", nil
	}

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		opts := getTagOptions(field.Tag.Get(TagName))

		if field.Anonymous && field.Type.Kind() == reflect.Struct {
			if n, o := findRelatedField(field.Type, name); n != "" {
				return n, o
			}
		}

		if v, ok := opts[FieldOptionName]; ok && v == name {
			return field.Name, opts
		}

		if field.Name == name {
			return field.Name, opts
		}
	}

	return "", nil
}

// getMethodNames returns the names of the exported methods in the method set
// of instance's dynamic type, in the lexicographic order reflect reports them.
// A pointer instance therefore also lists pointer-receiver methods. It returns
// nil for a nil instance or a type without exported methods.
func getMethodNames(instance interface{}) []string {
	t := reflect.TypeOf(instance)
	if t == nil {
		// reflect.TypeOf(nil) is nil; NumMethod on it would panic.
		return nil
	}

	var methods []string
	for i := 0; i < t.NumMethod(); i++ {
		methods = append(methods, t.Method(i).Name)
	}

	return methods
}

// getFieldNames returns the names of instance's exported fields, in
// declaration order. Exported embedded (anonymous, non-pointer) structs are
// flattened: their own exported field names are listed in place of the
// embedded field's name. instance may be a struct or a pointer to one; nil,
// nil pointers and non-struct values yield nil.
func getFieldNames(instance interface{}) []string {
	v := reflect.Indirect(reflect.ValueOf(instance))
	if !v.IsValid() {
		// nil interface or nil pointer: v.Type() would panic.
		return nil
	}

	t := v.Type()
	if t.Kind() != reflect.Struct {
		return nil
	}

	var fields []string
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// PkgPath is set only for unexported fields.
		if field.PkgPath != "" {
			continue
		}

		if field.Anonymous && field.Type.Kind() == reflect.Struct {
			fields = append(fields, getFieldNames(v.Field(i).Interface())...)
			continue
		}

		fields = append(fields, field.Name)
	}

	return fields
}

// isNullableType reports whether values of type t implement driver.Valuer,
// as the database/sql Null* types (e.g. sql.NullString) do. Such fields get
// special handling because their payload is obtained by calling Value().
func isNullableType(t reflect.Type) bool {
	valuer := reflect.TypeOf((*driver.Valuer)(nil)).Elem()
	return t.Implements(valuer)
}

你可能感兴趣的:(golang, java, json)