借助shared_ptr实现copy-on-write以提高多线程并发性能

        锁竞争是服务器性能四大杀手之一(其他三个是:数据拷贝、环境切换、动态资源申请),本文将基于之前发布的kimgbo网络I/O库,以一个多线程群发聊天服务器的实现为例,介绍如何借助shared_ptr提高多线程并发的性能。

        多线程群发聊天服务器实现的功能是,客户端连接服务器后,可以向服务器发送消息(消息=消息头+消息体),服务器负责将消息转发给其他正处于连接状态的客户端(包括发送消息的客户端)。示意图如下:

                   借助shared_ptr实现copy-on-write以提高多线程并发性能_第1张图片

       传统的基于Reactor模式的服务器,使用工作线程池来处理连接请求,并通过在操作之前加锁的方式来保护连接队列的数据安全。多线程对于请求队列的取出和插入操作实际上是串行的,整个服务器的并发性能较差。如果能让插入和取出处理转发任务的两个操作实现并行,则能够大大提升服务器的性能。

        shared_ptr是采用引用计数方式的智能指针,如果当前只有一个观察者,则其引用计数为1,可以通过shared_ptr::unique()判断,通过shared_ptr实现copy-on-write的原理如下:

1、read端在读之前将引用计数加1,读完将引用计数减1,这样可以保证在读期间其引用计数大于1,可以阻止并发写。

大致的流程如下:

    ConnectionListPtr connections = getConnectionList(); /*重新让一个shared_ptr指向连接队列,引用计数加1*/
    for (ConnectionList::iterator it = connections->begin(); it != connections->end(); ++it) /*读数据*/
    {
      m_codec.send((*it).get(), message); /*业务处理*/
    }

2、write端在写之前先检查引用计数是否为1,如果为1则直接修改。

3、write端写时如果发现引用计数大于1,则说明此时数据正在被read,则不能再原来的数据上并发写,应该创建一个副本,并在副本上修改,然后用副本替换以前的数据。

大致的流程如下:

    MutexLockGuard lock(m_mutex); /*此处需要加一下锁,但是仅仅是在写入是加锁,减少了读时锁的使用*/
    if(!m_connections.unique())
    {
    m_connections.reset(new ConnectionList(*m_connections)); /*如果引用计数大于1,则创建一个副本并在副本上修改,shared_ptr通过reset操作后会使引用计数减1,原先的数据在read结束后引用计数会减为0,进而被系统释放*/
    }
    assert(m_connections.unique()); /*断言新创建的副本引用计数一定为1*/
    

    /*下面执行相关业务处理*/
    if (conn->connected())
    { 
        m_connections->insert(conn);
    }
    else
    {
      m_connections->erase(conn);
    }

以上就是copy-on-write的大致实现流程,下面给出服务器程序的核心代码,全部代码参见kimgbo网络库的example/chat目录下https://github.com/kimg-bo/kimgbo ,kimgbo网络库的使用方式和muduo基本类似。核心代码如下:

#include 
#include 
#include 
#include 
#include "Logging.h"
#include "Mutex.h"
#include "EventLoop.h"
#include "TcpServer.h"
#include "codec.h"

using namespace kimgbo;
using namespace kimgbo::net;
	
class ChatServer
{
public:
  ChatServer(EventLoop* loop,
             const InetAddress& listenAddr)
  : m_loop(loop),
    m_server(m_loop, listenAddr, "ChatServer"),
    m_codec(std::bind(&ChatServer::onStringMessage, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)),
    m_connections(new ConnectionList)
  {
    m_server.setConnectionCallback(
        std::bind(&ChatServer::onConnection, this, std::placeholders::_1));
    m_server.setMessageCallback(
        std::bind(&LengthHeaderCodec::onMessage, &m_codec, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3));
  }

  void setThreadNum(int numThreads)
  {
    m_server.setThreadNum(numThreads);
  }

  void start()
  {
    m_server.start();
  }

private:
	typedef std::set ConnectionList; //存放链接的集合
	typedef std::shared_ptr ConnectionListPtr; //指向连接集合的shared_ptr
	
  void onConnection(const TcpConnectionPtr& conn)
  {
    LOG_INFO << conn->localAddress().toIpPort() << " -> "
        << conn->peerAddress().toIpPort() << " is "
        << (conn->connected() ? "UP" : "DOWN");

    MutexLockGuard lock(m_mutex); //write加锁
    LOG_INFO << "lock(m_mutex) ok"; 
    if(!m_connections.unique()) //检查引用计数
    {
    	LOG_INFO << "m_connections.unique()."; 
    	m_connections.reset(new ConnectionList(*m_connections)); //如果大于1,创建副本
    }
    assert(m_connections.unique());
    
    if (conn->connected())
    {
    	LOG_INFO << "insert before."; 
      m_connections->insert(conn);
      LOG_INFO << "insert ok"; 
    }
    else
    {
      m_connections->erase(conn);
    }
  }
  
  ConnectionListPtr getConnectionList() //获取链接集合
  {
  	MutexLockGuard lock(m_mutex);
  	return m_connections;
  }

  void onStringMessage(const TcpConnectionPtr&, const kimgbo::string& message, Timestamp)
  {
    ConnectionListPtr connections = getConnectionList(); //read操作直接读
    for (ConnectionList::iterator it = connections->begin(); it != connections->end(); ++it)
    {
      m_codec.send((*it).get(), message);
    }
  }
  
private:
  EventLoop* m_loop;
  TcpServer m_server;
  LengthHeaderCodec m_codec;
  MutexLock m_mutex;
  ConnectionListPtr m_connections;
};

int main(int argc, char* argv[])
{
  LOG_INFO << "pid = " << getpid();
  if (argc > 1)
  {
    EventLoop loop;
    uint16_t port = static_cast(atoi(argv[1]));
    InetAddress serverAddr(port);
    ChatServer server(&loop, serverAddr);
    if (argc > 2)
    {
      server.setThreadNum(atoi(argv[2]));
    }
    server.start();
    loop.loop();
  }
  else
  {
    printf("Usage: %s port [thread_num]\n", argv[0]);
  }
}

kimgbo开源网络I/O库见:https://github.com/kimg-bo/kimgbo


你可能感兴趣的:(Linux服务器端)