Go 语言封装 go-sql-driver/mysql

go get github.com/go-sql-driver/mysql

基础类:

package mysql

import (
	"context"
	"database/sql"
	"fmt"
	"math/rand"
	"time"
	"why/config"
	"why/log"

	"github.com/opentracing/opentracing-go"

	_ "github.com/go-sql-driver/mysql"
)

// DB bundles one master connection pool with any number of slave pools,
// together with the configuration they were created from.
type DB struct {
	// masterDB is the pool handed out by MasterDB.
	masterDB *sql.DB
	// slaveDB holds the pools SlaveDB chooses from.
	slaveDB []*sql.DB
	// Config is the configuration that was passed to New.
	Config *Config
}

// GetDSN assembles a data source name of the form
//
//	user:password@tcp(host:port)/db?charset=charset
//
// from the "mysql" configuration entry named conn.
func GetDSN(conn string) string {
	entry := config.GetConfigEntrance("mysql", conn)

	credentials := entry["user"] + ":" + entry["password"]
	address := "@tcp(" + entry["host"] + ":" + entry["port"] + ")"
	target := "/" + entry["db"] + "?charset=" + entry["charset"]

	return credentials + address + target
}

// New opens the master pool and every slave pool described by c, applies the
// configured MaxOpen/MaxIdle limits and verifies each pool with a Ping.
//
// On failure it returns a nil *DB and a wrapped error. Every pool that was
// opened before the failure is closed first, so a failed call does not leak
// connections. A nil c is rejected with an error instead of panicking.
func New(c *Config) (db *DB, err error) {
	if c == nil {
		return nil, fmt.Errorf("init db error: nil config")
	}

	d := &DB{Config: c}

	// Release whatever has been opened so far if any later step fails.
	defer func() {
		if err == nil {
			return
		}
		if d.masterDB != nil {
			_ = d.masterDB.Close()
		}
		for _, s := range d.slaveDB {
			_ = s.Close()
		}
		db = nil
	}()

	master, openErr := sql.Open("mysql", c.Master.DSN)
	if openErr != nil {
		return nil, errorsWrap(openErr, "init master db error")
	}
	d.masterDB = master

	master.SetMaxOpenConns(c.Master.MaxOpen)
	master.SetMaxIdleConns(c.Master.MaxIdle)
	if pingErr := master.Ping(); pingErr != nil {
		return nil, errorsWrap(pingErr, "master db ping error")
	}

	for i := range c.Slave {
		slave, openErr := sql.Open("mysql", c.Slave[i].DSN)
		if openErr != nil {
			return nil, errorsWrap(openErr, "init slave db error")
		}
		// Track the pool before pinging so the deferred cleanup closes it
		// as well when the ping fails.
		d.slaveDB = append(d.slaveDB, slave)

		slave.SetMaxOpenConns(c.Slave[i].MaxOpen)
		slave.SetMaxIdleConns(c.Slave[i].MaxIdle)
		if pingErr := slave.Ping(); pingErr != nil {
			return nil, errorsWrap(pingErr, "slave db ping error")
		}
	}

	return d, nil
}

// MasterDB returns the master database pool.
func (db *DB) MasterDB() *sql.DB {
	master := db.masterDB
	return master
}

// SlaveDB returns one of the slave pools, picked pseudo-randomly with
// math/rand. When no slave pool exists it returns the master pool instead.
func (db *DB) SlaveDB() *sql.DB {
	if count := len(db.slaveDB); count > 0 {
		return db.slaveDB[rand.Intn(count)]
	}
	return db.masterDB
}

// MasterDBClose releases the resources held by the master database pool.
// It returns nil without doing anything when no master pool is set.
func (db *DB) MasterDBClose() error {
	if db.masterDB == nil {
		return nil
	}
	return db.masterDB.Close()
}

// SlaveDBClose releases the resources held by every slave database pool.
//
// Every pool is closed even when an earlier Close fails, so one bad pool
// cannot keep the others open. The first error encountered is returned; nil
// means all pools closed cleanly.
func (db *DB) SlaveDBClose() (err error) {
	for _, slave := range db.slaveDB {
		if closeErr := slave.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}
	return err
}

// operate identifies which pool (master or slave) and which database/sql call
// a request is dispatched to.
type operate int64

// The supported operations, numbered consecutively from zero.
const (
	operateMasterExec = operate(iota)
	operateMasterQuery
	operateMasterQueryRow
	operateSlaveQuery
	operateSlaveQueryRow
)

// operationNames maps each operation to the name used for its tracing span
// and in log messages.
var operationNames = map[operate]string{
	// Master operations.
	operateMasterExec:     "masterDBExec",
	operateMasterQuery:    "masterDBQuery",
	operateMasterQueryRow: "masterDBQueryRow",

	// Slave operations.
	operateSlaveQuery:    "slaveDBQuery",
	operateSlaveQueryRow: "slaveDBQueryRow",
}

// operate runs query with args against the pool selected by op, wrapped in an
// opentracing span and followed by logging.
//
// The span is a child of the span found in ctx, if there is one. After the
// call finishes the deferred logger records start/end/latency on the log
// header taken from ctx, tags the span, emits a warning when the call took
// longer than db.Config.ExecTimeout and emits an error log when err is set.
// The log header's Module is switched to "databus/mysql" for those log lines
// and restored afterwards.
//
// The returned value is a *sql.Rows, *sql.Row or sql.Result depending on op.
// QueryRow operations never set err here. An op that is not one of the
// declared constants yields a nil value and a non-nil error.
func (db *DB) operate(ctx context.Context, op operate, query string, args ...interface{}) (i interface{}, err error) {
	parent := opentracing.SpanFromContext(ctx)
	operationName := operationNames[op]

	var span opentracing.Span
	if parent == nil {
		span = opentracing.StartSpan(operationName)
	} else {
		span = opentracing.StartSpan(operationName, opentracing.ChildOf(parent.Context()))
	}

	logFormat := log.LogHeaderFromContext(ctx)
	startAt := time.Now()

	// Deferred calls run last-in-first-out: log, finish the span, then put
	// the caller's Module value back.
	lastModule := logFormat.Module
	defer func() { logFormat.Module = lastModule }()

	defer span.Finish()
	defer func() {
		endAt := time.Now()
		elapsed := endAt.Sub(startAt)

		logFormat.StartTime = startAt
		logFormat.EndTime = endAt
		logFormat.LatencyTime = elapsed.Microseconds() // execution time

		span.SetTag("error", err != nil)
		span.SetTag("db.type", "sql")
		span.SetTag("db.statement", query)
		logFormat.Module = "databus/mysql"

		// args is a []interface{} of arbitrary values, so it is printed with
		// %v; %s would render non-string parameters as %!s(type=value).
		if elapsed > db.Config.ExecTimeout.Duration {
			log.Warnf(logFormat, "%s:[%s], params:%v, used: %d milliseconds", operationName, query,
				args, elapsed.Milliseconds())
		}

		if err != nil {
			log.Errorf(logFormat, "%s:[%s], params:%v, error: %s", operationName, query,
				args, err)
		}
	}()

	switch op {
	case operateMasterQuery:
		i, err = db.MasterDB().QueryContext(ctx, query, args...)
	case operateMasterQueryRow:
		i = db.MasterDB().QueryRowContext(ctx, query, args...)
	case operateMasterExec:
		i, err = db.MasterDB().ExecContext(ctx, query, args...)
	case operateSlaveQuery:
		i, err = db.SlaveDB().QueryContext(ctx, query, args...)
	case operateSlaveQueryRow:
		i = db.SlaveDB().QueryRowContext(ctx, query, args...)
	default:
		// Fail loudly instead of returning (nil, nil), which callers would
		// only notice as a type-assertion panic.
		err = fmt.Errorf("mysql: unsupported operation %d", op)
	}
	return i, err
}

// MasterDBExecContext executes query with args on the master database and
// returns the resulting sql.Result. On failure the result is nil and err
// carries the error reported by operate.
func (db *DB) MasterDBExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
	var out interface{}
	if out, err = db.operate(ctx, operateMasterExec, query, args...); err != nil {
		return nil, err
	}
	return out.(sql.Result), nil
}

// MasterDBQueryContext runs query with args on the master database and
// returns the resulting rows. On failure the rows are nil and err carries the
// error reported by operate.
func (db *DB) MasterDBQueryContext(ctx context.Context, query string, args ...interface{}) (result *sql.Rows, err error) {
	var out interface{}
	if out, err = db.operate(ctx, operateMasterQuery, query, args...); err != nil {
		return nil, err
	}
	return out.(*sql.Rows), nil
}

// MasterDBQueryRowContext runs a single-row query with args on the master
// database and returns the *sql.Row produced by operate.
func (db *DB) MasterDBQueryRowContext(ctx context.Context, query string, args ...interface{}) (result *sql.Row) {
	// The error value from operate is discarded, as this method has no
	// error result to carry it.
	row, _ := db.operate(ctx, operateMasterQueryRow, query, args...)
	result = row.(*sql.Row)
	return result
}

// SlaveDBQueryContext runs query with args on a slave database (see SlaveDB
// for how the pool is chosen) and returns the resulting rows. On failure the
// rows are nil and err carries the error reported by operate.
func (db *DB) SlaveDBQueryContext(ctx context.Context, query string, args ...interface{}) (result *sql.Rows, err error) {
	// Dispatch with the slave operation; using operateMasterQuery here would
	// send every "slave" read to the master.
	r, err := db.operate(ctx, operateSlaveQuery, query, args...)
	if err != nil {
		return nil, err
	}
	return r.(*sql.Rows), nil
}

// SlaveDBQueryRowContext runs a single-row query with args on a slave
// database and returns the *sql.Row produced by operate.
func (db *DB) SlaveDBQueryRowContext(ctx context.Context, query string, args ...interface{}) (result *sql.Row) {
	// The error value from operate is discarded, as this method has no
	// error result to carry it.
	row, _ := db.operate(ctx, operateSlaveQueryRow, query, args...)
	result = row.(*sql.Row)
	return result
}

// errorsWrap prefixes err with msg, producing "msg: <err>". The result wraps
// err (via %w), so the original error stays reachable through Unwrap.
//
// A nil err yields nil, so wrapping a successful result never manufactures an
// error.
func errorsWrap(err error, msg string) error {
	if err == nil {
		return nil
	}
	return fmt.Errorf("%s: %w", msg, err)
}


/* example


 */

common 层的数据库连接方法:

package models

import (
	"sync"

	"why/mysql"
	"why/util"
)

// Lazily created database handles, one per logical connection name.
// connMu guards both instance variables.
var (
	connMu        sync.Mutex
	priceInstance *mysql.DB
	ymtInstance   *mysql.DB
)

// GetConn returns the shared *mysql.DB for the connection named conn, which
// must be "hangqing" or "ymt360"; any other name panics.
//
// The handle is created on first use and then cached, so repeated calls reuse
// the same connection pools instead of opening new ones each time. Creation
// is serialized by connMu so concurrent first calls build only one handle.
func GetConn(conn string) *mysql.DB {
	connMu.Lock()
	defer connMu.Unlock()

	switch conn {
	case "hangqing":
		if priceInstance == nil {
			priceInstance = getPriceConn()
		}
		return priceInstance
	case "ymt360":
		if ymtInstance == nil {
			ymtInstance = getYmtConn()
		}
		return ymtInstance
	default:
		panic("err conn string")
	}
}

// getYmtConn builds a *mysql.DB for the ymt360 database: the "ymt360_write"
// DSN is the master and the "ymt360_read" DSN is the only slave, each limited
// to 5 open and 5 idle connections. Any error from mysql.New is passed to
// util.Must.
//
// NOTE(review): Config.ExecTimeout is left unset here; the mysql package
// compares each call's elapsed time against it, so confirm the zero value is
// intended.
func getYmtConn() *mysql.DB {
	cfg := &mysql.Config{
		Master: mysql.Conn{
			DSN:     mysql.GetDSN("ymt360_write"),
			MaxOpen: 5,
			MaxIdle: 5,
		},
		Slave: []mysql.Conn{
			{
				DSN:     mysql.GetDSN("ymt360_read"),
				MaxOpen: 5,
				MaxIdle: 5,
			},
		},
	}

	db, err := mysql.New(cfg)
	util.Must(err)
	return db
}

// getPriceConn builds a *mysql.DB for the hangqing database: the
// "hangqing_write" DSN is the master and the "hangqing_read" DSN is the only
// slave, each limited to 5 open and 5 idle connections. Any error from
// mysql.New is passed to util.Must.
//
// NOTE(review): Config.ExecTimeout is left unset here; the mysql package
// compares each call's elapsed time against it, so confirm the zero value is
// intended.
func getPriceConn() *mysql.DB {
	cfg := &mysql.Config{
		Master: mysql.Conn{
			DSN:     mysql.GetDSN("hangqing_write"),
			MaxOpen: 5,
			MaxIdle: 5,
		},
		Slave: []mysql.Conn{
			{
				DSN:     mysql.GetDSN("hangqing_read"),
				MaxOpen: 5,
				MaxIdle: 5,
			},
		},
	}

	db, err := mysql.New(cfg)
	util.Must(err)
	return db
}

model层:

package hangqing

import (
	"context"
	"github.com/jmoiron/sqlx"
	"why/util"
	"models"
)

// HqCustomer is one row of the hq_customer table as selected by
// GetCustomerBreedsByCid.
//
// NOTE(review): the underscore field names presumably mirror the column names
// so that sqlx can match them without struct tags — confirm before renaming.
type HqCustomer struct {
	Province_id, City_id, County_id, Location_id, Market_info_id int

	Point_key, Point_key2 string

	Product_id, Breed_id, Customer_id int
}

// GetCustomerBreedsByCid loads every hq_customer row whose customer_id equals
// cid from a slave of the "hangqing" database and returns each row converted
// to a map by util.StructToMap. It returns nil when no row matches.
//
// Query and scan errors are passed to util.Must (presumably panicking, per
// the usual Must convention — confirm against util).
func GetCustomerBreedsByCid(ctx context.Context, cid int) (data []map[string]interface{}) {
	const query = "select province_id,city_id,county_id,location_id,market_info_id,point_key,point_key2,product_id,breed_id,customer_id from hq_customer where customer_id = ?"

	db := models.GetConn("hangqing")
	rows, err := db.SlaveDBQueryContext(ctx, query, cid)
	util.Must(err)
	// Always release the result set: without this, a scan failure would leave
	// the underlying connection checked out of the pool.
	defer rows.Close()

	var list []*HqCustomer
	util.Must(sqlx.StructScan(rows, &list))

	if len(list) == 0 {
		return nil
	}

	data = make([]map[string]interface{}, 0, len(list))
	for _, v := range list {
		data = append(data, util.StructToMap(*v))
	}
	return data
}

 

config参考:https://blog.csdn.net/why444216978/article/details/103992579

 

 

你可能感兴趣的:(go)