写 Go 的 CRUD 代码时，经常需要把 MySQL 表结构转换成 struct，手写比较麻烦，于是写了一个自动转换的工具：
main.go
package main
import (
"convert/convert"
"flag"
"fmt"
"log"
"os"
"path"
)
// Version metadata printed by the -version / -v flags.
const (
	VERSION     = "1.0.0"
	VersionText = "Convert of mysql schema to golang struct"
)

// saveFilePath is the default output directory (<cwd>/models); it is filled
// in by init and used as the default of the -file flag.
var saveFilePath string
// init computes the default output directory (<cwd>/models), which becomes
// the default of the -file flag, and makes sure that directory exists.
//
// Failures are reported but not fatal: the user can still pass -file.
func init() {
	pa, err := os.Getwd()
	if err != nil {
		fmt.Println("获取运行文件路径获取失败:", err)
		return
	}
	saveFilePath = path.Join(pa, "models")
	// MkdirAll is a no-op when the directory already exists, so no separate
	// existence check is needed; previously the Mkdir error was dropped.
	if err := os.MkdirAll(saveFilePath, os.ModePerm); err != nil {
		fmt.Println("创建保存目录失败:", err)
	}
}
// main parses the command-line flags and generates Go structs from the MySQL
// schema reachable through -dsn, writing them under -file.
func main() {
	// Placeholder credentials; pass the real connection string with -dsn.
	connStr := "root:password@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=true"
	dsn := flag.String("dsn", connStr, "连接数据库字符串")
	file := flag.String("file", saveFilePath, "保存路径")
	table := flag.String("table", "", "要生成的表名") // empty means all tables
	realNameMethod := flag.String("realNameMethod", "TableName", "结构体对应的表名")
	packageName := flag.String("packageName", "", "生成struct的包名(默认为空的话, 则取名为: package models)")
	tagKey := flag.String("tagKey", "gorm", "字段tag的key")
	prefix := flag.String("prefix", "", "表前缀")
	version := flag.Bool("version", false, "版本号")
	v := flag.Bool("v", false, "版本号")
	enableJsonTag := flag.Bool("enableJsonTag", true, "是否添加json的tag,默认true")
	enableFormTag := flag.Bool("enableFormTag", false, "是否添加form的tag,默认false")
	// Without Parse every flag above keeps its default value and whatever the
	// user typed on the command line is silently ignored.
	flag.Parse()

	if *version || *v {
		fmt.Printf("\n version: %s\n %s\n using -h param for more help \n\n", VERSION, VersionText)
		return
	}

	t2t := convert.NewTable2Struct()
	t2t.Config(&convert.T2tConfig{
		// Upper-camel-case the json tag: request_method -> RequestMethod.
		JsonTagToHump: true,
		// Then lower-case its first letter: RequestMethod -> requestMethod.
		JsonTagToFirstLow: true,
		// Camel-case the struct name derived from the table name.
		StructNameToHump: true,
		// Keep tags even for columns whose names already start upper-case.
		RmTagIfUcFirsted: false,
		// Keep the original case of the tag value.
		TagToLower: false,
		// Upper-case the first letter of each word without lower-casing the rest.
		UcFirstOnly: false,
		// Put all structs into a single file.
		SeperatFile: false,
	})

	err := t2t.
		Table(*table). // empty: all tables
		Prefix(*prefix).
		EnableJsonTag(*enableJsonTag).
		EnableFormTag(*enableFormTag).
		PackageName(*packageName).       // empty: "package models"
		TagKey(*tagKey).                 // key of the column-name tag
		RealNameMethod(*realNameMethod). // method returning the table name; empty disables it
		SavePath(*file).
		Dsn(*dsn).
		Run()
	if err != nil {
		// Exit non-zero so scripts and CI notice the failure.
		log.Fatal(err)
	}
}
// PathExists reports whether a file or directory exists at path.
//
// It returns (false, nil) when os.Stat reports "not exist", and returns the
// Stat error unchanged when existence cannot be determined for any other
// reason (for example a permission problem).
func PathExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, statErr
	}
}
convert/tableTostruct.go（package convert）代码：
package convert
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path"
	"sort"
	"strings"

	_ "github.com/go-sql-driver/mysql"
)
//map for converting mysql type to golang types
var typeForMysqlToGo = map[string]string{
"int": "int64",
"integer": "int64",
"tinyint": "int64",
"smallint": "int64",
"mediumint": "int64",
"bigint": "int64",
"int unsigned": "int64",
"integer unsigned": "int64",
"tinyint unsigned": "int64",
"smallint unsigned": "int64",
"mediumint unsigned": "int64",
"bigint unsigned": "int64",
"bit": "int64",
"bool": "bool",
"enum": "string",
"set": "string",
"varchar": "string",
"char": "string",
"tinytext": "string",
"mediumtext": "string",
"text": "string",
"longtext": "string",
"blob": "string",
"tinyblob": "string",
"mediumblob": "string",
"longblob": "string",
"date": "time.Time", // time.Time or string
"datetime": "time.Time", // time.Time or string
"timestamp": "time.Time", // time.Time or string
"time": "time.Time", // time.Time or string
"float": "float64",
"double": "float64",
"decimal": "float64",
"binary": "string",
"varbinary": "string",
}
// TableToStruct generates Go struct definitions from a MySQL schema. Create
// one with NewTable2Struct, configure it through the chainable setters and
// call Run.
type TableToStruct struct {
	dsn            string     // MySQL DSN, used only when db is nil
	savePath       string     // directory the generated files are written to
	db             *sql.DB    // database handle; opened from dsn when nil
	table          string     // single table to convert (without prefix); empty = all
	prefix         string     // table-name prefix stripped from generated names
	config         *T2tConfig // naming/tag options; Run substitutes defaults when nil
	err            error      // error recorded by dialMysql and returned by Run
	realNameMethod string     // name of the generated table-name method; empty = none
	enableJsonTag  bool       // add a json tag to every field (zero value: off)
	enableFormTag  bool       // add a form tag as well (zero value: off)
	packageName    string     // package of the generated code; empty = "models"
	tagKey         string     // tag key carrying the column name; empty = "orm"
	dateToTime     bool       // map date/time columns to time.Time instead of string
}
// T2tConfig holds the naming and tag options of a TableToStruct. Every option
// defaults to false.
type T2tConfig struct {
	// StructNameToHump camel-cases the struct name derived from the table name.
	StructNameToHump bool
	// RmTagIfUcFirsted emits "-" instead of the column name as tag value when
	// the field name already starts with an upper-case letter.
	RmTagIfUcFirsted bool
	// TagToLower lower-cases the tag value when the column name contains
	// upper-case letters.
	TagToLower bool
	// JsonTagToHump converts the json tag to UpperCamelCase.
	JsonTagToHump bool
	// JsonTagToFirstLow lower-cases the first letter of the json tag.
	JsonTagToFirstLow bool
	// UcFirstOnly lower-cases the remaining letters of each word while the
	// first letter is upper-cased.
	UcFirstOnly bool
	// SeperatFile writes every struct to its own file instead of one shared file.
	SeperatFile bool
}
// NewTable2Struct returns a generator with every setting at its zero value;
// configure it with the chainable setters and then call Run.
func NewTable2Struct() *TableToStruct {
	return new(TableToStruct)
}
// Dsn sets the MySQL data source name. It is only used when no *sql.DB has
// been injected with DB.
func (t *TableToStruct) Dsn(dataSource string) *TableToStruct {
	t.dsn = dataSource
	return t
}

// TagKey sets the struct-tag key that carries the column name (for example
// "gorm"). An empty key falls back to "orm" during generation.
func (t *TableToStruct) TagKey(key string) *TableToStruct {
	t.tagKey = key
	return t
}

// PackageName sets the package clause of the generated files; an empty name
// produces "package models".
func (t *TableToStruct) PackageName(name string) *TableToStruct {
	t.packageName = name
	return t
}

// RealNameMethod sets the name of the generated method that returns the
// table name (for example "TableName"); an empty name disables the method.
func (t *TableToStruct) RealNameMethod(method string) *TableToStruct {
	t.realNameMethod = method
	return t
}

// SavePath sets the directory the generated .go files are written to.
func (t *TableToStruct) SavePath(dir string) *TableToStruct {
	t.savePath = dir
	return t
}

// DB injects an already-open database handle; when set, no connection is
// opened from the DSN.
func (t *TableToStruct) DB(handle *sql.DB) *TableToStruct {
	t.db = handle
	return t
}

// Table restricts generation to a single table, named without its prefix.
// An empty name means every table of the current database.
func (t *TableToStruct) Table(name string) *TableToStruct {
	t.table = name
	return t
}

// Prefix sets the table-name prefix that is stripped when naming structs.
func (t *TableToStruct) Prefix(tablePrefix string) *TableToStruct {
	t.prefix = tablePrefix
	return t
}

// EnableJsonTag toggles a json tag on every generated field.
func (t *TableToStruct) EnableJsonTag(enabled bool) *TableToStruct {
	t.enableJsonTag = enabled
	return t
}

// EnableFormTag toggles a form tag (reusing the json name); it only takes
// effect when json tags are enabled as well.
func (t *TableToStruct) EnableFormTag(enabled bool) *TableToStruct {
	t.enableFormTag = enabled
	return t
}

// DateToTime maps date, datetime, timestamp and time columns to time.Time
// instead of string.
func (t *TableToStruct) DateToTime(enabled bool) *TableToStruct {
	t.dateToTime = enabled
	return t
}

// Config sets the naming/tag options; Run substitutes defaults when it is nil.
func (t *TableToStruct) Config(cfg *T2tConfig) *TableToStruct {
	t.config = cfg
	return t
}
// Run connects to MySQL, reads the schema of the selected table(s) and writes
// one Go struct per table under t.savePath (created if missing).
//
// With config.SeperatFile each table goes to <table without prefix>.go;
// otherwise all structs go to <t.table>.go (when a table was set) or
// models.go. Tables are emitted in name order so the output is stable.
// gofmt -w is run on every file on a best-effort basis.
func (t *TableToStruct) Run() error {
	if t.config == nil {
		t.config = new(T2tConfig)
	}
	t.dialMysql()
	if t.err != nil {
		return t.err
	}
	tableColumns, err := t.getColumns()
	if err != nil {
		return err
	}
	if len(tableColumns) == 0 {
		return errors.New("no columns found for the selected table(s); check dsn, table and prefix")
	}

	packageHeader := "package models\n\n"
	if t.packageName != "" {
		packageHeader = fmt.Sprintf("package %s\n\n", t.packageName)
	}

	// Map iteration order is random; sort so the generated file is deterministic.
	tableNames := make([]string, 0, len(tableColumns))
	for name := range tableColumns {
		tableNames = append(tableNames, name)
	}
	sort.Strings(tableNames)

	if t.savePath != "" {
		if err := os.MkdirAll(t.savePath, os.ModePerm); err != nil {
			return fmt.Errorf("creating save path %q: %w", t.savePath, err)
		}
	}

	var combined strings.Builder
	combinedUsesTime := false
	for _, realName := range tableNames {
		content, usesTime := t.renderStruct(realName, tableColumns[realName])
		if t.config.SeperatFile {
			fileName := strings.TrimPrefix(realName, t.prefix) + ".go"
			if err := writeGoFile(path.Join(t.savePath, fileName), packageHeader, content, usesTime); err != nil {
				return err
			}
			continue
		}
		combined.WriteString(content)
		combinedUsesTime = combinedUsesTime || usesTime
	}

	if !t.config.SeperatFile {
		fileName := "models.go"
		if t.table != "" {
			fileName = t.table + ".go"
		}
		if err := writeGoFile(path.Join(t.savePath, fileName), packageHeader, combined.String(), combinedUsesTime); err != nil {
			return err
		}
	}
	log.Println("gen model finish!!!")
	return nil
}

// renderStruct returns the Go source for one table: its struct declaration
// and, when t.realNameMethod is set, a method returning the real (prefixed)
// table name. usesTime reports whether any field is a time.Time.
func (t *TableToStruct) renderStruct(realName string, cols []column) (content string, usesTime bool) {
	// Strip the prefix only when it is actually there; the old blind slice
	// panicked on short names and also leaked into the TableName() result.
	structName := strings.TrimPrefix(realName, t.prefix)
	if t.config.StructNameToHump {
		structName = t.camelCase(structName)
	} else {
		// Export the struct; this is what the previously dead tableName
		// capitalisation computed but never used.
		structName = FirstUpper(structName)
	}

	// A column comment containing a newline would break the generated source.
	oneLine := strings.NewReplacer("\r\n", " ", "\n", " ", "\r", " ")

	var b strings.Builder
	b.WriteString("type " + structName + " struct {\n")
	for _, col := range cols {
		var comment string
		if col.ColumnComment != "" {
			comment = " // " + oneLine.Replace(col.ColumnComment)
		}
		fmt.Fprintf(&b, "%s%s %s %s%s\n", tab(1), col.ColumnName, col.Type, col.Tag, comment)
		if col.Type == "time.Time" {
			usesTime = true
		}
	}
	b.WriteString("}\n\n")

	if t.realNameMethod != "" {
		fmt.Fprintf(&b, "func (*%s) %s() string {\n", structName, t.realNameMethod)
		// Return the real table name (prefix included): ORMs such as gorm use
		// this value verbatim as the table to query.
		fmt.Fprintf(&b, "%sreturn %q\n", tab(1), realName)
		b.WriteString("}\n\n")
	}
	return b.String(), usesTime
}

// writeGoFile writes packageHeader, an optional `import "time"` and body to
// filePath, then runs gofmt -w on it. A gofmt failure (for example gofmt not
// on PATH) is only logged because the written file is still valid Go.
func writeGoFile(filePath, packageHeader, body string, importTime bool) error {
	var importContent string
	if importTime {
		importContent = "import \"time\"\n\n"
	}
	f, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("creating %s: %w", filePath, err)
	}
	// Close explicitly (not deferred) so the file is flushed before gofmt
	// runs and so close errors are not lost.
	_, writeErr := f.WriteString(packageHeader + importContent + body)
	closeErr := f.Close()
	if writeErr != nil {
		return fmt.Errorf("writing %s: %w", filePath, writeErr)
	}
	if closeErr != nil {
		return fmt.Errorf("closing %s: %w", filePath, closeErr)
	}
	if out, err := exec.Command("gofmt", "-w", filePath).CombinedOutput(); err != nil {
		log.Printf("gofmt %s: %v: %s", filePath, err, out)
	}
	return nil
}
// dialMysql makes sure t.db is usable. When no handle was injected via DB, it
// opens one from t.dsn and pings it, so bad credentials or an unreachable
// host are reported here instead of at the first query (sql.Open alone does
// not connect). Any failure is stored in t.err and t.db stays nil.
func (t *TableToStruct) dialMysql() {
	if t.db != nil {
		return
	}
	if t.dsn == "" {
		t.err = errors.New("dsn数据库配置缺失")
		return
	}
	db, err := sql.Open("mysql", t.dsn)
	if err != nil {
		t.err = err
		return
	}
	if err := db.Ping(); err != nil {
		// The handle is unusable; its Close error adds nothing to report.
		_ = db.Close()
		t.err = fmt.Errorf("connecting to mysql: %w", err)
		return
	}
	t.db = db
}
// column is one row of information_schema.COLUMNS, rewritten in place during
// generation: ColumnName becomes the Go field name, Type the Go type and Tag
// the complete struct-tag literal.
type column struct {
	ColumnName, Type, Nullable, TableName, ColumnComment, Tag string
}
// getColumns loads column metadata from information_schema.COLUMNS for every
// table of the current database, or only for t.prefix+t.table when a table is
// set, and groups it by table name. Within a table, columns keep their
// ORDINAL_POSITION order. Each column's ColumnName is converted to a Go field
// name, Type to a Go type and Tag to the complete struct-tag literal.
//
// The variadic table parameter is unused; it is kept for compatibility.
func (t *TableToStruct) getColumns(table ...string) (map[string][]column, error) {
	// Resolve the default tag key once, without mutating the receiver per row.
	tagKey := t.tagKey
	if tagKey == "" {
		tagKey = "orm"
	}

	sqlStr := `SELECT COLUMN_NAME,DATA_TYPE,IS_NULLABLE,TABLE_NAME,COLUMN_COMMENT
		FROM information_schema.COLUMNS
		WHERE table_schema = DATABASE()`
	var args []interface{}
	if t.table != "" {
		// Bind the name instead of splicing it into the SQL text, so quotes in
		// the -table flag can neither break nor alter the query.
		sqlStr += " AND TABLE_NAME = ?"
		args = append(args, t.prefix+t.table)
	}
	sqlStr += " order by TABLE_NAME asc, ORDINAL_POSITION asc"

	rows, err := t.db.Query(sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("reading table information: %w", err)
	}
	defer rows.Close()

	tableColumns := make(map[string][]column)
	for rows.Next() {
		var col column
		if err := rows.Scan(&col.ColumnName, &col.Type, &col.Nullable, &col.TableName, &col.ColumnComment); err != nil {
			return nil, fmt.Errorf("scanning column information: %w", err)
		}
		rawName := col.ColumnName
		col.ColumnName = t.camelCase(rawName)
		col.Type = t.goType(col.Type)
		col.Tag = t.buildTag(rawName, tagKey)
		// append on a missing key starts from a nil slice, so no pre-check is needed.
		tableColumns[col.TableName] = append(tableColumns[col.TableName], col)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating column information: %w", err)
	}
	return tableColumns, nil
}

// goType maps a MySQL DATA_TYPE to the Go type of the generated field.
// Date/time types become string unless DateToTime(true) was set. The lookup
// never mutates typeForMysqlToGo, so one run's setting cannot leak into the
// next. Unknown types fall back to string so the output still compiles.
func (t *TableToStruct) goType(dataType string) string {
	switch dataType {
	case "date", "datetime", "timestamp", "time":
		if !t.dateToTime {
			return "string"
		}
	}
	if mapped, ok := typeForMysqlToGo[dataType]; ok && mapped != "" {
		return mapped
	}
	return "string"
}

// buildTag renders the struct-tag literal for the column named rawName,
// e.g. `gorm:"column:create_time" json:"createTime"`.
func (t *TableToStruct) buildTag(rawName, tagKey string) string {
	cfg := t.config
	if cfg == nil {
		cfg = new(T2tConfig)
	}
	ormName := rawName
	jsonName := rawName
	// Check the raw column name: the camel-cased field name always starts
	// upper-case, which made this option strip the tag of every column.
	if cfg.RmTagIfUcFirsted && startsUpper(rawName) {
		ormName = "-"
	} else {
		if cfg.TagToLower {
			ormName = strings.ToLower(ormName)
			jsonName = ormName
		}
		if cfg.JsonTagToHump {
			jsonName = t.camelCase(jsonName)
		}
		if cfg.JsonTagToFirstLow {
			jsonName = FirstLower(jsonName)
		}
	}
	if tagKey == "gorm" {
		ormName = "column:" + ormName
	}
	switch {
	case t.enableJsonTag && t.enableFormTag:
		return fmt.Sprintf("`%s:\"%s\" json:\"%s\" form:\"%s\"`", tagKey, ormName, jsonName, jsonName)
	case t.enableJsonTag:
		return fmt.Sprintf("`%s:\"%s\" json:\"%s\"`", tagKey, ormName, jsonName)
	default:
		return fmt.Sprintf("`%s:\"%s\"`", tagKey, ormName)
	}
}

// startsUpper reports whether s begins with an upper-case letter (a character
// that strings.ToLower changes). Digits, '_' and an empty s report false.
func startsUpper(s string) bool {
	runes := []rune(s)
	if len(runes) == 0 {
		return false
	}
	first := string(runes[0])
	return strings.ToLower(first) != first
}
// camelCase converts an underscore-separated identifier to UpperCamelCase,
// e.g. "create_time" -> "CreateTime". A leading t.prefix is removed first.
// With config.UcFirstOnly the rest of each word is lower-cased
// ("USER_ID" -> "UserId"); otherwise it is kept as is.
func (t *TableToStruct) camelCase(str string) string {
	// Strip only a leading prefix. The old strings.Replace removed the first
	// occurrence anywhere, so with prefix "t_" the column "cat_id" became "Caid".
	if t.prefix != "" {
		str = strings.TrimPrefix(str, t.prefix)
	}
	ucFirstOnly := t.config != nil && t.config.UcFirstOnly

	var b strings.Builder
	for _, word := range strings.Split(str, "_") {
		if word == "" {
			continue
		}
		// Byte offset just past the first rune, so a multi-byte first
		// character (e.g. a CJK column name) is upper-cased whole instead of
		// being cut in half and turned into U+FFFD.
		split := len(word)
		for i := range word {
			if i > 0 {
				split = i
				break
			}
		}
		b.WriteString(strings.ToUpper(word[:split]))
		rest := word[split:]
		if ucFirstOnly {
			rest = strings.ToLower(rest)
		}
		b.WriteString(rest)
	}
	return b.String()
}
// tab returns depth tab characters for indenting generated code. Like
// strings.Repeat, it panics when depth is negative.
func tab(depth int) string {
	const indent = "\t"
	return strings.Repeat(indent, depth)
}
// FirstUpper returns s with its first character upper-cased. The first
// character is taken as a whole rune, so multi-byte letters such as 'é' are
// handled correctly (slicing s[:1] split them and produced U+FFFD).
func FirstUpper(s string) string {
	if s == "" {
		return ""
	}
	n := firstRuneEnd(s)
	return strings.ToUpper(s[:n]) + s[n:]
}

// FirstLower returns s with its first character (a whole rune) lower-cased.
func FirstLower(s string) string {
	if s == "" {
		return ""
	}
	n := firstRuneEnd(s)
	return strings.ToLower(s[:n]) + s[n:]
}

// firstRuneEnd returns the byte offset just past the first rune of s, or
// len(s) when s holds at most one rune.
func firstRuneEnd(s string) int {
	for i := range s {
		if i > 0 {
			return i
		}
	}
	return len(s)
}
转换生成的 struct 示例：
package models
// Region is sample output of the generator: gorm tag key, json tags in
// lowerCamelCase, and column comments carried over as trailing comments.
// Presumably generated from a table named "region" — confirm.
// NOTE(review): with the default -realNameMethod=TableName the tool would also
// emit `func (*Region) TableName() string`, which this sample omits.
type Region struct {
Id int64 `gorm:"column:id" json:"id"` // primary key
RegionName string `gorm:"column:region_name" json:"regionName"` // region
Timezone string `gorm:"column:timezone" json:"timezone"` // time zone
CreateTime string `gorm:"column:create_time" json:"createTime"`
UpdateTime string `gorm:"column:update_time" json:"updateTime"`
}