动手写Amazon SQS客户端

Amazon SQS是AWS上主流的消息队列服务,按理说它是有SDK的,那么为什么还要自己编写客户端呢?因为它提供的SDK太简单,就几个Web API,没有办法直接用。我们具体来说一说。

SQS SDK中的API,我们主要用到的也就是getQueueUrl, sendMessage, receiveMessage等。getQueueUrl能根据传入的queueName查找到queueUrl,后续用这个queueUrl来访问相应的queue(即:调用sendMessage发消息,或调用receiveMessage收消息)。主要复杂度在于收消息:这个API是要主动调用的,可是你怎么知道有没有新消息需要你去收呢?事实上,这个receiveMessage API是基于拉模式(pull mode)的,你需要轮询来不停地拉取新消息,这个比较像Kafka。随之而来的,就需要线程管理,需要一个对SDK做了进一步包装的客户端库。

Spring Cloud Messaging提供了SQS的客户端库。但是当我们在2023年3月构建基于SQS的应用程序时,我们用的是AWS SDK V2,而Spring Cloud Messaging尚未正式支持AWS SDK V2。因此,我们决定自己编写SQS的客户端库。而且我们的设计也与Spring Cloud Messaging的有所不同:我们同时使用多个AWS账号,为此,我们直接在配置中引用queueUrl(它其实是静态值,可直接引用);而Spring Cloud Messaging只能在配置中引用queueName,然后再运行时获取当前AWS账号中相应的queueUrl。

现在就来讲一讲设计与实现。消息队列客户端遵循生产者-消费者模型,分为Producer和Consumer。SQS的消息体必须是不大于256KB的文本,因此可以把消息体当成一个String。

Producer

Producer很简单,把消息发出去就行了,顺便对超时和异常做适当的处理。库的用户可以自行决定消息体的序列化和反序列化方式,我们不干涉这件事。

Producer的使用方式很简单:

new SqsMessageProducer(queueUrl)
    .produce(yourMessagePayload);

Producer的完整实现代码大致如下:

/** How to use: Call produce() with your serialized message string. */
public class SqsMessageProducer {
  private final String queueUrl;
  private final int timeoutSeconds;

  private final SqsAsyncClient client;

  public SqsMessageProducer(String queueUrl, int timeoutSeconds) {
    this.queueUrl = queueUrl;
    this.timeoutSeconds = timeoutSeconds;
    client = new SqsClientFactory().createSqsAsyncClient();
  }

  public void produce(String payload) {
    var sendMessageFuture =
        client.sendMessage(
            SendMessageRequest.builder().queueUrl(queueUrl).messageBody(payload).build());
    // 不能无限等待future,要有超时机制
    try {
      sendMessageFuture.get(timeoutSeconds, TimeUnit.SECONDS);
    } catch (InterruptedException | ExecutionException | TimeoutException e) {
      throw new ProducerException(e);
    }
  }

  public static class ProducerException extends RuntimeException {
    public ProducerException(Throwable cause) {
      super(cause);
    }
  }
}

Consumer

Consumer的使用方式很简单,有效利用了函数式编程风格,不需要编写派生类,只需要创建Consumer的实例,传入一个消息处理函数,然后启动就可以。示例代码如下:

new SqsMessageConsumer(queueUrl, yourCustomizedThreadNamePrefix, yourMessageHandler)
  .runAsync();

Consumer的实现要复杂一些,需要实现消息驱动的异步计算风格。处理消息一般会比收取消息更花时间,因此它创建一个主循环线程用来轮询消息队列,创建一个工作线程池用来处理消息。主循环线程每次可能收到0~n个消息,把收到的消息分发给工作线程池来处理。因为工作线程池自带任务队列用于缓冲,所以这两种线程之间是互不阻塞的:如果工作线程慢了,主循环线程可以照常收取和分发新消息;如果主循环线程慢了,工作线程可以照常处理已有的消息。

注意一个要点:SQS不会自动清理已被收取的消息,因为它不知道你是否成功处理了消息。当一个消息被收取后,它会暂时被隐藏,以免其他消费者收到它,如果此消息一直没有被清理,它会在一段时间后(默认30秒,可配置)重新出现,被某个消费者再度收取。你需要一个机制来主动告知SQS某条消息已被处理,这个机制就是deleteMessage API:成功处理一个消息后,主动调deleteMessage来从队列中删除此消息;如果处理失败,什么都不用做,SQS会在一段时间后再次让消费者收取到此消息。

核心代码这么写:

private volatile boolean shouldShutdown = false;

// 只要没有关闭,主循环就一直收取消息
while (!shouldShutdown) {
  List messages;
  try {
    messages = receiveMessages();
  } catch (Throwable e) {
    logger.error("failed to receive", e);
    continue;
  }

  try {
    dispatchMessages(queueUrl, messages);
  } catch (Throwable e) {
    logger.error("failed to dispatch", e);
  }
}

// 收消息的具体实现
private List receiveMessages() throws ExecutionException, InterruptedException {
  // visibilityTimeout = message handling timeout
  // It is usually set at infrastructure level
  var receiveMessageFuture =
      client.receiveMessage(
          ReceiveMessageRequest.builder()
              .queueUrl(queueUrl)
              .waitTimeSeconds(10)
              .maxNumberOfMessages(maxParallelism)
              .build());
  // 上面已在请求中设置waitTimeSeconds=10,所以这里可以不设置超时
  return receiveMessageFuture.get().messages();
}

// 把收到消息分发给工作线程池做处理
// 要显式地把处理好的消息从队列中删除
// 如果不删除,会在未来再次被主循环收取到
private void dispatchMessages(String queueUrl, List messages) {
  for (Message message : messages) {
    workerThreadPool.execute(
        () -> {
          String messageId = message.messageId();
          try {
            logger.info("Started handling message with id={}", messageId);
            messageHandler.accept(message);
            logger.info("Completed handling message with id={}", messageId);
            // Should delete the succeeded message
            client.deleteMessage(
                DeleteMessageRequest.builder()
                    .queueUrl(queueUrl)
                    .receiptHandle(message.receiptHandle())
                    .build());
            logger.info("Deleted handled message with id={}", messageId);
          } catch (Throwable e) {
            // Logging is enough. Failed message is not deleted, and will be retried on a future polling.
            logger.error("Failed to handle message with id=$messageId", e);
          }
        });
  }
}

在以上代码中,每次receiveMessage时设置waitTimeSeconds=10,即最多等待10秒,若没有新消息就返回0条消息;若有新消息,就提前返回所收到的1或多条消息。之所以不无限等待,是怕网关自动关闭长时间静默的网络连接。

还需要一个优雅关闭机制,让服务器能顺利关闭和清理资源:

Thread mainLoopThread = Thread.currentThread();
// JVM awaits all shutdown hooks to complete
// https://stackoverflow.com/questions/8663107/how-does-the-jvm-terminate-daemon-threads-or-how-to-write-daemon-threads-that-t
Runtime.getRuntime()
    .addShutdownHook(
        new Thread(
            () -> {
              shouldShutdown = true;
              mainLoopThread.interrupt();
              try {
                workerThreadPool.shutdown();
                boolean terminated = workerThreadPool.awaitTermination(1, TimeUnit.MINUTES);
                if (!terminated) {
                  List runnables = workerThreadPool.shutdownNow();
                  logger.info("shutdownNow with {} runnables undone", runnables.size());
                }
              } catch (RuntimeException e) {
                logger.error("shutdown failed", e);
                throw e;
              } catch (InterruptedException e) {
                logger.error("shutdown interrupted", e);
                throw new IllegalStateException(e);
              }
            }));

有时网络连接不稳定,主循环频繁报错比较noisy,改成指数退避的重试:

while (!shouldShutdown) {
  List messages;
  try {
    messages = receiveMessages();
    // after success, restore backoff to the initial value
    receiveBackoffSeconds = 1;
  } catch (Throwable e) {
    logger.error("failed to receive", e);
    logger.info("Gonna sleep {} seconds for backoff", receiveBackoffSeconds);
    try {
      //noinspection BusyWait
      Thread.sleep(receiveBackoffSeconds * 1000L);
    } catch (InterruptedException ex) {
      logger.error("backoff sleep interrupted", ex);
    }
    // after failure, increment next backoff (≤ limit)
    receiveBackoffSeconds = exponentialBackoff(receiveBackoffSeconds, 60);
    continue;
  }

  try {
    dispatchMessages(queueUrl, messages);
  } catch (Throwable e) {
    logger.error("failed to dispatch", e);
  }
}

private int exponentialBackoff(int current, int limit) {
  int next = current * 2;
  return Math.min(next, limit);
}

工作线程池是一个ThreadPoolExecutor,使用一个有界的BlockingQueue来实现回压(back-pressure),当这个queue一满,主循环线程就会被迫暂停,以防止本地的消息积压过多:如果积压过多,既会浪费内存,又会导致很多消息被收取却得不到及时处理,这时还不如让给其他消费者实例去收取。创建工作线程池的相关代码如下:

workerThreadPool =
    new ThreadPoolExecutor(
        maxParallelism,
        maxParallelism,
        0,
        TimeUnit.SECONDS,
        // bounded queue for back pressure
        new LinkedBlockingQueue<>(100),
        new CustomizableThreadFactory(threadPoolPrefix + "-pool-"),
        new TimeoutBlockingPolicy(30));

// Used by workerThreadPool
private static class TimeoutBlockingPolicy implements RejectedExecutionHandler {

  private final long timeoutSeconds;

  public TimeoutBlockingPolicy(long timeoutSeconds) {
    this.timeoutSeconds = timeoutSeconds;
  }

  @Override
  public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
    try {
      BlockingQueue queue = executor.getQueue();
      if (!queue.offer(r, this.timeoutSeconds, TimeUnit.SECONDS)) {
        throw new RejectedExecutionException("Timeout after " + timeoutSeconds + " seconds");
      }
    } catch (InterruptedException e) {
      throw new IllegalStateException(e);
    }
  }
}

Consumer的完整实现代码大致如下:

/**
 * How to use:
 * 1. create a consumer instance with a queue name and a stateless messageHandler function.
 * 2. call runAsync() method to start listening to the queue.
 */
public class SqsMessageConsumer implements Runnable {
  private static final Logger logger = LoggerFactory.getLogger(SqsMessageConsumer.class);

  private final String queueUrl;
  private final Consumer messageHandler;
  private final int maxParallelism;

  private final SqsAsyncClient client;
  private final ExecutorService workerThreadPool;

  private volatile boolean shouldShutdown = false;

  public SqsMessageConsumer(
      String queueUrl,
      String threadPoolPrefix,
      Consumer messageHandler) {
    this(queueUrl, threadPoolPrefix, messageHandler, 8);
  }

  public SqsMessageConsumer(
      String queueUrl,
      String threadPoolPrefix,
      Consumer messageHandler,
      int maxParallelism) {
    this.queueUrl = queueUrl;
    this.messageHandler = messageHandler;
    this.maxParallelism = maxParallelism;
    client = new SqsClientFactory().createSqsAsyncClient();
    workerThreadPool =
        new ThreadPoolExecutor(
            maxParallelism,
            maxParallelism,
            0,
            TimeUnit.SECONDS,
            // bounded queue for back pressure
            new LinkedBlockingQueue<>(100),
            new CustomizableThreadFactory(threadPoolPrefix + "-pool-"),
            new TimeoutBlockingPolicy(30));
  }

  /** Use this method by default, it is asynchronous and handles threading for you. */
  public void runAsync() {
    Thread mainLoopThread = new Thread(this);
    mainLoopThread.start();
  }

  /**
   * Use this method only if you run it in your own thread pool, it runs synchronously in the
   * contextual thread.
   */
  @Override
  public void run() {
      Thread mainLoopThread = Thread.currentThread();
      // JVM awaits all shutdown hooks to complete
      // https://stackoverflow.com/questions/8663107/how-does-the-jvm-terminate-daemon-threads-or-how-to-write-daemon-threads-that-t
      Runtime.getRuntime()
        .addShutdownHook(
            new Thread(
                () -> {
                  shouldShutdown = true;
                  mainLoopThread.interrupt();
                  try {
                    workerThreadPool.shutdown();
                    boolean terminated = workerThreadPool.awaitTermination(1, TimeUnit.MINUTES);
                    if (!terminated) {
                      List runnables = workerThreadPool.shutdownNow();
                      logger.info("shutdownNow with {} runnables undone", runnables.size());
                    }
                  } catch (RuntimeException e) {
                    logger.error("shutdown failed", e);
                    throw e;
                  } catch (InterruptedException e) {
                    logger.error("shutdown interrupted", e);
                    throw new IllegalStateException(e);
                  }
                }));

    logger.info("polling loop started");
    int receiveBackoffSeconds = 1;
    // "shouldShutdown" state is more reliable than Thread interrupted state
    while (!shouldShutdown) {
      List messages;
      try {
        messages = receiveMessages();
        // after success, restore backoff to the initial value
        receiveBackoffSeconds = 1;
      } catch (Throwable e) {
        logger.error("failed to receive", e);
        logger.info("Gonna sleep {} seconds for backoff", receiveBackoffSeconds);
        try {
          //noinspection BusyWait
          Thread.sleep(receiveBackoffSeconds * 1000L);
        } catch (InterruptedException ex) {
          logger.error("backoff sleep interrupted", ex);
        }
        // after failure, increment next backoff (≤ limit)
        receiveBackoffSeconds = exponentialBackoff(receiveBackoffSeconds, 60);
        continue;
      }

      try {
        dispatchMessages(queueUrl, messages);
      } catch (Throwable e) {
        logger.error("failed to dispatch", e);
      }
    }
  }

  private int exponentialBackoff(int current, int limit) {
    int next = current * 2;
    return Math.min(next, limit);
  }

  private List receiveMessages() throws ExecutionException, InterruptedException {
    // visibilityTimeout = message handling timeout
    // It has usually been set at infrastructure level
    var receiveMessageFuture =
        client.receiveMessage(
            ReceiveMessageRequest.builder()
                .queueUrl(queueUrl)
                .waitTimeSeconds(10)
                .maxNumberOfMessages(maxParallelism)
                .build());
    // Consumer can wait infinitely for the next message, rely on library default timeout.
    return receiveMessageFuture.get().messages();
  }

  private void dispatchMessages(String queueUrl, List messages) {
    for (Message message : messages) {
      workerThreadPool.execute(
          () -> {
            String messageId = message.messageId();
            try {
              logger.info("Started handling message with id={}", messageId);
              messageHandler.accept(message);
              logger.info("Completed handling message with id={}", messageId);
              // Should delete the succeeded message
              client.deleteMessage(
                  DeleteMessageRequest.builder()
                      .queueUrl(queueUrl)
                      .receiptHandle(message.receiptHandle())
                      .build());
              logger.info("Deleted handled message with id={}", messageId);
            } catch (Throwable e) {
              // Logging is enough. Failed message is not deleted, will be retried at next polling.
              logger.error("Failed to handle message with id=$messageId", e);
            }
          });
    }
  }

  // Used by workerThreadPool
  private static class TimeoutBlockingPolicy implements RejectedExecutionHandler {

    private final long timeoutSeconds;

    public TimeoutBlockingPolicy(long timeoutSeconds) {
      this.timeoutSeconds = timeoutSeconds;
    }

    @Override
    public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
      try {
        BlockingQueue queue = executor.getQueue();
        if (!queue.offer(r, this.timeoutSeconds, TimeUnit.SECONDS)) {
          throw new RejectedExecutionException("Timeout after " + timeoutSeconds + " seconds");
        }
      } catch (InterruptedException e) {
        throw new IllegalStateException(e);
      }
    }
  }
}

你可能感兴趣的:(动手写Amazon SQS客户端)