Kotlin协程实现 CountDownLatch

import kotlinx.coroutines.CancellableContinuation
import kotlinx.coroutines.suspendCancellableCoroutine
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.resume
import kotlin.math.max


/**
 * A suspending, coroutine-friendly count-down latch.
 *
 * Implementing this with a `Mutex` or `Semaphore` means that, on release, every waiting
 * coroutine has to acquire and release the lock once, which is slow. This class instead
 * resumes all waiting coroutines in a single batch when the count reaches zero.
 *
 * Thread-safety: all state lives in atomics. Waiters are pushed onto an immutable stack with
 * CAS. The thread that brings the count to zero atomically swaps the stack for a terminal
 * [RELEASED] marker. After that swap no waiter can be enqueued, so none can be lost.
 *
 * Mutex issue: https://github.com/Kotlin/kotlinx.coroutines/issues/2371
 *
 * @param count initial count; a value `<= 0` creates an already-open latch.
 */
internal class SuspendCountDownLatch(count: Int) {

  private val countDownNumber = AtomicInteger(count)

  /**
   * Waiter state. It holds one of three values:
   * - `null`: no waiter has suspended yet;
   * - a [Waiter]: the most recently suspended waiter, i.e. the top of an immutable stack;
   * - [RELEASED]: the latch has opened; this value is terminal.
   *
   * Swapping in [RELEASED] also drops every continuation reference held by the latch, so
   * resumed coroutines are not retained afterwards.
   */
  private val waiters = AtomicReference<Any?>(null)

  /** Remaining count, clamped at zero (concurrent over-counting may drive the raw value negative). */
  fun getLockedCount() = max(0, countDownNumber.get())

  /** `true` while the count is still positive, i.e. [await] would suspend. */
  fun isLocked() = countDownNumber.get() > 0

  /**
   * Decrements the count. The single call that moves it from 1 to 0 resumes every suspended waiter.
   * Calls made after the latch has opened are no-ops.
   */
  fun countDown() {
    if (!isLocked()) return

    // decrementAndGet is monotonic, so exactly one caller ever observes 0.
    if (countDownNumber.decrementAndGet() == 0) {
      val top = waiters.getAndSet(RELEASED)
      if (top is Waiter) resumeInArrivalOrder(top)
    }
  }

  /** Suspends until the count reaches zero; returns immediately if it already has. */
  suspend fun await() {
    if (isLocked()) awaitSlowPath()
  }

  // A CancellableContinuation resumes according to its resume mode and coroutine context.
  // Resuming one whose coroutine was already cancelled is ignored rather than crashing.
  private suspend fun awaitSlowPath(): Unit = suspendCancellableCoroutine { cont ->
    while (true) {
      val top = waiters.get()
      if (top === RELEASED) {
        // The count hit zero (and waiters were drained) after our isLocked() check.
        cont.resume(Unit)
        return@suspendCancellableCoroutine
      }
      // top is null or a Waiter here. If the CAS wins, countDown's getAndSet is guaranteed to
      // observe this node; if it loses, someone else pushed or the latch opened, so retry.
      if (waiters.compareAndSet(top, Waiter(cont, top as Waiter?))) {
        return@suspendCancellableCoroutine
      }
    }
  }

  /**
   * Resumes the whole stack starting at [top]. The stack is newest-first, so it is walked in
   * reverse to resume waiters in the order they suspended.
   */
  private fun resumeInArrivalOrder(top: Waiter) {
    val newestFirst = generateSequence(top) { it.next }.toList()
    newestFirst.asReversed().forEach { it.continuation.resume(Unit) }
  }

  override fun toString(): String = "${super.toString()}[Count = ${getLockedCount()}]"

  /**
   * Immutable stack node. Immutability is what makes a CAS on [waiters] alone sufficient.
   * Cancelled waiters are not unlinked; they stay here until release, when their resume is a no-op.
   */
  private class Waiter(val continuation: CancellableContinuation<Unit>, val next: Waiter?)

  private companion object {
    /** Terminal marker stored in [waiters] once the latch has opened. */
    val RELEASED = Any()
  }
}

你可能感兴趣的:(Kotlin协程实现 CountDownLatch)