在并发编程里,sync.WaitGroup 并发原语的使用频率非常高,它经常用于协同等待的场景:goroutine A 在检查
点等待一组执行任务的 worker goroutine 全部完成,如果在执行任务的这些 goroutine 还没全部完成,
goroutine A 就会阻塞在检查点,直到所有 woker goroutine 都完成后才能继续执行。
如果在 woker goroutine 的执行过程中遇到错误并想要处理该怎么办? WaitGroup 并没有提供传播错误的功能,
遇到这种场景我们该怎么办? Go 语言在扩展库提供了 ErrorGroup 并发原语正好适合在这种场景下使用,它在
WaitGroup 的基础上还提供了,错误传播以及上下文取消的功能。
Go 扩展库通过 errorgroup.Group 提供 ErrorGroup 原语的功能,它有三个方法可调用:
func WithContext(ctx context.Context) (*Group, context.Context)
func (g *Group) Go(f func() error)
func (g *Group) Wait() error
接下来我们让主 goroutine 使用 ErrorGroup 代替 WaitGroup 等待所以子任务的完成,ErrorGroup 有一个特点是
会返回所以执行任务的 goroutine 遇到的第一个错误。试着执行一下下面的程序,观察程序的输出。
package main
import (
"fmt"
"golang.org/x/sync/errgroup"
"net/http"
)
func main() {
var urls = []string{
"http://www.golang.org/",
"http://www.baidu.com/",
"http://www.noexist11111111.com/",
}
g := new(errgroup.Group)
for _, url := range urls {
url := url
g.Go(func() error {
resp, err := http.Get(url)
if err != nil {
fmt.Println(err)
return err
}
fmt.Printf("get [%s] success: [%d] \n", url, resp.StatusCode)
return resp.Body.Close()
})
}
if err := g.Wait(); err != nil {
fmt.Println(err)
} else {
fmt.Println("All success!")
}
}
输出:
Get "http://www.noexist11111111.com/": dial tcp: lookup www.noexist11111111.com: no such host
get [http://www.baidu.com/] success: [200]
Get "http://www.golang.org/": dial tcp 172.217.24.113:80: connectex: A connection attempt failed because the connected party did not properly respond after a period o
f time, or established connection failed because connected host has failed to respond.
Get "http://www.noexist11111111.com/": dial tcp: lookup www.noexist11111111.com: no such host
ErrorGroup 有一个特点是会返回所以执行任务的 goroutine 遇到的第一个错误:
package main
import (
"fmt"
"golang.org/x/sync/errgroup"
"log"
"time"
)
func main() {
var eg errgroup.Group
for i := 0; i < 100; i++ {
i := i
eg.Go(func() error {
time.Sleep(2 * time.Second)
if i > 90 {
fmt.Println("Error:", i)
return fmt.Errorf("Error occurred: %d", i)
}
fmt.Println("End:", i)
return nil
})
}
if err := eg.Wait(); err != nil {
log.Fatal(err)
}
}
上面程序,遇到 i 大于 90 的都会产生错误结束执行,但是只有第一个执行时产生的错误被 ErrorGroup 返回,程
序的输出大概如下:
输出:
......
End: 35
End: 38
End: 28
End: 37
End:38;2;127;0;0m2023/06/29 14:18:03 Error occurred: 98
32
Error: 92
End: 23
End: 30
Error: 95
Error: 94
End: 74
End: 25
......
最早执行遇到错误的 goroutine 输出了Error: 98 但是所有未执行完的其他任务并没有停止执行,那么想让程序遇
到错误就终止其他子任务该怎么办呢?我们可以用 errgroup.Group 提供的 WithContext 方法创建一个带可取消
上下文功能的 ErrorGroup。
使用 errorgroup.Group 时注意它的两个特点:
package main
import (
"context"
"fmt"
"golang.org/x/sync/errgroup"
"log"
"time"
)
func main() {
eg, ctx := errgroup.WithContext(context.Background())
for i := 0; i < 100; i++ {
i := i
eg.Go(func() error {
time.Sleep(2 * time.Second)
select {
case <-ctx.Done():
fmt.Println("Canceled:", i)
return nil
default:
if i > 90 {
fmt.Println("Error:", i)
return fmt.Errorf("Error: %d", i)
}
fmt.Println("End:", i)
return nil
}
})
}
if err := eg.Wait(); err != nil {
log.Fatal(err)
}
}
Go 方法单独开启的 gouroutine 在执行参数传递进来的函数时,如果函数返回了错误,会对 ErrorGroup 持有的
err 字段进行赋值并及时调用 cancel 函数,通过上下文通知其他子任务取消执行任务。所以上面更新后的程序运
行后有如下类似的输出。
......
Canceled: 87
Canceled: 34
Canceled: 92
Canceled: 86
Cancled: 78
Canceled: 46
Cancel[38;2;127;0;0m2023/06/29 14:22:07 Error: 99
ed: 45
Canceled: 44
Canceled: 77
Canceled: 43
Canceled: 50
Canceled: 42
Canceled: 25
Canceled: 76
Canceled: 24
Canceled: 75
Canceled: 40
......
errorgroup源码:
在上面的例子中,子 goroutine 出现错误后,会 cancle 到其他的子任务,但是我们并没有看到调用 ctx 的 cancel
方法,下面我们看下源码,看看内部是怎么处理的。errgroup 的设计非常精练,全部代码如下:
package errgroup
import (
"context"
"sync"
)
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
cancel func()
wg sync.WaitGroup
errOnce sync.Once
err error
}
// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
}
return g.err
}
// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
g.wg.Add(1)
go func() {
defer g.wg.Done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
}()
}
可以看到,errgroup 的实现依靠于结构体 Group,它通过封装 sync.WaitGroup,继承了 WaitGroup 的特性,在
Go() 方法中新起一个子任务 goroutine,并在 Wait() 方法中通过 sync.WaitGroup 的 Wait 进行阻塞等待。
同时 Group 利用 sync.Once 保证了它有且仅会保留第一个子 goroutine 错误。
Group 通过嵌入 context.WithCancel 方法产生的 cancel 函数,能够在子 goroutine 发生错误时,及时通过调用
cancle 函数,将 Context 的取消信号及时传播出去。
再看一个实际应用的例子:
package main
import (
"context"
"fmt"
"golang.org/x/sync/errgroup"
)
func main() {
g, ctx := errgroup.WithContext(context.Background())
dataChan := make(chan int, 20)
// 数据生产端任务子goroutine
g.Go(func() error {
defer close(dataChan)
for i := 1; ; i++ {
if i == 10 {
return fmt.Errorf("data 10 is wrong")
}
dataChan <- i
fmt.Println(fmt.Sprintf("sending %d", i))
}
})
// 数据消费端任务子goroutine
for i := 0; i < 3; i++ {
g.Go(func() error {
for j := 1; ; j++ {
select {
case <-ctx.Done():
return ctx.Err()
case number := <-dataChan:
fmt.Println(fmt.Sprintf("receiving %d", number))
}
}
})
}
// 主任务goroutine等待pipeline结束数据流
err := g.Wait()
if err != nil {
fmt.Println(err)
}
fmt.Println("main goroutine done!")
}
# 输出
sending 1
sending 2
sending 3
sending 4
sending 5
sending 6
sending 7
sending 8
sending 9
receiving 2
receiving 1
receiving 3
data 10 is wrong
main goroutine done!
自己实现一个 ErrGroup:
package main
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
)
const (
M = 2
N = 8
)
func main() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*50)
defer cancel()
result := make([]int, N+1)
errCh := make(chan error, 1)
var firstSendErr int32
wg := new(sync.WaitGroup)
done := make(chan struct{}, 1)
limit := make(chan struct{}, M)
for i := 1; i <= N; i++ {
limit <- struct{}{}
var quit bool
select {
// context已经被cancel,不需要起新的goroutine了
case <-ctx.Done():
quit = true
default:
}
if quit {
break
}
wg.Add(1)
go func(x int) {
defer func() {
wg.Done()
<-limit
}()
if ret, err := doTask(ctx, x); err != nil {
if atomic.CompareAndSwapInt32(&firstSendErr, 0, 1) {
errCh <- err
// cancel其他的请求
cancel()
}
} else {
result[x] = ret
}
}(i)
}
go func() {
wg.Wait()
close(done)
}()
select {
case err := <-errCh:
handleErr(err, result[1:])
<-done
case <-done:
if len(errCh) > 0 {
err := <-errCh
handleErr(err, result[1:])
return
}
fmt.Println("success handle all task:", result[1:])
}
}
func handleErr(err error, result []int) {
fmt.Println("task err occurs: ", err, "result", result)
}
func doTask(ctx context.Context, i int) (ret int, err error) {
fmt.Println("task start", i)
defer func() {
fmt.Println("task done", i, "err", err)
}()
select {
// 模拟处理任务时间
case <-time.After(time.Second * time.Duration(i)):
// 处理任务要支持被context cancel,不然就一直等到处理完再返回了
case <-ctx.Done():
fmt.Println("task canceled", i)
return -1, ctx.Err()
}
// 模拟出现错误
if i == 6 {
return -1, errors.New("err test")
}
return i, nil
}
# 输出
task start 2
task start 1
task done 1 err <nil>
task start 3
task done 2 err <nil>
task start 4
task done 3 err <nil>
task start 5
task done 4 err <nil>
task start 6
task done 5 err <nil>
task start 7
task done 6 err err test
task canceled 7
task done 7 err context canceled
task err occurs: err test result [1 2 3 4 5 0 0 0]
总结:
使用 errorgroup.Group 时注意它的特点:
继承了 WaitGroup 的功能
errgroup.Group 在出现错误或者等待结束后都会调用 Context 对象 的 cancel 方法同步取消信号。
只有第一个出现的错误才会被返回,剩余的错误都会被直接抛弃。
context 信号传播:如果子任务 goroutine 中有循环逻辑,则可以添加 ctx.Done 逻辑,此时通过 context 的
取消信号,提前结束子任务执行。