gorm（v1）无法批量插入的解决办法（已实践）

gorm 2.0 以下的版本不支持批量插入，那么只好自己造个轮子：
利用反射获取数据集的类型和字段，拼接出批量插入的 SQL 语句，最后通过 Exec 执行。
具体代码如下(有很大的优化空间):

// BatchCreate inserts all elements of data using multi-row INSERT statements
// (gorm below v2 has no batch insert of its own).
//
// data must be a slice of structs, a slice of pointers to structs, or a pointer
// to either; anything else yields the error "数据类型不支持". An empty slice is
// a no-op. The table name is derived from the element type's name by
// getTableName.
//
// Inserted columns, in struct field order:
//   - created_at / updated_at for an embedded gorm.Model; a zero CreatedAt or
//     UpdatedAt is replaced by the time of the call;
//   - every other exported field whose gorm tag has a "column:<name>" entry.
//
// Values are sent as bind parameters ("?") rather than spliced into the SQL
// text, so quotes inside strings cannot break or alter the statement.
// time.Time / *time.Time values are sent as "2006-01-02 15:04:05" text (the
// wall clock of the value itself); zero or nil times other than
// CreatedAt/UpdatedAt are sent as NULL. All other values go to the driver as-is.
//
// Rows are written in chunks of at most 100. The chunks are not wrapped in a
// transaction here, so if one fails the earlier chunks stay inserted; pass a
// transaction (db.Begin()) as db when all-or-nothing behavior is needed.
func BatchCreate(db *gorm.DB, data interface{}) error {
	const (
		batchSize  = 100                   // max rows per INSERT statement
		timeLayout = "2006-01-02 15:04:05" // text format used for time columns
	)

	// Accept both []T and *[]T.
	rows := reflect.Indirect(reflect.ValueOf(data))
	if rows.Kind() != reflect.Slice {
		return errors.New("数据类型不支持")
	}
	if rows.Len() == 0 {
		return nil
	}

	// Derive the columns from the element type (not from the first row) so
	// []*T works and every row is read through the same field index paths.
	elemType := rows.Type().Elem()
	if elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
	}
	if elemType.Kind() != reflect.Struct {
		return errors.New("数据类型不支持")
	}

	cols := batchCreateColumns(elemType)
	if len(cols) == 0 {
		return fmt.Errorf("batch create: %s has no insertable columns", elemType)
	}
	names := make([]string, len(cols))
	for i, c := range cols {
		names[i] = c.name
	}

	head := fmt.Sprintf("insert into %s (%s) values", getTableName(elemType.Name()), strings.Join(names, ","))
	rowPlaceholder := "(" + strings.TrimSuffix(strings.Repeat("?,", len(cols)), ",") + ")"
	// One timestamp per call: every row gets the same auto-filled created_at/updated_at.
	now := time.Now().Format(timeLayout)

	for start := 0; start < rows.Len(); start += batchSize {
		end := start + batchSize
		if end > rows.Len() {
			end = rows.Len()
		}

		placeholders := make([]string, 0, end-start)
		args := make([]interface{}, 0, (end-start)*len(cols))
		for i := start; i < end; i++ {
			row := reflect.Indirect(rows.Index(i))
			if !row.IsValid() {
				return fmt.Errorf("batch create: element %d is nil", i)
			}
			for _, c := range cols {
				args = append(args, batchCreateArg(row.FieldByIndex(c.index), c.autoTime, now, timeLayout))
			}
			placeholders = append(placeholders, rowPlaceholder)
		}

		if err := db.Exec(head+strings.Join(placeholders, ","), args...).Error; err != nil {
			return err
		}
	}
	return nil
}

// batchCreateColumn describes one column written by BatchCreate.
type batchCreateColumn struct {
	name     string // database column name
	index    []int  // field index path for reflect.Value.FieldByIndex
	autoTime bool   // gorm.Model CreatedAt/UpdatedAt: a zero value means "now"
}

// batchCreateColumns lists, in field order, the columns BatchCreate inserts for struct type t.
func batchCreateColumns(t reflect.Type) []batchCreateColumn {
	var cols []batchCreateColumn
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if sf.Type.String() == "gorm.Model" {
			for j := 0; j < sf.Type.NumField(); j++ {
				switch sf.Type.Field(j).Name {
				case "CreatedAt":
					cols = append(cols, batchCreateColumn{name: "created_at", index: []int{i, j}, autoTime: true})
				case "UpdatedAt":
					cols = append(cols, batchCreateColumn{name: "updated_at", index: []int{i, j}, autoTime: true})
				}
			}
			continue
		}
		// Unexported fields cannot be read through reflection (Interface panics).
		if sf.PkgPath != "" {
			continue
		}
		tag := sf.Tag.Get("gorm")
		if tag == "" {
			continue
		}
		if column := getTagValues(tag)["column"]; column != "" {
			cols = append(cols, batchCreateColumn{name: column, index: []int{i}})
		}
	}
	return cols
}

// batchCreateArg returns the bind argument BatchCreate sends for one field.
// time.Time and *time.Time are formatted with layout; a zero or nil time
// becomes now when autoTime is set and NULL otherwise. Every other value is
// passed through unchanged for the database driver to convert.
func batchCreateArg(field reflect.Value, autoTime bool, now, layout string) interface{} {
	var t *time.Time
	switch v := field.Interface().(type) {
	case time.Time:
		t = &v
	case *time.Time:
		t = v
	default:
		return v
	}
	if t == nil || t.IsZero() {
		if autoTime {
			return now
		}
		return nil
	}
	return t.Format(layout)
}

// Patterns used by getTableName, compiled once at package initialization
// instead of on every call.
var (
	// tableNameWordRe matches a non-underscore character followed by a
	// capitalized lower-case word, e.g. "P"+"Log" in "HTTPLog".
	tableNameWordRe = regexp.MustCompile(`([^_])([A-Z][a-z]+)`)
	// tableNameCapRe matches a lower-case letter or digit followed by an
	// upper-case letter, e.g. "r"+"V" in "UserV2".
	tableNameCapRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
)

// getTableName converts a Go struct name into a snake_case table name:
// "User" -> "user", "UserInfo" -> "user_info", "HTTPLog" -> "http_log",
// "UserV2" -> "user_v2". Acronyms, digits and all-lower-case names are kept
// rather than dropped. The result is not pluralized.
func getTableName(modelName string) string {
	s := tableNameWordRe.ReplaceAllString(modelName, "${1}_${2}")
	s = tableNameCapRe.ReplaceAllString(s, "${1}_${2}")
	return strings.ToLower(s)
}

// getTagValues parses the value of a `gorm:"..."` struct tag into a map.
//
// The tag is split on ";" into settings. Each setting is split at its first
// ":" into key and value, and both are trimmed of surrounding spaces, so
// "column:name;default:a:b" yields {"column": "name", "default": "a:b"}.
// A setting without ":" (a flag such as "not null") maps to "". Empty settings,
// including an empty tag, are skipped instead of causing an index panic.
// Keys keep their original case.
func getTagValues(tag string) map[string]string {
	fieldMap := make(map[string]string)
	for _, setting := range strings.Split(tag, ";") {
		setting = strings.TrimSpace(setting)
		if setting == "" {
			continue
		}
		key, value := setting, ""
		if i := strings.Index(setting, ":"); i >= 0 {
			key, value = setting[:i], setting[i+1:]
		}
		fieldMap[strings.TrimSpace(key)] = strings.TrimSpace(value)
	}
	return fieldMap
}

// getField renders data as a SQL literal for string-built INSERT statements.
//
// The dynamic type of data decides the rendering; fieldType (the reflect type
// name the caller passes, e.g. "uint32" or "*time.Time") is no longer needed
// and is kept only for call compatibility. Renderings:
//   - time.Time: '2006-01-02 15:04:05', or '' for the zero time;
//   - *time.Time: as time.Time, or '' when nil;
//   - string kinds (including named types): single-quoted, with embedded single
//     quotes doubled ('O''Brien');
//   - signed/unsigned integer kinds: decimal; float kinds: %g; bool: TRUE/FALSE;
//   - other pointers: the pointed-to value, or '' when nil;
//   - anything else: ''.
//
// NOTE(review): doubling quotes is standard SQL, but it does not neutralize
// backslash escapes under MySQL's default sql_mode. Prefer bind parameters for
// untrusted strings.
func getField(data interface{}, fieldType string) string {
	const layout = "2006-01-02 15:04:05"

	switch v := data.(type) {
	case time.Time:
		s := v.Format(layout)
		if s == "0001-01-01 00:00:00" {
			return "''"
		}
		return "'" + s + "'"
	case *time.Time:
		if v == nil {
			return "''"
		}
		return getField(*v, "time.Time")
	}

	// Switch on the kind rather than the type name so named types such as
	// `type Sex uint32` are rendered like their underlying type.
	rv := reflect.ValueOf(data)
	switch rv.Kind() {
	case reflect.String:
		return "'" + strings.Replace(rv.String(), "'", "''", -1) + "'"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return fmt.Sprintf("%d", rv.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return fmt.Sprintf("%d", rv.Uint())
	case reflect.Float32:
		return fmt.Sprintf("%g", float32(rv.Float()))
	case reflect.Float64:
		return fmt.Sprintf("%g", rv.Float())
	case reflect.Bool:
		if rv.Bool() {
			return "TRUE"
		}
		return "FALSE"
	case reflect.Ptr:
		if rv.IsNil() {
			return "''"
		}
		elem := rv.Elem()
		return getField(elem.Interface(), elem.Type().String())
	}

	return "''"
}

使用方式:

// Model reproduces gorm.Model for reference: a primary key (tag primary_key),
// the CreatedAt/UpdatedAt timestamps, and a nullable DeletedAt that carries an
// sql "index" tag.
type Model struct {
	ID                   uint `gorm:"primary_key"`
	CreatedAt, UpdatedAt time.Time
	DeletedAt            *time.Time `sql:"index"`
}

// User is the example model passed to BatchCreate: the embedded gorm.Model
// contributes the created_at/updated_at columns, and Name/Sex are inserted
// through their gorm "column" tags.
type User struct {
	gorm.Model

	Name string `gorm:"column:name"`
	Sex  uint32 `gorm:"column:sex"`
}
users := []User{
		{
			Name: "小明",
			Sex: 0,
		},
		{
			Name: "小红",
			Sex: 1,
		},
	}
BatchCreate(conn, users)

你可能感兴趣的:(Go,go,gorm,批量插入)