golang sql 转 struct 插件实现

写go curd的时候,经常需要sql 转 struct,比较麻烦,写个自动转换的代码:
main.go

package main

import (
    "convert/convert"
    "flag"
    "fmt"
    "log"
    "os"
    "path"
)

const VERSION = "1.0.0"
const VersionText = "Convert of mysql schema to golang struct"

var saveFilePath string

func init() {
    pa, err := os.Getwd()
    if err != nil {
        fmt.Println("获取运行文件路径获取失败")
        return
    }
    saveFilePath = path.Join(pa, "models")
    if isexists, _ := PathExists(saveFilePath); !isexists {
        os.Mkdir(saveFilePath, os.ModePerm)
    }
}

func main() {
    connStr :="root:[email protected](127.0.0.1:3306)/test?charset=utf8mb4&parseTime=true"
    dsn := flag.String("dsn", connStr, "连接数据库字符串")
    file := flag.String("file", saveFilePath, "保存路径")
    table := flag.String("table", "", "要生成的表名") //不填默认所有表
    realNameMethod := flag.String("realNameMethod", "TableName", "结构体对应的表名")
    packageName := flag.String("packageName", "", "生成struct的包名(默认为空的话, 则取名为: package models)")
    tagKey := flag.String("tagKey", "gorm", "字段tag的key")
    prefix := flag.String("prefix", "", "表前缀")
    version := flag.Bool("version", false, "版本号")
    v := flag.Bool("v", false, "版本号")
    enableJsonTag := flag.Bool("enableJsonTag", true, "是否添加json的tag,默认false")
    enableFormTag := flag.Bool("enableFormTag", false, "是否添加form的tag,默认false")

    // 版本号
    if *version || *v {
        fmt.Println(fmt.Sprintf("\n version: %s\n %s\n using -h param for more help \n", VERSION, VersionText))
        return
    }

    // 初始化
    t2t := convert.NewTable2Struct()
    // 个性化配置
    t2t.Config(&convert.T2tConfig{
        // json tag是否转为驼峰(大驼峰式),默认为false,不转换
        JsonTagToHump:     true, //false=>json tag:  request_method  true=> json tag  RequestMethod
        JsonTagToFirstLow: true, // json  tag首字母是否转换小写, 默认true  => requestMethod
        // 结构体名称是否转为驼峰式,默认为false
        StructNameToHump: true,
        // 如果字段首字母本来就是大写, 就不添加tag, 默认false添加, true不添加
        RmTagIfUcFirsted: false,
        // tag的字段名字首字母是否转换为小写, 如果本身有大写字母的话, 默认false不转
        TagToLower: false,
        // 字段首字母大写的同时, 是否要把其他字母转换为小写,默认false不转换
        UcFirstOnly: false,
        // 每个struct放入单独的文件,默认false,放入同一个文件
        SeperatFile: false,
    })
    // 开始迁移转换
    err := t2t.
        // 指定某个表,如果不指定,则默认全部表都迁移
        Table(*table).
        // 表前缀
        Prefix(*prefix).
        // 是否添加json tag
        EnableJsonTag(*enableJsonTag).
        EnableFormTag(*enableFormTag).
        // 生成struct的包名(默认为空的话, 则取名为: package model)
        PackageName(*packageName).
        // tag字段的key值,默认是gorm
        TagKey(*tagKey).
        // 是否添加结构体方法获取表名
        RealNameMethod(*realNameMethod).
        // 生成的结构体保存路径
        SavePath(*file).
        // 数据库dsn
        Dsn(*dsn).
        // 执行
        Run()
    if err != nil {
        log.Println(err.Error())
    }
}

// PathExists 判断所给路径文件/文件夹是否存在
func PathExists(path string) (bool, error) {
    _, err := os.Stat(path)
    if err == nil {
        return true, nil
    }
    //isnotexist来判断,是不是不存在的错误
    if os.IsNotExist(err) { //如果返回的错误类型使用os.isNotExist()判断为true,说明文件或者文件夹不存在
        return false, nil
    }
    return false, err //如果有错误了,但是不是不存在的错误,所以把这个错误原封不动的返回
}

tableTostruct.go 代码:

package convert

import (
    "database/sql"
    "errors"
    "fmt"
    _ "github.com/go-sql-driver/mysql"
    "log"
    "os"
    "os/exec"
    "path"
    "strings"
)

//map for converting mysql type to golang types
var typeForMysqlToGo = map[string]string{
    "int":                "int64",
    "integer":            "int64",
    "tinyint":            "int64",
    "smallint":           "int64",
    "mediumint":          "int64",
    "bigint":             "int64",
    "int unsigned":       "int64",
    "integer unsigned":   "int64",
    "tinyint unsigned":   "int64",
    "smallint unsigned":  "int64",
    "mediumint unsigned": "int64",
    "bigint unsigned":    "int64",
    "bit":                "int64",
    "bool":               "bool",
    "enum":               "string",
    "set":                "string",
    "varchar":            "string",
    "char":               "string",
    "tinytext":           "string",
    "mediumtext":         "string",
    "text":               "string",
    "longtext":           "string",
    "blob":               "string",
    "tinyblob":           "string",
    "mediumblob":         "string",
    "longblob":           "string",
    "date":               "time.Time", // time.Time or string
    "datetime":           "time.Time", // time.Time or string
    "timestamp":          "time.Time", // time.Time or string
    "time":               "time.Time", // time.Time or string
    "float":              "float64",
    "double":             "float64",
    "decimal":            "float64",
    "binary":             "string",
    "varbinary":          "string",
}

type TableToStruct struct {
    dsn            string
    savePath       string
    db             *sql.DB
    table          string
    prefix         string
    config         *T2tConfig
    err            error
    realNameMethod string
    enableJsonTag  bool   // 是否添加json的tag, 默认不添加
    enableFormTag  bool   // 是否添加form的tag, 默认不添加
    packageName    string // 生成struct的包名(默认为空的话, 则取名为: package model)
    tagKey         string // tag字段的key值,默认是orm
    dateToTime     bool   // 是否将 date相关字段转换为 time.Time,默认否
}

type T2tConfig struct {
    StructNameToHump  bool // 结构体名称是否转为驼峰式,默认为false
    RmTagIfUcFirsted  bool // 如果字段首字母本来就是大写, 就不添加tag, 默认false添加, true不添加
    TagToLower        bool // tag的字段名字是否转换为小写, 如果本身有大写字母的话, 默认false不转
    JsonTagToHump     bool // json tag是否转为驼峰(大驼峰式),默认为false,不转换
    JsonTagToFirstLow bool // json tag 首字母是否转换小写
    UcFirstOnly       bool // 字段首字母大写的同时, 是否要把其他字母转换为小写,默认false不转换
    SeperatFile       bool // 每个struct放入单独的文件,默认false,放入同一个文件
}

func NewTable2Struct() *TableToStruct {
    return &TableToStruct{}
}

func (t *TableToStruct) Dsn(d string) *TableToStruct {
    t.dsn = d
    return t
}

func (t *TableToStruct) TagKey(r string) *TableToStruct {
    t.tagKey = r
    return t
}

func (t *TableToStruct) PackageName(r string) *TableToStruct {
    t.packageName = r
    return t
}

func (t *TableToStruct) RealNameMethod(r string) *TableToStruct {
    t.realNameMethod = r
    return t
}

func (t *TableToStruct) SavePath(p string) *TableToStruct {
    t.savePath = p
    return t
}

func (t *TableToStruct) DB(d *sql.DB) *TableToStruct {
    t.db = d
    return t
}

func (t *TableToStruct) Table(tab string) *TableToStruct {
    t.table = tab
    return t
}

func (t *TableToStruct) Prefix(p string) *TableToStruct {
    t.prefix = p
    return t
}

func (t *TableToStruct) EnableJsonTag(p bool) *TableToStruct {
    t.enableJsonTag = p
    return t
}

func (t *TableToStruct) EnableFormTag(p bool) *TableToStruct {
    t.enableFormTag = p
    return t
}

func (t *TableToStruct) DateToTime(d bool) *TableToStruct {
    t.dateToTime = d
    return t
}

func (t *TableToStruct) Config(c *T2tConfig) *TableToStruct {
    t.config = c
    return t
}

// Run 生成逻辑
func (t *TableToStruct) Run() error {
    if t.config == nil {
        t.config = new(T2tConfig)
    }
    // 链接mysql, 获取db对象
    t.dialMysql()
    if t.err != nil {
        return t.err
    }

    // 获取表和字段的shcema
    tableColumns, err := t.getColumns()
    if err != nil {
        return err
    }

    // 包名
    var packageName string
    if t.packageName == "" {
        packageName = "package models\n\n"
    } else {
        packageName = fmt.Sprintf("package %s\n\n", t.packageName)
    }

    // 组装struct
    var structContent string
    for tableRealName, item := range tableColumns {
        // 去除前缀
        if t.prefix != "" {
            tableRealName = tableRealName[len(t.prefix):]
        }
        tableName := tableRealName
        structName := tableName
        if t.config.StructNameToHump {
            structName = t.camelCase(structName)
        }

        switch len(tableName) {
        case 0:
        case 1:
            tableName = strings.ToUpper(tableName[0:1])
        default:
            // 字符长度大于1时
            tableName = strings.ToUpper(tableName[0:1]) + tableName[1:]
        }
        depth := 1
        structContent += "type " + structName + " struct {\n"
        for _, v := range item {
            //structContent += tab(depth) + v.ColumnName + " " + v.Type + " " + v.Json + "\n"
            // 字段注释
            var clumnComment string
            if v.ColumnComment != "" {
                clumnComment = fmt.Sprintf(" // %s", v.ColumnComment)
            }
            structContent += fmt.Sprintf("%s%s %s %s%s\n",
                tab(depth), v.ColumnName, v.Type, v.Tag, clumnComment)
        }
        structContent += tab(depth-1) + "}\n\n"

        // 添加 method 获取真实表名
        if t.realNameMethod != "" {
            structContent += fmt.Sprintf("func (*%s) %s() string {\n",
                structName, t.realNameMethod)
            structContent += fmt.Sprintf("%sreturn \"%s\"\n",
                tab(depth), tableRealName)
            structContent += "}\n\n"
        }

        //如果为 true,每个表按照 struct 分开存放一个文件
        if t.config.SeperatFile {
            // 如果有引入 time.Time, 则需要引入 time 包
            var importContent string
            if strings.Contains(structContent, "time.Time") {
                importContent = "import \"time\"\n\n"
            }

            // 写入文件struct
            var savePath = t.savePath
            savePath = path.Join(savePath, tableRealName+".go")

            filePath := fmt.Sprintf("%s", savePath)
            f, err := os.Create(filePath)
            if err != nil {
                log.Println("Can not write file")
                return err
            }
            defer f.Close()

            f.WriteString(packageName + importContent + structContent)

            cmd := exec.Command("gofmt", "-w", filePath)
            cmd.Run()
            structContent = ""
        }
    }
    // false, 多个表放入同一个文件
    if t.config.SeperatFile == false {
        // 如果有引入 time.Time, 则需要引入 time 包
        var importContent string
        if strings.Contains(structContent, "time.Time") {
            importContent = "import \"time\"\n\n"
        }

        // 写入文件struct
        var savePath = t.savePath
        if t.table != "" {
            savePath = path.Join(savePath, t.table+".go")
        } else {
            savePath = path.Join(savePath, "models.go")
        }
        filePath := fmt.Sprintf("%s", savePath)
        f, err := os.Create(filePath)
        if err != nil {
            log.Println("Can not write file")
            return err
        }
        defer f.Close()

        f.WriteString(packageName + importContent + structContent)

        cmd := exec.Command("gofmt", "-w", filePath)
        cmd.Run()
    }
    log.Println("gen model finish!!!")

    return nil
}

func (t *TableToStruct) dialMysql() {
    if t.db == nil {
        if t.dsn == "" {
            t.err = errors.New("dsn数据库配置缺失")
            return
        }
        t.db, t.err = sql.Open("mysql", t.dsn)
    }
    return
}

type column struct {
    ColumnName    string
    Type          string
    Nullable      string
    TableName     string
    ColumnComment string
    Tag           string
}

// Function for fetching schema definition of passed table
func (t *TableToStruct) getColumns(table ...string) (tableColumns map[string][]column, err error) {
    // 根据设置,判断是否要把 date 相关字段替换为 string
    if t.dateToTime == false {
        typeForMysqlToGo["date"] = "string"
        typeForMysqlToGo["datetime"] = "string"
        typeForMysqlToGo["timestamp"] = "string"
        typeForMysqlToGo["time"] = "string"
    }
    tableColumns = make(map[string][]column)
    // sql
    var sqlStr = `SELECT COLUMN_NAME,DATA_TYPE,IS_NULLABLE,TABLE_NAME,COLUMN_COMMENT
        FROM information_schema.COLUMNS 
        WHERE table_schema = DATABASE()`
    // 是否指定了具体的table
    if t.table != "" {
        sqlStr += fmt.Sprintf(" AND TABLE_NAME = '%s'", t.prefix+t.table)
    }
    // sql排序
    sqlStr += " order by TABLE_NAME asc, ORDINAL_POSITION asc"

    rows, err := t.db.Query(sqlStr)
    if err != nil {
        log.Println("Error reading table information: ", err.Error())
        return
    }

    defer rows.Close()

    for rows.Next() {
        col := column{}
        err = rows.Scan(&col.ColumnName, &col.Type, &col.Nullable, &col.TableName, &col.ColumnComment)

        if err != nil {
            log.Println(err.Error())
            return
        }

        //col.Json = strings.ToLower(col.ColumnName)
        col.Tag = col.ColumnName
        col.ColumnName = t.camelCase(col.ColumnName)
        col.Type = typeForMysqlToGo[col.Type]
        jsonTag := col.Tag
        // 字段首字母本身大写, 是否需要删除tag
        if t.config.RmTagIfUcFirsted &&
            col.ColumnName[0:1] == strings.ToUpper(col.ColumnName[0:1]) {
            col.Tag = "-"
        } else {
            // 是否需要将tag转换成小写
            if t.config.TagToLower {
                col.Tag = strings.ToLower(col.Tag)
                jsonTag = col.Tag
            }

            if t.config.JsonTagToHump {
                jsonTag = t.camelCase(jsonTag)
            }
            if t.config.JsonTagToFirstLow {
                jsonTag = FirstLower(jsonTag)
            }

            //if col.Nullable == "YES" {
            //    col.Json = fmt.Sprintf("`json:\"%s,omitempty\"`", col.Json)
            //} else {
            //}
        }
        if t.tagKey == "" {
            t.tagKey = "orm"
        } else if t.tagKey == "gorm" {
            col.Tag = "column:" + col.Tag
        }
        if t.enableJsonTag {
            //col.Json = fmt.Sprintf("`json:\"%s\" %s:\"%s\"`", col.Json, t.config.TagKey, col.Json)
            if t.enableFormTag {
                col.Tag = fmt.Sprintf("`%s:\"%s\" json:\"%s\" form:\"%s\"`", t.tagKey, col.Tag, jsonTag, jsonTag)
            } else {
                col.Tag = fmt.Sprintf("`%s:\"%s\" json:\"%s\"`", t.tagKey, col.Tag, jsonTag)
            }
        } else {
            col.Tag = fmt.Sprintf("`%s:\"%s\"`", t.tagKey, col.Tag)
        }
        //columns = append(columns, col)
        if _, ok := tableColumns[col.TableName]; !ok {
            tableColumns[col.TableName] = []column{}
        }
        tableColumns[col.TableName] = append(tableColumns[col.TableName], col)
    }
    return
}

func (t *TableToStruct) camelCase(str string) string {
    // 是否有表前缀, 设置了就先去除表前缀
    if t.prefix != "" {
        str = strings.Replace(str, t.prefix, "", 1)
    }
    var text string
    //for _, p := range strings.Split(name, "_") {
    for _, p := range strings.Split(str, "_") {
        // 字段首字母大写的同时, 是否要把其他字母转换为小写
        switch len(p) {
        case 0:
        case 1:
            text += strings.ToUpper(p[0:1])
        default:
            // 字符长度大于1时
            if t.config.UcFirstOnly == true {
                text += strings.ToUpper(p[0:1]) + strings.ToLower(p[1:])
            } else {
                text += strings.ToUpper(p[0:1]) + p[1:]
            }
        }
    }
    return text
}
func tab(depth int) string {
    return strings.Repeat("\t", depth)
}

// FirstUpper 字符串首字母大写
func FirstUpper(s string) string {
    if s == "" {
        return ""
    }
    return strings.ToUpper(s[:1]) + s[1:]
}

// FirstLower 字符串首字母小写
func FirstLower(s string) string {
    if s == "" {
        return ""
    }
    return strings.ToLower(s[:1]) + s[1:]
}

目录结构截图:
golang sql 转 struct 插件实现_第1张图片

转换的struct:

package models

type Region struct {
    Id         int64  `gorm:"column:id" json:"id"`                  // 主键
    RegionName string `gorm:"column:region_name" json:"regionName"` // 区域
    Timezone   string `gorm:"column:timezone" json:"timezone"`      // 时区
    CreateTime string `gorm:"column:create_time" json:"createTime"`
    UpdateTime string `gorm:"column:update_time" json:"updateTime"`
}

你可能感兴趣的:(go)