go-zero源码阅读-负载均衡(下)#第六期

一致性哈希

一致性哈希主要用于缓存服务的负载均衡,目的是避免缓存节点变更后大量缓存同时失效、请求穿透缓存直接打到数据库,从而把数据库打垮。

一致性哈希的原理可以参考《图解一致性哈希算法》一文,本文不再赘述其细节。

我们来看看其核心算法

// ServiceNode is a physical backend (ip:port) that is placed on the hash ring.
// Index is the node's position in the caller's node list; NewServiceNode
// leaves it at 0 until SetIndex is called.
type ServiceNode struct {
    Ip    string
    Port  string
    Index int
}

// NewServiceNode returns a ServiceNode for ip:port with Index left at 0.
func NewServiceNode(ip, port string) *ServiceNode {
    node := new(ServiceNode)
    node.Ip, node.Port = ip, port
    return node
}

// SetIndex records the node's position in the caller's node list.
func (sn *ServiceNode) SetIndex(index int) { sn.Index = index }

// UInt32Slice implements sort.Interface for ring hashes in ascending order.
type UInt32Slice []uint32

// Len reports how many hashes the slice holds.
func (s UInt32Slice) Len() int { return len(s) }

// Less orders hashes from smallest to largest.
func (s UInt32Slice) Less(i, j int) bool { return s[i] < s[j] }

// Swap exchanges the hashes at positions i and j.
func (s UInt32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// VirtualNode is the consistent-hash ring. VirtualNodes maps each virtual
// point's hash to the physical node owning it; NodeKeys holds the same hashes
// in ascending order (rebuilt by sortHash) for binary search. The embedded
// RWMutex is taken by the ring's methods around every access to both fields.
type VirtualNode struct {
    VirtualNodes map[uint32]*ServiceNode
    NodeKeys     UInt32Slice
    sync.RWMutex
}

// NewVirtualNode returns an empty ring, ready for AddVirtualNode.
func NewVirtualNode() *VirtualNode {
    ring := &VirtualNode{}
    ring.VirtualNodes = make(map[uint32]*ServiceNode)
    return ring
}

// AddVirtualNode puts virtualNum points for serviceNode on the ring, keyed by
// the hash of "ip:port:i" for i in [0, virtualNum), and then re-sorts the ring
// keys. A point whose hash is already present is overwritten.
func (v *VirtualNode) AddVirtualNode(serviceNode *ServiceNode, virtualNum uint) {
    // Both the map and the sorted key slice are written: exclusive lock.
    v.Lock()
    defer v.Unlock()

    prefix := serviceNode.Ip + ":" + serviceNode.Port + ":"
    for n := uint(0); n < virtualNum; n++ {
        point := v.getHashCode(prefix + strconv.Itoa(int(n)))
        v.VirtualNodes[point] = serviceNode
    }
    v.sortHash()
}

// RemoveVirtualNode takes serviceNode's virtualNum points ("ip:port:i") off the
// ring and re-sorts the ring keys.
//
// Points are matched by address, not pointer identity, so a freshly built
// ServiceNode with the same ip:port removes the points added earlier. A slot
// is only deleted when its current owner has that address. If another node's
// point collided with this hash, AddVirtualNode let the later writer own the
// slot, and deleting it here would silently drop a live node's point.
func (v *VirtualNode) RemoveVirtualNode(serviceNode *ServiceNode, virtualNum uint) {
    // The map and the sorted key slice are written: exclusive lock.
    v.Lock()
    defer v.Unlock()
    for i := uint(0); i < virtualNum; i++ {
        hashStr := serviceNode.Ip + ":" + serviceNode.Port + ":" + strconv.Itoa(int(i))
        hash := v.getHashCode(hashStr)
        owner, ok := v.VirtualNodes[hash]
        if !ok || owner.Ip != serviceNode.Ip || owner.Port != serviceNode.Port {
            continue
        }
        delete(v.VirtualNodes, hash)
    }
    v.sortHash()
}

// GetVirtualNodel returns the physical node responsible for routeKey. It binary
// searches the sorted ring for the first point whose hash is strictly greater
// than hash(routeKey), wrapping around to the smallest point when none is. It
// returns nil when the ring has no points, instead of indexing NodeKeys[0] and
// panicking.
func (v *VirtualNode) GetVirtualNodel(routeKey string) *ServiceNode {
    // Read lock: concurrent lookups may proceed together, writers are excluded.
    v.RLock()
    defer v.RUnlock()
    if len(v.NodeKeys) == 0 {
        return nil
    }
    hashCode := v.getHashCode(routeKey)
    i := sort.Search(len(v.NodeKeys), func(i int) bool { return v.NodeKeys[i] > hashCode })
    // Past the largest point: wrap around the ring to the first one.
    if i == len(v.NodeKeys) {
        i = 0
    }
    return v.VirtualNodes[v.NodeKeys[i]]
}

// sortHash rebuilds NodeKeys from the keys of VirtualNodes in ascending order.
// It does no locking of its own; AddVirtualNode and RemoveVirtualNode call it
// while holding the write lock.
func (v *VirtualNode) sortHash() {
    var keys UInt32Slice
    for hash := range v.VirtualNodes {
        keys = append(keys, hash)
    }
    sort.Sort(keys)
    v.NodeKeys = keys
}

// getHashCode maps nodeHash onto the 32-bit ring. It takes the MD5 digest,
// hex-encodes it, and parses the first 8 hex digits (the first 4 digest bytes,
// big-endian) as a uint32.
//
// Do not fold all 32 hex characters into an integer with shifts. Each shift
// pushes earlier characters out, so only the last 4 ASCII characters survive
// the uint32 conversion. That leaves 16 bits of entropy (65536 distinct
// points), and virtual points start colliding and overwriting each other.
func (v *VirtualNode) getHashCode(nodeHash string) uint32 {
    // crc32 alternative:
    // return crc32.ChecksumIEEE([]byte(nodeHash))
    sum := md5.Sum([]byte(nodeHash))
    md5Str := hex.EncodeToString(sum[:])
    // 8 hex digits always fit in 32 bits, so ParseUint cannot fail here.
    h, _ := strconv.ParseUint(md5Str[:8], 16, 32)
    return uint32(h)
}

我们来写测试代码,测试下

// Test_HashConsistency builds a ring of ten physical nodes (ten virtual points
// each) and looks up keys user:id:1..9,0. It then removes node 127.0.0.1:3301
// (serverNodes[1]), repeats the lookups, and prints a diff of the two routings.
func Test_HashConsistency(t *testing.T) {
    // Ten physical nodes on ports 3300-3309.
    var serverNodes []*hashconsistency.ServiceNode
    for port := 3300; port <= 3309; port++ {
        serverNodes = append(serverNodes, hashconsistency.NewServiceNode("127.0.0.1", fmt.Sprint(port)))
    }
    serverNodesLen := uint(len(serverNodes))
    ring := hashconsistency.NewVirtualNode()
    // Every physical node gets as many virtual points as there are physical nodes.
    for _, sn := range serverNodes {
        ring.AddVirtualNode(sn, serverNodesLen)
    }

    var before, after []string
    fmt.Println("-------- node 调度顺序--------")
    for i := 1; i <= 20; i++ {
        // Halfway through, take node 1 (port 3301) off the ring.
        if i == 11 {
            ring.RemoveVirtualNode(serverNodes[1], serverNodesLen)
        }
        cacheKey := fmt.Sprintf("user:id:%d", i%10)
        node := ring.GetVirtualNodel(cacheKey)
        line := fmt.Sprintf("node: %s cachekey: %s", node.Ip+":"+node.Port, cacheKey)
        if i > 10 {
            after = append(after, line)
        } else {
            before = append(before, line)
        }
    }
    utils.PrintDiff(strings.Join(before, "\n"), strings.Join(after, "\n"))
}

测试结果如下:

-------- node 调度顺序--------
-node: 127.0.0.1:3301 cachekey: user:id:1 // node1宕机
+node: 127.0.0.1:3300 cachekey: user:id:1 // 原node1的缓存路由到此node0
 node: 127.0.0.1:3309 cachekey: user:id:2
 node: 127.0.0.1:3309 cachekey: user:id:3
 node: 127.0.0.1:3309 cachekey: user:id:4
 node: 127.0.0.1:3300 cachekey: user:id:5
 node: 127.0.0.1:3307 cachekey: user:id:6
-node: 127.0.0.1:3301 cachekey: user:id:7 // node1宕机
+node: 127.0.0.1:3302 cachekey: user:id:7 // 原node1的缓存路由到此node2
 node: 127.0.0.1:3305 cachekey: user:id:8
-node: 127.0.0.1:3301 cachekey: user:id:9 // node1宕机
+node: 127.0.0.1:3300 cachekey: user:id:9 // 原node1的缓存路由到此node0
 node: 127.0.0.1:3309 cachekey: user:id:0

从测试中可以看出,原本落在宕机node上的缓存key都被自动路由到哈希环上顺时针方向的下一个node,而没有宕机的node继续承接原来的缓存key。这说明通过一致性哈希算法,可以保证缓存不会因为个别服务宕机而造成大面积缓存失效的问题。

我们再把一致性哈希算法带入到服务中,来看看效果如何

// Config is the load balancer configuration. Proxy and Nodes are decoded from
// ./config.json by init.
type Config struct {
    Proxy                     Proxy   `json:"proxy"`
    Nodes                     []*Node `json:"nodes"`
    // HashConsistency is the consistent-hash ring; init builds it from Nodes.
    HashConsistency           *VirtualNode
    // HashConsistencyVirtualNum is the number of virtual points per node. It
    // has no json tag, so encoding/json matches it by field name; init
    // replaces a zero value with 10.
    HashConsistencyVirtualNum uint
}

// Proxy describes the load balancer itself. Url is the listen address that
// ProxyServerStart passes to http.Server.Addr.
type Proxy struct {
    Url string `json:"url"`
}

// Node is a backend server the load balancer forwards to. URL is used as
// "host:port": init splits it on ":" and NodeServerStart listens on it.
// IsDead is read and written through GetIsDead/SetDead under mu; UseCount is
// incremented by rrlbbHandler under the package-level mu.
type Node struct {
    URL      string `json:"url"`
    IsDead   bool
    UseCount int
    mu       sync.RWMutex
}

// cfg is the process-wide configuration loaded by init.
var cfg Config

// init loads ./config.json into cfg and builds the consistent-hash ring from
// cfg.Nodes. Each node URL must be "host:port". The node's position in
// cfg.Nodes is stored as the ServiceNode index, so rrlbbHandler can map a ring
// lookup back to its *Node. Read, decode and address errors are fatal.
func init() {
    data, err := ioutil.ReadFile("./config.json")
    if err != nil {
        log.Fatal(err.Error())
    }
    // Fail fast on malformed JSON rather than running with a half-filled cfg.
    if err := json.Unmarshal(data, &cfg); err != nil {
        log.Fatalf("parsing ./config.json: %v", err)
    }
    if cfg.HashConsistencyVirtualNum == 0 {
        cfg.HashConsistencyVirtualNum = 10
    }
    cfg.HashConsistency = NewVirtualNode()
    for i, node := range cfg.Nodes {
        addr := strings.Split(node.URL, ":")
        // Without this check a URL lacking a port panics on addr[1].
        if len(addr) != 2 {
            log.Fatalf("node %d: url %q must be host:port", i, node.URL)
        }
        serviceNode := NewServiceNode(addr[0], addr[1])
        serviceNode.SetIndex(i)
        cfg.HashConsistency.AddVirtualNode(serviceNode, cfg.HashConsistencyVirtualNum)
    }
}

// GetCfg returns a copy of cfg. The copy is shallow: Nodes and HashConsistency
// still point at the live objects.
func GetCfg() Config {
    return cfg
}

// SetDead records whether node is dead and keeps the hash ring in sync.
// A live→dead transition removes the node's virtual points, so lookups route
// around it. A dead→live transition adds them back, indexed by the node's
// position in cfg.Nodes.
//
// Calls that do not change the state leave the ring alone. healthCheck calls
// SetDead(false) for every healthy node on every tick, and that must not evict
// the node from the ring.
func (node *Node) SetDead(b bool) {
    node.mu.Lock()
    defer node.mu.Unlock()
    if node.IsDead == b {
        return
    }
    node.IsDead = b

    addr := strings.Split(node.URL, ":")
    serviceNode := NewServiceNode(addr[0], addr[1])
    if b {
        cfg.HashConsistency.RemoveVirtualNode(serviceNode, cfg.HashConsistencyVirtualNum)
        return
    }
    // rrlbbHandler indexes cfg.Nodes with ServiceNode.Index, so re-added points
    // must carry this node's position.
    for i, n := range cfg.Nodes {
        if n == node {
            serviceNode.SetIndex(i)
            cfg.HashConsistency.AddVirtualNode(serviceNode, cfg.HashConsistencyVirtualNum)
            return
        }
    }
}

// GetIsDead reports whether node is currently marked dead.
func (node *Node) GetIsDead() bool {
    node.mu.RLock()
    defer node.mu.RUnlock()
    return node.IsDead
}

var mu sync.Mutex

// rrlbbHandler is a handler for round robin load balancing
func rrlbbHandler(w http.ResponseWriter, r *http.Request) {
    // Round Robin
    mu.Lock()
    cacheKey := r.Header.Get("cache-key")
    virtualNodel := cfg.HashConsistency.GetVirtualNodel(cacheKey)
    targetURL, err := url.Parse(fmt.Sprintf("http://%s:%s", virtualNodel.Ip, virtualNodel.Port))
    if err != nil {
        log.Fatal(err.Error())
    }
    currentNode := cfg.Nodes[virtualNodel.Index]
    currentNode.UseCount++
    if currentNode.GetIsDead() {
        rrlbbHandler(w, r)
        return
    }
    mu.Unlock()
    reverseProxy := httputil.NewSingleHostReverseProxy(targetURL)
    reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, e error) {
        // NOTE: It is better to implement retry.
        log.Printf("%v is dead.", targetURL)
        currentNode.SetDead(true)
        rrlbbHandler(w, r)
    }
    w.Header().Add("balancer-node", virtualNodel.Ip+virtualNodel.Port)
    reverseProxy.ServeHTTP(w, r)
}

// isAlive reports whether a TCP connection to u.Host can be opened within one
// minute. An unreachable host is logged. A successful connection is closed
// before returning.
func isAlive(u *url.URL) bool {
    conn, err := net.DialTimeout("tcp", u.Host, time.Minute*1)
    if err == nil {
        defer conn.Close()
        return true
    }
    log.Printf("Unreachable to %v, error %s:", u.Host, err.Error())
    return false
}

// healthCheck probes every node once a minute with a TCP dial, records the
// result via SetDead, and logs it. It never returns.
//
// Node URLs are plain "host:port": init splits them on ":" and NodeServerStart
// listens on them directly. url.Parse rejects such strings ("first path
// segment in URL cannot contain colon"), and the log.Fatal on that error
// killed the whole proxy on the first tick. Put the address in URL.Host
// directly instead.
func healthCheck() {
    t := time.NewTicker(time.Minute * 1)
    for range t.C {
        for _, node := range cfg.Nodes {
            pingURL := &url.URL{Host: node.URL}
            alive := isAlive(pingURL)
            node.SetDead(!alive)
            msg := "ok"
            if !alive {
                msg = "dead"
            }
            log.Printf("%v checked %s by healthcheck", node.URL, msg)
        }
    }
}

// ProxyServerStart launches the health-check loop in the background and then
// serves rrlbbHandler on cfg.Proxy.Url. It blocks, and any server error is
// fatal.
func ProxyServerStart() {
    go healthCheck()
    server := http.Server{
        Addr:    cfg.Proxy.Url,
        Handler: http.HandlerFunc(rrlbbHandler),
    }
    if err := server.ListenAndServe(); err != nil {
        log.Fatal(err.Error())
    }
}

// NodeServerStart registers a /ping handler (answering "pong") on the default
// mux and starts an HTTP listener for each entry of cfg.Nodes, 100ms apart.
//
// Node 0 is not started. This appears intentional: the unconditional call is
// commented out, and the sample run expects the first node to be dead. The
// listeners block (or exit via log.Fatal), so only node 0's goroutine reaches
// wg.Done, and wg.Wait keeps this function from returning.
func NodeServerStart() {
    http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("pong"))
    })
    wg := new(sync.WaitGroup)
    wg.Add(len(cfg.Nodes))
    for i, node := range cfg.Nodes {
        // Pass the loop variables explicitly. Before Go 1.22 the closure
        // shared them across iterations and could listen on the wrong node.
        go func(i int, node *Node) {
            defer wg.Done()
            if i != 0 {
                log.Fatal(http.ListenAndServe(node.URL, nil))
            }
            // log.Fatal(http.ListenAndServe(node.URL, nil))
        }(i, node)
        time.Sleep(time.Millisecond * 100)
    }
    wg.Wait()
}

编写测试代码测试下:

// Test_HashConsistencyWithServer starts the demo backends and the proxy, then
// sends cache keys user:id:1..9,0 through the proxy twice. It prints a diff of
// the balancer-node response header between the two passes, followed by each
// node's UseCount.
func Test_HashConsistencyWithServer(t *testing.T) {
    go hashconsistency.NodeServerStart()
    time.Sleep(time.Millisecond * 200)
    go hashconsistency.ProxyServerStart()
    time.Sleep(time.Millisecond * 100)
    for _, tt := range [...]struct {
        name, method, uri string
        body              io.Reader
        want              *http.Request // currently unused
        wantBody          string
    }{
        {
            name:     "GET with ping url",
            method:   "GET",
            uri:      "http://127.0.0.1:8080/ping",
            body:     nil,
            wantBody: "pong",
        },
    } {
        t.Run(tt.name, func(t *testing.T) {
            fmt.Println("-------- node 调度顺序--------")
            var nodes1, nodes2 []string
            for i := 1; i <= 20; i++ {
                cacheKey := fmt.Sprintf("user:id:%d", i%10)
                cli := utils.NewHttpClient().
                    SetHeader(map[string]string{
                        "cache-key": cacheKey,
                    }).SetMethod(tt.method).SetUrl(tt.uri).SetBody(tt.body)
                // A failed request has no header or body to inspect, so stop
                // rather than reading them anyway.
                if err := cli.Request(nil); err != nil {
                    t.Fatalf("Request(%s): %v", cacheKey, err)
                }
                str := fmt.Sprintf("node: %s cachekey: %s", cli.GetRspHeader().Get("balancer-node"), cacheKey)
                if string(cli.GetRspBody()) != tt.wantBody {
                    t.Errorf("Body = %q; want %q", cli.GetRspBody(), tt.wantBody)
                }
                if i <= 10 {
                    nodes1 = append(nodes1, str)
                } else {
                    nodes2 = append(nodes2, str)
                }
            }
            utils.PrintDiff(strings.Join(nodes1, "\n"), strings.Join(nodes2, "\n"))
            fmt.Println("-------- node 调用次数 --------")
            for _, node := range hashconsistency.GetCfg().Nodes {
                log.Printf("node: %s useCount: %d", node.URL, node.UseCount)
            }
        })
    }
}

测试结果如下:

-------- node 调度顺序--------
2022/04/08 15:14:55 http://127.0.0.1:8081 is dead.
 node: 127.0.0.18082 cachekey: user:id:1
-node: 127.0.0.18081 cachekey: user:id:2
+node: 127.0.0.18083 cachekey: user:id:2
 node: 127.0.0.18083 cachekey: user:id:3
 node: 127.0.0.18082 cachekey: user:id:4
 node: 127.0.0.18082 cachekey: user:id:5
 node: 127.0.0.18082 cachekey: user:id:6
 node: 127.0.0.18083 cachekey: user:id:7
 node: 127.0.0.18083 cachekey: user:id:8
 node: 127.0.0.18082 cachekey: user:id:9
 node: 127.0.0.18083 cachekey: user:id:0
-------- node 调用次数 --------
2022/04/08 15:14:55 node: 127.0.0.1:8081 useCount: 1
2022/04/08 15:14:55 node: 127.0.0.1:8082 useCount: 10
2022/04/08 15:14:55 node: 127.0.0.1:8083 useCount: 10

测试结果符合预期,nice :)

go-zero

go-zero 的负载均衡算法通过替换 grpc 默认负载均衡算法来实现负载均衡

详细注释代码请参阅 https://github.com/TTSimple/g...

我们看看其中核心的两个算法

  • 一、牛顿冷却

原理请参阅 https://www.ruanyifeng.com/bl...

const (
    decayTime = int64(time.Second * 1) // decay constant, in nanoseconds
)

// NLOC scores "hotness" with Newton's law of cooling: something last touched
// at time t has weight exp(-(now-t)/decayTime).
type NLOC struct{}

// NewNLOC returns a stateless NLOC scorer.
func NewNLOC() *NLOC {
    return &NLOC{}
}

// Hot returns exp(-elapsed/decayTime), where elapsed is the time since timex,
// clamped to 0 for times in the future. The result is 1 at timex and falls
// toward 0, reaching 1/e once decayTime has elapsed.
func (n *NLOC) Hot(timex time.Time) float64 {
    // Elapsed time and decayTime must share a unit (nanoseconds). Dividing
    // whole Unix seconds by a nanosecond constant made the weight drop by only
    // about 1e-9 per second.
    td := int64(time.Since(timex))
    if td < 0 {
        td = 0
    }
    w := math.Exp(float64(-td) / float64(decayTime))
    // w, _ = utils.MathRound(w, 9)
    return w
}

我们来测试下:

// Test_NLOC prints the hotness of a fixed timestamp once per second for ten
// seconds, showing how it decays.
func Test_NLOC(t *testing.T) {
    n := NewNLOC()
    timex := time.Now()
    // Read the ticker synchronously and stop it on return. A background
    // goroutine would outlive the test and keep printing. Racing a 10s timer
    // against the ticker also made the sample count nondeterministic.
    ticker := time.NewTicker(time.Second * 1)
    defer ticker.Stop()
    for i := 0; i < 10; i++ {
        <-ticker.C
        fmt.Println(n.Hot(timex))
    }
}

测试结果如下:

0.999999900000005
0.99999980000002
0.999999700000045
0.99999960000008
0.999999500000125
0.99999940000018
0.999999300000245
0.99999920000032
0.999999100000405
0.9999990000005

从上面结果中可以看出,热度是随时间逐渐衰退的

  • 二、EWMA(指数加权移动平均)

原理请参阅 https://blog.csdn.net/mzpmzk/...

const (
    // AVG_METRIC_AGE is the span N from which the smoothing factor is derived.
    AVG_METRIC_AGE float64 = 30.0
    // DECAY is the smoothing factor 2/(N+1) for N = AVG_METRIC_AGE.
    DECAY float64 = 2 / (AVG_METRIC_AGE + 1)
)

// SimpleEWMA is an exponentially weighted moving average with a fixed decay of
// DECAY. Its zero value is ready to use.
type SimpleEWMA struct {
    // value is the current average; 0 doubles as "no sample yet".
    value float64
}

// Add folds value into the average. While the average is still 0, the sample
// is taken as-is; after that it is blended as value*DECAY + avg*(1-DECAY).
func (e *SimpleEWMA) Add(value float64) {
    if e.value != 0 {
        e.value = (value * DECAY) + (e.value * (1 - DECAY))
        return
    }
    // Zero is the "uninitialized" marker, so seed with the first sample.
    e.value = value
}

// Value returns the current average.
func (e *SimpleEWMA) Value() float64 { return e.value }

// Set overwrites the current average.
func (e *SimpleEWMA) Set(value float64) { e.value = value }

编写测试代码测试下:

// testMargin is the absolute tolerance used when comparing float results.
const testMargin = 1e-8

// samples is a fixed series of 100 observations fed to the EWMA test.
var samples = [100]float64{
    4599, 5711, 4746, 4621, 5037, 4218, 4925, 4281, 5207, 5203, 5594, 5149,
    4948, 4994, 6056, 4417, 4973, 4714, 4964, 5280, 5074, 4913, 4119, 4522,
    4631, 4341, 4909, 4750, 4663, 5167, 3683, 4964, 5151, 4892, 4171, 5097,
    3546, 4144, 4551, 6557, 4234, 5026, 5220, 4144, 5547, 4747, 4732, 5327,
    5442, 4176, 4907, 3570, 4684, 4161, 5206, 4952, 4317, 4819, 4668, 4603,
    4885, 4645, 4401, 4362, 5035, 3954, 4738, 4545, 5433, 6326, 5927, 4983,
    5364, 4598, 5071, 5231, 5250, 4621, 4269, 3953, 3308, 3623, 5264, 5322,
    5395, 4753, 4936, 5315, 5243, 5060, 4989, 4921, 4480, 3426, 3687, 4220,
    3197, 5139, 6101, 5279,
}

// withinMargin reports whether x and y differ by at most testMargin.
func withinMargin(x, y float64) bool {
    diff := math.Abs(x - y)
    return diff <= testMargin
}

// TestSimpleEWMA feeds samples through a SimpleEWMA and checks the final
// average against a reference value. It then checks that Set overrides the
// average.
func TestSimpleEWMA(t *testing.T) {
    const want = 4734.500946466118
    var avg SimpleEWMA
    for _, sample := range samples {
        avg.Add(sample)
    }
    got := avg.Value()
    fmt.Println(got)
    if !withinMargin(got, want) {
        t.Errorf("e.Value() is %v, wanted %v", got, want)
    }

    avg.Set(1.0)
    if got := avg.Value(); got != 1.0 {
        t.Errorf("e.Value() is %v", got)
    }
}

测试成功,加油!!!

引用文章:

你可能感兴趣的:(go-zero源码阅读-负载均衡(下)#第六期)