golang - sync.WaitGroup

go 版本基于1.18

结构体

结构体定义如下:

type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers only guarantee that 64-bit fields are 32-bit aligned.
    // For this reason on 32 bit architectures we need to check in state()
    // if state1 is aligned or not, and dynamically "swap" the field order if
    // needed.
    state1 uint64
    state2 uint32
}

当我们初始化一个WaitGroup对象时,其counter值、waiter值、semap值均为0

  • noCopy :
    空结构体,它并不会占用内存,编译器也不会对其进行字节填充。它主要是为了通过go vet工具来做静态编译检查,主要作用是防止开发者在使用WaitGroup过程中对其进行了复制,从而导致的安全隐患

  • state1, state2:
    主要代表三部分内容:

    1. 通过Add()设置的子goroutine的计数值counter
    2. 通过Wait()陷入阻塞的waiter数
    3. 信号量semap

    其中在64位 的操作系统中(对齐系数为8), 此时state1 的的高32 位代表计数器counter, 低32位代表waiter 数, state2 代表信号量
    在32 位的操作系统中(对齐系数为4), 此时将state1 和state2 unsafe.Pointer() 转化为[3]uint32的state数组,其中state[0] 代表信号量semap, state[0]作为uint64的高32位,即counter, state[1] 作为uint64的低32位, 即waiter。 具体实现的代码如下

    state方法就是返回对应的计数(counter,waiter)和信号量(semap)

      func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
              if unsafe.Alignof(wg.state1) == 8 ||     uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
          // state1 is 64-bit aligned: nothing to do.
              return &wg.state1, &wg.state2
          } else {
              // state1 is 32-bit aligned but not 64-bit aligned: this   means that
              // (&state1)+4 is 64-bit aligned.
              state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
          return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
        }
      }
    

方法

1. Add

  • 源码实现
func (wg *WaitGroup) Add(delta int) {
        // 获取计数器和信号量
    statep, semap := wg.state()
        // 竞争检测相关,与功能无关,忽略
    if race.Enabled {
        _ = *statep // trigger nil deref early
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()
    }
        // 计数值加上 delta: statep 的前四个字节是计数值,因此将 delta 前移 32位
    state := atomic.AddUint64(statep, uint64(delta)<<32)
        // 当前的counter计数值
    v := int32(state >> 32)
        // 当前的waiter 计数值
    w := uint32(state)
        // 竞争检测,忽略
    if race.Enabled && delta > 0 && v == int32(delta) {
        // The first increment must be synchronized with Wait.
        // Need to model this as a read, because there can be
        // several concurrent wg.counter transitions from 0.
        race.Read(unsafe.Pointer(semap))
    }
        // counter 计数值<0 , 曝panic 异常
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    // delta > 0 && v == int32(delta) : 表示从 0 开始添加计数值
   // w!=0 :表示已经有了等待者
   // 说明在添加counter计数值的时候,同时添加了等待者,非法操作。添加等待者需要在添加计数值之后
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
         // v>0 : 计数值不等于0,不需要唤醒等待者,直接返回
         // w==0: 没有等待者,不需要唤醒,直接返回
    if v > 0 || w == 0 {
        return
    }
    // This goroutine has set counter to 0 when waiters > 0.
    // Now there can't be concurrent mutations of state:
    // - Adds must not happen concurrently with Wait,
    // - Wait does not increment waiters if it sees counter == 0.
    // Still do a cheap sanity check to detect WaitGroup misuse.
          // 再次检查数据是否一致
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

     // 到这里说明计数值为0,且等待者大于0,需要唤醒所有的等待者,并把系统置为初始状态(0状态)
  // 将计数值和等待者数量都置为0
    *statep = 0
          // 唤醒等待者
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

2. Done

  • 源码
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

完成一个任务,将计数值减一,当计数值减为0时,需要唤醒所有的等待者

3.Wait

  • 源码
func (wg *WaitGroup) Wait() {
        // 获取计数器和信号量
    statep, semap := wg.state()
        // 竞争检测,忽略
    if race.Enabled {
        _ = *statep // trigger nil deref early
        race.Disable()
    }
        
    for {
                // 原子操作,获取计数器值
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        w := uint32(state)
                // 所有任务都完成了,counter =0,此时直接退出,即不阻塞
        if v == 0 {
            // Counter is 0, no need to wait.
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
        // waiter 计数器加一
                // 这里会有竞争,比如多个 Wait 调用,或者在同时调用 Add 方法,增加不成功会继续 for 循环
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(semap))
            }
                        //   // 增加成功后,阻塞在信号量这里,等待被唤醒
            runtime_Semacquire(semap)
                         // 被唤醒的时候,计数器应该是0状态。如果重用 WaitGroup,需要等 Wait 返回
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
    }
}

注意事项

  • 保证 Add 在 Wait 前调用: 确保在子go 程中不使用Add 方法, 又可能导致和wait 造成竞争冲突,最后导致panic
  • Add 函数不要传入负值,有可能导致panic 或者导致 wait 函数中 信号量P 操作死锁等待
  • 不要复制使用 WaitGroup,函数传递时使用指针传递, WaitGroup 不支持复制操作, 可用go tool vet 检查是否对WaitGroup 复制使用
  • 尽量不复用 WaigGroup,减少出问题的风险, 复用的前提要在wait 函数返回之后

使用示例

package main

import (
    "sync"
)

type httpPkg struct{}

func (httpPkg) Get(url string) {}

var http httpPkg

func main() {
    var wg sync.WaitGroup
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.example.com/",
    }
    for _, url := range urls {
        // Increment the WaitGroup counter.
        wg.Add(1)
        // Launch a goroutine to fetch the URL.
        go func(url string) {
            // Decrement the counter when the goroutine completes.
            defer wg.Done()
            // Fetch the URL.
            http.Get(url)
        }(url)
    }
    // Wait for all HTTP fetches to complete.
    wg.Wait()
}

你可能感兴趣的:(golang - sync.WaitGroup)