gorm 2.0 以下的版本不支持批量插入，那么咱只好自己造个轮子。
利用反射获取数据集元素的类型和字段，生成批量插入的 SQL 语句，最后通过 `db.Exec` 执行。
具体代码如下（仍有很大的优化空间）：
// BatchCreate inserts every element of data into one table using multi-row
// INSERT statements (gorm versions before 2.0 have no batch insert).
//
// data must be a slice of structs or of pointers to structs; anything else
// yields the error "数据类型不支持", and a nil element yields an error too.
// An empty slice is a no-op. The table name is derived from the element
// type name via getTableName. Columns are the exported fields carrying a
// `gorm:"column:..."` tag, plus created_at and updated_at for an embedded
// gorm.Model (its ID and DeletedAt are not written). A zero CreatedAt or
// UpdatedAt is replaced by the current local time.
//
// Rows are sent batchCreateChunkSize at a time. Every value is passed as a
// bind argument, so quotes or backslashes inside string fields can neither
// break nor alter the generated SQL. The chunks are not wrapped in a
// transaction: if a later chunk fails, earlier chunks stay inserted unless db
// is itself a transaction.
func BatchCreate(db *gorm.DB, data interface{}) error {
	rows := reflect.ValueOf(data)
	if rows.Kind() != reflect.Slice {
		return errors.New("数据类型不支持")
	}
	if rows.Len() == 0 {
		return nil
	}
	elemType := rows.Type().Elem()
	if elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
	}
	if elemType.Kind() != reflect.Struct {
		return errors.New("数据类型不支持")
	}

	tableName := getTableName(elemType.Name())
	// Resolve reflection and tag parsing once per call, not once per row.
	fields := batchCreateFields(elemType)
	columns := make([]string, len(fields))
	for i, f := range fields {
		columns[i] = f.column
	}
	// One timestamp for the whole call, so rows that left both timestamps
	// unset get identical created_at and updated_at values.
	now := time.Now().Format(batchCreateTimeLayout)

	var args []interface{}
	pending := 0
	for i := 0; i < rows.Len(); i++ {
		// Indirect makes []T and []*T look the same; a nil *T becomes an
		// invalid Value.
		row := reflect.Indirect(rows.Index(i))
		if !row.IsValid() {
			return fmt.Errorf("BatchCreate: element %d is nil", i)
		}
		for _, f := range fields {
			// Read from the current row (not the first one) for every column.
			arg := batchCreateArg(row.FieldByIndex(f.index))
			if f.autoTime && arg == "" {
				arg = now
			}
			args = append(args, arg)
		}
		pending++
		if pending >= batchCreateChunkSize {
			if err := batchCreateExec(db, tableName, columns, pending, args); err != nil {
				return err
			}
			args = nil
			pending = 0
		}
	}
	if pending > 0 {
		return batchCreateExec(db, tableName, columns, pending, args)
	}
	return nil
}

const (
	// batchCreateChunkSize is the maximum number of rows per INSERT statement.
	batchCreateChunkSize = 100
	// batchCreateTimeLayout is the layout used to render time values.
	batchCreateTimeLayout = "2006-01-02 15:04:05"
)

// batchCreateField describes one column written by BatchCreate.
type batchCreateField struct {
	index    []int  // path for reflect.Value.FieldByIndex
	column   string // database column name
	autoTime bool   // gorm.Model CreatedAt/UpdatedAt: a zero value is replaced by "now"
}

// batchCreateFields lists, in column order, the fields of struct type t
// that BatchCreate writes.
func batchCreateFields(t reflect.Type) []batchCreateField {
	var fields []batchCreateField
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if sf.Type.String() == "gorm.Model" {
			// From the embedded gorm.Model only the timestamps are written.
			for j := 0; j < sf.Type.NumField(); j++ {
				var column string
				switch sf.Type.Field(j).Name {
				case "CreatedAt":
					column = "created_at"
				case "UpdatedAt":
					column = "updated_at"
				default:
					continue
				}
				fields = append(fields, batchCreateField{index: []int{i, j}, column: column, autoTime: true})
			}
			continue
		}
		if sf.PkgPath != "" {
			// Unexported: reflect.Value.Interface would panic on it.
			continue
		}
		if column := getTagValues(sf.Tag.Get("gorm"))["column"]; column != "" {
			fields = append(fields, batchCreateField{index: []int{i}, column: column})
		}
	}
	return fields
}

// batchCreateArg converts one field value into a bind argument.
//
// Strings, integers, floats and bools are passed through as their base kind.
// time.Time and *time.Time are formatted with batchCreateTimeLayout in the
// value's own location. They are passed as strings, so the stored wall-clock
// value does not depend on how a driver converts time.Time (some drivers
// convert to a configured location). The zero time, a nil *time.Time and any
// other type become "".
func batchCreateArg(v reflect.Value) interface{} {
	switch x := v.Interface().(type) {
	case time.Time:
		return batchCreateFormatTime(x)
	case *time.Time:
		if x == nil {
			return ""
		}
		return batchCreateFormatTime(*x)
	}
	switch v.Kind() {
	case reflect.String:
		return v.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return v.Uint()
	case reflect.Float32, reflect.Float64:
		return v.Float()
	case reflect.Bool:
		return v.Bool()
	}
	return ""
}

// batchCreateFormatTime renders t with batchCreateTimeLayout, or "" for the
// zero time.
func batchCreateFormatTime(t time.Time) string {
	s := t.Format(batchCreateTimeLayout)
	if s == "0001-01-01 00:00:00" {
		return ""
	}
	return s
}

// batchCreateExec runs one multi-row INSERT of rowCount rows. args holds the
// values row by row, len(columns) per row, matching the "?" placeholders.
func batchCreateExec(db *gorm.DB, table string, columns []string, rowCount int, args []interface{}) error {
	row := "(" + strings.TrimSuffix(strings.Repeat("?,", len(columns)), ",") + ")"
	values := strings.TrimSuffix(strings.Repeat(row+",", rowCount), ",")
	sql := fmt.Sprintf("insert into %s (%s) values%s", table, strings.Join(columns, ","), values)
	return db.Exec(sql, args...).Error
}
// tableNameWord matches one capitalised word (an upper-case letter followed
// by one or more lower-case letters) in a Go type name. It is compiled once at
// package initialisation instead of on every call.
var tableNameWord = regexp.MustCompile("[A-Z]([a-z]+)")

// getTableName derives a table name from a struct type name by joining its
// capitalised words with '_' and lower-casing the result, e.g. "UserInfo" ->
// "user_info".
//
// NOTE(review): characters outside such words are dropped, so acronyms and
// digits disappear ("UserID" -> "user", "OrderV2" -> "order"). Neither
// pluralisation nor a model's TableName() method is considered. Confirm this
// matches the table names your gorm configuration actually uses.
func getTableName(modelName string) string {
	return strings.ToLower(strings.Join(tableNameWord.FindAllString(modelName, -1), "_"))
}
// getTagValues parses a gorm struct tag such as
// `column:name;type:varchar(100)` into a key -> value map.
//
// Segments are separated by ';' and split at the first ':' only, so values
// that themselves contain ':' (e.g. default:'a:b') stay intact. Keys and
// values are trimmed of surrounding spaces. A flag segment without a value
// (e.g. primary_key) maps to "". Empty segments are skipped, including the
// single one produced by an empty tag (a field with no gorm tag).
func getTagValues(tag string) map[string]string {
	fieldMap := make(map[string]string)
	for _, segment := range strings.Split(tag, ";") {
		kv := strings.SplitN(segment, ":", 2)
		key := strings.TrimSpace(kv[0])
		if key == "" {
			continue
		}
		value := ""
		if len(kv) == 2 {
			value = strings.TrimSpace(kv[1])
		}
		fieldMap[key] = value
	}
	return fieldMap
}
// getField renders one value as an SQL literal for direct interpolation into
// an INSERT statement.
//
// The value is classified by its dynamic type. fieldType (the field's
// reflect.Type.String()) is kept for signature compatibility only, so a
// fieldType that disagrees with data can no longer cause a type-assertion
// panic.
//
//   - string kinds: single-quoted, with each embedded ' doubled (the SQL
//     standard escape), so values like O'Brien no longer break the statement.
//   - signed and unsigned integer kinds: decimal digits.
//   - time.Time / *time.Time: quoted "2006-01-02 15:04:05" in the value's
//     own location. The zero time or a nil pointer yields ''.
//   - anything else (including nil): ''.
//
// NOTE: MySQL in its default SQL mode also treats backslash as an escape
// character inside string literals, so doubling quotes is not a complete
// defense there. Prefer bind parameters for untrusted input.
func getField(data interface{}, fieldType string) string {
	const layout = "2006-01-02 15:04:05"
	quoteTime := func(t time.Time) string {
		s := t.Format(layout)
		if s == "0001-01-01 00:00:00" {
			return "''"
		}
		return "'" + s + "'"
	}

	switch v := data.(type) {
	case time.Time:
		return quoteTime(v)
	case *time.Time:
		if v == nil {
			return "''"
		}
		return quoteTime(*v)
	}

	rv := reflect.ValueOf(data)
	switch rv.Kind() {
	case reflect.String:
		return "'" + strings.Replace(rv.String(), "'", "''", -1) + "'"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return fmt.Sprintf("%d", rv.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return fmt.Sprintf("%d", rv.Uint())
	}
	return "''"
}
使用方式:
// Model reproduces the definition of gorm.Model for reference; the example
// User below embeds gorm.Model. BatchCreate recognises the embedded field by
// its type name "gorm.Model" and writes only CreatedAt and UpdatedAt (as
// created_at and updated_at). ID and DeletedAt are not part of the INSERT.
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `sql:"index"`
}
// User is the example model passed to BatchCreate. getTableName("User")
// yields the table name "user". The inserted columns are created_at and
// updated_at (from the embedded gorm.Model) followed by name and sex (from
// the `column` tags), in that order.
type User struct {
gorm.Model
Name string `gorm:"column:name"`
Sex uint32 `gorm:"column:sex"`
}
// Insert two users with a single BatchCreate call; conn is the *gorm.DB
// handle. CreatedAt/UpdatedAt are left zero, so BatchCreate fills in the
// current time. BatchCreate reports unsupported input and database failures
// through its error result, so check it. This snippet assumes it runs inside
// a function that returns an error.
users := []User{
	{
		Name: "小明",
		Sex:  0,
	},
	{
		Name: "小红",
		Sex:  1,
	},
}
if err := BatchCreate(conn, users); err != nil {
	return err
}