使用golang开发MySQL binlog同步工具demo

背景

本文介绍一个使用 golang 开发的 binlog 解析工具,偏向 demo 和研究性质。简单来说,就是模拟 MySQL 的复制协议,实现一个充当 MySQL “从库”的服务来获取 binlog,有点像 Java 开发的 canal。

实践

过程和结构

执行过程主要在 server 模块中。首先与 MySQL 建立连接并完成握手认证,这部分参考了我们使用的中间件 kingshard;接着关闭 binlog checksum(SET @master_binlog_checksum='NONE');然后作为从库向主库注册(COM_REGISTER_SLAVE),并发送 binlog dump 命令(COM_BINLOG_DUMP)。最后持续读取主库推送的 binlog,通过 go-mysql 提供的方法把 binlog event 解析出来并打印。

代码

  1. config
    配置部分,描述起始 binlog 文件和位置,以及主库 MySQL 的连接与账号信息等。
package app

// Config describes the master to replicate from and the binlog coordinates
// to start streaming at. Field order is significant: callers may build it
// with an unkeyed composite literal.
type Config struct {
    // Address of the master; dialed as Host:Port during the handshake.
    Host string
    Port int
    // Account used to authenticate. Pass is also sent verbatim in the
    // COM_REGISTER_SLAVE packet.
    User, Pass string
    // ServerId identifies this client to the master as a replica; it is sent
    // in both COM_REGISTER_SLAVE and COM_BINLOG_DUMP.
    ServerId int

    // Binlog file name and byte offset to start dumping from.
    LogFile  string
    Position int
}

  2. server模块
    整个工具的核心部分:连接、注册、发送命令、获取 binlog 都在这里完成。binlog 的解析使用了 go-mysql。
package app

import (
    "bufio"
    "bytes"
    "context"
    "crypto/sha1"
    "encoding/binary"
    "errors"
    "fmt"
    "github.com/siddontang/go-mysql/replication"
    "io"
    "net"
    "os"
    "time"
)

// MySQL client/server protocol constants used by this package.
const (
    // MinProtocolVersion is the lowest initial-handshake protocol version
    // the client accepts.
    MinProtocolVersion byte = 10

    // First payload byte of a server reply, identifying the reply type.
    OK_HEADER          byte = 0x00
    ERR_HEADER         byte = 0xff
    EOF_HEADER         byte = 0xfe
    LocalInFile_HEADER byte = 0xfb

    // MaxPayloadLength is the largest payload a single packet can carry
    // (its length field is 3 bytes); longer payloads span several packets.
    MaxPayloadLength = 1<<24 - 1
)

// Server is a minimal MySQL replication client. Despite its name it acts as
// a client: it dials the master described by Cfg, registers as a replica
// (COM_REGISTER_SLAVE) and reads the binlog stream over conn.
type Server struct {
    // Cfg holds the master address, credentials and start coordinates.
    Cfg          *Config
    // NOTE(review): Ctx is not read anywhere in this file.
    Ctx          context.Context
    // conn is the TCP connection to the master, set by handshake.
    conn         net.Conn
    // io frames packets on conn; created by handshake.
    io           *PacketIo
    // registerSucc is set when the master answers COM_REGISTER_SLAVE with OK.
    registerSucc bool
}

// Run executes the whole replication session on the calling goroutine.
// Quit is deferred so the connection is torn down however dump ends,
// including by panic (dump panics when the handshake fails).
func (s *Server) Run() {
    defer s.Quit()
    s.dump()
}

// dump drives the replication session: handshake, disable binlog checksums,
// register as a replica, request the binlog stream, then parse every event
// with go-mysql and print it to stdout.
//
// It panics if the handshake fails. It returns when the stream can no longer
// be read (I/O or framing error), when the master sends an ERR packet, or
// when a packet that is not a binlog event arrives. A read error on this
// connection is not recoverable, so retrying would only spin on a dead socket.
func (s *Server) dump() {
    if err := s.handshake(); err != nil {
        panic(err)
    }
    s.invalidChecksum()
    fmt.Println("dump ...")
    s.register()
    s.writeDumpCommand()

    parser := replication.NewBinlogParser()
    for {
        data, err := s.io.readPacket()
        if err != nil {
            fmt.Println(err)
            return
        }
        if len(data) == 0 {
            continue
        }

        switch data[0] {
        case OK_HEADER:
            // Each binlog event in the stream is prefixed with an OK byte.
            e, err := parser.Parse(data[1:])
            if err != nil {
                fmt.Println(err)
                continue
            }
            e.Dump(os.Stdout)
        case ERR_HEADER:
            // The master ends the dump after reporting an error.
            s.io.HandleError(data)
            return
        default:
            // Not a binlog event (e.g. EOF); decoding it as an ERR packet
            // would misread or panic on short payloads.
            fmt.Printf("unexpected packet header 0x%02x, stopping dump\n", data[0])
            return
        }
    }
}

// invalidChecksum asks the master to send binlog events without checksums
// by setting the session variable @master_binlog_checksum to 'NONE'.
//
// Failures are printed, not returned. The server's reply to the query must
// be consumed, otherwise it would be mistaken for the reply to the next
// command. If the query could not be sent there is no reply, so nothing is
// read; reading anyway would block forever.
func (s *Server) invalidChecksum() {
    sql := `SET @master_binlog_checksum='NONE'`
    if err := s.query(sql); err != nil {
        fmt.Println(err)
        return
    }
    // NOTE(review): the reply is consumed but its type (OK/ERR) is not checked.
    if _, err := s.io.readPacket(); err != nil {
        fmt.Println(err)
    }
}

// handshake dials the master, performs the protocol-v10 handshake and
// authenticates Cfg.User / Cfg.Pass with mysql_native_password scrambling
// (calPassword). On success s.conn and s.io are ready for commands.
//
// Only the mysql_native_password exchange is implemented. Any reply other
// than OK or ERR (for example an auth-switch request) is reported as an
// error rather than treated as success.
func (s *Server) handshake() error {
    // Capability flags sent to the server (0x7A285): LONG_PASSWORD |
    // LONG_FLAG | LOCAL_FILES | PROTOCOL_41 | TRANSACTIONS |
    // SECURE_CONNECTION | MULTI_STATEMENTS | MULTI_RESULTS | PS_MULTI_RESULTS.
    const clientCapability uint32 = 500357

    // JoinHostPort brackets IPv6 literals correctly.
    addr := net.JoinHostPort(s.Cfg.Host, fmt.Sprint(s.Cfg.Port))
    conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
    if err != nil {
        return err
    }

    tc := conn.(*net.TCPConn)
    tc.SetKeepAlive(true)
    tc.SetNoDelay(true)
    s.conn = tc

    s.io = &PacketIo{}
    s.io.r = bufio.NewReaderSize(s.conn, 16*1024)
    s.io.w = tc

    data, err := s.io.readPacket()
    if err != nil {
        return err
    }
    hs, err := parseInitialHandshake(data)
    if err != nil {
        return err
    }
    // mysql_native_password needs both scramble parts (8 + 12 bytes).
    if len(hs.salt) < 20 {
        return fmt.Errorf("auth scramble too short: %d bytes", len(hs.salt))
    }
    fmt.Printf("conn_id:%v, status:%d, plugin:%v\n", hs.connID, hs.status, hs.pluginName)

    auth := calPassword(hs.salt[:20], []byte(s.Cfg.Pass))
    if err := s.io.writePacket(buildHandshakeResponse(s.Cfg.User, clientCapability, auth)); err != nil {
        return fmt.Errorf("write auth packet error: %w", err)
    }

    pk, err := s.io.readPacket()
    if err != nil {
        return err
    }
    if len(pk) == 0 {
        return errors.New("empty handshake response")
    }
    switch pk[0] {
    case OK_HEADER:
        fmt.Println("handshake ok ")
        return nil
    case ERR_HEADER:
        s.io.HandleError(pk)
        return errors.New("handshake error ")
    default:
        return fmt.Errorf("unsupported handshake response (header 0x%02x, plugin %q)", pk[0], hs.pluginName)
    }
}

// initialHandshake holds the fields of the server's protocol-v10 greeting
// that the client uses.
type initialHandshake struct {
    connID     uint32
    salt       []byte // auth-plugin-data part 1 followed by part 2
    status     uint16
    pluginName string
}

// parseInitialHandshake decodes the server greeting payload, returning an
// error instead of panicking when it is truncated or malformed.
func parseInitialHandshake(data []byte) (*initialHandshake, error) {
    if len(data) == 0 {
        return nil, errors.New("empty handshake packet")
    }
    if data[0] == ERR_HEADER {
        return nil, errors.New("error packet")
    }
    if data[0] < MinProtocolVersion {
        return nil, fmt.Errorf("version is too lower, current:%d", data[0])
    }
    malformed := errors.New("malformed handshake packet")

    // Server version: NUL-terminated string after the protocol byte.
    end := bytes.IndexByte(data[1:], 0x00)
    if end < 0 {
        return nil, malformed
    }
    pos := 1 + end + 1

    // connection id (4) + auth-plugin-data part 1 (8) + filler (1) + capability low (2)
    if len(data) < pos+15 {
        return nil, malformed
    }
    hs := &initialHandshake{}
    hs.connID = binary.LittleEndian.Uint32(data[pos : pos+4])
    pos += 4
    // Copy so that appending part 2 never writes into data's backing array.
    hs.salt = append([]byte(nil), data[pos:pos+8]...)
    pos += 8 + 1
    // The server's capability flags are skipped: the client sends a fixed set.
    pos += 2

    if len(data) > pos {
        // charset (1) + status (2) + capability high (2) + auth-data length (1)
        // + reserved (10) + auth-plugin-data part 2 (12)
        if len(data) < pos+28 {
            return nil, malformed
        }
        pos++ // charset
        hs.status = binary.LittleEndian.Uint16(data[pos : pos+2])
        pos += 2 + 2  // status, capability high
        pos += 1 + 10 // auth-data length, reserved
        hs.salt = append(hs.salt, data[pos:pos+12]...)
        pos += 13 // part 2 plus its NUL terminator

        if pos < len(data) {
            if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
                hs.pluginName = string(data[pos : pos+end])
            } else {
                hs.pluginName = string(data[pos:])
            }
        }
    }
    return hs, nil
}

// buildHandshakeResponse builds a HandshakeResponse41 packet, with 4 leading
// header bytes left for writePacket.
func buildHandshakeResponse(user string, capability uint32, auth []byte) []byte {
    // capability (4) + max packet size (4) + charset (1) + reserved (23)
    // + user + NUL + auth length (1) + auth
    length := 4 + 4 + 1 + 23 + len(user) + 1 + 1 + len(auth)
    data := make([]byte, 4+length)

    binary.LittleEndian.PutUint32(data[4:], capability)
    // data[8:12] (max packet size) stays 0.
    data[12] = 33 // utf8_general_ci
    pos := 13 + 23
    pos += copy(data[pos:], user)
    pos++ // NUL terminator, already zero
    data[pos] = byte(len(auth))
    copy(data[pos+1:], auth)
    return data
}

// writeDumpCommand sends COM_BINLOG_DUMP (0x12). It asks the master to
// stream binlog events from Cfg.LogFile at offset Cfg.Position, identifying
// this client as replica Cfg.ServerId.
//
// COM_BINLOG_DUMP has no OK reply: the master answers directly with the
// event stream, or with an ERR packet. So nothing is read here; the caller's
// read loop must handle the first packet. Reading it here would silently
// discard the first event of the stream.
func (s *Server) writeDumpCommand() {
    s.io.seq = 0
    data := make([]byte, 4+1+4+2+4+len(s.Cfg.LogFile))
    pos := 4
    data[pos] = 18 // COM_BINLOG_DUMP
    pos++
    binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.Position))
    pos += 4

    // Dump flags = 0: BINLOG_DUMP_NON_BLOCK not set, so the master waits for
    // new events instead of ending the stream.
    binary.LittleEndian.PutUint16(data[pos:], 0)
    pos += 2

    binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
    pos += 4

    copy(data[pos:], s.Cfg.LogFile)

    if err := s.io.writePacket(data); err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println("send dump command ok.")
}

// register sends COM_REGISTER_SLAVE (0x15) so the master records this client
// as replica Cfg.ServerId, then reads the reply. It sets s.registerSucc on
// an OK reply; any failure is printed, not returned.
func (s *Server) register() {
    s.io.seq = 0
    // Best effort: the hostname is informational, so whatever os.Hostname
    // returns (possibly empty) is sent.
    hostname, _ := os.Hostname()
    data := make([]byte, 4+1+4+1+len(hostname)+1+len(s.Cfg.User)+1+len(s.Cfg.Pass)+2+4+4)
    pos := 4
    data[pos] = 21 // COM_REGISTER_SLAVE
    pos++
    binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
    pos += 4

    // hostname, user and password are each sent as 1-byte length + bytes.
    data[pos] = uint8(len(hostname))
    pos++
    pos += copy(data[pos:], hostname)

    data[pos] = uint8(len(s.Cfg.User))
    pos++
    pos += copy(data[pos:], s.Cfg.User)

    data[pos] = uint8(len(s.Cfg.Pass))
    pos++
    pos += copy(data[pos:], s.Cfg.Pass)

    // NOTE(review): this is the replica's own port field; the configured
    // master port is sent here.
    binary.LittleEndian.PutUint16(data[pos:], uint16(s.Cfg.Port))
    pos += 2

    // replication rank = 0
    binary.LittleEndian.PutUint32(data[pos:], 0)
    pos += 4

    // master id = 0
    binary.LittleEndian.PutUint32(data[pos:], 0)

    if err := s.io.writePacket(data); err != nil {
        fmt.Println(err)
        return
    }

    res, err := s.io.readPacket()
    if err != nil {
        fmt.Println(err)
        return
    }
    switch {
    case len(res) == 0:
        fmt.Println("register slave: empty response")
    case res[0] == OK_HEADER:
        fmt.Println("register success.")
        s.registerSucc = true
    default:
        // Decode the server's reply, not the request that was sent.
        s.io.HandleError(res)
    }
}

// writeCommand sends a command that has no arguments (a one-byte payload)
// as the first packet of a new command sequence. Write errors are
// deliberately ignored.
func (s *Server) writeCommand(command byte) {
    s.io.seq = 0
    // Bytes 0-3 are the header (length 1, seq 0); writePacket rewrites them.
    pkt := [5]byte{0x01, 0x00, 0x00, 0x00, command}
    _ = s.io.writePacket(pkt[:])
}

// query sends q as a COM_QUERY (0x03) packet, starting a new command
// sequence. It only writes the request; the caller must read the reply.
func (s *Server) query(q string) error {
    s.io.seq = 0
    // Four zero bytes reserve room for the header that writePacket fills in.
    pkt := make([]byte, 4, 4+1+len(q))
    pkt = append(pkt, 3)
    pkt = append(pkt, q...)
    return s.io.writePacket(pkt)
}

// Quit sends COM_QUIT (0x01) as a best effort and closes the connection.
//
// It is a no-op when no connection was established, for example when the
// dial failed or a signal arrived before Run connected. Without this check
// the nil s.io / s.conn would be dereferenced. A close error is printed.
func (s *Server) Quit() {
    if s.conn == nil || s.io == nil {
        return
    }
    s.writeCommand(byte(1)) // COM_QUIT
    if err := s.conn.Close(); err != nil {
        fmt.Printf("error in close :%v\n", err)
    }
}


// PacketIo frames MySQL protocol packets over one connection. Each packet is
// a 3-byte little-endian payload length, a 1-byte sequence id, then the
// payload.
type PacketIo struct {
    // r is the buffered reader packets are read from.
    r   *bufio.Reader
    // w receives framed packets.
    w   io.Writer
    // seq is the next sequence id to expect or send; callers reset it to 0
    // before starting a new command.
    seq uint8
}

// readPacket reads one logical packet and returns its payload.
//
// A payload of exactly MaxPayloadLength bytes continues in the next packet,
// so consecutive packets are concatenated until a shorter one (possibly
// empty) arrives. The sequence id of every packet is checked against p.seq,
// which is advanced per packet. A zero-length packet yields an empty,
// non-nil slice. Any payload length from 0 to MaxPayloadLength is valid,
// including 1 (e.g. a result-set column count).
func (p *PacketIo) readPacket() ([]byte, error) {
    var header [4]byte
    if _, err := io.ReadFull(p.r, header[:]); err != nil {
        return nil, err
    }

    length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
    if seq := header[3]; seq != p.seq {
        return nil, fmt.Errorf("invalid seq %d", seq)
    }
    p.seq++

    if length == 0 {
        return []byte{}, nil
    }

    data := make([]byte, length)
    if _, err := io.ReadFull(p.r, data); err != nil {
        return nil, err
    }
    if length < MaxPayloadLength {
        return data, nil
    }

    rest, err := p.readPacket()
    if err != nil {
        return nil, err
    }
    return append(data, rest...), nil
}

// writePacket frames data as one or more MySQL packets and writes them to
// p.w, advancing p.seq once per packet.
//
// data must reserve 4 leading bytes for the header; the payload is data[4:].
// A payload of MaxPayloadLength bytes or more is sent as max-size packets
// followed by a final shorter (possibly empty) packet. data is modified in
// place: headers are written over it, including over payload bytes that
// have already been sent.
func (p *PacketIo) writePacket(data []byte) error {
    if len(data) < 4 {
        return errors.New("packet buffer shorter than the 4-byte header")
    }

    // write sends b in full. io.Writer must report short writes as errors;
    // the length check guards against writers that do not.
    write := func(b []byte) error {
        n, err := p.w.Write(b)
        if err != nil {
            return fmt.Errorf("write packet: %w", err)
        }
        if n != len(b) {
            return io.ErrShortWrite
        }
        return nil
    }

    length := len(data) - 4
    for length >= MaxPayloadLength {
        data[0], data[1], data[2] = 0xff, 0xff, 0xff
        data[3] = p.seq
        if err := write(data[:4+MaxPayloadLength]); err != nil {
            return err
        }
        p.seq++
        length -= MaxPayloadLength
        // The last 4 payload bytes just sent become the next header slot.
        data = data[MaxPayloadLength:]
    }

    data[0] = byte(length)
    data[1] = byte(length >> 8)
    data[2] = byte(length >> 16)
    data[3] = p.seq
    if err := write(data); err != nil {
        return err
    }
    p.seq++
    return nil
}

// calPassword computes the mysql_native_password auth response:
//
//    SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password must be answered with an empty auth response, so nil
// is returned in that case. The scramble slice is not modified.
func calPassword(scramble, password []byte) []byte {
    if len(password) == 0 {
        return nil
    }

    stage1 := sha1.Sum(password)
    stage2 := sha1.Sum(stage1[:])

    h := sha1.New()
    h.Write(scramble)
    h.Write(stage2[:])
    token := h.Sum(nil)

    for i := range token {
        token[i] ^= stage1[i]
    }
    return token
}

// HandleError decodes an ERR packet payload and prints it.
//
// Layout: 0xff, error code (2 bytes LE), optionally '#' followed by a
// 5-byte SQL state, then the message. Short or malformed payloads are
// printed raw instead of causing an index panic.
func (p *PacketIo) HandleError(data []byte) {
    if len(data) < 3 {
        fmt.Printf("malformed error packet: % x\n", data)
        return
    }
    code := binary.LittleEndian.Uint16(data[1:3])
    pos := 3

    var state string
    if len(data) >= pos+6 && data[pos] == '#' {
        state = string(data[pos+1 : pos+6])
        pos += 6
    }
    msg := string(data[pos:])
    fmt.Printf("code:%d, state:%s, msg:%s\n", code, state, msg)
}

  3. main
package main

import (
    "flag"
    "fmt"
    "github.com/igoso/gbinlog/app"
    "os"
    "os/signal"
    "runtime"
    "syscall"
)

// Command-line flags describing the master to replicate from and the
// server id this client registers with.
var (
    myHost   = flag.String("host", "127.0.0.1", "MySQL replication host")
    myPort   = flag.Int("port", 3306, "MySQL replication port")
    myUser   = flag.String("user", "root", "MySQL replication user")
    myPass   = flag.String("pass", "****", "MySQL replication pass")
    serverId = flag.Int("server_id", 1111, "MySQL replication server id")
)

// main starts a replication client against the master given by the flags
// and runs until a termination signal arrives.
//
// -file and -pos select the binlog coordinates to start from; their
// defaults are the values that used to be hard-coded.
func main() {
    // SIGKILL (os.Kill) cannot be caught, so it is not registered.
    // os.Interrupt is SIGINT.
    sc := make(chan os.Signal, 1)
    signal.Notify(sc,
        os.Interrupt,
        syscall.SIGHUP,
        syscall.SIGQUIT,
        syscall.SIGTERM,
    )

    logFile := flag.String("file", "mysql-bin.000032", "binlog file to start dumping from")
    logPos := flag.Int("pos", 3070, "binlog position to start dumping from")

    runtime.GOMAXPROCS(runtime.NumCPU()/4 + 1)
    flag.Parse()
    cfg := &app.Config{
        Host:     *myHost,
        Port:     *myPort,
        User:     *myUser,
        Pass:     *myPass,
        ServerId: *serverId,
        LogFile:  *logFile,
        Position: *logPos,
    }
    srv := &app.Server{Cfg: cfg}
    go srv.Run()

    n := <-sc
    srv.Quit()
    fmt.Printf("receive signal %v, closing\n", n)
}

  4. go.mod
    只有一个依赖
module github.com/igoso/gbinlog

go 1.15

require (
    github.com/siddontang/go-mysql v1.1.0
)

其他

注意:如果对 binlog dump 连接执行 quit 命令,在 MySQL 端查看时该连接并不会立刻消失,而是处于 CLOSE_WAIT 状态,直到下一次有新的连接过来后才会消失并被新连接取代。这期间可能出现 1236 错误(已存在相同 server_id 的从库),但不影响使用。

本来尝试自己解析 binlog,但真正做起来工作量很大,因为需要处理的 binlog event 种类很多。后来发现 siddontang 的 go-mysql 包中已经实现了完善的 binlog 解析方法,他实现的 BinlogSyncer 也非常好用,感兴趣的可以参考下面的示例。

package main

import (
    "context"
    "flag"
    "fmt"
    "os"

    "github.com/pingcap/errors"
    "github.com/siddontang/go-mysql/mysql"
    "github.com/siddontang/go-mysql/replication"
)

// Connection flags for the master.
var (
    host     = flag.String("host", "127.0.0.1", "MySQL host")
    port     = flag.Int("port", 3306, "MySQL port")
    user     = flag.String("user", "root", "MySQL user, must have replication privilege")
    password = flag.String("password", "****", "MySQL password")
)

// Server flavor and binlog start coordinates.
var (
    flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb")
    file   = flag.String("file", "mysql-bin.000032", "Binlog filename")
    pos    = flag.Int("pos", 3070, "Binlog position")
)

// Operating-mode flags.
var (
    semiSync   = flag.Bool("semisync", false, "Support semi sync")
    backupPath = flag.String("backup_path", "", "backup path to store binlog files")
    rawMode    = flag.Bool("raw", false, "Use raw mode")
)

// main streams binlog events from the master with go-mysql's BinlogSyncer,
// starting at -file/-pos. It prints each event to stdout or, when
// -backup_path is set, backs up raw binlog files there instead.
func main() {
    flag.Parse()

    cfg := replication.BinlogSyncerConfig{
        ServerID: 101,
        Flavor:   *flavor,

        Host:            *host,
        Port:            uint16(*port),
        User:            *user,
        Password:        *password,
        RawModeEnabled:  *rawMode,
        SemiSyncEnabled: *semiSync,
        UseDecimal:      true,
    }

    b := replication.NewBinlogSyncer(cfg)
    // Release the syncer's connection on every exit path.
    defer b.Close()

    // Named startPos so it does not shadow the package-level pos flag.
    startPos := mysql.Position{Name: *file, Pos: uint32(*pos)}
    if len(*backupPath) > 0 {
        // Backup will always use RawMode.
        if err := b.StartBackup(*backupPath, startPos, 0); err != nil {
            fmt.Printf("Start backup error: %v\n", errors.ErrorStack(err))
        }
        return
    }

    s, err := b.StartSync(startPos)
    if err != nil {
        fmt.Printf("Start sync error: %v\n", errors.ErrorStack(err))
        return
    }

    for {
        e, err := s.GetEvent(context.Background())
        if err != nil {
            // Try to output all left events
            for _, left := range s.DumpEvents() {
                left.Dump(os.Stdout)
            }
            fmt.Printf("Get event error: %v\n", errors.ErrorStack(err))
            return
        }

        e.Dump(os.Stdout)
    }
}

以上就是本期的全部内容。

你可能感兴趣的:(使用golang开发MySQL binlog同步工具demo)