go-zero数据库

目录结构说明

本节内容的代码都放在/rpc/database/下,目录结构如下:

├─gorm                      
├─sql                       
│  └─user
├─sqlc
└─sqlx
  • gorm:gorm相关代码;
  • sql:主要是sql文件,下面可以进一步分组;
  • sqlc:带缓存的数据库操作代码;
  • sqlx:无缓存的数据库操作代码;

相关命令

参考:goctl model mysql 指令

  • goctl model mysql 指令用于生成基于 MySQL 的 model 代码,支持生成带缓存和不带缓存的代码。
  • MySQL 代码生成支持从 sql 文件,数据库连接两个来源生成代码。

注意:go-zero 的 goctl model mysql 指令支持从 sql 文件和数据库连接两个来源生成代码,两者生成的代码完全一样;但我个人更推荐根据 sql 文件生成,因为 sql 文件可以纳入版本管理,记录表结构的变化。

生成sqlx代码命令

注意:最后的参数-style=go_zero是指定生成文件名称的格式,这里是蛇形命名,不喜欢的可以去除这个参数。

使用sql 文件生成sqlx代码的命令:【推荐】

单表:

goctl model mysql ddl -src="./rpc/database/sql/user/zero_users.sql" -dir="./rpc/database/sqlx/usermodel" -style=go_zero

多表:

goctl model mysql ddl -src="./rpc/database/sql/user/zero_*.sql" -dir="./rpc/database/sqlx/usermodel" -style=go_zero

-src:sql 文件路径(支持通配符,如 zero_*.sql);
-dir:sqlx 代码的输出目录;

使用数据库连接生成sqlx代码的命令:

goctl model mysql datasource -url="root:root@tcp(127.0.0.1:3357)/go-zero-micro" -table="zero_users" -dir="./rpc/database/sqlx/usermodel"

-url:数据库连接;
-table:数据表;
-dir:sqlx 代码的输出目录;

生成sqlc代码命令

同生成sqlx代码的命令类似,只是后面需要再加一个 -cache即可。

使用sql 文件生成sqlc代码的命令:【推荐】

单表:

goctl model mysql ddl -src="./rpc/database/sql/user/zero_users.sql" -dir="./rpc/database/sqlc/usermodel" -style=go_zero -cache

多表:

goctl model mysql ddl -src="./rpc/database/sql/user/zero_*.sql" -dir="./rpc/database/sqlc/usermodel" -style=go_zero -cache

-src:sql 文件路径(支持通配符,如 zero_*.sql);
-dir:sqlc 代码的输出目录;

使用数据库连接生成sqlc代码的命令:

goctl model mysql datasource -url="root:root@tcp(127.0.0.1:3357)/go-zero-micro" -table="zero_users" -dir="./rpc/database/sqlc/usermodel" -cache

-url:数据库连接;
-table:数据表;
-dir:sqlc 代码的输出目录;

sqlx

sqlx代码讲解

通过上文的命令生成的 sqlx 代码有三个文件(下面是默认命名风格的文件名;使用 -style=go_zero 时分别为 vars.go、zero_users_model_gen.go、zero_users_model.go):

  • vars.go:只声明了数据不存在的错误;
  • zerousersmodel_gen.go:基本的增删改查方法,不推荐手动修改;
  • zerousersmodel.go:自定义的model,可以在这里新增所需要的数据库操作接口及其实现。

新增操作接口及其实现

主要代码都在 zerousersmodel.go,这里使用了反射对拼接的sql语句进行了优化:

注意:其实自定义的操作接口应该都加入context参数,便于链路追踪,这一点已在该分支最新提交的代码中补上。

package usermodel

import (
	"context"
	"database/sql"
	"fmt"
	"github.com/zeromicro/go-zero/core/stores/sqlc"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/common/utils"
	"reflect"
	"strings"
	"time"
)

// Compile-time assertion that *customZeroUsersModel implements ZeroUsersModel.
var _ ZeroUsersModel = (*customZeroUsersModel)(nil)

// ZeroUsersModel is the model callers use. It embeds the goctl-generated
// zeroUsersModel methods and adds the hand-written queries and mutations
// implemented by customZeroUsersModel.
type ZeroUsersModel interface {
	zeroUsersModel

	Trans(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error
	Count(data *ZeroUsers, beginTime, endTime string) (int64, error)
	FindPageListByParam(data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error)
	FindAllByParam(data *ZeroUsers) ([]*ZeroUsers, error)
	FindOneByParam(data *ZeroUsers) (*ZeroUsers, error)
	Save(ctx context.Context, data *ZeroUsers) (sql.Result, error)
	Edit(ctx context.Context, data *ZeroUsers) (sql.Result, error)
	DeleteData(ctx context.Context, data *ZeroUsers) error
}

// customZeroUsersModel embeds the generated defaultZeroUsersModel, which
// supplies the connection (conn) and table name (table) used by the custom
// methods.
type customZeroUsersModel struct {
	*defaultZeroUsersModel
}

// Trans runs fn inside one transaction on the model's connection. fn receives
// the transaction session and must issue its statements on it; TransactCtx
// commits when fn returns nil and rolls back when it returns an error.
func (c customZeroUsersModel) Trans(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	return c.conn.TransactCtx(ctx, fn)
}

// userSqlJoins builds the " AND ..." condition fragment appended to the WHERE
// clause of the custom queries. Only non-zero fields of queryModel contribute;
// column names come from the `db` struct tags:
//   - int64 > 0           -> col=<n>
//   - non-empty string    -> col LIKE '%<s>%'
//   - non-zero time.Time  -> col='<yyyy-mm-dd hh:mm:ss>'
//
// Other field types are ignored. String values are escaped (backslash and
// single quote) so caller input cannot terminate the SQL literal.
func userSqlJoins(queryModel *ZeroUsers) string {
	typ := reflect.TypeOf(queryModel).Elem() // pointer: inspect the struct it points to
	val := reflect.ValueOf(queryModel).Elem()
	escape := strings.NewReplacer(`\`, `\\`, `'`, `''`)

	var b strings.Builder
	for i := 0; i < val.NumField(); i++ {
		field := val.Field(i)
		colName := typ.Field(i).Tag.Get("db")

		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				fmt.Fprintf(&b, " AND %s=%d", colName, field.Int())
			}
		case "string":
			if s := field.String(); s != "" {
				fmt.Fprintf(&b, " AND %s LIKE '%%%s%%'", colName, escape.Replace(s))
			}
		case "time.Time":
			// Format explicitly: reflect.Value.String() on a time.Time returns
			// the placeholder "<time.Time Value>", not the timestamp.
			// The layout is MySQL's DATETIME literal format.
			if t := field.Interface().(time.Time); !t.IsZero() {
				fmt.Fprintf(&b, " AND %s='%s'", colName, t.Format("2006-01-02 15:04:05"))
			}
		}
	}
	return b.String()
}

// Count returns how many non-deleted users match the non-zero fields of data,
// optionally restricted to created_at >= beginTime and/or <= endTime (an empty
// bound is skipped).
func (c customZeroUsersModel) Count(data *ZeroUsers, beginTime, endTime string) (int64, error) {
	query := fmt.Sprintf("SELECT count(*) as count FROM %s WHERE deleted_flag = %d", c.table, utils.DelNo)
	query += userSqlJoins(data)

	// beginTime/endTime are caller-supplied, so bind them as placeholders
	// instead of splicing them into the statement text.
	var args []any
	if beginTime != "" {
		query += " AND created_at >= ?"
		args = append(args, beginTime)
	}
	if endTime != "" {
		query += " AND created_at <= ?"
		args = append(args, endTime)
	}

	var count int64
	err := c.conn.QueryRow(&count, query, args...)
	switch err {
	case nil:
		return count, nil
	case sqlc.ErrNotFound:
		return 0, ErrNotFound
	default:
		return 0, err
	}
}

// FindPageListByParam returns one page of non-deleted users matching the
// non-zero fields of data and the optional created_at range, newest first.
// current is 1-based; values below 1 are treated as the first page.
func (c customZeroUsersModel) FindPageListByParam(data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error) {
	query := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	query += userSqlJoins(data)

	// Caller-supplied time bounds are bound as placeholders, not spliced in.
	var args []any
	if beginTime != "" {
		query += " AND created_at >= ?"
		args = append(args, beginTime)
	}
	if endTime != "" {
		query += " AND created_at <= ?"
		args = append(args, endTime)
	}

	// A page number below 1 would produce a negative LIMIT offset, which
	// MySQL rejects as a syntax error.
	if current < 1 {
		current = 1
	}
	query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT %d,%d", (current-1)*pageSize, pageSize)

	var result []*ZeroUsers
	err := c.conn.QueryRows(&result, query, args...)
	switch err {
	case nil:
		return result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}

// FindAllByParam returns every non-deleted user matching the non-zero fields
// of data, newest first.
func (c customZeroUsersModel) FindAllByParam(data *ZeroUsers) ([]*ZeroUsers, error) {
	query := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d%s ORDER BY created_at DESC",
		zeroUsersRows, c.table, utils.DelNo, userSqlJoins(data))

	var rows []*ZeroUsers
	if err := c.conn.QueryRows(&rows, query); err != nil {
		if err == sqlc.ErrNotFound {
			return nil, ErrNotFound
		}
		return nil, err
	}
	return rows, nil
}

// FindOneByParam returns the newest non-deleted user matching the non-zero
// fields of data, or ErrNotFound when none matches.
func (c customZeroUsersModel) FindOneByParam(data *ZeroUsers) (*ZeroUsers, error) {
	query := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	// Only the first row is read, so let the database stop after it.
	query += userSqlJoins(data) + " ORDER BY created_at DESC LIMIT 1"

	var result ZeroUsers
	err := c.conn.QueryRow(&result, query)
	switch err {
	case nil:
		return &result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}

// Save inserts data as a new row. Field 0 (assumed to be the auto-increment
// id) is skipped. int64 fields are written only when > 0, string fields are
// always written, and time.Time fields only when non-zero (formatted with
// utils.DateTimeFormat). All values are bound as placeholders, so
// caller-supplied strings cannot alter the statement.
func (c customZeroUsersModel) Save(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem()

	var cols, marks []string
	var args []any
	for i := 1; i < val.NumField(); i++ {
		field := val.Field(i)
		var arg any
		switch field.Type().String() {
		case "int64":
			if field.Int() <= 0 {
				continue
			}
			arg = field.Int()
		case "string":
			arg = field.String()
		case "time.Time":
			t := field.Interface().(time.Time)
			if t.IsZero() {
				continue
			}
			arg = t.Format(utils.DateTimeFormat)
		default:
			continue
		}
		cols = append(cols, fmt.Sprintf("`%s`", typ.Field(i).Tag.Get("db")))
		marks = append(marks, "?")
		args = append(args, arg)
	}

	query := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, strings.Join(cols, ","), strings.Join(marks, ","))
	return c.conn.ExecCtx(ctx, query, args...)
}

// Edit updates the row identified by data.Id from the fields of data and
// clears its soft-delete flag (deleted_flag = utils.DelNo). Field 0 (the id)
// is skipped. int64 fields are written only when > 0, string fields are
// always written (an empty string overwrites the column), and time.Time
// fields only when non-zero. Values are bound as placeholders.
func (c customZeroUsersModel) Edit(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem() // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem()

	sets := []string{fmt.Sprintf("deleted_flag = %d", utils.DelNo)}
	var args []any
	for i := 1; i < val.NumField(); i++ {
		field := val.Field(i)
		var arg any
		switch field.Type().String() {
		case "int64":
			if field.Int() <= 0 {
				continue
			}
			arg = field.Int()
		case "string":
			arg = field.String()
		case "time.Time":
			t := field.Interface().(time.Time)
			if t.IsZero() {
				continue
			}
			arg = t.Format(utils.DateTimeFormat)
		default:
			continue
		}
		sets = append(sets, fmt.Sprintf("`%s`=?", typ.Field(i).Tag.Get("db")))
		args = append(args, arg)
	}
	args = append(args, data.Id)

	// Joining the assignment list keeps the statement valid even when only
	// deleted_flag is being set.
	query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?", c.table, strings.Join(sets, ", "))
	return c.conn.ExecCtx(ctx, query, args...)
}

// DeleteData soft-deletes the row data.Id: deleted_flag is set to
// utils.DelYes and deleted_at is stamped with data.UpdatedAt.
func (c customZeroUsersModel) DeleteData(ctx context.Context, data *ZeroUsers) error {
	deletedAt := "'" + data.UpdatedAt.Format(utils.DateTimeFormat) + "'"
	stmt := fmt.Sprintf("UPDATE %s SET deleted_flag = %d,deleted_at= %s WHERE id = %d", c.table, utils.DelYes, deletedAt, data.Id)
	if _, err := c.conn.ExecCtx(ctx, stmt); err != nil {
		return err
	}
	return nil
}

// NewZeroUsersModel builds the users model on top of conn by wrapping the
// generated default model.
func NewZeroUsersModel(conn sqlx.SqlConn) ZeroUsersModel {
	m := &customZeroUsersModel{}
	m.defaultZeroUsersModel = newZeroUsersModel(conn)
	return m
}

RPC服务使用sqlx步骤

RPC服务yaml配置加入MySQL连接配置:

 MySQL:
   # Local database. parseTime=true lets the MySQL driver scan DATETIME
   # columns into time.Time; loc sets the time zone used for them.
   DataSource: root:root@tcp(127.0.0.1:3357)/go-zero-micro?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai

RPC服务yaml配置映射类internal/config/config.go加入MySQL连接配置:

 package config

 import "github.com/zeromicro/go-zero/zrpc"

 type (
     // Config is the RPC service configuration mapped from the service yaml.
     Config struct {
         zrpc.RpcServerConf

         // JWT holds the token signing settings used at login.
         JWT struct {
             AccessSecret string
             AccessExpire int64
         }

         // MySQL holds the DSN passed to sqlx.NewMysql.
         MySQL struct {
             DataSource string
         }

         UploadFile UploadFile
     }

     // UploadFile configures file uploads: maximum file count and size, and
     // the save path.
     UploadFile struct {
         MaxFileNum  int64
         MaxFileSize int64
         SavePath    string
     }
 )

internal/svc/servicecontext.go创建操作数据库的连接

 package svc

 import (
     "github.com/zeromicro/go-zero/core/stores/sqlx"
     "go-zero-micro/rpc/code/ucenter/internal/config"
     "go-zero-micro/rpc/database/sqlx/usermodel"
 )

 // ServiceContext carries the configuration and data-access models shared by
 // the RPC logic handlers.
 type ServiceContext struct {
     Config     config.Config
     UsersModel usermodel.ZeroUsersModel
 }

 // NewServiceContext creates a MySQL connection from c.MySQL.DataSource and
 // builds the users model on it.
 func NewServiceContext(c config.Config) *ServiceContext {
     sc := &ServiceContext{Config: c}
     sc.UsersModel = usermodel.NewZeroUsersModel(sqlx.NewMysql(c.MySQL.DataSource))
     return sc
 }

internal/logic/ucentersqlx/loginuserlogic.go使用具体的操作接口

 package ucentersqlxlogic

 import (
     "context"
     "errors"
     "fmt"
     "go-zero-micro/common/utils"
     "go-zero-micro/rpc/database/sqlx/usermodel"
     "time"

     "go-zero-micro/rpc/code/ucenter/internal/svc"
     "go-zero-micro/rpc/code/ucenter/ucenter"

     "github.com/jinzhu/copier"
     "github.com/zeromicro/go-zero/core/logx"
 )

 // LoginUserLogic handles the LoginUser RPC for one request.
 type LoginUserLogic struct {
     ctx    context.Context
     svcCtx *svc.ServiceContext
     logx.Logger
 }

 // NewLoginUserLogic binds the logic to a request context; the embedded
 // logger is derived from that context.
 func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
     logic := &LoginUserLogic{ctx: ctx, svcCtx: svcCtx}
     logic.Logger = logx.WithContext(ctx)
     return logic
 }

 // LoginUser authenticates in.Account / in.Password and, on success, returns
 // the token response built by LoginSuccess.
 func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
     param := &usermodel.ZeroUsers{
         Account: in.Account,
     }
     dbRes, err := l.svcCtx.UsersModel.FindOneByParam(param)
     if err != nil {
         // Report the account only: formatting `in` would copy the plaintext
         // password into the log and into the returned error.
         l.Errorf("LoginUser: FindOneByParam failed, account=%q: %v", in.Account, err)
         return nil, fmt.Errorf("LoginUser:FindOneByParam:db err:%w , account : %s", err, in.Account)
     }
     if !utils.ComparePassword(in.Password, dbRes.Password) {
         return nil, errors.New("LoginUser:user password error:account : " + in.Account)
     }
     if err := copier.Copy(in, dbRes); err != nil {
         return nil, err
     }
     return l.LoginSuccess(in)
 }

 // LoginSuccess issues a JWT for in.Id and builds the login response.
 // AccessExpire in the response is now + the configured expiry and
 // RefreshAfter is the midpoint of that window (now is in Unix seconds, so the
 // configured AccessExpire is presumably seconds too — confirm).
 func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
     accessSecret := l.svcCtx.Config.JWT.AccessSecret
     accessExpire := l.svcCtx.Config.JWT.AccessExpire
     now := time.Now().Unix()

     jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
     if err != nil {
         return nil, err
     }
     resp := &ucenter.UserLoginResp{}
     // Surface copy failures instead of returning a partially filled response.
     if err := copier.Copy(resp, in); err != nil {
         return nil, err
     }
     resp.AccessToken = jwtToken
     resp.AccessExpire = now + accessExpire
     resp.RefreshAfter = now + accessExpire/2
     return resp, nil
 }

事务

提示:本示例对前面的代码有较大优化

注意:

  • 其实自定义的操作接口应该都加入context参数,便于链路追踪,这一点已在该分支最新提交的代码中补上。
  • 事务特性需要手动实现,改动的地方较多,如果不熟悉或者遗忘某一环节,容易导致事务特性失效,这一点不够友好,后续可以改为使用gorm。

使用事务特性的步骤

在 xxxmodel.go 中加入调用事务的接口,新增含有session的数据库操作接口

调用事务的接口及其实现
// Interface method:
TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error

// Implementation: fn is handed straight to TransactCtx, which commits when fn
// returns nil and rolls back when it returns an error.
func (c customZeroUsersModel) TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	return c.conn.TransactCtx(ctx, fn)
}
含有session的数据库操作接口(以添加用户为例)

可以发现,与没有事务特性的插入相比只是更改了操作的调用者为session。

//接口,有session参数
TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error)

//实现
func (c customZeroUsersModel) TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem()  //指针类型需要加 Elem()
	val := reflect.ValueOf(data).Elem() //指针类型需要加 Elem()
	fieldNum := val.NumField()

	names := ""
	values := ""
	for i := 1; i < fieldNum; i++ {
		Field := val.Field(i)
		colType := Field.Type().String()
		if colType == "int64" {
			if Field.Int() > 0 {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("%d,", Field.Int())
			}
		} else if colType == "string" {
			names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
			values += fmt.Sprintf("'%s',", Field.String())
		} else if colType == "time.Time" {
			value := Field.Interface().(time.Time)
			if !value.IsZero() {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("'%s',", value.Format(utils.DateTimeFormat))
			}
		}
	}
	names = strings.TrimRight(names, ",")
	values = strings.TrimRight(values, ",")
	saveSql := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, names, values)

	//result, err := c.conn.ExecCtx(ctx, saveSql)
	//return result, err
	result, err := session.ExecCtx(ctx, saveSql)
	return result, err
}

在xxxlogic.go中使用事务

  • 在xxxlogic.go中将对主子表的操作全部放到同一个事务中,每一步操作有错误就返回错误,事务遇到返回的错误会回滚,没有错误最后就返回nil;
  • 注意在同一事务里的每步操作(往主子表插入数据时),需要使用同一个session,否则事务特性不生效;

代码示例在 internal/logic/ucentersqlx/adduserlogic.go中:

package ucentersqlxlogic

import (
	"context"
	"github.com/jinzhu/copier"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/common/errorx"
	"go-zero-micro/common/utils"
	"go-zero-micro/rpc/database/sqlx/usermodel"
	"time"

	"go-zero-micro/rpc/code/ucenter/internal/svc"
	"go-zero-micro/rpc/code/ucenter/ucenter"

	"github.com/zeromicro/go-zero/core/logx"
)

// AddUserLogic handles the AddUser RPC for one request.
type AddUserLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
}

// NewAddUserLogic binds the logic to a request context; the embedded logger
// is derived from that context.
func NewAddUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AddUserLogic {
	logic := &AddUserLogic{ctx: ctx, svcCtx: svcCtx}
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// AddUser creates a user row and its user-info row atomically and returns the
// new user id.
//
// The user table holds the account; the user-info table is a child table
// (e-mail, phone, ...) that stores the parent's generated id. Both inserts run
// on the same transaction session, so an error in either rolls both back.
// Without the transaction the parent row could exist without its child.
func (l *AddUserLogic) AddUser(in *ucenter.User) (*ucenter.BaseResp, error) {
	operatorID := utils.GetUidFromCtxInt64(l.ctx, "userId")
	now := time.Now()

	var insertedID int64
	err := l.svcCtx.UsersModel.TransCtx(l.ctx, func(ctx context.Context, session sqlx.Session) error {
		userParam := &usermodel.ZeroUsers{}
		if err := copier.Copy(userParam, in); err != nil {
			return err
		}
		userParam.Password = utils.GeneratePassword(l.svcCtx.Config.DefaultConfig.DefaultPassword)
		userParam.CreatedBy = operatorID
		userParam.CreatedAt = now
		// Every statement must use this session; otherwise it runs outside
		// the transaction.
		dbUserRes, err := l.svcCtx.UsersModel.TransSaveCtx(ctx, session, userParam)
		if err != nil {
			return err
		}
		uid, err := dbUserRes.LastInsertId()
		if err != nil {
			return err
		}

		userInfoParam := &usermodel.ZeroUserInfos{}
		if err := copier.Copy(userInfoParam, in); err != nil {
			return err
		}
		userInfoParam.UserId = uid
		userInfoParam.CreatedBy = operatorID
		userInfoParam.CreatedAt = now
		if _, err := l.svcCtx.UserInfosModel.TransSaveCtx(ctx, session, userInfoParam); err != nil {
			return err
		}
		insertedID = uid
		return nil
	})
	if err != nil {
		// The caller only receives a generic error code, so keep the cause in
		// the log.
		l.Errorf("AddUser: transaction failed: %v", err)
		return nil, errorx.NewDefaultError(errorx.DbAddErrorCode)
	}

	return &ucenter.BaseResp{
		Id: insertedID,
	}, nil
}

使用泛型优化sqlx代码

  • 前面使用反射特性减少了操作每个数据表(即xxxmodel.go)中查询、添加、修改接口的代码量,减少了手动拼接sql。
  • 本节使用golang新增的泛型特性,对查询、新增、修改这三种类型的代码进一步优化,从而实现使用同一个接口操作多个数据表的目的。
  • 使用泛型优化查询、新增、修改的拼接sql,具体代码在common/utils/database.go:
// QuerySqlJoins builds the " AND ..." WHERE-clause fragment from the non-zero
// fields of data, a pointer to a struct whose fields carry `db` tags:
//   - int64 > 0           -> col=<n>
//   - non-empty string    -> col LIKE '%<s>%'
//   - non-zero time.Time  -> col='<yyyy-mm-dd hh:mm:ss>'
//
// Other field types are ignored. String values are escaped (backslash and
// single quote) so caller input cannot terminate the SQL literal. Generic so
// one helper serves every table model.
func QuerySqlJoins[T any](data *T) string {
	typ := reflect.TypeOf(data).Elem() // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem()
	escape := strings.NewReplacer(`\`, `\\`, `'`, `''`)

	var b strings.Builder
	for i := 0; i < val.NumField(); i++ {
		field := val.Field(i)
		colName := typ.Field(i).Tag.Get("db")

		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				fmt.Fprintf(&b, " AND %s=%d", colName, field.Int())
			}
		case "string":
			if s := field.String(); s != "" {
				fmt.Fprintf(&b, " AND %s LIKE '%%%s%%'", colName, escape.Replace(s))
			}
		case "time.Time":
			// Format explicitly: reflect.Value.String() on a time.Time returns
			// the placeholder "<time.Time Value>", not the timestamp.
			// The layout is MySQL's DATETIME literal format.
			if t := field.Interface().(time.Time); !t.IsZero() {
				fmt.Fprintf(&b, " AND %s='%s'", colName, t.Format("2006-01-02 15:04:05"))
			}
		}
	}
	return b.String()
}

// SaveSqlJoins builds an INSERT statement for table from data, a pointer to a
// struct whose fields carry `db` tags. Field 0 (assumed to be the
// auto-increment id) is skipped. Every int64 field is written (including 0),
// every string field is written, and time.Time fields only when non-zero
// (formatted with DateTimeFormat). String values are escaped (backslash and
// single quote) so caller input cannot terminate the SQL literal.
func SaveSqlJoins[T any](data *T, table string) string {
	typ := reflect.TypeOf(data).Elem() // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem()
	escape := strings.NewReplacer(`\`, `\\`, `'`, `''`)

	var cols, vals []string
	for i := 1; i < val.NumField(); i++ {
		field := val.Field(i)
		col := fmt.Sprintf("`%s`", typ.Field(i).Tag.Get("db"))

		switch field.Type().String() {
		case "int64":
			cols = append(cols, col)
			vals = append(vals, fmt.Sprintf("%d", field.Int()))
		case "string":
			cols = append(cols, col)
			vals = append(vals, "'"+escape.Replace(field.String())+"'")
		case "time.Time":
			if t := field.Interface().(time.Time); !t.IsZero() {
				cols = append(cols, col)
				vals = append(vals, "'"+t.Format(DateTimeFormat)+"'")
			}
		}
	}
	return fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", table, strings.Join(cols, ","), strings.Join(vals, ","))
}

// EditSqlJoins builds an UPDATE statement for the row Id in table from data, a
// pointer to a struct whose fields carry `db` tags. The statement always
// resets deleted_flag to DelNo. Field 0 (the id) is skipped. int64 fields are
// written only when > 0, strings always, and time.Time only when non-zero.
// String values are escaped (backslash and single quote) so caller input
// cannot terminate the SQL literal.
func EditSqlJoins[T any](data *T, table string, Id int64) string {
	typ := reflect.TypeOf(data).Elem() // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem()
	escape := strings.NewReplacer(`\`, `\\`, `'`, `''`)

	var sets []string
	for i := 1; i < val.NumField(); i++ {
		field := val.Field(i)
		col := typ.Field(i).Tag.Get("db")

		switch field.Type().String() {
		case "int64":
			if field.Int() > 0 {
				sets = append(sets, fmt.Sprintf("`%s`=%d", col, field.Int()))
			}
		case "string":
			sets = append(sets, fmt.Sprintf("`%s`='%s'", col, escape.Replace(field.String())))
		case "time.Time":
			if t := field.Interface().(time.Time); !t.IsZero() {
				sets = append(sets, fmt.Sprintf("`%s`='%s'", col, t.Format(DateTimeFormat)))
			}
		}
	}

	stmt := fmt.Sprintf("UPDATE %s SET deleted_flag = %d", table, DelNo)
	// Only add the separator when there is something to assign; otherwise
	// the statement would contain a dangling ", " before WHERE.
	if len(sets) > 0 {
		stmt += ", " + strings.Join(sets, ",")
	}
	return stmt + fmt.Sprintf(" WHERE id = %d", Id)
}

替换xxx_model.go中的相应代码,例如

查询

// Before: the model-specific condition builder.
//joinSql := userSqlJoins(data)
// After: the generic builder from common/utils.
joinSql := utils.QuerySqlJoins(data)

新增

原sql:

// SaveCtx (before the refactor): builds the INSERT inline with reflection.
// Field 0 (the id) is skipped; int64 fields are written only when > 0,
// strings always, and time.Time only when non-zero.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem()  // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem() // data is a pointer; inspect the struct
	fieldNum := val.NumField()

	names := ""
	values := ""
	for i := 1; i < fieldNum; i++ {
		Field := val.Field(i)
		colType := Field.Type().String()
		if colType == "int64" {
			if Field.Int() > 0 {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("%d,", Field.Int())
			}
		} else if colType == "string" {
			names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
			values += fmt.Sprintf("'%s',", Field.String())
		} else if colType == "time.Time" {
			value := Field.Interface().(time.Time)
			if !value.IsZero() {
				names += fmt.Sprintf("`%s`,", typ.Field(i).Tag.Get("db"))
				values += fmt.Sprintf("'%s',", value.Format(utils.DateTimeFormat))
			}
		}
	}
	names = strings.TrimRight(names, ",")
	values = strings.TrimRight(values, ",")
	saveSql := fmt.Sprintf("INSERT INTO %s(%s) VALUE(%s)", c.table, names, values)
	result, err := c.conn.ExecCtx(ctx, saveSql)
	return result, err
}

新sql:

// SaveCtx inserts data using the generic INSERT builder from common/utils.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	return c.conn.ExecCtx(ctx, utils.SaveSqlJoins(data, c.table))
}

修改

原sql

// EditCtx (before the refactor): builds the UPDATE inline with reflection and
// resets deleted_flag to utils.DelNo. Field 0 (the id) is skipped; int64
// fields are written only when > 0, strings always, and time.Time only when
// non-zero.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	typ := reflect.TypeOf(data).Elem()  // data is a pointer; inspect the struct
	val := reflect.ValueOf(data).Elem() // data is a pointer; inspect the struct
	fieldNum := val.NumField()

	names := ""
	for i := 1; i < fieldNum; i++ {
		Field := val.Field(i)
		colType := Field.Type().String()
		if colType == "int64" {
			if Field.Int() > 0 {
				names += fmt.Sprintf("`%s`=%d,", typ.Field(i).Tag.Get("db"), Field.Int())
			}
		} else if colType == "string" {
			names += fmt.Sprintf("`%s`='%s',", typ.Field(i).Tag.Get("db"), Field.String())
		} else if colType == "time.Time" {
			value := Field.Interface().(time.Time)
			if !value.IsZero() {
				names += fmt.Sprintf("`%s`='%s',", typ.Field(i).Tag.Get("db"), value.Format(utils.DateTimeFormat))
			}
		}
	}
	names = strings.TrimRight(names, ",")
	sql := fmt.Sprintf("UPDATE %s SET deleted_flag = %d, %s WHERE id = %d", c.table, utils.DelNo, names, data.Id)
	result, err := c.conn.ExecCtx(ctx, sql)
	return result, err
}

新sql

// EditCtx updates the row data.Id using the generic UPDATE builder from
// common/utils.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	return c.conn.ExecCtx(ctx, utils.EditSqlJoins(data, c.table, data.Id))
}

sqlc及缓存的使用

sqlc说明

  • sqlc相比sqlx,主要是加入了缓存(Redis),可以避免频繁访问数据库。
  • 要想使用sqlc,只需要把生成sqlx的命令中再加入 -cache,同时加入缓存相关的配置即可。
  • ※默认的缓存接口主要针对的是单条数据,因为设置到 Redis 里的 key 只细化到了 id,因此只有针对单条数据的增、删、改、查才使用缓存。

sqlc使用步骤

以zero_users数据表为例:

执行生成sqlc代码命令。

参考上文

在database/sqlc/usermodel/zero_users_model.go自定义其他查询接口及具体实现。

注意:这里只有FindOneByParamCtx、EditCtx、DeleteDataCtx使用了缓存。

package usermodel

import (
	"database/sql"
	"fmt"
	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlc"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/common/utils"
	"golang.org/x/net/context"
)

// Compile-time assertion that *customZeroUsersModel implements ZeroUsersModel.
var _ ZeroUsersModel = (*customZeroUsersModel)(nil)

// ZeroUsersModel extends the generated (cache-enabled) zeroUsersModel with the
// custom context-aware queries and mutations implemented by
// customZeroUsersModel.
type ZeroUsersModel interface {
	zeroUsersModel

	TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error
	CountCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string) (int64, error)
	FindPageListByParamCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error)
	FindAllByParamCtx(ctx context.Context, data *ZeroUsers) ([]*ZeroUsers, error)
	FindOneByParamCtx(ctx context.Context, data *ZeroUsers) (*ZeroUsers, error)
	SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error)
	EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error)
	DeleteDataCtx(ctx context.Context, data *ZeroUsers) error

	// TransSaveCtx inserts on an existing transaction session.
	TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error)
}

// customZeroUsersModel embeds the generated defaultZeroUsersModel, which
// provides the cached-connection methods (QueryRowCtx, ExecCtx, ...) and the
// table name used by the custom methods.
type customZeroUsersModel struct {
	*defaultZeroUsersModel
}

// TransCtx runs fn in one transaction; fn must issue its statements on the
// supplied session to take part in it. TransactCtx commits when fn returns
// nil and rolls back when it returns an error.
func (c customZeroUsersModel) TransCtx(ctx context.Context, fn func(context context.Context, session sqlx.Session) error) error {
	return c.TransactCtx(ctx, fn)
}

// CountCtx returns how many non-deleted users match the non-zero fields of
// data, optionally restricted to created_at >= beginTime and/or <= endTime
// (an empty bound is skipped). The count bypasses the cache.
func (c customZeroUsersModel) CountCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string) (int64, error) {
	querySql := fmt.Sprintf("SELECT count(*) as count FROM %s WHERE deleted_flag = %d", c.table, utils.DelNo)
	querySql += utils.QuerySqlJoins(data)

	// beginTime/endTime are caller-supplied, so bind them as placeholders
	// instead of splicing them into the statement text.
	var args []any
	if beginTime != "" {
		querySql += " AND created_at >= ?"
		args = append(args, beginTime)
	}
	if endTime != "" {
		querySql += " AND created_at <= ?"
		args = append(args, endTime)
	}

	var count int64
	err := c.QueryRowNoCacheCtx(ctx, &count, querySql, args...)
	switch err {
	case nil:
		return count, nil
	case sqlc.ErrNotFound:
		return 0, ErrNotFound
	default:
		return 0, err
	}
}

// FindPageListByParamCtx returns one page of non-deleted users matching the
// non-zero fields of data and the optional created_at range, newest first,
// without using the cache. current is 1-based; values below 1 are treated as
// the first page.
func (c customZeroUsersModel) FindPageListByParamCtx(ctx context.Context, data *ZeroUsers, beginTime, endTime string, current, pageSize int64) ([]*ZeroUsers, error) {
	querySql := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	querySql += utils.QuerySqlJoins(data)

	// Caller-supplied time bounds are bound as placeholders, not spliced in.
	var args []any
	if beginTime != "" {
		querySql += " AND created_at >= ?"
		args = append(args, beginTime)
	}
	if endTime != "" {
		querySql += " AND created_at <= ?"
		args = append(args, endTime)
	}

	// A page number below 1 would produce a negative LIMIT offset, which
	// MySQL rejects as a syntax error.
	if current < 1 {
		current = 1
	}
	querySql += fmt.Sprintf(" ORDER BY created_at DESC LIMIT %d,%d", (current-1)*pageSize, pageSize)

	var result []*ZeroUsers
	err := c.QueryRowsNoCacheCtx(ctx, &result, querySql, args...)
	switch err {
	case nil:
		return result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}

// FindAllByParamCtx returns every non-deleted user matching the non-zero
// fields of data, newest first, without using the cache.
func (c customZeroUsersModel) FindAllByParamCtx(ctx context.Context, data *ZeroUsers) ([]*ZeroUsers, error) {
	querySql := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d%s ORDER BY created_at DESC",
		zeroUsersRows, c.table, utils.DelNo, utils.QuerySqlJoins(data))

	var rows []*ZeroUsers
	if err := c.QueryRowsNoCacheCtx(ctx, &rows, querySql); err != nil {
		if err == sqlc.ErrNotFound {
			return nil, ErrNotFound
		}
		return nil, err
	}
	return rows, nil
}

// FindOneByParamCtx returns the newest non-deleted user matching the non-zero
// fields of data, or ErrNotFound.
//
// The Redis entry is keyed by id alone (cacheZeroUsersIdPrefix + id), and the
// same key is invalidated by EditCtx/DeleteDataCtx. The cache is therefore
// used only when id is the sole condition. If other conditions were present,
// a cached row would be returned without checking them. A miss caused by
// those conditions would also be cached under the id key as "not found",
// because go-zero caches misses as a placeholder.
func (c customZeroUsersModel) FindOneByParamCtx(ctx context.Context, data *ZeroUsers) (*ZeroUsers, error) {
	joinSql := utils.QuerySqlJoins(data)
	querySql := fmt.Sprintf("SELECT %s FROM %s WHERE deleted_flag = %d", zeroUsersRows, c.table, utils.DelNo)
	// Only one row is read, so let the database stop after it.
	querySql += joinSql + " ORDER BY created_at DESC LIMIT 1"

	var result ZeroUsers
	var err error
	idOnly := data.Id > 0 && joinSql == utils.QuerySqlJoins(&ZeroUsers{Id: data.Id})
	if idOnly {
		zeroUsersIdKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)
		err = c.QueryRowCtx(ctx, &result, zeroUsersIdKey, func(ctx context.Context, conn sqlx.SqlConn, v any) error {
			return conn.QueryRowCtx(ctx, v, querySql)
		})
	} else {
		err = c.QueryRowNoCacheCtx(ctx, &result, querySql)
	}

	switch err {
	case nil:
		return &result, nil
	case sqlc.ErrNotFound:
		return nil, ErrNotFound
	default:
		return nil, err
	}
}

// SaveCtx inserts data via ExecNoCacheCtx, so no cache key is touched by the
// insert.
func (c customZeroUsersModel) SaveCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	insertSql := utils.SaveSqlJoins(data, c.table)
	return c.ExecNoCacheCtx(ctx, insertSql)
}

// EditCtx updates the row data.Id and passes that id's cache key to ExecCtx so
// the cached row is invalidated.
func (c customZeroUsersModel) EditCtx(ctx context.Context, data *ZeroUsers) (sql.Result, error) {
	cacheKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)
	updateSql := utils.EditSqlJoins(data, c.table, data.Id)

	exec := func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		return conn.ExecCtx(ctx, updateSql)
	}
	return c.ExecCtx(ctx, exec, cacheKey)
}

// DeleteDataCtx soft-deletes the row data.Id (deleted_flag = utils.DelYes,
// deleted_at = data.UpdatedAt) and passes that id's cache key to ExecCtx so
// the cached row is invalidated.
func (c customZeroUsersModel) DeleteDataCtx(ctx context.Context, data *ZeroUsers) error {
	deletedAt := "'" + data.UpdatedAt.Format(utils.DateTimeFormat) + "'"
	deleteSql := fmt.Sprintf("UPDATE %s SET deleted_flag = %d,deleted_at= %s WHERE id = %d", c.table, utils.DelYes, deletedAt, data.Id)
	cacheKey := fmt.Sprintf("%s%v", cacheZeroUsersIdPrefix, data.Id)

	exec := func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		return conn.ExecCtx(ctx, deleteSql)
	}
	if _, err := c.ExecCtx(ctx, exec, cacheKey); err != nil {
		return err
	}
	return nil
}

// TransSaveCtx inserts data on the given transaction session (not on the
// model's own connection), so the insert is part of the caller's transaction.
func (c customZeroUsersModel) TransSaveCtx(ctx context.Context, session sqlx.Session, data *ZeroUsers) (sql.Result, error) {
	return session.ExecCtx(ctx, utils.SaveSqlJoins(data, c.table))
}

// NewZeroUsersModel builds the cache-enabled users model on conn, using the
// cache configuration c and any extra cache options.
func NewZeroUsersModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) ZeroUsersModel {
	base := newZeroUsersModel(conn, c, opts...)
	return &customZeroUsersModel{defaultZeroUsersModel: base}
}

添加Cache相关配置

RPC服务的yaml中新增缓存的配置,缓存用的是Redis。
# Redis used by the sqlc models' cache; mapped to Config.CacheRedis (cache.CacheConf).
CacheRedis:
  - Host: 127.0.0.1:6379
    Type: node
    Pass: ""
internal/config/config.go加入缓存配置的映射。
package config

import (
	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/zrpc"
)

type (
	// Config is the RPC service configuration mapped from the service yaml.
	Config struct {
		zrpc.RpcServerConf

		// JWT holds the token signing settings used at login.
		JWT struct {
			AccessSecret string
			AccessExpire int64
		}

		// MySQL holds the DSN shared by the sqlx and sqlc models.
		MySQL struct {
			DataSource string
		}

		// CacheRedis configures the Redis nodes behind the sqlc model cache.
		CacheRedis cache.CacheConf
	}
)

在internal/svc/servicecontext.go加入sqlc数据库查询依赖。
package svc

import (
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/rpc/code/ucenter/internal/config"
	sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
	sqlx_usermodel "go-zero-micro/rpc/database/sqlx/usermodel"
)

// ServiceContext carries the configuration and the data-access models shared
// by the RPC logic handlers. Each table has an uncached sqlx model and a
// cached sqlc model, both built on the same MySQL connection.
type ServiceContext struct {
	Config config.Config

	// Uncached models.
	SqlxUsersModel     sqlx_usermodel.ZeroUsersModel
	SqlxUserInfosModel sqlx_usermodel.ZeroUserInfosModel

	// Cached models.
	SqlcUsersModel     sqlc_usermodel.ZeroUsersModel
	SqlcUserInfosModel sqlc_usermodel.ZeroUserInfosModel
}

// NewServiceContext creates a MySQL connection from c.MySQL.DataSource and
// builds every model on it; the sqlc models also receive c.CacheRedis.
func NewServiceContext(c config.Config) *ServiceContext {
	conn := sqlx.NewMysql(c.MySQL.DataSource)

	sc := &ServiceContext{Config: c}
	sc.SqlxUsersModel = sqlx_usermodel.NewZeroUsersModel(conn)
	sc.SqlxUserInfosModel = sqlx_usermodel.NewZeroUserInfosModel(conn)
	sc.SqlcUsersModel = sqlc_usermodel.NewZeroUsersModel(conn, c.CacheRedis)
	sc.SqlcUserInfosModel = sqlc_usermodel.NewZeroUserInfosModel(conn, c.CacheRedis)
	return sc
}

将xxxlogic.go处理逻辑中的sqlx代码替换为sqlc代码

internal/logic/ucentersqlx/loginuserlogic.go

注意:在loginuserlogic.go中查询参数加入了Id,这样就可以测试缓存是否生效了。

package ucentersqlxlogic

import (
	"context"
	"errors"
	"fmt"
	"go-zero-micro/common/utils"
	sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
	"time"

	"go-zero-micro/rpc/code/ucenter/internal/svc"
	"go-zero-micro/rpc/code/ucenter/ucenter"

	"github.com/jinzhu/copier"
	"github.com/zeromicro/go-zero/core/logx"
)

// LoginUserLogic handles the LoginUser RPC for one request.
type LoginUserLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
}

// NewLoginUserLogic binds the logic to a request context; the embedded
// logger is derived from that context.
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
	logic := &LoginUserLogic{ctx: ctx, svcCtx: svcCtx}
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// LoginUser authenticates a user by account and password (用户登录).
//
// The user row is read through the cached (sqlc) model. On a password match the
// stored user fields are copied into in, and LoginSuccess issues the JWT.
//
// Errors identify the user by account only. in carries the plaintext password,
// so it must never be formatted into an error or log line.
func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	param := &sqlc_usermodel.ZeroUsers{
		// NOTE: demo only. Id is pinned to 1 so the sqlc cache can be observed
		// on repeated logins. This presumably restricts login to the user with
		// id 1; use in.Id (or drop Id) for real traffic.
		Id:      1,
		Account: in.Account,
	}
	dbRes, err := l.svcCtx.SqlcUsersModel.FindOneByParamCtx(l.ctx, param)
	if err != nil {
		l.Errorf("LoginUser:FindOneByParam:db err:%v, account: %q", err, in.Account)
		return nil, fmt.Errorf("LoginUser:FindOneByParam:db err:%w, account: %s", err, in.Account)
	}
	if !utils.ComparePassword(in.Password, dbRes.Password) {
		return nil, errors.New("LoginUser:user password error:account : " + in.Account)
	}
	// Replace the request fields with the stored user (id etc.) for the token and response.
	if err := copier.Copy(in, dbRes); err != nil {
		return nil, fmt.Errorf("LoginUser:copy user: %w", err)
	}
	return l.LoginSuccess(in)
}

// LoginSuccess issues a JWT for the authenticated user in and builds the login
// response.
//
// JWT.AccessExpire is a lifetime in seconds, added to the current Unix time.
// The response carries the absolute expiry time and a RefreshAfter time halfway
// through the token lifetime.
func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	accessSecret := l.svcCtx.Config.JWT.AccessSecret
	accessExpire := l.svcCtx.Config.JWT.AccessExpire
	now := time.Now().Unix()

	jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
	if err != nil {
		return nil, err
	}
	resp := &ucenter.UserLoginResp{}
	// copier.Copy can fail (e.g. on incompatible types); don't return a half-filled response.
	if err := copier.Copy(resp, in); err != nil {
		return nil, err
	}
	resp.AccessToken = jwtToken
	resp.AccessExpire = now + accessExpire
	resp.RefreshAfter = now + accessExpire/2
	return resp, nil
}

internal/logic/ucentersqlx/adduserlogic.go

package ucentersqlxlogic

import (
	"context"
	"github.com/jinzhu/copier"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/common/errorx"
	"go-zero-micro/common/utils"
	sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
	"time"

	"go-zero-micro/rpc/code/ucenter/internal/svc"
	"go-zero-micro/rpc/code/ucenter/ucenter"

	"github.com/zeromicro/go-zero/core/logx"
)

// AddUserLogic implements user creation on top of the cached (sqlc) user models.
type AddUserLogic struct {
	// ctx is the incoming request context.
	ctx context.Context
	// svcCtx gives access to config and data models.
	svcCtx *svc.ServiceContext
	// Logger is the context-aware logger set by NewAddUserLogic.
	logx.Logger
}

// NewAddUserLogic returns an AddUserLogic bound to the request context ctx,
// with a logger derived from ctx via logx.WithContext.
func NewAddUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AddUserLogic {
	a := &AddUserLogic{svcCtx: svcCtx}
	a.ctx = ctx
	a.Logger = logx.WithContext(ctx)
	return a
}

// AddUser creates a user (添加用户).
//
// A user is stored as a ZeroUsers row (account data) plus a child ZeroUserInfos
// row (e.g. email, phone) that references it through UserId. The child insert
// needs the id generated for the parent, so the parent goes first. Both inserts
// run in one transaction: without it, a failed child insert could leave a
// parent row with no child.
//
// New users get the configured default password. Returns the new ZeroUsers id.
func (l *AddUserLogic) AddUser(in *ucenter.User) (*ucenter.BaseResp, error) {
	// Id of the caller, taken from the request context; recorded as CreatedBy.
	creatorId := utils.GetUidFromCtxInt64(l.ctx, "userId")
	currentTime := time.Now()

	var insertUserId int64
	// Any error returned from the callback makes the transaction roll back.
	err := l.svcCtx.SqlcUsersModel.TransCtx(l.ctx, func(ctx context.Context, session sqlx.Session) error {
		userParam := &sqlc_usermodel.ZeroUsers{}
		if err := copier.Copy(userParam, in); err != nil {
			return err
		}
		userParam.Password = utils.GeneratePassword(l.svcCtx.Config.DefaultConfig.DefaultPassword)
		userParam.CreatedBy = creatorId
		userParam.CreatedAt = currentTime
		dbUserRes, err := l.svcCtx.SqlcUsersModel.TransSaveCtx(ctx, session, userParam)
		if err != nil {
			return err
		}
		uid, err := dbUserRes.LastInsertId()
		if err != nil {
			return err
		}

		userInfoParam := &sqlc_usermodel.ZeroUserInfos{}
		if err := copier.Copy(userInfoParam, in); err != nil {
			return err
		}
		userInfoParam.UserId = uid
		userInfoParam.CreatedBy = creatorId
		userInfoParam.CreatedAt = currentTime
		if _, err := l.svcCtx.SqlcUserInfosModel.TransSaveCtx(ctx, session, userInfoParam); err != nil {
			return err
		}
		insertUserId = uid
		return nil
	})
	if err != nil {
		// The client only sees a generic error code; keep the real cause in the logs.
		l.Errorf("AddUser: transaction failed: %v", err)
		return nil, errorx.NewDefaultError(errorx.DbAddErrorCode)
	}

	return &ucenter.BaseResp{
		Id: insertUserId,
	}, nil
}

sqlc执行源码分析与model详解

总结:通过简单分析源码可以得出缓存的具体使用过程。其中,将指定key的数据设置为占位符*,目的是防止缓存穿透,即避免大量针对数据库中不存在数据的请求每次都打到数据库上。

  • 先查询缓存,缓存中指定key的数据是*,则说明前面已经查询过数据库了,结果是数据库中没有数据,所以暂时不用马上再次查询数据库;
  • 查询缓存时出错(而不是未命中),则有可能是Redis节点故障;此时直接返回错误,而不去查询数据库,以免把压力转嫁到数据库;
  • 查询缓存没有出现故障,而且没有指定key的数据,则需要查询一次数据库;
  • 查询数据库后没有数据,则在Redis缓存中将指定key的数据设置为*,时长自定义。
  • 查询数据库有数据,则在Redis缓存中将指定key的数据设置为数据库查到的数据,时长自定义,同时将查询到的结果返回。

gorm

go-zero使用gorm的步骤

RPC服务:

在database/gorm/usermodel/gorm_zero_models.go添加数据表结构对应的结构体
package usermodel

import (
	"database/sql"
	"time"
)

type (
	// ZeroUsers is the gorm model for a user account row; the field comments
	// describe the columns.
	ZeroUsers struct {
		Id          int64        // id
		Account     string       // account name
		Username    string       // user name
		Password    string       // password; presumably the hash from utils.GeneratePassword
		Gender      int64        // gender: 1 = not set, 2 = male, 3 = female
		UpdatedBy   int64        // id of the user who last updated the row
		UpdatedAt   time.Time    // last update time
		CreatedBy   int64        // id of the user who created the row
		CreatedAt   time.Time    // creation time
		DeletedAt   sql.NullTime // deletion time (plain sql.NullTime, so gorm's automatic soft delete does not apply)
		DeletedFlag int64        // deletion flag: 1 = active, 2 = deleted
	}

	// ZeroUserInfos is the gorm model for a user's detail row (email, phone),
	// linked to its ZeroUsers row through UserId.
	ZeroUserInfos struct {
		Id          int64        // id
		UserId      int64        // id of the owning ZeroUsers row
		Email       string       // email address
		Phone       string       // phone number
		UpdatedBy   int64        // id of the user who last updated the row
		UpdatedAt   time.Time    // last update time
		CreatedBy   int64        // id of the user who created the row
		CreatedAt   time.Time    // creation time
		DeletedAt   sql.NullTime // deletion time (plain sql.NullTime, so gorm's automatic soft delete does not apply)
		DeletedFlag int64        // deletion flag: 1 = active, 2 = deleted
	}
)
在internal/svc/servicecontext.go中创建gorm连接
package svc

import (
	"fmt"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"go-zero-micro/rpc/code/ucenter/internal/config"
	sqlc_usermodel "go-zero-micro/rpc/database/sqlc/usermodel"
	sqlx_usermodel "go-zero-micro/rpc/database/sqlx/usermodel"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

// ServiceContext holds the dependencies shared by the service's logic handlers:
// the config plus three data-access layers over the same MySQL database.
type ServiceContext struct {
	Config config.Config
	// Uncached goctl models.
	SqlxUsersModel     sqlx_usermodel.ZeroUsersModel
	SqlxUserInfosModel sqlx_usermodel.ZeroUserInfosModel

	// Cached goctl models (Redis, via c.CacheRedis).
	SqlcUsersModel     sqlc_usermodel.ZeroUsersModel
	SqlcUserInfosModel sqlc_usermodel.ZeroUserInfosModel
	// GormDb is the gorm handle used by the gorm-based logic.
	GormDb *gorm.DB
}

// NewServiceContext builds the service dependencies from c.
//
// All layers use the same DSN (c.MySQL.DataSource). One go-zero sqlx
// connection backs the uncached (sqlx) and cached (sqlc, using c.CacheRedis)
// models, and a separate gorm handle backs the gorm logic. If gorm.Open fails,
// the function panics, aborting startup.
func NewServiceContext(c config.Config) *ServiceContext {
	conn := sqlx.NewMysql(c.MySQL.DataSource)

	gormCfg := &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			// Optional table naming tweaks:
			//TablePrefix:   "tech_", // table-name prefix: `User` maps to `tech_users`
			//SingularTable: true,    // singular table names: with the prefix above, `User` maps to `tech_user`
		},
	}
	gormDb, err := gorm.Open(mysql.Open(c.MySQL.DataSource), gormCfg)
	if err != nil {
		panic(fmt.Sprintf("Gorm connect database err:%v", err))
	}
	// To let gorm keep table schemas in sync with the model structs instead of
	// creating tables by hand, call AutoMigrate on the handle, e.g.:
	//gormDb.AutoMigrate(&usermodel.ZeroUsers{})

	return &ServiceContext{
		Config:             c,
		SqlxUsersModel:     sqlx_usermodel.NewZeroUsersModel(conn),
		SqlxUserInfosModel: sqlx_usermodel.NewZeroUserInfosModel(conn),

		SqlcUsersModel:     sqlc_usermodel.NewZeroUsersModel(conn, c.CacheRedis),
		SqlcUserInfosModel: sqlc_usermodel.NewZeroUserInfosModel(conn, c.CacheRedis),
		GormDb:             gormDb,
	}
}

在internal/logic/ucentergorm/loginuserlogic.go中使用
package ucentergormlogic

import (
	"context"
	"errors"
	"fmt"
	"github.com/jinzhu/copier"
	"go-zero-micro/common/utils"
	gorm_usermodel "go-zero-micro/rpc/database/gorm/usermodel"
	"time"

	"go-zero-micro/rpc/code/ucenter/internal/svc"
	"go-zero-micro/rpc/code/ucenter/ucenter"

	"github.com/zeromicro/go-zero/core/logx"
)

// LoginUserLogic implements user login on top of gorm.
type LoginUserLogic struct {
	// ctx is the incoming request context.
	ctx context.Context
	// svcCtx gives access to config and the gorm handle.
	svcCtx *svc.ServiceContext
	// Logger is the context-aware logger set by NewLoginUserLogic.
	logx.Logger
}

// NewLoginUserLogic returns a LoginUserLogic bound to the request context ctx,
// with a logger derived from ctx via logx.WithContext.
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
	l := &LoginUserLogic{svcCtx: svcCtx, ctx: ctx}
	l.Logger = logx.WithContext(ctx)
	return l
}

// LoginUser authenticates a user by account (and id, when set) and password
// (用户登录), reading the user row through gorm.
//
// gorm's struct conditions skip zero-valued fields, so with in.Id == 0 only
// Account is matched. With both empty, the query would have no conditions and
// First would return the row with the lowest primary key, so that case is
// rejected up front. Lookup errors (including record-not-found) are returned
// instead of comparing against an empty row.
//
// Errors identify the user by account only. in carries the plaintext password,
// so it must never be formatted into an error or log line.
func (l *LoginUserLogic) LoginUser(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	if in.Id == 0 && in.Account == "" {
		return nil, errors.New("LoginUser:account is required")
	}
	param := &gorm_usermodel.ZeroUsers{
		Id:      in.Id,
		Account: in.Account,
	}
	dbRes := &gorm_usermodel.ZeroUsers{}
	// WithContext propagates request cancellation/deadlines to the query.
	if err := l.svcCtx.GormDb.WithContext(l.ctx).Where(param).First(dbRes).Error; err != nil {
		l.Errorf("LoginUser:First:db err:%v, account: %q", err, in.Account)
		return nil, fmt.Errorf("LoginUser:First:db err:%w, account: %s", err, in.Account)
	}

	if !utils.ComparePassword(in.Password, dbRes.Password) {
		return nil, errors.New("LoginUser:user password error:account : " + in.Account)
	}
	// Replace the request fields with the stored user (id etc.) for the token and response.
	if err := copier.Copy(in, dbRes); err != nil {
		return nil, fmt.Errorf("LoginUser:copy user: %w", err)
	}
	return l.LoginSuccess(in)
}

// LoginSuccess issues a JWT for the authenticated user in and builds the login
// response.
//
// JWT.AccessExpire is a lifetime in seconds, added to the current Unix time.
// The response carries the absolute expiry time and a RefreshAfter time halfway
// through the token lifetime.
func (l *LoginUserLogic) LoginSuccess(in *ucenter.User) (*ucenter.UserLoginResp, error) {
	accessSecret := l.svcCtx.Config.JWT.AccessSecret
	accessExpire := l.svcCtx.Config.JWT.AccessExpire
	now := time.Now().Unix()

	jwtToken, err := utils.GenerateJwtToken(accessSecret, now, accessExpire, in.Id)
	if err != nil {
		return nil, err
	}
	resp := &ucenter.UserLoginResp{}
	// copier.Copy can fail (e.g. on incompatible types); don't return a half-filled response.
	if err := copier.Copy(resp, in); err != nil {
		return nil, err
	}
	resp.AccessToken = jwtToken
	resp.AccessExpire = now + accessExpire
	resp.RefreshAfter = now + accessExpire/2
	return resp, nil
}

API服务

在 internal/logic/login/loginbypasswordlogic.go中将UcenterSqlxRpc替换为UcenterGormRpc即可

gorm中使用缓存

gorm本身不支持缓存,如果想使用缓存的话,可以参考sqlc中是如何使用缓存的。

sqlx切换成gorm的流程:
(1)把sqlx换成gorm,同时保留sqlc的缓存逻辑;网上已经有同学这样做过。
(2)只需要把带缓存生成的model中,sqlx执行db部分换成gorm即可。
(3)替换后也不影响结合dtm实现的分布式事务,因为dtm同样支持gorm,详见dtm官网。

参考文章https://blog.csdn.net/Mr_XiMu/article/details/131658247

你可能感兴趣的:(后端)