WaitGroup是多个goroutine之间协作的一种实现方式,主要功能就是阻塞等待一组goroutine执行完成。
常用的使用场景:主goroutine调用Add函数设置需要等待的goroutine的数量,当每个goroutine执行完成后调用Done函数(将counter减1),Wait函数用于阻塞等待直到该组中的所有goroutine都执行完成。
源码中主要设计到的三个概念:counter、waiter和semaphore
counter: 当前还未执行结束的goroutine计数器
waiter : 等待goroutine-group结束的goroutine数量,即有多少个等候者
semaphore: 信号量
信号量是Unix系统提供的一种保护共享资源的机制,用于防止多个线程同时访问某个资源。
可简单理解为信号量为一个数值:
当信号量>0时,表示资源可用,获取信号量时系统自动将信号量减1;
当信号量=0时,表示资源暂不可用,获取信号量时,当前线程会进入睡眠,当信号量为正时被唤醒。
Golang源码版本 :1.10.3
1.结构体
type WaitGroup struct {
noCopy noCopy //该WaitGroup对象不允许拷贝使用,只能用指针传递
// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
// 64-bit atomic operations require 64-bit alignment, but 32-bit
// compilers do not ensure it. So we allocate 12 bytes and then use
// the aligned 8 bytes in them as state.
//用于存储计数器(counter)和waiter的值
// 只需要64位,即8个字节,其中高32位是counter值,低32位值是waiter值
// 不直接使用uint64,是因为uint64的原子操作需要64位系统,而32位系统下,可能会出现崩溃
// 所以这里用byte数组来实现,32位系统下4字节对齐,64位系统下8字节对齐,所以申请12个字节,其中必定有8个字节是符合8字节对齐的,下面的state()函数中有进行判断
state1 [12]byte
sema uint32 //信号量
}
从结构体中我们看到
state1是一个12位长度的byte数组,用于存储counter和waiter的值
sema就是传说中的信号量
2.state函数
state是一个内部函数,用于获取counter和 waiter的值
//获取counter 、 waiter的值 (counter是uint64的高32位,waiter是uint64的低32位)
func (wg *WaitGroup) state() *uint64 {
// 根据state1的起始地址分析,若是8字节对齐的,则直接用前8个字节作为*uint64类型
// 若不是,说明是4字节对齐,则后移4个字节后,这样必为8字节对齐,然后取后面8个字节作为*uint64类型
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1))
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[4]))
}
}
3.Add方法
//用于增加或减少计数器(counter)的值
//如果计数器为0,则释放调用Wait方法时的阻塞,如果计数器为负,则panic
//Add()方法应该在Wait()方法调用之前
func (wg *WaitGroup) Add(delta int) {
//获取当前counter和 waiter的值
statep := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
if delta < 0 {
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
//将delta的值添加到counter上
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) //counter值
w := uint32(state) //waiter值
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
//counter为负数,则触发panic
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// waiter值不为0,累加后的counter值和delta相等,说明Wait()方法没有在Add()方法之后调用,触发panic,因为正确的做法是先Add()后Wait()
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
//Add()添加正常返回
//1.counter > 0,说明还不需要释放信号量,可以直接返回
//2. waiter = 0 ,说明没有等待的goroutine,也不需要释放信号量,可以直接返回
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
//下面是 counter == 0 并且 waiter > 0的情况
//现在若原state和新的state不等,则有以下两种可能
//1. Add 和 Wait方法同时调用
//2. counter已经为0,但waiter值有增加,这种情况永远不会触发信号量了
// 以上两种情况都是错误的,所以触发异常
//注:state := atomic.AddUint64(statep, uint64(delta)<<32) 这一步调用之后,state和*statep的值应该是相等的,除非有以上两种情况发生
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
//将waiter 和 counter都置为0
*statep = 0
//原子递减信号量,并通知等待的goroutine
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false)
}
}
4.Done方法
// Done decrements the WaitGroup counter by one.
//将计数器(counter)的值减1
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
5.Wait方法
// Wait blocks until the WaitGroup counter is zero.
//调用Wait方法会阻塞当前调用的goroutine直到 counter的值为0
//也会增加waiter的值
func (wg *WaitGroup) Wait() {
//获取当前counter和 waiter的值
statep := wg.state()
if race.Enabled {
_ = *statep // trigger nil deref early
race.Disable()
}
//一直等待,直到无需等待或信号量触发,才返回
for {
state := atomic.LoadUint64(statep)
v := int32(state >> 32) //counter值
w := uint32(state) //waiter值
//如果counter值为0,则说明所有goroutine都退出了,无需等待,直接退出
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
//原子增加waiter的值,CAS方法,外面for循环会一直尝试,保证多个goroutine同时调用Wait()也能正常累加waiter
if atomic.CompareAndSwapUint64(statep, state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
//一直等待信号量sema,直到信号量触发,
runtime_Semacquire(&wg.sema)
//从上面的Add()方法看到,触发信号量之前会将seatep置为0(即counter和waiter都置为0),所以此时应该也为0
//如果不为0,说明WaitGroup此时又执行了Add()或者Wait()操作,所以会触发panic
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
1.Add()必须在Wait()前调用
2.Add()设置的值必须与实际等待的goroutine个数一致,如果设置的值大于实际的goroutine数量,可能会一直阻塞。如果小于会触发panic
3. WaitGroup不可拷贝,可以通过指针传递,否则很容易造成BUG
以下为值拷贝引起的Bug示例
demo1:因为值拷贝引起的死锁
func main() {
var wg sync.WaitGroup
wg.Add(5)
for i := 0 ; i < 5 ; i++ {
test(wg)
}
wg.Wait()
}
func test(wg sync.WaitGroup) {
go func() {
fmt.Println("hello")
wg.Done()
}()
}
demo2:因为值拷贝引起的不会阻塞等待现象
func main() {
var wg sync.WaitGroup
for i := 0 ; i < 5 ; i++ {
test(wg)
}
wg.Wait()
}
func test(wg sync.WaitGroup) {
go func() {
wg.Add(1)
fmt.Println("hello")
time.Sleep(time.Second*5)
wg.Done()
}()
}
demo3:因为值拷贝引发的panic
type person struct {
wg sync.WaitGroup
}
func (t *person) say() {
go func() {
fmt.Println("say Hello!")
time.Sleep(time.Second*5)
t.wg.Done()
}()
}
func main() {
var wg sync.WaitGroup
t := person{wg:wg}
wg.Add(5)
for i := 0 ; i< 5 ;i++ {
t.say()
}
wg.Wait()
}
感谢:https://blog.csdn.net/yzf279533105/article/details/97302666