本节内容的代码都放在/rpc/database/下,目录结构如下:
├─gorm
├─sql
│ └─user
├─sqlc
└─sqlx
参考:goctl model mysql 指令
注意:go-zero 的 goctl model mysql 指令支持从 sql 文件和数据库连接两个来源生成代码,两者生成的代码完全一样。我个人更推荐根据 sql 文件生成,因为这样可以记录 sql 文件的变化。
注意:最后的参数-style=go_zero是指定生成文件名称的格式,这里是蛇形命名,不喜欢的可以去除这个参数。
单表:
goctl model mysql ddl -src="./rpc/database/sql/user/zero_users.sql" -dir="./rpc/database/sqlx/usermodel" -style=go_zero
多表:
goctl model mysql ddl -src="./rpc/database/sql/user/zero_*.sql" -dir="./rpc/database/sqlx/usermodel" -style=go_zero
-src:sql 文件目录;
-dir:sqlx 代码目录;
goctl model mysql datasource -url="root:root@tcp(127.0.0.1:3357)/go-zero-micro" -table="zero_users" -dir="./rpc/database/sqlx/usermodel"
-url:数据库连接;
-table:数据表;
-dir:sqlx 代码目录;
同生成sqlx代码的命令类似,只是后面需要再加一个 -cache即可。
单表:
goctl model mysql ddl -src="./rpc/database/sql/user/zero_users.sql" -dir="./rpc/database/sqlc/usermodel" -style=go_zero -cache
多表:
goctl model mysql ddl -src="./rpc/database/sql/user/zero_*.sql" -dir="./rpc/database/sqlc/usermodel" -style=go_zero -cache
-src:sql 文件目录;
-dir:sqlc 代码目录;
goctl model mysql datasource -url="root:root@tcp(127.0.0.1:3357)/go-zero-micro" -table="zero_users" -dir="./rpc/database/sqlc/usermodel" -cache
-url:数据库连接;
-table:数据表;
-dir:sqlc 代码目录;
通过 2.1的命令生成的sqlx代码有三个文件:
主要代码都在 zerousersmodel.go,这里使用了反射对拼接的sql语句进行了优化:
注意:其实自定义的操作接口应该都加入context参数,便于链路追踪,这一点已在该分支最新提交的代码中补上。
package usermodel
import (
"context"
"database/sql"
"fmt"
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/common/utils"
"reflect"
"strings"
"time"
)
var _ ZeroUsersModel = (*customZeroUsersModel)(nil)
type (
	// ZeroUsersModel is an interface to be customized, add more methods here,
	// and implement the added methods in customZeroUsersModel.
	ZeroUsersModel interface {
		// zeroUsersModel is declared outside this snippet
		// (presumably the goctl-generated base interface — confirm).
		zeroUsersModel
		// Trans runs fn inside a database transaction.
		Trans(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error
		// Count returns the number of non-deleted rows matching data and the
		// optional created_at bounds beginTime / endTime.
		Count(data *ZeroUsers, beginTime, endTime string) (int64, error)
		// FindPageListByParam returns one page of matching rows, newest first;
		// current is the 1-based page number and pageSize the rows per page.
		FindPageListByParam(data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error)
		// FindAllByParam returns every non-deleted row matching data, newest first.
		FindAllByParam(data *ZeroUsers) ([]*ZeroUsers, error)
		// FindOneByParam returns the newest non-deleted row matching data.
		FindOneByParam(data *ZeroUsers) (*ZeroUsers, error)
		// Save inserts data as a new row.
		Save(ctx context.Context, data *ZeroUsers) (sql.Result, error)
		// Edit updates the row whose id equals data.Id.
		Edit(ctx context.Context, data *ZeroUsers) (sql.Result, error)
		// DeleteData soft-deletes the row whose id equals data.Id.
		DeleteData(ctx context.Context, data *ZeroUsers) error
	}

	// customZeroUsersModel implements ZeroUsersModel on top of the embedded
	// defaultZeroUsersModel.
	customZeroUsersModel struct {
		*defaultZeroUsersModel
	}
)
// Trans executes fn within a transaction opened on the model's connection.
// The transaction-scoped context and session are handed to fn unchanged and
// fn's error is returned to TransactCtx as-is.
func (c customZeroUsersModel) Trans(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	runInTx := func(txCtx context.Context, tx sqlx.Session) error {
		return fn(txCtx, tx)
	}
	return c.conn.TransactCtx(ctx, runInTx)
}
// userSqlJoins builds the extra " AND ..." conditions for a ZeroUsers query
// from the populated fields of queryModel.
//
// Each field carrying a `db` tag is inspected by its exact type:
//   - int64 greater than zero      -> " AND col=N"
//   - non-empty string             -> " AND col LIKE '%value%'"
//   - non-zero time.Time           -> " AND col='YYYY-MM-DD hh:mm:ss'"
//
// Fields of any other type, and fields without a usable `db` tag, are ignored.
// A nil queryModel yields an empty string.
//
// String values are escaped (backslash and single quote) before being spliced
// into the statement. Bound parameters would be preferable, but this helper's
// contract is to return a plain SQL fragment.
func userSqlJoins(queryModel *ZeroUsers) string {
	if queryModel == nil {
		return ""
	}
	// MySQL DATETIME literal layout.
	// NOTE(review): presumably identical to utils.DateTimeFormat — confirm.
	const dateTimeLayout = "2006-01-02 15:04:05"

	typ := reflect.TypeOf(queryModel).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(queryModel).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	var conditions strings.Builder
	for i := 0; i < val.NumField(); i++ {
		colName := typ.Field(i).Tag.Get("db")
		if colName == "" || colName == "-" {
			// Without a column name the condition would be " AND =...".
			continue
		}
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				conditions.WriteString(fmt.Sprintf(" AND %s=%d", colName, field.Int()))
			}
		case "string":
			if field.String() != "" {
				conditions.WriteString(fmt.Sprintf(" AND %s LIKE '%%%s%%'", colName, escaper.Replace(field.String())))
			}
		case "time.Time":
			// reflect.Value.String() on a non-string value returns
			// "<time.Time Value>", so the time must be formatted explicitly.
			if value := field.Interface().(time.Time); !value.IsZero() {
				conditions.WriteString(fmt.Sprintf(" AND %s='%s'", colName, value.Format(dateTimeLayout)))
			}
		}
	}
	return conditions.String()
}
// Count returns how many non-deleted rows match data, optionally restricted
// to created_at >= beginTime and created_at <= endTime (empty bounds are
// skipped). sqlc.ErrNotFound is translated to the package's ErrNotFound.
func (c customZeroUsersModel) Count(data *ZeroUsers, beginTime, endTime string) (int64, error) {
	var query strings.Builder
	query.WriteString(fmt.Sprintf("SELECT count(*) as count FROM %s WHERE deleted_flag = %d", c.table, utils.DelNo))
	query.WriteString(userSqlJoins(data))
	if beginTime != "" {
		query.WriteString(" AND created_at >= '" + beginTime + "'")
	}
	if endTime != "" {
		query.WriteString(" AND created_at <= '" + endTime + "'")
	}

	var total int64
	err := c.conn.QueryRow(&total, query.String())
	if err == nil {
		return total, nil
	}
	if err == sqlc.ErrNotFound {
		return 0, ErrNotFound
	}
	return 0, err
}
// FindPageListByParam returns one page of non-deleted rows matching data,
// newest first, optionally restricted to created_at >= beginTime and
// created_at <= endTime (empty bounds are skipped).
//
// current is the 1-based page number and pageSize the number of rows per
// page. A current below 1 is treated as the first page and a negative
// pageSize as 0, because a negative OFFSET or row count is invalid SQL.
// sqlc.ErrNotFound is translated to the package's ErrNotFound.
func (c customZeroUsersModel) FindPageListByParam(data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error) {
	if current < 1 {
		current = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}

	query := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	query += userSqlJoins(data)
	if beginTime != "" {
		query += fmt.Sprintf(" AND created_at >= %s", "'"+beginTime+"'")
	}
	if endTime != "" {
		query += fmt.Sprintf(" AND created_at <= %s", "'"+endTime+"'")
	}
	query += " ORDER BY created_at DESC"
	query += fmt.Sprintf(" LIMIT %d,%d", (current-1)*pageSize, pageSize)

	var result []*ZeroUsers
	err := c.conn.QueryRows(&result, query)
	switch err {
	case nil:
		return result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// FindAllByParam returns every non-deleted row matching data, ordered by
// created_at descending. sqlc.ErrNotFound is translated to the package's
// ErrNotFound.
func (c customZeroUsersModel) FindAllByParam(data *ZeroUsers) ([]*ZeroUsers, error) {
	base := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	query := base + userSqlJoins(data) + " ORDER BY created_at DESC"

	var rows []*ZeroUsers
	if err := c.conn.QueryRows(&rows, query); err != nil {
		if err == sqlc.ErrNotFound {
			return nil, ErrNotFound
		}
		return nil, err
	}
	return rows, nil
}
// FindOneByParam returns the newest non-deleted row matching data.
//
// Only one row is ever scanned, so the statement is capped with LIMIT 1
// instead of asking the server for every matching row.
// sqlc.ErrNotFound is translated to the package's ErrNotFound.
func (c customZeroUsersModel) FindOneByParam(data *ZeroUsers) (*ZeroUsers, error) {
	query := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	query += userSqlJoins(data)
	query += " ORDER BY created_at DESC LIMIT 1"

	var result ZeroUsers
	err := c.conn.QueryRow(&result, query)
	switch err {
	case nil:
		return &result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// Save inserts data as a new row and returns the driver result.
//
// Columns are taken from the `db` tags of data's fields: int64 fields are
// written when greater than zero, string fields are always written, and
// time.Time fields are written when non-zero. Other field types are ignored.
// String values are escaped (backslash and single quote) so that a value
// containing a quote can neither break nor alter the statement.
func (c customZeroUsersModel) Save(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	fieldNum := val.NumField()
	names := make([]string, 0, fieldNum)
	values := make([]string, 0, fieldNum)
	// The loop starts at 1, so the struct's first field is never inserted
	// (presumably the auto-increment id — confirm against ZeroUsers).
	for i := 1; i < fieldNum; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue // no column to write to
		}
		column := "`" + tag + "`"
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				names = append(names, column)
				values = append(values, fmt.Sprintf("%d", field.Int()))
			}
		case "string":
			names = append(names, column)
			values = append(values, "'"+escaper.Replace(field.String())+"'")
		case "time.Time":
			if value := field.Interface().(time.Time); !value.IsZero() {
				names = append(names, column)
				values = append(values, "'"+value.Format(utils.DateTimeFormat)+"'")
			}
		}
	}

	saveSql := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, strings.Join(names, ","), strings.Join(values, ","))
	return c.conn.ExecCtx(ctx, saveSql)
}
// Edit updates the row whose id equals data.Id and returns the driver result.
//
// deleted_flag is always reset to utils.DelNo. In addition, int64 fields
// greater than zero, all string fields, and non-zero time.Time fields are
// assigned to the columns named by their `db` tags.
// NOTE(review): empty strings are written too, so a string field left unset
// by the caller overwrites the stored value — confirm this is intended.
//
// String values are escaped (backslash and single quote), and the SET clause
// is assembled so that an empty assignment list cannot leave a dangling comma.
func (c customZeroUsersModel) Edit(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	fieldNum := val.NumField()
	assignments := make([]string, 0, fieldNum)
	// The loop starts at 1, so the struct's first field is never assigned
	// (presumably the id used in the WHERE clause — confirm).
	for i := 1; i < fieldNum; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue // no column to write to
		}
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				assignments = append(assignments, fmt.Sprintf("`%s`=%d", tag, field.Int()))
			}
		case "string":
			assignments = append(assignments, fmt.Sprintf("`%s`='%s'", tag, escaper.Replace(field.String())))
		case "time.Time":
			if value := field.Interface().(time.Time); !value.IsZero() {
				assignments = append(assignments, fmt.Sprintf("`%s`='%s'", tag, value.Format(utils.DateTimeFormat)))
			}
		}
	}

	setClause := fmt.Sprintf("deleted_flag = %d", utils.DelNo)
	if len(assignments) > 0 {
		setClause += ", " + strings.Join(assignments, ",")
	}
	editSql := fmt.Sprintf("UPDATE %s SET %s WHERE id = %d", c.table, setClause, data.Id)
	return c.conn.ExecCtx(ctx, editSql)
}
// DeleteData soft-deletes the row whose id equals data.Id: deleted_flag is
// set to utils.DelYes and deleted_at to data.UpdatedAt.
//
// When the caller left UpdatedAt unset, the current time is used instead;
// the zero time would format as '0001-01-01 00:00:00', which is outside the
// range MySQL accepts for DATETIME columns.
func (c customZeroUsersModel) DeleteData(ctx context.Context, data *ZeroUsers) error {
	deletedAt := data.UpdatedAt
	if deletedAt.IsZero() {
		deletedAt = time.Now()
	}
	deleteSql := fmt.Sprintf("UPDATE %s SET deleted_flag = %d,deleted_at= %s WHERE id = %d",
		c.table, utils.DelYes, "'"+deletedAt.Format(utils.DateTimeFormat)+"'", data.Id)
	_, err := c.conn.ExecCtx(ctx, deleteSql)
	return err
}
// NewZeroUsersModel builds the customizable users model on top of the default
// model created for conn.
func NewZeroUsersModel(conn sqlx.SqlConn) ZeroUsersModel {
	model := new(customZeroUsersModel)
	model.defaultZeroUsersModel = newZeroUsersModel(conn)
	return model
}
MySQL:
  # 本地数据库
  DataSource: root:root@tcp(127.0.0.1:3357)/go-zero-micro?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
package config
import "github.com/zeromicro/go-zero/zrpc"
// Config is the configuration of the ucenter RPC service.
type Config struct {
	zrpc.RpcServerConf // embedded go-zero RPC server settings
	// JWT holds the token settings read when a login succeeds.
	JWT struct {
		AccessSecret string // secret passed to utils.GenerateJwtToken
		AccessExpire int64  // added to the Unix-second issue time to get the expiry
	}
	// MySQL holds the database connection settings.
	MySQL struct {
		DataSource string // DSN handed to sqlx.NewMysql
	}
	UploadFile UploadFile
}
// UploadFile groups the file-upload settings.
// NOTE(review): the meanings below are inferred from the field names only;
// no code in this section reads them — confirm units against the consumer.
type UploadFile struct {
	MaxFileNum  int64  // presumably the maximum number of files per upload
	MaxFileSize int64  // presumably the maximum size of one file
	SavePath    string // presumably the directory uploads are stored in
}
package svc
import (
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/rpc/code/ucenter/internal/config"
"go-zero-micro/rpc/database/sqlx/usermodel"
)
// ServiceContext bundles the dependencies shared by the service's logic types.
type ServiceContext struct {
	Config     config.Config            // configuration the context was built from
	UsersModel usermodel.ZeroUsersModel // sqlx-based users model
}
// NewServiceContext wires the service dependencies: it creates a MySQL
// connection from c.MySQL.DataSource and builds the users model on it.
func NewServiceContext(c config.Config) *ServiceContext {
	svcCtx := new(ServiceContext)
	svcCtx.Config = c
	svcCtx.UsersModel = usermodel.NewZeroUsersModel(sqlx.NewMysql(c.MySQL.DataSource))
	return svcCtx
}
package ucentersqlxlogic
import (
"context"
"errors"
"fmt"
"go-zero-micro/common/utils"
"go-zero-micro/rpc/database/sqlx/usermodel"
"time"
"go-zero-micro/rpc/code/ucenter/internal/svc"
"go-zero-micro/rpc/code/ucenter/ucenter"
"github.com/jinzhu/copier"
"github.com/zeromicro/go-zero/core/logx"
)
// LoginUserLogic carries the per-request state used by the login logic.
type LoginUserLogic struct {
	ctx         context.Context     // request context given to NewLoginUserLogic
	svcCtx      *svc.ServiceContext // shared service dependencies
	logx.Logger                     // logger bound to ctx by NewLoginUserLogic
}
// NewLoginUserLogic returns a LoginUserLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
	logic := new(LoginUserLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// LoginUser authenticates a user by account and password.
//
// The user row is looked up by in.Account; when the supplied password matches
// the stored one, the row is copied into in and the login response (including
// the JWT) is produced by LoginSuccess.
//
// Errors never include the request formatted with %+v, because the request
// carries the plaintext password; only the account is reported.
func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	param := &usermodel.ZeroUsers{
		Account: in.Account,
	}
	dbRes, err := l.svcCtx.UsersModel.FindOneByParam(param)
	if err != nil {
		// Log through the embedded logger so the entry carries the request context.
		l.Error(err)
		return nil, fmt.Errorf("LoginUser:FindOneByParam:db err:%w , account : %s", err, in.Account)
	}

	if !utils.ComparePassword(in.Password, dbRes.Password) {
		return nil, errors.New(fmt.Sprintf("LoginUser:user password error:account : %s", in.Account))
	}

	if err := copier.Copy(in, dbRes); err != nil {
		return nil, fmt.Errorf("LoginUser:copy user:%w", err)
	}
	return l.LoginSuccess(in)
}
// LoginSuccess builds the login response for an authenticated user.
//
// A JWT is generated for in.Id with the configured secret and lifetime. The
// response is populated from in, then AccessExpire is set to now plus the
// lifetime and RefreshAfter to now plus half the lifetime (Unix seconds).
// A failure to copy the user into the response is returned instead of being
// silently ignored.
func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	accessSecret := l.svcCtx.Config.JWT.AccessSecret
	accessExpire := l.svcCtx.Config.JWT.AccessExpire
	now := time.Now().Unix()

	jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
	if err != nil {
		return nil, err
	}

	resp := &ucenter.UserLoginResp{}
	if err := copier.Copy(resp, in); err != nil {
		return nil, err
	}
	resp.AccessToken = jwtToken
	resp.AccessExpire = now + accessExpire
	resp.RefreshAfter = now + accessExpire/2
	return resp, nil
}
提示:本示例对前面的代码有较大优化
注意:
//接口
TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error
//实现
// TransCtx executes fn within a transaction opened on the model's connection.
// The transaction-scoped context and session are handed to fn unchanged and
// fn's error is returned to TransactCtx as-is.
func (c customZeroUsersModel) TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	runInTx := func(txCtx context.Context, tx sqlx.Session) error {
		return fn(txCtx, tx)
	}
	return c.conn.TransactCtx(ctx, runInTx)
}
可以发现,与没有事务特性的插入相比只是更改了操作的调用者为session。
//接口,有session参数
TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error)
//实现
// TransSaveCtx inserts data as a new row through session, so the insert takes
// part in the transaction the session belongs to. It differs from the plain
// insert only in executing on session instead of the model's connection.
//
// Columns are taken from the `db` tags of data's fields: int64 fields are
// written when greater than zero, string fields are always written, and
// time.Time fields are written when non-zero. Other field types are ignored.
// String values are escaped (backslash and single quote) so that a value
// containing a quote can neither break nor alter the statement.
func (c customZeroUsersModel) TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	fieldNum := val.NumField()
	names := make([]string, 0, fieldNum)
	values := make([]string, 0, fieldNum)
	// The loop starts at 1, so the struct's first field is never inserted
	// (presumably the auto-increment id — confirm against ZeroUsers).
	for i := 1; i < fieldNum; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue // no column to write to
		}
		column := "`" + tag + "`"
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				names = append(names, column)
				values = append(values, fmt.Sprintf("%d", field.Int()))
			}
		case "string":
			names = append(names, column)
			values = append(values, "'"+escaper.Replace(field.String())+"'")
		case "time.Time":
			if value := field.Interface().(time.Time); !value.IsZero() {
				names = append(names, column)
				values = append(values, "'"+value.Format(utils.DateTimeFormat)+"'")
			}
		}
	}

	saveSql := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, strings.Join(names, ","), strings.Join(values, ","))
	return session.ExecCtx(ctx, saveSql)
}
代码示例在 internal/logic/ucentersqlx/adduserlogic.go中:
package ucentersqlxlogic
import (
"context"
"github.com/jinzhu/copier"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/common/errorx"
"go-zero-micro/common/utils"
"go-zero-micro/rpc/database/sqlx/usermodel"
"time"
"go-zero-micro/rpc/code/ucenter/internal/svc"
"go-zero-micro/rpc/code/ucenter/ucenter"
"github.com/zeromicro/go-zero/core/logx"
)
// AddUserLogic carries the per-request state used by the add-user logic.
type AddUserLogic struct {
	ctx         context.Context     // request context given to NewAddUserLogic
	svcCtx      *svc.ServiceContext // shared service dependencies
	logx.Logger                     // logger bound to ctx by NewAddUserLogic
}
// NewAddUserLogic returns an AddUserLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewAddUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AddUserLogic {
	logic := new(AddUserLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// AddUser creates a user and its user-info row in a single transaction and
// returns the id of the new user.
//
// The users table stores the account; the user-infos table is a child table
// storing related data such as email and phone number. The child row needs
// the id generated for the parent row, so both inserts run in one
// transaction: any error returned from the callback rolls it back, which
// prevents a parent row without a child row.
//
// The caller always receives the generic DbAddErrorCode error on failure; the
// underlying cause is logged here so it is not lost.
func (l *AddUserLogic) AddUser(in *ucenter.User) (*ucenter.BaseResp, error) {
	operatorID := utils.GetUidFromCtxInt64(l.ctx, "userId")
	currentTime := time.Now()

	var insertedUserID int64
	err := l.svcCtx.UsersModel.TransCtx(l.ctx, func(ctx context.Context, session sqlx.Session) error {
		userParam := &usermodel.ZeroUsers{}
		if err := copier.Copy(userParam, in); err != nil {
			return err
		}
		// New users always start with the configured default password.
		userParam.Password = utils.GeneratePassword(l.svcCtx.Config.DefaultConfig.DefaultPassword)
		userParam.CreatedBy = operatorID
		userParam.CreatedAt = currentTime
		// Use the context supplied to the callback rather than l.ctx.
		dbUserRes, err := l.svcCtx.UsersModel.TransSaveCtx(ctx, session, userParam)
		if err != nil {
			return err
		}
		uid, err := dbUserRes.LastInsertId()
		if err != nil {
			return err
		}

		userInfoParam := &usermodel.ZeroUserInfos{}
		if err := copier.Copy(userInfoParam, in); err != nil {
			return err
		}
		userInfoParam.UserId = uid
		userInfoParam.CreatedBy = operatorID
		userInfoParam.CreatedAt = currentTime
		if _, err := l.svcCtx.UserInfosModel.TransSaveCtx(ctx, session, userInfoParam); err != nil {
			return err
		}

		insertedUserID = uid
		return nil
	})
	if err != nil {
		l.Errorf("AddUser: transaction failed: %v", err)
		return nil, errorx.NewDefaultError(errorx.DbAddErrorCode)
	}

	return &ucenter.BaseResp{
		Id: insertedUserID,
	}, nil
}
// QuerySqlJoins builds the extra " AND ..." query conditions from the
// populated fields of data. It is generic so any model struct can use it.
//
// Each field carrying a `db` tag is inspected by its exact type:
//   - int64 greater than zero      -> " AND col=N"
//   - non-empty string             -> " AND col LIKE '%value%'"
//   - non-zero time.Time           -> " AND col='YYYY-MM-DD hh:mm:ss'"
//
// Fields of any other type, and fields without a usable `db` tag, are
// ignored. A nil pointer or a non-struct T yields an empty string.
//
// String values are escaped (backslash and single quote) before being spliced
// into the statement. Bound parameters would be preferable, but this helper's
// contract is to return a plain SQL fragment.
func QuerySqlJoins[T any](data *T) string {
	if data == nil {
		return ""
	}
	val := reflect.ValueOf(data).Elem() // pointer type, so Elem() is required
	if val.Kind() != reflect.Struct {
		return ""
	}
	typ := val.Type()

	// MySQL DATETIME literal layout.
	// NOTE(review): presumably identical to the package's DateTimeFormat — confirm.
	const dateTimeLayout = "2006-01-02 15:04:05"
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	var conditions strings.Builder
	for i := 0; i < val.NumField(); i++ {
		colName := typ.Field(i).Tag.Get("db")
		if colName == "" || colName == "-" {
			// Without a column name the condition would be " AND =...".
			continue
		}
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				conditions.WriteString(fmt.Sprintf(" AND %s=%d", colName, field.Int()))
			}
		case "string":
			if field.String() != "" {
				conditions.WriteString(fmt.Sprintf(" AND %s LIKE '%%%s%%'", colName, escaper.Replace(field.String())))
			}
		case "time.Time":
			// reflect.Value.String() on a non-string value returns
			// "<time.Time Value>", so the time must be formatted explicitly.
			if value := field.Interface().(time.Time); !value.IsZero() {
				conditions.WriteString(fmt.Sprintf(" AND %s='%s'", colName, value.Format(dateTimeLayout)))
			}
		}
	}
	return conditions.String()
}
// SaveSqlJoins builds the INSERT statement for data on table. It is generic
// so any model struct can use it.
//
// Columns are taken from the `db` tags of data's fields: int64 and string
// fields are always written (including zero and empty values), time.Time
// fields are written when non-zero, and other field types are ignored.
// Fields without a usable `db` tag are skipped.
//
// String values are escaped (backslash and single quote) so that a value
// containing a quote can neither break nor alter the statement.
func SaveSqlJoins[T any](data *T, table string) string {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	fieldNum := val.NumField()
	names := make([]string, 0, fieldNum)
	values := make([]string, 0, fieldNum)
	// The loop starts at 1, so the struct's first field is never inserted
	// (presumably the auto-increment id — confirm for each model).
	for i := 1; i < fieldNum; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue // no column to write to
		}
		column := "`" + tag + "`"
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			names = append(names, column)
			values = append(values, fmt.Sprintf("%d", field.Int()))
		case "string":
			names = append(names, column)
			values = append(values, "'"+escaper.Replace(field.String())+"'")
		case "time.Time":
			if value := field.Interface().(time.Time); !value.IsZero() {
				names = append(names, column)
				values = append(values, "'"+value.Format(DateTimeFormat)+"'")
			}
		}
	}

	return fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", table, strings.Join(names, ","), strings.Join(values, ","))
}
// EditSqlJoins builds the UPDATE statement that writes data to the row of
// table whose id equals Id. It is generic so any model struct can use it.
//
// deleted_flag is always reset to DelNo. In addition, int64 fields greater
// than zero, all string fields, and non-zero time.Time fields are assigned to
// the columns named by their `db` tags; fields without a usable tag are
// skipped.
// NOTE(review): empty strings are written too, so a string field left unset
// by the caller overwrites the stored value — confirm this is intended.
//
// String values are escaped (backslash and single quote), and the SET clause
// is assembled so that an empty assignment list cannot leave a dangling comma.
func EditSqlJoins[T any](data *T, table string, Id int64) string {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	fieldNum := val.NumField()
	assignments := make([]string, 0, fieldNum)
	// The loop starts at 1, so the struct's first field is never assigned
	// (presumably the id used in the WHERE clause — confirm for each model).
	for i := 1; i < fieldNum; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue // no column to write to
		}
		field := val.Field(i)
		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				assignments = append(assignments, fmt.Sprintf("`%s`=%d", tag, field.Int()))
			}
		case "string":
			assignments = append(assignments, fmt.Sprintf("`%s`='%s'", tag, escaper.Replace(field.String())))
		case "time.Time":
			if value := field.Interface().(time.Time); !value.IsZero() {
				assignments = append(assignments, fmt.Sprintf("`%s`='%s'", tag, value.Format(DateTimeFormat)))
			}
		}
	}

	setClause := fmt.Sprintf("deleted_flag = %d", DelNo)
	if len(assignments) > 0 {
		setClause += ", " + strings.Join(assignments, ",")
	}
	return fmt.Sprintf("UPDATE %s SET %s WHERE id = %d", table, setClause, Id)
}
//原调用的查询拼接sql:
//joinSql := userSqlJoins(data)
//新调用的查询拼接sql:
joinSql := utils.QuerySqlJoins(data)
原sql:
// SaveCtx is shown here in its form BEFORE the refactoring: the INSERT
// statement is assembled inline via reflection. int64 fields are written when
// greater than zero, string fields always, time.Time fields when non-zero.
//
// The stray `saveSql := utils.SaveSqlJoins(data, c.table)` line from the
// refactored version is dropped from this listing; it redeclared saveSql and
// prevented the snippet from compiling.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	fieldNum := val.NumField()
	names := ""
	values := ""
	// The loop starts at 1, so the struct's first field is never inserted.
	for i := 1; i < fieldNum; i++ {
		Field := val.Field(i)
		colType := Field.Type().String()
		if colType == "int64" {
			if Field.Int() > 0 {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("%d,", Field.Int())
			}
		} else if colType == "string" {
			names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
			values += fmt.Sprintf("'%s',", Field.String())
		} else if colType == "time.Time" {
			value := Field.Interface().(time.Time)
			if !value.IsZero() {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("'%s',", value.Format(utils.DateTimeFormat))
			}
		}
	}
	names = strings.TrimRight(names, ",")
	values = strings.TrimRight(values, ",")
	saveSql := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, names, values)
	result, err := c.conn.ExecCtx(ctx, saveSql)
	return result, err
}
新sql:
// SaveCtx inserts data as a new row; the INSERT statement is produced by
// utils.SaveSqlJoins and executed on the model's connection.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	return c.conn.ExecCtx(ctx, utils.SaveSqlJoins(data, c.table))
}
原sql
// EditCtx is shown here in its form BEFORE the refactoring: the UPDATE
// statement is assembled inline via reflection. int64 fields are assigned
// when greater than zero, string fields always, time.Time fields when
// non-zero; deleted_flag is always reset to utils.DelNo.
//
// The two stray lines from the refactored version (building editSql and
// executing it) are dropped from this listing; they redeclared result and err
// and prevented the snippet from compiling.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // pointer type, so Elem() is required
	val := reflect.ValueOf(data).Elem()
	fieldNum := val.NumField()
	names := ""
	// The loop starts at 1, so the struct's first field is never assigned.
	for i := 1; i < fieldNum; i++ {
		Field := val.Field(i)
		colType := Field.Type().String()
		if colType == "int64" {
			if Field.Int() > 0 {
				names += fmt.Sprintf("`%s`=%d,", typ.Field(i).Tag.Get("db"), Field.Int())
			}
		} else if colType == "string" {
			names += fmt.Sprintf("`%s`='%s',", typ.Field(i).Tag.Get("db"), Field.String())
		} else if colType == "time.Time" {
			value := Field.Interface().(time.Time)
			if !value.IsZero() {
				names += fmt.Sprintf("`%s`='%s',", typ.Field(i).Tag.Get("db"), value.Format(utils.DateTimeFormat))
			}
		}
	}
	names = strings.TrimRight(names, ",")
	sql := fmt.Sprintf("UPDATE %s SET deleted_flag = %d, %s WHERE id = %d", c.table, utils.DelNo, names, data.Id)
	result, err := c.conn.ExecCtx(ctx, sql)
	return result, err
}
新sql
// EditCtx updates the row whose id equals data.Id; the UPDATE statement is
// produced by utils.EditSqlJoins and executed on the model's connection.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	return c.conn.ExecCtx(ctx, utils.EditSqlJoins(data, c.table, data.Id))
}
以zero_users数据表为例:
参考上文
注意:这里只有FindOneByParamCtx、EditCtx、DeleteDataCtx使用了缓存。
package usermodel
import (
"database/sql"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/common/utils"
"golang.org/x/net/context"
)
var _ ZeroUsersModel = (*customZeroUsersModel)(nil)
type (
	// ZeroUsersModel is an interface to be customized, add more methods here,
	// and implement the added methods in customZeroUsersModel.
	ZeroUsersModel interface {
		// zeroUsersModel is declared outside this snippet
		// (presumably the goctl-generated base interface — confirm).
		zeroUsersModel
		// TransCtx runs fn inside a database transaction.
		TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error
		// CountCtx returns the number of non-deleted rows matching data and the
		// optional created_at bounds; the query bypasses the cache.
		CountCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string) (int64, error)
		// FindPageListByParamCtx returns one page of matching rows, newest
		// first; the query bypasses the cache.
		FindPageListByParamCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error)
		// FindAllByParamCtx returns every matching row, newest first; the
		// query bypasses the cache.
		FindAllByParamCtx(ctx context.Context, data *ZeroUsers) ([]*ZeroUsers, error)
		// FindOneByParamCtx returns the newest matching row; the cache is
		// used only when data.Id is greater than zero.
		FindOneByParamCtx(ctx context.Context, data *ZeroUsers) (*ZeroUsers, error)
		// SaveCtx inserts data as a new row without touching the cache.
		SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error)
		// EditCtx updates the row with id data.Id, passing its id cache key
		// to the cached executor.
		EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error)
		// DeleteDataCtx soft-deletes the row with id data.Id, passing its id
		// cache key to the cached executor.
		DeleteDataCtx(ctx context.Context, data *ZeroUsers) error
		// TransSaveCtx inserts data through session so the insert takes part
		// in that session's transaction.
		TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error)
	}

	// customZeroUsersModel implements ZeroUsersModel on top of the embedded
	// defaultZeroUsersModel.
	customZeroUsersModel struct {
		*defaultZeroUsersModel
	}
)
// TransCtx executes fn within a transaction started through the model's
// TransactCtx. The transaction-scoped context and session are handed to fn
// unchanged and fn's error is returned to TransactCtx as-is.
func (c customZeroUsersModel) TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	runInTx := func(txCtx context.Context, tx sqlx.Session) error {
		return fn(txCtx, tx)
	}
	return c.TransactCtx(ctx, runInTx)
}
// CountCtx returns how many non-deleted rows match data, optionally
// restricted to created_at >= beginTime and created_at <= endTime (empty
// bounds are skipped). The query bypasses the cache. sqlc.ErrNotFound is
// translated to the package's ErrNotFound.
func (c customZeroUsersModel) CountCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string) (int64, error) {
	statement := fmt.Sprintf("SELECT count(*) as count FROM %s WHERE deleted_flag = %d", c.table, utils.DelNo)
	statement += utils.QuerySqlJoins(data)
	if beginTime != "" {
		statement += " AND created_at >= '" + beginTime + "'"
	}
	if endTime != "" {
		statement += " AND created_at <= '" + endTime + "'"
	}

	var total int64
	err := c.QueryRowNoCacheCtx(ctx, &total, statement)
	if err == nil {
		return total, nil
	}
	if err == sqlc.ErrNotFound {
		return 0, ErrNotFound
	}
	return 0, err
}
// FindPageListByParamCtx returns one page of non-deleted rows matching data,
// newest first, optionally restricted to created_at >= beginTime and
// created_at <= endTime (empty bounds are skipped). The query bypasses the
// cache.
//
// current is the 1-based page number and pageSize the number of rows per
// page. A current below 1 is treated as the first page and a negative
// pageSize as 0, because a negative OFFSET or row count is invalid SQL.
// sqlc.ErrNotFound is translated to the package's ErrNotFound.
func (c customZeroUsersModel) FindPageListByParamCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error) {
	if current < 1 {
		current = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}

	querySql := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	querySql += utils.QuerySqlJoins(data)
	if beginTime != "" {
		querySql += fmt.Sprintf(" AND created_at >= %s", "'"+beginTime+"'")
	}
	if endTime != "" {
		querySql += fmt.Sprintf(" AND created_at <= %s", "'"+endTime+"'")
	}
	querySql += " ORDER BY created_at DESC"
	querySql += fmt.Sprintf(" LIMIT %d,%d", (current-1)*pageSize, pageSize)

	var result []*ZeroUsers
	err := c.QueryRowsNoCacheCtx(ctx, &result, querySql)
	switch err {
	case nil:
		return result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// FindAllByParamCtx returns every non-deleted row matching data, ordered by
// created_at descending. The query bypasses the cache. sqlc.ErrNotFound is
// translated to the package's ErrNotFound.
func (c customZeroUsersModel) FindAllByParamCtx(ctx context.Context, data *ZeroUsers) ([]*ZeroUsers, error) {
	base := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	statement := base + utils.QuerySqlJoins(data) + " ORDER BY created_at DESC"

	var rows []*ZeroUsers
	if err := c.QueryRowsNoCacheCtx(ctx, &rows, statement); err != nil {
		if err == sqlc.ErrNotFound {
			return nil, ErrNotFound
		}
		return nil, err
	}
	return rows, nil
}
// FindOneByParamCtx returns the newest non-deleted row matching data.
//
// When data.Id is greater than zero the lookup goes through the cache under
// the id key; otherwise the query bypasses the cache. Only one row is ever
// scanned, so the statement is capped with LIMIT 1.
// sqlc.ErrNotFound is translated to the package's ErrNotFound.
//
// NOTE(review): the cache key is built from the id alone, while the query
// also filters on every other populated field of data. A cached row may
// therefore be returned for a later call with the same id but different
// filters — confirm callers only combine Id with filters that cannot change
// the result.
func (c customZeroUsersModel) FindOneByParamCtx(ctx context.Context, data *ZeroUsers) (*ZeroUsers, error) {
	querySql := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	querySql += utils.QuerySqlJoins(data)
	querySql += " ORDER BY created_at DESC LIMIT 1"

	var result ZeroUsers
	var err error
	if data.Id > 0 {
		zeroUsersIdKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)
		err = c.QueryRowCtx(ctx, &result, zeroUsersIdKey, func(ctx context.Context, conn sqlx.SqlConn, v any) error {
			return conn.QueryRowCtx(ctx, v, querySql)
		})
	} else {
		err = c.QueryRowNoCacheCtx(ctx, &result, querySql)
	}

	switch err {
	case nil:
		return &result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}
// SaveCtx inserts data as a new row. The statement comes from
// utils.SaveSqlJoins and is executed without touching the cache.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	return c.ExecNoCacheCtx(ctx, utils.SaveSqlJoins(data, c.table))
}
// EditCtx updates the row whose id equals data.Id. The statement comes from
// utils.EditSqlJoins and is run through the cached executor together with the
// row's id cache key.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	statement := utils.EditSqlJoins(data, c.table, data.Id)
	cacheKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)
	run := func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		return conn.ExecCtx(ctx, statement)
	}
	return c.ExecCtx(ctx, run, cacheKey)
}
// DeleteDataCtx soft-deletes the row whose id equals data.Id: deleted_flag is
// set to utils.DelYes and deleted_at to data.UpdatedAt. The statement is run
// through the cached executor together with the row's id cache key.
func (c customZeroUsersModel) DeleteDataCtx(ctx context.Context, data *ZeroUsers) error {
	deletedAt := "'" + data.UpdatedAt.Format(utils.DateTimeFormat) + "'"
	statement := fmt.Sprintf("UPDATE %s SET deleted_flag = %d,deleted_at= %s WHERE id = %d", c.table, utils.DelYes, deletedAt, data.Id)
	cacheKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)
	run := func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		return conn.ExecCtx(ctx, statement)
	}
	_, err := c.ExecCtx(ctx, run, cacheKey)
	return err
}
// TransSaveCtx inserts data as a new row through session, so the insert takes
// part in the transaction that session belongs to. The statement comes from
// utils.SaveSqlJoins.
func (c customZeroUsersModel) TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error) {
	return session.ExecCtx(ctx, utils.SaveSqlJoins(data, c.table))
}
// NewZeroUsersModel builds the cache-enabled users model on top of the
// default model created from conn, the cache configuration c and opts.
func NewZeroUsersModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) ZeroUsersModel {
	model := new(customZeroUsersModel)
	model.defaultZeroUsersModel = newZeroUsersModel(conn, c, opts...)
	return model
}
CacheRedis:
  - Host: 127.0.0.1:6379
    Type: node
    Pass: ""
package config
import (
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/zrpc"
)
// Config is the configuration of the ucenter RPC service.
type Config struct {
	zrpc.RpcServerConf // embedded go-zero RPC server settings
	// JWT holds the token settings read when a login succeeds.
	JWT struct {
		AccessSecret string // secret passed to utils.GenerateJwtToken
		AccessExpire int64  // added to the Unix-second issue time to get the expiry
	}
	// MySQL holds the database connection settings.
	MySQL struct {
		DataSource string // DSN handed to sqlx.NewMysql
	}
	CacheRedis cache.CacheConf // cache configuration passed to the sqlc model constructors
}
package svc
import (
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/rpc/code/ucenter/internal/config"
sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
sqlx_usermodel "go-zero-micro/rpc/database/sqlx/usermodel"
)
// ServiceContext bundles the dependencies shared by the service's logic types.
type ServiceContext struct {
	Config             config.Config                     // configuration the context was built from
	SqlxUsersModel     sqlx_usermodel.ZeroUsersModel     // users model without cache
	SqlxUserInfosModel sqlx_usermodel.ZeroUserInfosModel // user-infos model without cache
	SqlcUsersModel     sqlc_usermodel.ZeroUsersModel     // users model built with the cache configuration
	SqlcUserInfosModel sqlc_usermodel.ZeroUserInfosModel // user-infos model built with the cache configuration
}
// NewServiceContext wires the service dependencies: one MySQL connection
// created from c.MySQL.DataSource is shared by the sqlx models and by the
// sqlc models, which additionally receive c.CacheRedis.
func NewServiceContext(c config.Config) *ServiceContext {
	conn := sqlx.NewMysql(c.MySQL.DataSource)

	svcCtx := new(ServiceContext)
	svcCtx.Config = c
	svcCtx.SqlxUsersModel = sqlx_usermodel.NewZeroUsersModel(conn)
	svcCtx.SqlxUserInfosModel = sqlx_usermodel.NewZeroUserInfosModel(conn)
	svcCtx.SqlcUsersModel = sqlc_usermodel.NewZeroUsersModel(conn, c.CacheRedis)
	svcCtx.SqlcUserInfosModel = sqlc_usermodel.NewZeroUserInfosModel(conn, c.CacheRedis)
	return svcCtx
}
internal/logic/ucentersqlx/loginuserlogic.go
注意:在loginuserlogic.go中查询参数加入了Id,这样就可以测试缓存是否生效了。
package ucentersqlxlogic
import (
"context"
"errors"
"fmt"
"go-zero-micro/common/utils"
sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
"time"
"go-zero-micro/rpc/code/ucenter/internal/svc"
"go-zero-micro/rpc/code/ucenter/ucenter"
"github.com/jinzhu/copier"
"github.com/zeromicro/go-zero/core/logx"
)
// LoginUserLogic carries the per-request state used by the login logic.
type LoginUserLogic struct {
	ctx         context.Context     // request context given to NewLoginUserLogic
	svcCtx      *svc.ServiceContext // shared service dependencies
	logx.Logger                     // logger bound to ctx by NewLoginUserLogic
}
// NewLoginUserLogic returns a LoginUserLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
	logic := new(LoginUserLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// LoginUser authenticates a user by account and password using the
// cache-enabled model.
//
// When the supplied password matches the stored one, the row is copied into
// in and the login response (including the JWT) is produced by LoginSuccess.
//
// Errors never include the request formatted with %+v, because the request
// carries the plaintext password; only the account is reported.
func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	param := &sqlc_usermodel.ZeroUsers{
		// Hard-coded on purpose: a non-zero Id makes FindOneByParamCtx take
		// its cached path, which is what this example demonstrates.
		// NOTE(review): must not be kept outside the cache demo.
		Id:      1,
		Account: in.Account,
	}
	dbRes, err := l.svcCtx.SqlcUsersModel.FindOneByParamCtx(l.ctx, param)
	if err != nil {
		// Log through the embedded logger so the entry carries the request context.
		l.Error(err)
		return nil, fmt.Errorf("LoginUser:FindOneByParam:db err:%w , account : %s", err, in.Account)
	}

	if !utils.ComparePassword(in.Password, dbRes.Password) {
		return nil, errors.New(fmt.Sprintf("LoginUser:user password error:account : %s", in.Account))
	}

	if err := copier.Copy(in, dbRes); err != nil {
		return nil, fmt.Errorf("LoginUser:copy user:%w", err)
	}
	return l.LoginSuccess(in)
}
// LoginSuccess builds the login response for an authenticated user.
//
// A JWT is generated for in.Id with the configured secret and lifetime. The
// response is populated from in, then AccessExpire is set to now plus the
// lifetime and RefreshAfter to now plus half the lifetime (Unix seconds).
// A failure to copy the user into the response is returned instead of being
// silently ignored.
func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	accessSecret := l.svcCtx.Config.JWT.AccessSecret
	accessExpire := l.svcCtx.Config.JWT.AccessExpire
	now := time.Now().Unix()

	jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
	if err != nil {
		return nil, err
	}

	resp := &ucenter.UserLoginResp{}
	if err := copier.Copy(resp, in); err != nil {
		return nil, err
	}
	resp.AccessToken = jwtToken
	resp.AccessExpire = now + accessExpire
	resp.RefreshAfter = now + accessExpire/2
	return resp, nil
}
internal/logic/ucentersqlx/adduserlogic.go
package ucentersqlxlogic
import (
"context"
"github.com/jinzhu/copier"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/common/errorx"
"go-zero-micro/common/utils"
sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
"time"
"go-zero-micro/rpc/code/ucenter/internal/svc"
"go-zero-micro/rpc/code/ucenter/ucenter"
"github.com/zeromicro/go-zero/core/logx"
)
// AddUserLogic carries the per-request state used by the add-user logic.
type AddUserLogic struct {
	ctx         context.Context     // request context given to NewAddUserLogic
	svcCtx      *svc.ServiceContext // shared service dependencies
	logx.Logger                     // logger bound to ctx by NewAddUserLogic
}
// NewAddUserLogic returns an AddUserLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewAddUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AddUserLogic {
	logic := new(AddUserLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// AddUser creates a user and its user-info row in a single transaction using
// the cache-enabled models, and returns the id of the new user.
//
// The users table stores the account; the user-infos table is a child table
// storing related data such as email and phone number. The child row needs
// the id generated for the parent row, so both inserts run in one
// transaction: any error returned from the callback rolls it back, which
// prevents a parent row without a child row.
//
// The caller always receives the generic DbAddErrorCode error on failure; the
// underlying cause is logged here so it is not lost.
func (l *AddUserLogic) AddUser(in *ucenter.User) (*ucenter.BaseResp, error) {
	operatorID := utils.GetUidFromCtxInt64(l.ctx, "userId")
	currentTime := time.Now()

	var insertedUserID int64
	err := l.svcCtx.SqlcUsersModel.TransCtx(l.ctx, func(ctx context.Context, session sqlx.Session) error {
		userParam := &sqlc_usermodel.ZeroUsers{}
		if err := copier.Copy(userParam, in); err != nil {
			return err
		}
		// New users always start with the configured default password.
		userParam.Password = utils.GeneratePassword(l.svcCtx.Config.DefaultConfig.DefaultPassword)
		userParam.CreatedBy = operatorID
		userParam.CreatedAt = currentTime
		// Use the context supplied to the callback rather than l.ctx.
		dbUserRes, err := l.svcCtx.SqlcUsersModel.TransSaveCtx(ctx, session, userParam)
		if err != nil {
			return err
		}
		uid, err := dbUserRes.LastInsertId()
		if err != nil {
			return err
		}

		userInfoParam := &sqlc_usermodel.ZeroUserInfos{}
		if err := copier.Copy(userInfoParam, in); err != nil {
			return err
		}
		userInfoParam.UserId = uid
		userInfoParam.CreatedBy = operatorID
		userInfoParam.CreatedAt = currentTime
		if _, err := l.svcCtx.SqlcUserInfosModel.TransSaveCtx(ctx, session, userInfoParam); err != nil {
			return err
		}

		insertedUserID = uid
		return nil
	})
	if err != nil {
		l.Errorf("AddUser: transaction failed: %v", err)
		return nil, errorx.NewDefaultError(errorx.DbAddErrorCode)
	}

	return &ucenter.BaseResp{
		Id: insertedUserID,
	}, nil
}
总结:通过简单分析源码可以得出缓存的具体使用过程,其中把查询不到数据的 key 的值设置为 *(占位符)的目的是防止缓存穿透。
package usermodel
import (
"database/sql"
"time"
)
type (
	// ZeroUsers is the gorm model of a user account row.
	ZeroUsers struct {
		Id          int64        // id
		Account     string       // account
		Username    string       // user name
		Password    string       // password
		Gender      int64        // gender: 1 = not set, 2 = male, 3 = female
		UpdatedBy   int64        // updated by
		UpdatedAt   time.Time    // update time
		CreatedBy   int64        // created by
		CreatedAt   time.Time    // creation time
		DeletedAt   sql.NullTime // deletion time
		DeletedFlag int64        // deletion flag: 1 = normal, 2 = deleted
	}

	// ZeroUserInfos is the gorm model of the additional user information row.
	ZeroUserInfos struct {
		Id          int64        // id
		UserId      int64        // user id
		Email       string       // email
		Phone       string       // phone number
		UpdatedBy   int64        // updated by
		UpdatedAt   time.Time    // update time
		CreatedBy   int64        // created by
		CreatedAt   time.Time    // creation time
		DeletedAt   sql.NullTime // deletion time
		DeletedFlag int64        // deletion flag: 1 = normal, 2 = deleted
	}
)
package svc
import (
"fmt"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"go-zero-micro/rpc/code/ucenter/internal/config"
sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
sqlx_usermodel "go-zero-micro/rpc/database/sqlx/usermodel"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// ServiceContext bundles the dependencies shared by the service's logic types.
type ServiceContext struct {
	Config             config.Config                     // configuration the context was built from
	SqlxUsersModel     sqlx_usermodel.ZeroUsersModel     // users model without cache
	SqlxUserInfosModel sqlx_usermodel.ZeroUserInfosModel // user-infos model without cache
	SqlcUsersModel     sqlc_usermodel.ZeroUsersModel     // users model built with the cache configuration
	SqlcUserInfosModel sqlc_usermodel.ZeroUserInfosModel // user-infos model built with the cache configuration
	GormDb             *gorm.DB                          // gorm handle opened on the same data source
}
// NewServiceContext wires the service dependencies: the sqlx and sqlc models
// share one MySQL connection, and a gorm handle is opened on the same data
// source. It panics when gorm cannot open the database.
func NewServiceContext(c config.Config) *ServiceContext {
	mysqlConn := sqlx.NewMysql(c.MySQL.DataSource)

	// The naming strategy is left at its defaults; TablePrefix (a table name
	// prefix) and SingularTable (singular table names) can be set here.
	gormCfg := &gorm.Config{
		NamingStrategy: schema.NamingStrategy{},
	}
	gormDb, err := gorm.Open(mysql.Open(c.MySQL.DataSource), gormCfg)
	if err != nil {
		panic(fmt.Sprintf("Gorm connect database err:%v", err))
	}
	// Automatic schema migration (AutoMigrate) is intentionally not run here.

	svcCtx := new(ServiceContext)
	svcCtx.Config = c
	svcCtx.SqlxUsersModel = sqlx_usermodel.NewZeroUsersModel(mysqlConn)
	svcCtx.SqlxUserInfosModel = sqlx_usermodel.NewZeroUserInfosModel(mysqlConn)
	svcCtx.SqlcUsersModel = sqlc_usermodel.NewZeroUsersModel(mysqlConn, c.CacheRedis)
	svcCtx.SqlcUserInfosModel = sqlc_usermodel.NewZeroUserInfosModel(mysqlConn, c.CacheRedis)
	svcCtx.GormDb = gormDb
	return svcCtx
}
package ucentergormlogic
import (
"context"
"errors"
"fmt"
"github.com/jinzhu/copier"
"go-zero-micro/common/utils"
gorm_usermodel "go-zero-micro/rpc/database/gorm/usermodel"
"time"
"go-zero-micro/rpc/code/ucenter/internal/svc"
"go-zero-micro/rpc/code/ucenter/ucenter"
"github.com/zeromicro/go-zero/core/logx"
)
// LoginUserLogic carries the per-request state used by the gorm login logic.
type LoginUserLogic struct {
	ctx         context.Context     // request context given to NewLoginUserLogic
	svcCtx      *svc.ServiceContext // shared service dependencies
	logx.Logger                     // logger bound to ctx by NewLoginUserLogic
}
// NewLoginUserLogic returns a LoginUserLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
	logic := new(LoginUserLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
// LoginUser authenticates a user by id and/or account and password through
// gorm.
//
// The lookup is rejected when neither in.Id nor in.Account is set: gorm skips
// zero-valued struct fields in Where, so an empty filter would match the
// first row of the table. The query error is checked, so a missing user or a
// database failure is reported instead of falling through to the password
// comparison with an empty row.
//
// Errors never include the request formatted with %+v, because the request
// carries the plaintext password; only the account is reported.
func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	if in.Id == 0 && in.Account == "" {
		return nil, errors.New("LoginUser:neither id nor account provided")
	}

	param := &gorm_usermodel.ZeroUsers{
		Id:      in.Id,
		Account: in.Account,
	}
	dbRes := &gorm_usermodel.ZeroUsers{}
	if err := l.svcCtx.GormDb.WithContext(l.ctx).Where(param).First(dbRes).Error; err != nil {
		// Log through the embedded logger so the entry carries the request context.
		l.Error(err)
		return nil, fmt.Errorf("LoginUser:gorm First:db err:%w , account : %s", err, in.Account)
	}

	if !utils.ComparePassword(in.Password, dbRes.Password) {
		return nil, errors.New(fmt.Sprintf("LoginUser:user password error:account : %s", in.Account))
	}

	if err := copier.Copy(in, dbRes); err != nil {
		return nil, fmt.Errorf("LoginUser:copy user:%w", err)
	}
	return l.LoginSuccess(in)
}
// LoginSuccess builds the login response for an authenticated user.
//
// A JWT is generated for in.Id with the configured secret and lifetime. The
// response is populated from in, then AccessExpire is set to now plus the
// lifetime and RefreshAfter to now plus half the lifetime (Unix seconds).
// A failure to copy the user into the response is returned instead of being
// silently ignored.
func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	accessSecret := l.svcCtx.Config.JWT.AccessSecret
	accessExpire := l.svcCtx.Config.JWT.AccessExpire
	now := time.Now().Unix()

	jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
	if err != nil {
		return nil, err
	}

	resp := &ucenter.UserLoginResp{}
	if err := copier.Copy(resp, in); err != nil {
		return nil, err
	}
	resp.AccessToken = jwtToken
	resp.AccessExpire = now + accessExpire
	resp.RefreshAfter = now + accessExpire/2
	return resp, nil
}
在 internal/logic/login/loginbypasswordlogic.go中将UcenterSqlxRpc替换为UcenterGormRpc即可
gorm本身不支持缓存,如果想使用缓存的话,可以参考sqlc中是如何使用缓存的。
sqlx切换成gorm的流程:
(1)把 sqlx 切换成 gorm,同时继续结合 sqlc 的缓存能力;网上已有同学采用这种做法,改法如下。
(2)只需要把带缓存生成的model中,sqlx执行db部分换成gorm即可。
(3)替换后不影响go-zero 中封装的数据库分布式事务,因为dtm支持gorm,可以看dtm官网。
参考文章https://blog.csdn.net/Mr_XiMu/article/details/131658247