golang WaitGroup的使用与底层实现

使用的go版本为 go1.21.2

首先我们写一个简单的WaitGroup的使用代码

package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		fmt.Println("xiaochuan")
	}()

	wg.Wait()
}

WaitGroup的基本使用场景就是等待子协程完毕后,执行主协程,比如我的api需要多个下游api支持开多个协程进行访问,等待耗时最高的api返回过来后执行,这种场景是比较适合WaitGroup的。

我们来看一下WaitGroup构造体相关的底层源码

WaitGroup结构体

//代码位于 GOROOT/src/sync/waitgroup.go L:23

type WaitGroup struct {
    //防止WaitGroup被复制, 君子协议,编译可以通过,某些编辑器会报waring
    //有兴趣可以看一下这里 https://github.com/golang/go/issues/8005#issuecomment-190753527
    noCopy noCopy

    // 高32位表示计数器,低32位表示等待的waiter数量。
    // 低版本go的state字段类型是[3]uint32,需要进行位数对齐
    state atomic.Uint64
    // 信号量
    sema  uint32
}
编辑器的warning

golang WaitGroup的使用与底层实现_第1张图片

Add函数

//代码位于 GOROOT/src/sync/waitgroup.go L:43

func (wg *WaitGroup) Add(delta int) {
	if race.Enabled { //使用竞态检查
		if delta < 0 { //如果传递的数值是负数,递减等待同步
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable() //竞态检查 禁用
		defer race.Enable() //竞态检查 启用
	}
	//计算我们要进行add的值,将其加入到比特位上
	//<< 32 为二进制左位移 32位
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32) // state变量的高位是计数
	w := uint32(state) // state变量的低位是waiter计数
	//使用竞态检查,当前传入的值与v相同,说明当前是第一次调度add
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(&wg.sema))
	}
	//如果 计数器小于0 说明了多进行了done操作或者add传递负数,业务代码的出现逻辑错误了
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// 如果当前存在等待,而且计数器不为0
	// 说明当前有地方调度了Wait后,又进行add操作了, 违反了官方的使用设计
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 计数大于0,没有等待,就是单纯的add直接返回
	if v > 0 || w == 0 {
		return
	}

	// 再做一次检测,防止有并发调度
	// 比如我有两个goroutine A goroutine 在add, B goroutine 在调度 wait 
	// 刚刚好A加完了计数,B突然wait导致state更变就会触发这个panic
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 重置waiter为0
	wg.state.Store(0)
	for ; w != 0; w-- { // 逐步释放信号量
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

Done函数

//代码位于 GOROOT/src/sync/waitgroup.go L:86

//这个很简单 调用了一下add函数传了一个-1
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait函数

//代码位于 GOROOT/src/sync/waitgroup.go L:91

func (wg *WaitGroup) Wait() {
	if race.Enabled { //使用竞态检查
		race.Disable() //竞态检查 禁用
	}
	for {
		state := wg.state.Load() // 原子操作读取state字段
		v := int32(state >> 32) // state变量的高位是计数
		w := uint32(state) // state变量的低位是waiter计数
		if v == 0 { // 如果当前计数器为0 就没必要等待直接返回了
			if race.Enabled {
				race.Enable() //竞态检查 启用
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// 将waiter计数+1 因为waiter处于低32位所以不需要位移直接加就行了
		if wg.state.CompareAndSwap(state, state+1) {
			if race.Enabled && w == 0 { // 使用竞态检查,第一次进行wait操作
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(&wg.sema))
			}
			// 获取信号量,这行代码会进行G的阻塞
			runtime_Semacquire(&wg.sema)
			//重新获取一下state,正常来讲计数为0, waiter为0
			//执行判断之前,又有一个协程进行了add操作,会触发panic
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled { //使用竞态检查
				race.Enable() //竞态检查 启用
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

总结

我们从上面的源码分析了解WaitGroup的数据结构、Add、Done和Wait这些基本操作原理,在项目中我们可以使用比特位来减少内存的占用,从源码分析我们得知Go官方设计不允许进行WaitGroup复制(君子协议)与并发调度同一个WaitGroup操作。

你可能感兴趣的:(GoLang,golang,开发语言,后端)