mysql驱动源码解析_go-sql-driver 源码解析

Intro

最近正在给 mysql 封装一个库,顺带研究一下 go-mysql-driver 这个库的源码实现。

Buffer.go

buffer 是一个用于给 数据库连接 (net.Conn) 进行缓冲的一个数据结构,其结构为:

type buffer struct {

buf []byte // 缓冲池中的数据

nc net.Conn // 负责缓冲的数据库连接对象

idx int // 已读数据索引

length int // 缓冲池中未读数据的长度

timeout time.Duration // 数据库连接的超时设置

}

可以看到,因为 数据库连接 (net.Conn) 在通信的时候是 同步 的。而为了让其能够 同时 读/写 ,所以实现了 buffer 这个数据结构,通过该 buffer 进行数据缓冲还能实现 零拷贝 ( zero-copy-ish ) 。

其函数分别有:

newBuffer(nc net.Conn) buffer :创建并返回一个 buffer

(*buffer) readNext(need int) ([]byte, error) :读取并返回未读数据的 need 位,如果 need 大于 buffer 的 length ,就会调用 fill(need int) error 对 buffer进行 扩容 。

(*buffer) fill(need int) error :对 buffer 进行 (need/defaultBufSize) 的倍数扩容,并在 timeout 时间结束前从 buffer.nc 中读取 need 长度的数据。

(*buffer) takeBuffer(length int) []byte :读取 buffer 中 length 长度的数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil 。如果需要读取的长度大于 buffer 的容量,则会进行扩容。

(*buffer) takeSmallBuffer(length int) []byte :读取保证不超过 defaultBufSize 长度的数据的快捷函数(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil 。

(*buffer) takeCompleteBuffer() []byte : 读取全部的 buffer 数据(只包含已读),如果 buffer.length > 0 ,即还有未读数据,则立即返回 nil 。

Collations.go

collations 包含了 MySQL 所有支持的 字符集 格式,并支持通过 COLLATION_NAME 返回其字符集 ID。

如果需要查询 MySQL 支持的 字符集 格式,可以使用 SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS 语句获取。

Dsn.go

DSN 即 数据源名称 (Data Source Name) ,是 驱动程序连接数据库的变量信息 ,简而言之就是根据你连接的不同数据库使用对应的连接信息。

通常,数据库的连接配置就是在这里定义的:

// Config 基本的数据库连接信息

type Config struct {

User string // Username

Passwd string // Password (requires User)

Net string // Network type

Addr string // Network address (requires Net)

DBName string // Database name

Params map[string]string // Connection parameters

Collation string // Connection collation

Loc *time.Location // Location for time.Time values

TLSConfig string // TLS configuration name

tls *tls.Config // TLS configuration

Timeout time.Duration // Dial timeout

ReadTimeout time.Duration // I/O read timeout

WriteTimeout time.Duration // I/O write timeout

AllowAllFiles bool // 允许文件使用 LOAD DATA LOCAL INFILE 导入数据库

AllowCleartextPasswords bool // 支持明文密码客户端

AllowOldPasswords bool // 允许使用不可靠的旧密码

ClientFoundRows bool // 返回匹配的行数而不是受影响的行数

ColumnsWithAlias bool // 将表名前置在列名

InterpolateParams bool // 将占位符插入查询的SQL字符串

MultiStatements bool // 允许一条语句多次查询

ParseTime bool // 格式化时间值为 time.Time 变量

Strict bool // 将 warnings 返回 errors

}

这都是一些常见的配置项,就此略过。

该文件有两个公共函数支持 Config 与 DSN 之间转换。

(*Config)FormatDSN() string

ParseDSN(dsn string) (*Config, error)

Errors.go

errors 定义了 Logger 、MySQLError 、 MySQLWarning 等数据结构。

Logger

复用了 Go 原生的 log 包,并将其中的输出重定向至控制台的 标准错误 。

type Logger interface {

Print(v ...interface{})

}

var errLog = Logger(log.New(os.Stderr, "[mysql]", log.Ldate|log.Ltime|log.Lshortfile))

func SetLogger(logger Logger) error { // 当然,你也可以使用自定义的错误 Logger

if logger == nil {

return errors.New("logger is nil")

}

errLog =logger

return nil

}

MySQLError

而 MySQLError 则简单定义了 MySQL 输出的错误的结构。

type MySQLError struct {

Number uint16

Message string

}

MySQLWarning

MySQLWarning 则有些不一样,它需要从 MySQL 中进行一次 查询 ,以获取所有的警告信息,所以该包也定义了 MySQLWarning 的 slice 结构。

type MySQLWarning struct {

Level string

Code string

Message string

}

type MySQLWarnings []MySQLWarning

func (mc *mysqlConn) getWarnings() (err error) {

rows, err := mc.Query("SHOW WARNINGS", nil)

// handle err

// initzation MySQLWarnings

for {

err = rows.Next(values)

switch err {

case nil:

warning := MySQLWarning{}

if raw, ok := values[0].([]byte); ok {

warning.Level = string(raw)

}else {

warning.Level = fmt.Sprintf("%s", values[0])

}

if raw, ok := values[1].([]byte); ok {

warning.Code = string(raw)

} else {

warning.Code = fmt.Sprintf("%s", values[1])

}

if raw, ok := values[2].([]byte); ok {

warning.Message = string(raw)

} else {

warning.Message = fmt.Sprintf("%s", values[0])

}

warnings = append(warnings, warning)

}

case io.EOF:

return warnings

default:

rows.Close() // 值得注意的是,如果该函数没有 case 运行 default ,该 rows 就不会被默认关闭,就会占用连接池中的一个连接,是否应该使用 `defer rows.Close() ` 避免该情况?

return

}

}

Infile.go

前面也有提到 MySQL 在导入大型文件的时候,需要使用 LOAD DATA LOCAL INFILE 的形式进行导入,而该 infile.go 就是实现该协议的代码。

本包在实现的 LOAD DATA 的时候提供了两种方式进行导入:

最常见的,使用服务器的文件路径,如 /data/students.csv ,下文命名其为 文件路径注册器

最通用的,使用实现了 io.Reader 接口的数据结构,通过返回该数据结构的数据进行导入,如 bytes os.file 等,下文命名其为 Reader 接口注册器

在实现该功能的时候,注册器 的实现是用名字作为 Key 的 Map ,为了避免 Map 的 读写竞态 ,需要对其配置一个读写锁。

var (

fileRegister map[string]bool // 文件路径注册器

fileRegisterLock sync.RWMutex // 文件路径注册器读写锁

readerRegister map[string]func() io.Reader // Reader 接口注册器

readerRegisterLock sync.RWMutex // Reader 接口注册器读写锁

)

除了对两个注册器的 注册 以及 注销 函数,还有一个需要分析的一个函数:

(mc *mysqlConn) handleInFileRequest(name string) (err error)

通过传入 文件路径 或者 Reader 名称 就可以将数据发往 MySQL 了。

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {

packSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP

if mc.maxWriteSize < packSize { // 设置发往 MySQL 的数据块大小

packSize = mc.maxWriteSize

}

// 获取 文件 或 Reader 的数据,并将其赋值到 rdr 中

// var rdr io.Reader

// send context packets

if err != nil {

data := make([]byte, 4+packetSize) // 需要留 4 个 byte 给协议使用

var n int

for err == nil {

n, err = rdr.Read(data[4:]) // 将数据存入 data 的 [4:] 中

if n > 0 {

if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { // 将 data 数据发往 MySQL

return ioErr

}

}

}

if err == io.EOF { // rdr 中的数据读完了

err = nil

}

}

// send empty packet (termination)

if data == nil {

data = make([]byte, 4)

}

if ioErr := mc.writePacket(data[:4]); ioErr != nil { // 告诉 MySQL 文件发送完毕

return ioErr

}

// read OK packet

if err == nil { // 一切正常结束

return mc.readResultOK()

}

mc.readPacket() // 如果中途出错,将错误信息读取到 mysqlConn 中,并返回该错误

return err

}

到此,infile.go 的实现已经整理完毕了,可以看到, 作者 在实现这个功能的时候还是做了一些优化的,比如 map Lazy init ,send packet size limited 等。而我们通过分析规范的源码包,能够提升自己的编码水平。

Packets.go

接下来就要深入到 MySQL 的通信协议中了,官方的 通信协议文档 非常齐全,我在这里只将一些基础的,我后面分析源码会用到的协议分析下,如果有兴趣,可以到官方文档处进行查阅。

Protocol Basics

基础数据类型

MySQL 通信的基本数据类型有两种, Integer 、 String

Integer : 分别有 1, 2, 3, 4, 8 个字节长度的类型,使用小端传输。

String : 分别有 固定长度字符串(协议规定),NULL结尾字符串(长度不固定),长度编码字符串(长度不固定)。

报文协议

报文分为 消息头 以及 消息体,而 消息头 由 3 字节的 消息长度 以及 1 字节的 序号 sequence (新客户端由 0 开始)组成,消息体 则由 消息长度 的字节组成。

3 字节的 消息长度 最大值为 0xFFFFFF ,即为 16 MB - 1 byte ,这就意味着,如果整个消息(不包括消息头)的长度大于 16MB - 1byte - 4byte 大小时,消息就会被分包。

1 字节的 序号 在每次新的客户端发起请求时,以 0 开始,依次递增 1 ,如果消息需要分包, 序号 会随着分包的数量递增。而在一次应答中, 客户端会校验服务器 返回序号 是否与 发送序号 一致,如果不一致,则返回错误异常。

协议类型

handshake : 发起连接

auth : 登录权限校验

ok | error : 返回结果状态 *

ok : 首字节为 0 (0x00)

error : 首字节为 255 (0xff)

resultset : 结果集

header

field

eof

row

command package : 命令

在整个 MySQL 发起交互的过程如下图所示:

mysql connect

在了解这些 MySQL 基础协议知识后,我们再来看 packages.go 的源码就轻松多了。

源码

先来看看 readPacket ,结合上面的知识点应该非常好理解。

func (mc *mysqlConn) readPacket() ([]byte, error) {

var payload []byte

for { // for 循环是为了读取有可能分片的数据

// Read package header

data, err := mc.buf.readNext(4) // 从 buffer 缓冲器中读取 4 字节的 header

if err != nil { // 如果读取发生异常,则关闭连接,并返回一个错误连接的异常

errLog.Print(err)

mc.Close()

return nil, driver.ErrBadConn

}

// Packet Length [24 bit]

pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // 读取 3 字节的消息长度

if pktLen < 1 {

// 如上所示,关闭连接,并返回一个错误连接的异常

}

// Check Packet Sync [8 bit]

if data[3] != mc.sequence { // 判断服务端返回的序号是否与客户端一致

if data[3] > mc.sequence {

return nil, ErrPktSyncMul // 如果服务端返回序号大于客户端的序号,则有可能是在一次请求中做了多次操作

}

return nil, ErrPktSync // 返回序号不一致错误

}

mc.sequence++ // 本次序号匹配相符,为了匹配下一次请求,先将序号自增1

data, err := mc.buf.readNext(pktLen) // 读取 消息长度 的数据

if err != nil {

// 如上所示,关闭连接,并返回一个错误连接的异常

}

isLastPacket := (pktLen < maxPacketSize) // 如果是最后一个数据包,必然小于 maxPacketSize (16MB - 1byte)

// Zero allocations for non-splitting packets

if isLastPacket && payload == nil { // 无分包情况,立即返回

return data, nil

}

payload = append(payload, data...)

if isLastPacket { // 如果是最后一个包,读取完毕后返回

return payload, nil

}

// 还有未读数据,开始下一次循环

}

}

下面来看下结合 握手报文协议 来看下客户端向服务端发起请求的 readInitPacket :

mysql handshack protocol

func (mc *mysqlConn) readInitPacket() ([]byte, error) {

data, err := mc.readPacket() // 调用上面的函数读取服务端返回的数据

if err != nil {

return nil, err

}

if data[0] == iERR { // iERR = 0xff 消息体的第一个字节返回 0xff ,则意味着 error package

return nil, mc.handleErrorPacket(data)

}

// protocol version [1 byte]

if data[0] < minProtocolVersion { // 判断是否是兼容的协议版本

return nil, fmt.Errorf(

"unsupported protocol version %d. Version %d or higher is required",

data[0],

minProtocolVersion,

)

}

// server version [null terminated string]

// connection id [4 bytes]

pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // 读取 NULL (0x00)为结尾的字符串,跳过服务器线程 ID

// first part of the password cipher [8 bytes]

cipher := data[pos : pos+8] // 获取挑战随机数

// (filler) always 0x00 [1 byte]

pos += 8 + 1

// capability flags (lower 2 bytes) [2 bytes]

mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) // 获取服务器权能标识

if mc.flags&clientProtocol41 == 0 { // 说明 MySQL 服务器不支持高于 41 版本的协议

return nil, ErrOldProtocol

}

if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { // 说明 MySQL 服务器需要 SSL 加密,但是客户端没有配置 SSL

return nil, ErrNoTLS

}

pos += 2 // 指针向后两位

if len(data) > pos {

// 指针跳过标志位

pos += 1 + 2 + 2 + 1 + 10

// second part of the password cipher [mininum 13 bytes],

// where len=MAX(13, length of auth-plugin-data - 8)

//

// The web documentation is ambiguous about the length. However,

// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,

// the 13th byte is "\0 byte, terminating the second part of

// a scramble". So the second part of the password cipher is

// a NULL terminated string that's at least 13 bytes with the

// last byte being NULL.

//

// The official Python library uses the fixed length 12

// which seems to work but technically could have a hidden bug.

cipher = append(cipher, data[pos:pos+12]...)

// TODO: Verify string termination

// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)

// \NUL otherwise

//

//if data[len(data)-1] == 0 {

// return

//}

//return ErrMalformPkt

// make a memory safe copy of the cipher slice

var b [20]byte

copy(b[:], cipher)

return b[:], nil

}

// make a memory safe copy of the cipher slice

var b [8]byte // 返回 8 字节的挑战随机数

copy(b[:], cipher)

return b[:], nil

}

除了上面解析的两个函数, packages.go 还有 initialisation process / result packages / prepared statements 等协议的 写入/读取 ,有兴趣的读者可以结合上面的知识点自行阅读。

Driver.go

接下来就要分析一些比较重要的代码了,比如接下来要讲的 driver.go ,它主要负责与 MySQL 数据库进行各种协议的连接,并返回该连接。可以说它才是最基础、最核心的功能。

不过首先我们需要看下 database/sql 包中的 Driver 接口需要如何实现:

// database/sql/driver/driver.go

// 数据库驱动

type Driver interface {

Open(name string) (Conn, error)

}

// ...

// 非并发安全数据库连接

type Conn interface {

// 返回一个绑定到 sql 的准备语句

Prepare(query string) (Stmt, error)

// 关闭该连接,并标记为不再使用,停止所有准备语句和事务

// 因为 database/sql 包维护了一个空闲的连接池,并且在空闲连接过多的时候会自动调用 Close ,所以驱动程序包不需要显式调用该函数

Close() error

// 开始并返回一个新的事务,而新的事务与旧的连接没有任何关联

Begin() (Tx, error)

}

根据 database/sql 提供的 Driver 接口, go-sql-driver/mysql 实现了自己的 数据库驱动 结构:

type MySQLDriver struct{}

func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {

mc := &mysqlConn {

// set max value

}

mc.cfg = ParseDSN(dsn) // 通过解析 DSN 设置 MySQL 连接的配置

// set parseTime and strict

// ...

// connect to server

if dial, ok := dials[mc.cfg.Net]; ok { // 根据 地址 以及 协议类型,尝试连接上服务器

mc.netConn, err = dial(mc.cfg.Addr)

} else { // 连接服务器失败,尝试重连

nd := net.Dialer{Timeout: mc.cfg.Timeout}

mc.netConn, err := nd.Dial(mc.cfg.Net, mc.cfg.Addr)

}

if err != nil { // 重试失败,返回异常

return nil, err

}

// Enable TCP Keepalives on TCP connections

if tc, ok := mc.netConn.(*net.Conn); ok { // tcp 连接类型转换

if err := tc.SetKeepAlive(true); err != nil {

// Don't send COM_QUIT before handshake.

mc.netConn.Close() // 如果设置长连接失败,返回异常之前一定要记得将连接断开

mc.netConn = nil

return nil, err

}

}

mc.buff = newBuff(mc.netConn) // 生成一个带缓冲的 buffer,如上面 buffer.go 中所说

// set I/O timeout

// ...

// Reading Handshake Initialization Packet

cipher, err := mc.readInitPacket() // 发起数据库首次握手

if err != nil {

mc.cleanup() // 将当前 mysqlConn 对象销毁,后面我们会说这个函数

return nil, err

}

// Send Client Authentication Packet

if err = mc.writeAuthPacket(cipher); err != nil { // 向数据库发送登录信息校验

mc.cleanup()

return nil, err

}

}

connection.go

终于要讲到这个包的核心数据结构 mysqlConn 了,可以说,驱动的所有功能几乎都围绕着这个数据结构,我们先来看看它的结构:

type mysqlConn struct {

buf buffer // buffer 缓冲器

netConn net.Conn // 网络连接

affectedRows uint64 // sql 执行成功影响行数

insertId uint64 // sql 添加成功最新的主键 ID

cfg *Config // dsn 中的 基础配置

maxPacketAllowed int // 允许的最大报文的字节长度,最大不能超过 (16MB - 1byte)

maxWriteSize int // 允许最大的写入字节长度,最大不能超过 (16MB - 1byte)

writeTimeout time.Duration // 执行 sql 的 超时时间

flags clientFlag // 客户端状态标识

status statusFlag // 服务端状态标识

sequence uint8 // 序号

parseTime bool // 是否格式化时间

strict bool // 是否使用严格模式

}

// driver.go

// 而创建一个 mysqlConn 连接需要通过 driver.go 中的 Open 函数,也说明 mysqlConn 实现了 driver.Conn 接口

func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {

mc := &mysqlConn{

// ...

}

// ...

return mc, nil

}

当一个新的客户端连接上服务器的时候 (三次握手结束,客户端进入 established 状态),需要先对 MySQL 服务器进行 会话的用户/系统环境变量 的设置。

// Handles parameters set in DSN after the connection is established

func (mc *mysqlConn) handleParams() (err error) {

for param, val := range mc.cfg.Params { // Params: map[string]string

switch param {

// Charset

case "charset": // 如果是字符集,则调用 SET NAMES 命令

charsets := strings.Split(val, ",")

for i := range charsets {

// ignore errors here - a charset may not exist

err = mc.exec("SET NAMES " + charsets[i])

if err == nil {

break

}

}

if err != nil {

return

}

// System Vars

default: // 执行系统环境变量设置

err = mc.exec("SET " + param + "=" + val + "")

if err != nil {

return

}

}

}

}

conntion.go 还负责 事务 、预处理语句 、执行/查询 的管理,但是基本都是往 mysqlConn 中发送 command package ,如:

// Begin 开启事务

func (mc *mysqlConn) Begin() (driver.Tx, error) {

if mc.netConn == nil {

errLog.Print(ErrInvalidConn)

return nil, driver.ErrBadConn

}

err := mc.exec("START TRANSACTION")

if err == nil {

return &mysqlTx{mc}, err // 返回成功开启的事务,重用之前的连接

}

return nil, err

}

// Internal function to execute commands

func (mc *mysqlConn) exec(query string) error {

// Send command

err := mc.writeCommandPacketStr(comQurey, query)

if err != nil {

return err

}

// Read Result

resLen, err := mc.readResultSetHeaderPacket() // 根据 data[0] 的值判断是否出错,如果没有错误,则返回消息体的长度

if err == nil && resLen > 0 { // 存在有效消息体

if err = mc.readUntilEOF(); err != nil { // 读取 columns

return err

}

err = mc.readUntilEOF() // 读取 rows

}

return err

}

我想 conntion.go 中最重要的一个函数应该是 cleanup ,它负责将 连接关闭 、 重置环境变量 等功能,但是该函数不能随意调用,它只有在 登录权限校验异常 时候才应该被调用,否则服务器在不知道客户端 被强行关闭 的情况下,依然会向该客户端发送消息,导致严重异常:

// Closes the network connection and unsets internal variables. Do not call this

// function after successfully authentication, call Close instead. This function

// is called before auth or on auth failure because MySQL will have already

// closed the network connection.

func (mc *mysqlConn) cleanup() {

// Makes cleanup idempotent 保证函数的幂等性

if mc.netConn != nil {

if err := mc.netConn.Close(); err != nil { // Close 会尝试发送 comQuit command 到服务器

errLog.Print(err)

}

mc.netConn = nil // 不管 Close 是否成功,必须将 netConn 清空

}

mc.cfg = nil

mc.buf.nc = nil // 缓冲器中的 netConn 也要关闭

}

Result.go

每当 MySQL 返回一个 OK 的 状态报文 ,该报文协议会携带上本次执行的结果 affectedRows 以及 insertId ,而 result.go 就包含着一个数据结构,用于存储本次的执行结果。

type mysqlResult struct {

affectedRows int64

insertId int64

}

// 两个 getter

func (res *mysqlResult) LastInsertId() (int64, error) {

return res.insertId, nil

}

func (res *mysqlResult) RowsAffected() (int64, error) {

return res.affectedRows, nil

}

接下来我们看下在 conntion.go 中是怎么生成 mysqlResult 对象的:

// connect.go

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {

// ...

err := exec(query)

if err == nil {

return &mysqlResult{ // 返回执行的结果

affectedRows: int64(mc.affectedRows),

insertId: int64(mc.insertId),

}, err

}

return nil, err

}

// exec 函数的解析可以返回上面 package.go 中浏览

// package.go

func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {

data, err := mc.readPacket()

if err == nil {

switch data[0] {

case iOK:

return 0, mc.handleOkPacket(data) // 处理 OK 状态报文

// ...

}

func (mc *mysqlConn) handleOkPacket(data []byte) error {

var n, m int

// 0x00 [1 byte]

// Affected rows [Length Coded Binary]

mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])

// Insert id [Length Coded Binary]

mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

// ...

}

Row.go

当 MySQL 执行 插入、更新、删除 等操作后,都会返回 Result ,但是 查询 返回的是 Rows ,我们先来看看 go-mysql-driver 驱动所实现的 接口 Rows 的接口描述:

// database/sql/driver/driver.go

// Rows 是执行查询返回的结果的 游标

type Rows interface {

// Columns 返回列的名称,从 slice 的长度可以判断列的长度

// 如果一个列的名称未知,则为该列返回一个空字符串

Columns() []string

// Close 关闭游标

Close() error

// Next 将下一行数据填充到 desc 切片中

// 如果读取的是最后一行数据,应该返回一个 io.EOF 错误

Next(desc []Value) error

}

type Value interface{} // Value is a value that drivers must be able to handle.

为什么我要说这是 go-mysql-driver 驱动所实现的 接口 Rows 呢?眼尖的同学应该已经看到了, Next 函数好像和我们平常见到的不一样啊!!

是的,因为我们平常使用的:

rows.Next()

rows.Scan(dest ...interface{}) error

等函数的对象 rows 并不是上面的 接口描述 Rows ,而是另一个封装的 同名数据结构 Rows ,它就在 database/sql 包中 :

// database/sql.go

type Rows struct {

dc *driverConn

releaseConn func(error)

rowsi driver.Rows // 接口描述的 Rows 藏在这!!!

// 忽略其他字段,因为我们不分析这个包...

// lastcols is only used in Scan, Next, and NextResultSet which are expected

// not not be called concurrently.

lastcols []driver.Value

}

我们跳过 database/sql 包中的 Rows 实现,其无非是提供了更多功能的一个结果集而已,让我们回到真正与数据库进行交互的 Rows 中进行源码分析。

在 go-sql-driver 实现的 mysqlRows 数据结构只实现了 Columns() 和 Close() 两个行数,剩下的 Next(desc []driver.Value) 实现则交给了 MySQL 的两种结果集协议:

// rows.go

type mysqlField struct {

tableName string

name string

flags fieldFlag

fieldType byte

decimals byte

}

type mysqlRows struct {

mc *mysqlConn

columns []mysqlField

}

type binaryRows struct { // 二进制结果集协议

mysqlRows // 对于 Go 的 组合特性 应该不会陌生吧?

}

type textRows struct { // 文本结果集协议

mysqlRows

}

func (rows *mysqlRows) Columns() []string {

columns := make([]string, len(rows.columns))

// 将列名赋值到 columns ,如果有设置别名则赋值别名...

return columns

}

func (rows *mysqlRows) Close() error {

// 将连接里面的未读数据读完,然后将连接置空

}

// 接下来的 Next 函数实现就交由 binaryRows 和 textRows 了

func (rows *binaryRows) Next(desc []driver.Value) error {

if mc := rows.mc; mc != nil {

if mc.netConn == nil {

return ErrInvalidConn

}

return rows.readRow(dest) // 读二进制协议结果集

}

return io.EOF

}

func (rows *testRows) Next(desc []driver.Value) error {

if mc := rows.mc; mc != nil {

if mc.netConn == nil {

return ErrInvalidConn

}

return rows.readRow(dest) // 读取文本协议

}

return io.EOF

}

可以说,实现了 driver.Rows 接口的只有 binaryRows 和 testRows ,而他们里面的 readRow(desc) 实现由于都是和协议强相关的代码,就不再解析了。

我们跟着源码可以看到,使用 textRows 的场景在 getSystemVar 以及 Query 中,而使用 binaryRows 的场景在 statement 中,就是我们下一步需要解析的部分。

Statement.go

Prepared Statement ,即预处理语句,他有什么优势呢,为什么 MySQL 要加入它?

执行性能更高:MySQL 会对 Prepared Statement 语句预先进行编译成模板,并将 占位符 替换 参数 的位置,这样如果频繁执行一条参数只有少量替换的语句时候,性能会得到大量提高。可能有同学会有疑问,为什么 MySQL 语句还需要编译?那么可以来参考下这篇 MySQL Prepare 原理 。

传输协议更优:Prepare Statement 在传输时候使用的是 Binary Protocol ,比使用 Text Protocol 的查询具有 传输数据量更小 、 无需转换数据格式 等优势,缓解了 CPU 和 网络 的开销。

安全性更好:由 MySQL Prepare 原理 我们可以知道,Perpare 编译之后会生成 语法树,在执行的时候才会将参数传进来,这样就避免了平常直接执行 SQL 语句 会发生的 SQL 注入 问题。

好了,先来看下 mysqlStmt 的数据结构:

type mysqlStmt struct {

mc *mysqlConn

id uint32

paramCount int

columns []mysqlField // cached from the first query (既然SQL已经预编译好了,返回的结果集列名已经是确定的,所以在收到 PREPARE_OK 之后解析数据后会缓存下来)

}

我们发现,它比 mysqlRows 多了两个成员变量:

id :MySQL 预处理语句之后,会给该语句分配一个 id 并返回客户端,用于:

客户端提交该 id 给服务器调用对应的预处理语句。

paramCount :参数数量,等于 占位符 的个数,用于:

判断传入的参数个数是否与预编译语句中的占位符个数一致。

判断返回的 PREPARE_OK 响应报文是否带有 参数列名 数据。

下面来看看如何创建并使用一个 Prepare Statement :

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // 传入需要预编译的 SQL 语句

// 检查连接是否可用...

err = mc.writeCommandPacketStr(comStmtPrepare, query) // 将 SQL 发往数据库进行预编译

if err != nil {

return nil, err

}

stmt := &mysqlStmt{ // 预编译成功,先创建 stmt 对象

mc: mc,

}

// Read Result

columnCount, err := stmt.readPrepareResultPacket() // 从 stmt 的连接读取返回 响应报文

if err == nil {

if stmt.paramCount > 0 { // 如果预编译的 SQL 的有参数

if err = mc.readUntilEOF(); err != nil { // 读取参数列名数据

return nil, err

}

}

if columnCount > 0 { // 返回执行结果的列表个数

err = mc.readUntilEOF() // 读取执行结果的列名数据

}

}

return stmt, err

}

因为是已经预编译好的语句,所以在执行的时候只需要将参数传进去就可以了。

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {

// 检查连接是否可用...

err := stmt.writeExecutePacket(args)

if err != nil {

return nil, err

}

// 读取结果集的行、列数据

}

func(stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {

if len(args) != stmt.paramCount { // 判断传进来的参数和预编译好的SQL参数 个数是否一致

return fmt.Errorf(

"argument count mismatch (got: %d; has: %d)",

len(args),

stmt.paramCount,

)

}

// 读取缓冲器中的数据,如果为空,则返回异常...

// command [1 byte]

data[4] = comStmtExecute

// statement_id [4 bytes] 将预编译语句的 id 转换为 4字节的二进制数据

data[5] = byte(stmt.id)

data[6] = byte(stmt.id >> 8)

data[7] = byte(stmt.id >> 16)

data[8] = byte(stmt.id >> 24)

// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]

data[9] = 0x00

// iteration_count (uint32(1)) [4 bytes]

data[10] = 0x01

data[11] = 0x00

data[12] = 0x00

data[13] = 0x00

// 将参数按照不同的类型转换为 binary protobuf 并 append 到 data 中...

return mc.writePacket(data)

}

相信看到这里,已经能对看懂源码的 70% 了,剩余的代码都是和协议相关,就留待有兴趣的读者继续研究,这里就不再展开讲了。

Transaction.go

事务是 MySQL 中很重要的一部分,但是驱动的实现却很简单,因为一切的事务控制都已经交由 MySQL 去执行了,驱动所需要做的,只要发送一个 commit 或者 rollback 的 command packet 即可。

type mysqlTx struct {

mc *mysqlConn

}

func (tx *mysqlTx) Commit() (err error) {

if tx.mc == nil || tx.mc.netConn == nil {

return ErrInvalidConn

}

err = tx.mc.exec("COMMIT")

tx.mc = nil

return

}

func (tx *mysqlTx) Rollback() (err error) {

if tx.mc == nil || tx.mc.netConn == nil {

return ErrInvalidConn

}

err = tx.mc.exec("ROLLBACK")

tx.mc = nil

return

}

总结

最后,其实 buffer 的实现对我来说印象是最深刻的,因为它是最简单而又是最有效的实现了一个消息缓冲器,它实现的巧妙让我决定把它放到第一节,而其他的几乎都和 MySQL 的协议相关,看这些源码让我对 MySQL 有了更多的认识。

好了,本篇字数比较多,也会有很多不足,希望大家能够给本篇博客多提点意见,让我可以改进的更好。如果还有机会,我会带来其他篇章的源码解析,敬请期待 :)

参考链接

你可能感兴趣的:(mysql驱动源码解析)