Go语言ErrGroup

Go语言ErrGroup

在并发编程里,sync.WaitGroup 并发原语的使用频率非常高,它经常用于协同等待的场景:goroutine A 在检查

点等待一组执行任务的 worker goroutine 全部完成,如果在执行任务的这些 goroutine 还没全部完成,

goroutine A 就会阻塞在检查点,直到所有 woker goroutine 都完成后才能继续执行。

如果在 woker goroutine 的执行过程中遇到错误并想要处理该怎么办? WaitGroup 并没有提供传播错误的功能,

遇到这种场景我们该怎么办? Go 语言在扩展库提供了 ErrorGroup 并发原语正好适合在这种场景下使用,它在

WaitGroup 的基础上还提供了,错误传播以及上下文取消的功能。

Go 扩展库通过 errorgroup.Group 提供 ErrorGroup 原语的功能,它有三个方法可调用:

func WithContext(ctx context.Context) (*Group, context.Context)
func (g *Group) Go(f func() error)
func (g *Group) Wait() error

接下来我们让主 goroutine 使用 ErrorGroup 代替 WaitGroup 等待所以子任务的完成,ErrorGroup 有一个特点是

会返回所以执行任务的 goroutine 遇到的第一个错误。试着执行一下下面的程序,观察程序的输出。

package main

import (
	"fmt"
	"golang.org/x/sync/errgroup"
	"net/http"
)

func main() {
	var urls = []string{
		"http://www.golang.org/",
		"http://www.baidu.com/",
		"http://www.noexist11111111.com/",
	}
	g := new(errgroup.Group)
	for _, url := range urls {
		url := url
		g.Go(func() error {
			resp, err := http.Get(url)
			if err != nil {
				fmt.Println(err)
				return err
			}
			fmt.Printf("get [%s] success: [%d] \n", url, resp.StatusCode)
			return resp.Body.Close()
		})
	}
	if err := g.Wait(); err != nil {
		fmt.Println(err)
	} else {
		fmt.Println("All success!")
	}
}

输出:

Get "http://www.noexist11111111.com/": dial tcp: lookup www.noexist11111111.com: no such host
get [http://www.baidu.com/] success: [200]
Get "http://www.golang.org/": dial tcp 172.217.24.113:80: connectex: A connection attempt failed because the connected party did not properly respond after a period o
f time, or established connection failed because connected host has failed to respond.
Get "http://www.noexist11111111.com/": dial tcp: lookup www.noexist11111111.com: no such host

ErrorGroup 有一个特点是会返回所以执行任务的 goroutine 遇到的第一个错误:

package main

import (
	"fmt"
	"golang.org/x/sync/errgroup"
	"log"
	"time"
)

func main() {
	var eg errgroup.Group
	for i := 0; i < 100; i++ {
		i := i
		eg.Go(func() error {
			time.Sleep(2 * time.Second)
			if i > 90 {
				fmt.Println("Error:", i)
				return fmt.Errorf("Error occurred: %d", i)
			}
			fmt.Println("End:", i)
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		log.Fatal(err)
	}
}

上面程序,遇到 i 大于 90 的都会产生错误结束执行,但是只有第一个执行时产生的错误被 ErrorGroup 返回,程

序的输出大概如下:

输出:

......
End: 35
End: 38
End: 28
End: 37
End:38;2;127;0;0m2023/06/29 14:18:03 Error occurred: 98
32
Error: 92
End: 23
End: 30
Error: 95
Error: 94
End: 74
End: 25
......

最早执行遇到错误的 goroutine 输出了Error: 98 但是所有未执行完的其他任务并没有停止执行,那么想让程序遇

到错误就终止其他子任务该怎么办呢?我们可以用 errgroup.Group 提供的 WithContext 方法创建一个带可取消

上下文功能的 ErrorGroup。

使用 errorgroup.Group 时注意它的两个特点:

  • errgroup.Group 在出现错误或者等待结束后都会调用 Context 对象的 cancel 方法同步取消信号。
  • 只有第一个出现的错误才会被返回,剩余的错误都会被直接抛弃。
package main

import (
	"context"
	"fmt"
	"golang.org/x/sync/errgroup"
	"log"
	"time"
)

func main() {
	eg, ctx := errgroup.WithContext(context.Background())
	for i := 0; i < 100; i++ {
		i := i
		eg.Go(func() error {
			time.Sleep(2 * time.Second)
			select {
			case <-ctx.Done():
				fmt.Println("Canceled:", i)
				return nil
			default:
				if i > 90 {
					fmt.Println("Error:", i)
					return fmt.Errorf("Error: %d", i)
				}
				fmt.Println("End:", i)
				return nil
			}
		})
	}
	if err := eg.Wait(); err != nil {
		log.Fatal(err)
	}
}

Go 方法单独开启的 gouroutine 在执行参数传递进来的函数时,如果函数返回了错误,会对 ErrorGroup 持有的

err 字段进行赋值并及时调用 cancel 函数,通过上下文通知其他子任务取消执行任务。所以上面更新后的程序运

行后有如下类似的输出。

......
Canceled: 87
Canceled: 34
Canceled: 92
Canceled: 86
Cancled: 78
Canceled: 46
Cancel[38;2;127;0;0m2023/06/29 14:22:07 Error: 99
ed: 45
Canceled: 44
Canceled: 77
Canceled: 43
Canceled: 50
Canceled: 42
Canceled: 25
Canceled: 76
Canceled: 24
Canceled: 75
Canceled: 40
......

errorgroup源码:

在上面的例子中,子 goroutine 出现错误后,会 cancle 到其他的子任务,但是我们并没有看到调用 ctx 的 cancel

方法,下面我们看下源码,看看内部是怎么处理的。errgroup 的设计非常精练,全部代码如下:

package errgroup

import (
    "context"
    "sync"
)

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
    cancel func()
    wg sync.WaitGroup
    errOnce sync.Once
    err     error
}

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
    ctx, cancel := context.WithCancel(ctx)
    return &Group{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
    g.wg.Wait()
    if g.cancel != nil {
        g.cancel()
    }
    return g.err
}

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
    g.wg.Add(1)

    go func() {
        defer g.wg.Done()

        if err := f(); err != nil {
            g.errOnce.Do(func() {
                g.err = err
                if g.cancel != nil {
                    g.cancel()
                }
            })
        }
    }()
}

可以看到,errgroup 的实现依靠于结构体 Group,它通过封装 sync.WaitGroup,继承了 WaitGroup 的特性,在

Go() 方法中新起一个子任务 goroutine,并在 Wait() 方法中通过 sync.WaitGroup 的 Wait 进行阻塞等待。

同时 Group 利用 sync.Once 保证了它有且仅会保留第一个子 goroutine 错误。

Group 通过嵌入 context.WithCancel 方法产生的 cancel 函数,能够在子 goroutine 发生错误时,及时通过调用

cancle 函数,将 Context 的取消信号及时传播出去。

再看一个实际应用的例子:

package main

import (
	"context"
	"fmt"
	"golang.org/x/sync/errgroup"
)

func main() {
	g, ctx := errgroup.WithContext(context.Background())
	dataChan := make(chan int, 20)
	// 数据生产端任务子goroutine
	g.Go(func() error {
		defer close(dataChan)
		for i := 1; ; i++ {
			if i == 10 {
				return fmt.Errorf("data 10 is wrong")
			}
			dataChan <- i
			fmt.Println(fmt.Sprintf("sending %d", i))
		}
	})
	// 数据消费端任务子goroutine
	for i := 0; i < 3; i++ {
		g.Go(func() error {
			for j := 1; ; j++ {
				select {
				case <-ctx.Done():
					return ctx.Err()
				case number := <-dataChan:
					fmt.Println(fmt.Sprintf("receiving %d", number))
				}
			}
		})
	}
	// 主任务goroutine等待pipeline结束数据流
	err := g.Wait()
	if err != nil {
		fmt.Println(err)
	}
	fmt.Println("main goroutine done!")
}
# 输出
sending 1
sending 2
sending 3
sending 4
sending 5
sending 6
sending 7
sending 8
sending 9
receiving 2
receiving 1
receiving 3
data 10 is wrong
main goroutine done!

自己实现一个 ErrGroup:

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

const (
	M = 2
	N = 8
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*50)
	defer cancel()
	result := make([]int, N+1)
	errCh := make(chan error, 1)
	var firstSendErr int32
	wg := new(sync.WaitGroup)
	done := make(chan struct{}, 1)
	limit := make(chan struct{}, M)
	for i := 1; i <= N; i++ {
		limit <- struct{}{}
		var quit bool
		select {
		// context已经被cancel,不需要起新的goroutine了
		case <-ctx.Done():
			quit = true
		default:
		}
		if quit {
			break
		}
		wg.Add(1)
		go func(x int) {
			defer func() {
				wg.Done()
				<-limit
			}()
			if ret, err := doTask(ctx, x); err != nil {
				if atomic.CompareAndSwapInt32(&firstSendErr, 0, 1) {
					errCh <- err
					// cancel其他的请求
					cancel()
				}
			} else {
				result[x] = ret
			}
		}(i)
	}
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case err := <-errCh:
		handleErr(err, result[1:])
		<-done
	case <-done:
		if len(errCh) > 0 {
			err := <-errCh
			handleErr(err, result[1:])
			return
		}
		fmt.Println("success handle all task:", result[1:])
	}
}

func handleErr(err error, result []int) {
	fmt.Println("task err occurs: ", err, "result", result)
}

func doTask(ctx context.Context, i int) (ret int, err error) {
	fmt.Println("task start", i)
	defer func() {
		fmt.Println("task done", i, "err", err)
	}()
	select {
	// 模拟处理任务时间
	case <-time.After(time.Second * time.Duration(i)):
	// 处理任务要支持被context cancel,不然就一直等到处理完再返回了
	case <-ctx.Done():
		fmt.Println("task canceled", i)
		return -1, ctx.Err()
	}
	// 模拟出现错误
	if i == 6 {
		return -1, errors.New("err test")
	}
	return i, nil
}
# 输出
task start 2
task start 1
task done 1 err <nil>
task start 3
task done 2 err <nil>
task start 4
task done 3 err <nil>
task start 5
task done 4 err <nil>
task start 6
task done 5 err <nil>
task start 7
task done 6 err err test
task canceled 7
task done 7 err context canceled
task err occurs:  err test result [1 2 3 4 5 0 0 0]

总结:

使用 errorgroup.Group 时注意它的特点:

  • 继承了 WaitGroup 的功能

  • errgroup.Group 在出现错误或者等待结束后都会调用 Context 对象 的 cancel 方法同步取消信号。

  • 只有第一个出现的错误才会被返回,剩余的错误都会被直接抛弃。

  • context 信号传播:如果子任务 goroutine 中有循环逻辑,则可以添加 ctx.Done 逻辑,此时通过 context 的

    取消信号,提前结束子任务执行。

你可能感兴趣的:(golang,golang)