做过Java开发的同学肯定知道,JDK7加入的Fork/Join是一个非常优秀的设计,到了JDK8,又结合并行流中进行了优化和增强,是一个非常好的工具。
Fork/Join本质上是一种任务分解,即:将一个很大的任务分解成若干个小任务,然后再对小任务进一步分解,直到最小颗粒度,然后并发执行。
这么做的优点很明显,就是可以大幅提升计算性能,缺点嘛,也有一点,那就是资源开销要大一些。
在网上找了一张图,任务分解就是这个意思:
对于Golang中的Fork/Join的实现,我参考了JDK的源码,利用了Goroutine特性,这样就能充分利用MPG模型,不必自己再处理任务窃取等问题了,用起来还是蛮爽的。
废话不多说,请看代码:
package like_fork_join
import (
"fmt"
"github.com/oklog/ulid/v2"
)
const defaultPageSize = 10
type MyForkJoinTask struct {
size int
}
// NewMyTask 初始化一个任务
func NewMyTask(pageSize int) *MyForkJoinTask {
var size = defaultPageSize
if pageSize > size {
size = pageSize
}
return &MyForkJoinTask{
size: size,
}
}
// Do 执行任务时,传入一个切片
func (t *MyForkJoinTask) Do(numbers []int) int {
JoinCh := make(chan bool, 1)
resultCh := make(chan int, 1)
t.do(numbers, JoinCh, resultCh, ulid.Make().String())
result := <-resultCh
return result
}
func (t *MyForkJoinTask) do(numbers []int, joinCh chan bool, resultCh chan int, id string) {
defer func() {
joinCh <- true
close(joinCh)
close(resultCh)
}()
fmt.Printf("id %s numbers %+v\n", id, numbers)
// 任务小于最小颗粒度时,直接执行逻辑(此处是求和),不再拆分,否则进行分治
if len(numbers) <= t.size {
var sum = 0
for _, number := range numbers {
sum += number
}
resultCh <- sum
fmt.Printf("id %s numbers %+v, result %+v\n", id, numbers, sum)
return
} else {
start := 0
end := len(numbers)
middle := (start + end) / 2
// 左
leftJoinCh := make(chan bool, 1)
leftResultCh := make(chan int, 1)
leftId := ulid.Make().String()
go t.do(numbers[start:middle], leftJoinCh, leftResultCh, id+"->left->"+leftId)
// 右
rightJoinCh := make(chan bool, 1)
rightResultCh := make(chan int, 1)
rightId := ulid.Make().String()
go t.do(numbers[middle:], rightJoinCh, rightResultCh, id+"->right->"+rightId)
// 等待左边和右边分治子任务结束
var leftDone, rightDone = false, false
for {
select {
case _, ok := <-leftJoinCh:
if ok {
fmt.Printf("left %s join done\n", leftId)
leftDone = true
}
case _, ok := <-rightJoinCh:
if ok {
fmt.Printf("right %s join done\n", rightId)
rightDone = true
}
}
if leftDone && rightDone {
break
}
}
// 取结果
var (
left = 0
right = 0
leftResultDone = false
rightResultDone = false
)
for {
select {
case l, ok := <-leftResultCh:
if ok {
fmt.Printf("id %s numbers %+v, left %s return: %+v\n", id, numbers, leftId, left)
left = l
leftResultDone = true
}
case r, ok := <-rightResultCh:
if ok {
fmt.Printf("id %s numbers %+v, right %s return: %+v\n", id, numbers, rightId, right)
right = r
rightResultDone = true
}
}
if leftResultDone && rightResultDone {
break
}
}
resultCh <- left + right
return
}
}
代码也不复杂,有注释,大家耐心读一下就明白了。
我写了一个比较有压力的测试用例代码,请看:
package like_fork_join
import (
"fmt"
"testing"
)
func TestMyTask_Do(t1 *testing.T) {
type args struct {
numbers []int
}
const max = 10000
var nums = make([]int, 0, max)
var want = 0
for i := 1; i <= max; i++ {
nums = append(nums, i)
want += i
}
tests := []struct {
name string
args args
want int
}{
{name: fmt.Sprintf("sum(1,%d)", max), args: args{numbers: nums}, want: want},
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
for i := 0; i <= 100; i += 5 {
t := NewMyTask(i)
if got := t.Do(tt.args.numbers); got != tt.want {
t1.Errorf("Do() = %v, want %v", got, tt.want)
}
}
})
}
}
测试成功:
--- PASS: TestMyTask_Do/sum(1,10000) (1257.79s)
PASS
删除所有fmt包的控制台输出,再跑单元测试结果:
=== RUN TestMyTask_Do
--- PASS: TestMyTask_Do (60.53s)
=== RUN TestMyTask_Do/sum(1,10000)
--- PASS: TestMyTask_Do/sum(1,10000) (60.53s)
PASS
20万次加法计算,长度为1万的数组的20次计算,60秒搞定,性能巨强,Golang就是棒!
计划后续再研究研究,看能否把执行任务的逻辑做成泛型和函数闭包,给抽象出来,这样就能单独形成一个通用型的代码包,供外部各种应用程序使用了,不过考虑到goroutine的上下文等问题,估计会让代码比较复杂,眼下这个版本足够简单,也能满足绝大多数场景了。