Task.WhileAll扩展方法

TPL实现Task.WhileAll扩展方法

文章翻译整理自 Nikola Malovic 两篇博文:

当 Task.WhenAll 遇见 Task.WhenAny

在 TPL (Task Parallel Library) 中,有两种通过非阻塞方式等待 Task 数组任务结束的方式:Task.WhenAll 和 Task.WhenAny 。

它们的工作方式是:

  • WhenAll 当每项任务都完成时为完成。
  • WhenAny 当任意项任务完成时为完成。

现在我们需要一项功能,完成 Task 数组中的所有任务,并且当有任务完成时汇报状态。

我们称这个扩展方法为:Task.WhileAll 。

扩展方法实现

复制代码
 1     public static class TaskExtensions
 2     {
 3         public static async Task<IList<T>> WhileAll<T>(this IList<Task<T>> tasks, IProgress<T> progress)
 4         {
 5             var result = new List<T>(tasks.Count);
 6             var done = new List<Task<T>>(tasks);
 7 
 8             while (done.Count > 0)
 9             {
10                 await Task.WhenAny(tasks);
11 
12                 var spinning = new List<Task<T>>(done.Count - 1);
13                 for (int i = 0; i < done.Count; i++)
14                 {
15                     if (done[i].IsCompleted)
16                     {
17                         result.Add(done[i].Result);
18                         progress.Report(done[i].Result);
19                     }
20                     else
21                     {
22                         spinning.Add(done[i]);
23                     }
24                 }
25 
26                 done = spinning;
27             }
28 
29             return result;
30         }
31     }
复制代码

代码实现很简单:

  • 其是 IList<Task<T>> 的一个 async 扩展方法
  • 方法返回完整的 IList<T> 结果
  • 方法会接受一个 IProgress<T> 类型的参数,用于向订阅者发布 Task 完成信息
  • 在方法体内,我们使用一个循环来检测,直到所有 Task 完成
  • 通过使用 Task.WhenAny 来异步等待 Task 完成

单元测试

复制代码
 1     [TestClass]
 2     public class UnitTest1
 3     {
 4         [TestMethod]
 5         public async Task TestTaskExtensionsWhileAll()
 6         {
 7             var task1 = Task.Run(() => 101);
 8             var task2 = Task.Run(() => 102);
 9             var tasks = new List<Task<int>>() { task1, task2 };
10 
11             List<int> result = new List<int>();
12             var listener = new Progress<int>(
13                 taskResult =>
14                 {
15                     result.Add(taskResult);
16                 });
17 
18             var actual = await tasks.WhileAll(listener);
19             Thread.Sleep(50); // wait a bit for progress reports to complete
20 
21             Assert.AreEqual(2, result.Count);
22             Assert.IsTrue(result.Contains(101));
23             Assert.IsTrue(result.Contains(102));
24 
25             Assert.AreEqual(2, actual.Count);
26             Assert.IsTrue(actual.Contains(101));
27             Assert.IsTrue(actual.Contains(102));
28         }
29     }
复制代码

同样,测试代码也不复杂:

  • 创建两个哑元 Task,并存到数组中
  • 定义进度侦听器 Progress<T>,来监测每个任务运行的结果
  • 通过 await 方式来调用方法
  • 使用 Thread.Sleep 来等待 50ms ,以便 Progress 可以来得及处理结果
  • 检查所有 Task 执行完毕后均已上报 Progress
  • 检查所有 Task 均已执行完毕

我知道每当使用 Thread.Sleep 时绝不是件好事,所以我决定摆脱它。

实现IProgressAsync<T>

问题实际上是因为 IProgress<T> 接口定义的是 void 委托,因此无法使用 await 进行等待。

因此我决定定义一个新的接口,使用同样的 Report 行为,但会返回 Task ,用以实现真正的异步。

1     public interface IProgressAsync<in T>
2     {
3         Task ReportAsync(T value);
4     }

有了异步版本的支持,将使订阅者更容易处理 await 调用。当然也可以使用 async void 来达成,但我认为 async void 总会延伸出更差的设计。所以,我还是选择通过定义 Task 返回值签名的接口来达成这一功能。

如下为接口实现:

复制代码
 1     public class ProgressAsync<T> : IProgressAsync<T>
 2     {
 3         private readonly Func<T, Task> handler;
 4 
 5         public ProgressAsync(Func<T, Task> handler)
 6         {
 7             this.handler = handler;
 8         }
 9 
10         public async Task ReportAsync(T value)
11         {
12             await this.handler.InvokeAsync(value);
13         }
14     }
复制代码

显然也没什么特别的:

  • 使用 Func<T, Task> 来代替 Action<T>,以便可以使用 await
  • ReportAsync 通过使用 await 方式来提供 Task

有了这些之后,我们来更新扩展方法:

复制代码
 1     public static class TaskExtensions
 2     {
 3         public static async Task<IList<T>> WhileAll<T>(this IList<Task<T>> tasks, IProgressAsync<T> progress)
 4         {
 5             var result = new List<T>(tasks.Count);
 6             var remainingTasks = new List<Task<T>>(tasks);
 7 
 8             while (remainingTasks.Count > 0)
 9             {
10                 await Task.WhenAny(tasks);
11                 var stillRemainingTasks = new List<Task<T>>(remainingTasks.Count - 1);
12                 for (int i = 0; i < remainingTasks.Count; i++)
13                 {
14                     if (remainingTasks[i].IsCompleted)
15                     {
16                         result.Add(remainingTasks[i].Result);
17                         await progress.ReportAsync(remainingTasks[i].Result);
18                     }
19                     else
20                     {
21                         stillRemainingTasks.Add(remainingTasks[i]);
22                     }
23                 }
24 
25                 remainingTasks = stillRemainingTasks;
26             }
27 
28             return result;
29         }
30 
31         public static Task InvokeAsync<T>(this Func<T, Task> task, T value)
32         {
33             return Task<Task>.Factory.FromAsync(task.BeginInvoke, task.EndInvoke, value, null);
34         }
35     }
复制代码

所有都就绪后,我们就可以将 Thread.Sleep 从单元测试中移除了。

复制代码
 1     [TestClass]
 2     public class UnitTest1
 3     {
 4         private List<int> result = new List<int>();
 5         private async Task OnProgressAsync(int arg)
 6         {
 7             result.Add(arg);
 8         }     
 9 
10         [TestMethod]
11         public async Task TestTaskExtensionsWhileAll()
12         {
13             var task1 = Task.Run(() => 101);
14             var task2 = Task.Run(() => 102);
15             var tasks = new List<Task<int>>() { task1, task2 };
16 
17             var listener = new ProgressAsync<int>(this.OnProgressAsync);
18             var actual = await tasks.WhileAll(listener);
19 
20             Assert.AreEqual(2, this.result.Count);
21             Assert.IsTrue(this.result.Contains(101));
22             Assert.IsTrue(this.result.Contains(102));
23 
24             Assert.AreEqual(2, actual.Count);
25             Assert.IsTrue(actual.Contains(101));
26             Assert.IsTrue(actual.Contains(102));
27         }
28     }
复制代码

 

你可能感兴趣的:(while)