Modifying the Thrift source: reworking RPC to support multiple Processors [Java version]

Note: based on version 0.8.0.

 

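Thrift RPC only supports a single-interface mode: a TServerSocket can be served by only one Processor. In real applications a single-interface setup is probably rare, and it is not clear why Facebook designed it this way. Every exposed service interface occupies its own port, which also makes service management inconvenient. Sharing one port would solve the port problem, but a single process still seems able to load only one Processor, unless I have missed the right approach.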

 

To support loading multiple Processors, I started modifying the Thrift source.

 

[Server-side changes]

Classes to modify:

TSimpleServer (blocking, simple mode; rarely used)

TThreadPoolServer.WorkerProcess (blocking, thread-pool mode)

AbstractNonblockingServer (non-blocking mode)

TBaseProcessor

TProcessorFactory

 

[Client-side changes]

Classes to modify:

TServiceClient (blocking)

TCompactProtocol (non-blocking)

 

[New classes]

ClassScanner (class scanner; used to load the processors)

Service (service annotation)
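The article does not show the Service annotation itself. The sketch below assumes it is only a marker on implementation classes. It must have RUNTIME retention, because otherwise Class.getAnnotation() returns null and init() would find no services. The annotation is declared here in the default package rather than in org.apache.thrift. ServiceAnnotationDemo, EchoServiceImpl and HelperClass are hypothetical names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

final class ServiceAnnotationDemo {
    public static void main(String[] args) {
        System.out.println(isService(EchoServiceImpl.class)); // true
        System.out.println(isService(HelperClass.class));     // false
    }

    // Same check as in TProcessorFactory.init: item.getAnnotation(Service.class) != null
    static boolean isService(Class<?> c) {
        return c.getAnnotation(Service.class) != null;
    }
}

// Hypothetical stand-in for org.apache.thrift.Service
@Retention(RetentionPolicy.RUNTIME) // must be visible at runtime for getAnnotation()
@Target(ElementType.TYPE)           // only classes are marked
@interface Service {
}

// Hypothetical implementation class that the scanner should pick up
@Service
class EchoServiceImpl {
    public String echo(String s) {
        return s;
    }
}

// Hypothetical helper class without the annotation; the scanner should skip it
class HelperClass {
}
```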

 

Code analysis:

First, a typical server startup example:

TServerSocket serverTransport = new TServerSocket(8811);

ServiceDemo.Processor processor = new ServiceDemo.Processor(new ServiceDemoImpl());

Factory protFactory = new TBinaryProtocol.Factory(true,true);

Args rpcArgs = new Args(serverTransport);

rpcArgs.processor(processor);

rpcArgs.protocolFactory(protFactory);

TServer server = new TThreadPoolServer(rpcArgs);

server.serve();

 

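Every generated Processor extends TBaseProcessor, so let's start with TBaseProcessor.

Look at TBaseProcessor's field private final I iface. The iface instance is declared final, so each Processor is bound to exactly one service implementation. The generated code shows the same design: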

private static <I extends Iface> Map<String, org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase>> getProcessMap(Map<String, org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase>> processMap) {
      processMap.put("test", new test());
      return processMap;
}

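Here the processMap key is the method name. When a call is sent, this key goes into the name field of the TMessage, and no class information is transmitted. This is also why Thrift does not support method overloading. The reason is presumably that some target languages, such as C, have no overloading.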

 

Next, look at TProcessorFactory:

public class TProcessorFactory {

 private final TProcessor processor_;

 public TProcessorFactory(TProcessor processor) {

   processor_ = processor;

  }

 public TProcessor getProcessor(TTransport trans) {

   return processor_;

  }

}

The processor_ field here is also final. During server startup this constructor is called to set processor_. Executing rpcArgs.processor(processor) creates the TProcessorFactory, as this AbstractServerArgs method shows:

public T processor(TProcessor processor) {

     this.processorFactory = new TProcessorFactory(processor);

     return (T) this;

}

 

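At this point the general idea is clear. There are two main changes:

1. Replace the single final processor with processors loaded into a container at initialization;

2. Have the client add class information to the TMessage it sends.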
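Both points can be illustrated without Thrift at all. In this sketch all names are hypothetical, and Handler stands in for ProcessFunction. Two services both define test(). A table keyed only by method name could hold just one of them, while a ServiceName_methodName key built the same way on both sides keeps them apart:

```java
import java.util.HashMap;
import java.util.Map;

final class PrefixedDispatchDemo {
    // Hypothetical stand-in for ProcessFunction
    interface Handler {
        String handle();
    }

    // Point 1: one shared container instead of one final processor per server
    static final Map<String, Handler> FUNCTIONS = new HashMap<String, Handler>();

    static {
        // Two hypothetical services that both define test(); a bare "test" key would collide
        register("TestService", "test", new Handler() {
            public String handle() { return "TestService.test()"; }
        });
        register("TestService1", "test", new Handler() {
            public String handle() { return "TestService1.test()"; }
        });
    }

    // Point 2: client and server build the same "ServiceName_methodName" key
    static String key(String service, String method) {
        return service + "_" + method;
    }

    static void register(String service, String method, Handler h) {
        FUNCTIONS.put(key(service, method), h);
    }

    // Looks up the full message name, as the modified server does
    static String dispatch(String messageName) {
        Handler h = FUNCTIONS.get(messageName);
        return h == null ? "unknown method: " + messageName : h.handle();
    }

    public static void main(String[] args) {
        System.out.println(dispatch(key("TestService", "test")));  // TestService.test()
        System.out.println(dispatch(key("TestService1", "test"))); // TestService1.test()
        System.out.println(dispatch("test"));                      // unknown method: test
    }
}
```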

 

 

First, the client-side changes, for both blocking and non-blocking modes.

 

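For blocking mode, only one class changes: TServiceClient.

The code generator is not modified here, so the structure of the generated code cannot change. The work therefore has to be done in the runtime library that the generated code calls. Every Client method works in two steps, send and receive, and calls sendBase/receiveBase in TServiceClient. It is enough to add the class information to the TMessage in sendBase: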

protected void sendBase(String methodName, TBase args) throws TException {
    // Determine the calling service: stack[1] is the generated Client method,
    // whose source file name (e.g. ServiceDemo.java) gives the service name
    String className = "";
    StackTraceElement[] stack = (new Throwable()).getStackTrace();
    if (stack.length > 1 && stack[1].getFileName() != null) { // file name is null without debug info
        StackTraceElement ste = stack[1];
        className = ste.getFileName();
        className = className.split("\\.")[0];
    }
    oprot_.writeMessageBegin(new TMessage(className + "_" + methodName, TMessageType.CALL, ++seqid_));
    args.write(oprot_);
    oprot_.writeMessageEnd();
    oprot_.getTransport().flush();
}

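The lines before writeMessageBegin obtain the class name, which is then prepended when the message is written. The default format is ClassName_methodName. The class name taken here is not that of the inner class Processor but the name of the generated service file, for example ServiceDemo. The generated ServiceDemo.java contains nested classes such as Iface, Client and Processor, so a frame from any of them reports the file ServiceDemo.java.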

 

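Likewise, when the server initializes the Processor container, it uses the same name as the key: the outer class name, which matches the file name.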
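A caveat on the lookup above: StackTraceElement.getFileName() returns null when classes are compiled without debug information (for example with javac -g:none). The file name also matches the service only when the calling class is defined in ServiceDemo.java. The sketch below shows an alternative, assuming the caller is a nested class of the generated service class. It reads getClassName() instead and applies the same rule the server uses in TBaseProcessor: cut at the first '$' and keep the simple name. CallerServiceNameDemo and its nested Client are hypothetical stand-ins for a generated service class.

```java
final class CallerServiceNameDemo {
    public static void main(String[] args) {
        System.out.println(outerSimpleName("com.test.rpc.TestService$Client")); // TestService
        System.out.println(new Client().sendTest()); // CallerServiceNameDemo_test
    }

    // Service name of the method that called this one, from its class name rather than its file name
    static String callerServiceName() {
        StackTraceElement[] stack = new Throwable().getStackTrace();
        if (stack.length < 2) {
            return ""; // the JVM may omit frames; fall back to an empty name
        }
        // stack[0] is this method, stack[1] is its caller
        return outerSimpleName(stack[1].getClassName());
    }

    // Same rule as TBaseProcessor: cut at the first '$', then drop the package
    static String outerSimpleName(String binaryName) {
        String outer = binaryName.split("\\$")[0];
        return outer.substring(outer.lastIndexOf('.') + 1);
    }

    // Hypothetical stand-in for a generated nested Client class
    static final class Client {
        String sendTest() {
            // In the article this lookup happens inside TServiceClient.sendBase,
            // one frame below the generated send_ method
            return callerServiceName() + "_test";
        }
    }
}
```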

 

For non-blocking mode, modify TCompactProtocol.

On the same principle, change writeMessageBegin:

public void writeMessageBegin(TMessage message) throws TException {
    // Determine the calling service (same rule as TServiceClient.sendBase)
    String className = "";
    StackTraceElement[] stack = (new Throwable()).getStackTrace();
    if (stack.length > 1 && stack[1].getFileName() != null) { // file name is null without debug info
        StackTraceElement ste = stack[1];
        className = ste.getFileName();
        className = className.split("\\.")[0];
    }
    writeByteDirect(PROTOCOL_ID);
    writeByteDirect((VERSION & VERSION_MASK) | ((message.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK));
    writeVarint32(message.seqid);
    writeString(className + "_" + message.name); // prepend the class name; format: ClassName_methodName
}
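The class name only changes the length-prefixed name field; the rest of the compact message header is untouched. The sketch below encodes the header the way the modified writeMessageBegin would, to show the resulting bytes. The constant values are as recalled from the 0.8.0 TCompactProtocol source and should be checked against it. CompactHeaderDemo and encodeHeader are hypothetical names.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.Charset;

final class CompactHeaderDemo {
    // Assumed values, as recalled from TCompactProtocol in 0.8.0
    static final int PROTOCOL_ID = 0x82;
    static final int VERSION = 1;
    static final int VERSION_MASK = 0x1f;
    static final int TYPE_MASK = 0xE0;
    static final int TYPE_SHIFT_AMOUNT = 5;
    static final byte CALL = 1; // TMessageType.CALL

    public static void main(String[] args) {
        byte[] h = encodeHeader("TestService", "getUserName", CALL, 1);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 4; i++) {
            sb.append(String.format("%02x ", h[i] & 0xff));
        }
        System.out.println(sb.toString().trim() + " ... (" + h.length + " bytes)"); // 82 21 01 17 ... (27 bytes)
    }

    // Unsigned varint: 7 bits per byte, high bit set on every byte except the last
    static void writeVarint32(ByteArrayOutputStream out, int n) {
        while ((n & ~0x7F) != 0) {
            out.write((n & 0x7F) | 0x80);
            n >>>= 7;
        }
        out.write(n);
    }

    // Protocol id, version/type byte, varint seqid, then the varint-length-prefixed name
    static byte[] encodeHeader(String className, String method, byte type, int seqid) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(PROTOCOL_ID);
        out.write((VERSION & VERSION_MASK) | ((type << TYPE_SHIFT_AMOUNT) & TYPE_MASK));
        writeVarint32(out, seqid);
        byte[] name = (className + "_" + method).getBytes(Charset.forName("UTF-8"));
        writeVarint32(out, name.length);
        out.write(name, 0, name.length);
        return out.toByteArray();
    }
}
```

With the code as written in the article, a blocking client built on the modified TCompactProtocol would get two prefixes. sendBase adds ServiceDemo_, and writeMessageBegin then adds the file name of its own caller, TServiceClient. The two changes therefore assume that blocking clients use another protocol, such as the TBinaryProtocol used in the test client.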

That completes the first step.

 

Next, modify the server-side handling.

 

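1. Modify TProcessorFactory

Keep the original content and add only three things: the containers, accessor methods for Processors, and the logic that loads the container. Loading relies on a class scanner (ClassScanner) and a service annotation (org.apache.thrift.Service); the initialization procedure is described later. The code: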

public class TProcessorFactory {
    private static Logger logger = Logger.getLogger(TProcessorFactory.class);
    private final TProcessor processor_;
    // New: processor container, service name -> service implementation (iface)
    private static Map<String, Object> processorMap = new HashMap<String, Object>();
    // New: function container, "ClassName_methodName" -> ProcessFunction
    private static Map<String, ProcessFunction> functionMap = new HashMap<String, ProcessFunction>();

    public TProcessorFactory(TProcessor processor) {
        processor_ = processor;
    }

    public TProcessor getProcessor(TTransport trans) {
        return processor_;
    }

    /**
     * New: add a processor to the container (the first registration for a key wins).
     * @param key       service name
     * @param processor service implementation
     */
    public static void addProcessor(String key, Object processor) {
        if (processorMap.containsKey(key)) {
            return;
        }
        processorMap.put(key, processor);
        logger.info("Loaded Processor: " + processor.getClass().getName());
    }

    /**
     * New: get a processor (service implementation) by service name.
     * @param key service name
     * @return the registered implementation, or null
     */
    public static Object getProcessor(String key) {
        return processorMap.get(key);
    }

    /**
     * New: add a function to the container, ignoring duplicates.
     * @param key  "ClassName_methodName"
     * @param func the ProcessFunction
     */
    public static void addFunction(String key, ProcessFunction func) {
        if (functionMap.containsKey(key)) { // already registered
            return;
        }
        functionMap.put(key, func);
    }

    /**
     * New: get a function instance.
     * @param key "ClassName_methodName"
     * @return the ProcessFunction, or null
     */
    public static ProcessFunction getFunction(String key) {
        return functionMap.get(key);
    }

    /**
     * Get the number of registered processors.
     */
    public static int getProcessorCount() {
        return processorMap.size();
    }

    /**
     * Initialize the processor container.
     * @param basePackage base package to scan
     */
    public static void init(String basePackage) throws Exception {
        init(new String[]{basePackage});
    }

    /**
     * Initialize the container.
     * @param basePackages base packages to scan
     */
    public static void init(String[] basePackages) throws Exception {
        Set<Class<?>> classList = ClassScanner.getClasses(basePackages, null, true, false, true);
        List<Class<?>> ifaceList = new ArrayList<Class<?>>();
        Map<String, Class<?>> pMap = new HashMap<String, Class<?>>();
        List<Class<?>> serviceList = new ArrayList<Class<?>>();
        if (classList != null && classList.size() > 0) {
            for (Class<?> item : classList) {
                if (item.getName().contains("Iface")) {            // generated Iface (also matches AsyncIface)
                    ifaceList.add(item);
                } else if (item.getName().endsWith("Processor")) { // generated Processor, keyed by outer class
                    pMap.put(item.getName().split("\\$")[0], item);
                } else if (item.getAnnotation(org.apache.thrift.Service.class) != null) { // implementations
                    serviceList.add(item);
                }
            }
        }

        for (Class<?> face : ifaceList) {
            for (Class<?> service : serviceList) {
                if (face.isAssignableFrom(service)) {
                    Class<?> processor = pMap.get(face.getName().split("\\$")[0]);
                    if (processor != null) {
                        // Find the one-argument constructor Processor(I iface) and call it;
                        // the TBaseProcessor constructor then registers the functions and the implementation
                        Constructor<?>[] ctor = processor.getDeclaredConstructors();
                        for (Constructor<?> c : ctor) {
                            Class<?>[] cx = c.getParameterTypes();
                            if (cx.length == 1 && cx[0].isAssignableFrom(face)) {
                                try {
                                    processor.getConstructor(cx).newInstance(service.newInstance());
                                } catch (Exception e) {
                                    logger.error("Initialization failed: " + e);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
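init() only pairs classes by name and relies on a side effect: calling Processor(I iface) runs the TBaseProcessor constructor, which fills the containers. The self-contained sketch below shows that pairing step. It uses hypothetical stand-ins (TestService, TestServiceImpl) for the generated and hand-written classes, and a plain map in place of TProcessorFactory. Unlike init(), it takes the Processor class as a parameter instead of looking it up in a scanned map.

```java
import java.lang.reflect.Constructor;
import java.util.HashMap;
import java.util.Map;

final class ProcessorWiringDemo {
    // Shared container, like TProcessorFactory.processorMap: service name -> implementation
    static final Map<String, Object> PROCESSORS = new HashMap<String, Object>();

    public static void main(String[] args) throws Exception {
        wire(TestService.Iface.class, TestServiceImpl.class, TestService.Processor.class);
        TestService.Iface svc = (TestService.Iface) PROCESSORS.get("TestService");
        System.out.println(svc.getUserName(3L)); // user-3
    }

    // "com.test.rpc.TestService$Processor" -> "TestService" (same rule as TBaseProcessor)
    static String outerName(Class<?> c) {
        String outer = c.getName().split("\\$")[0];
        return outer.substring(outer.lastIndexOf('.') + 1);
    }

    // Pairs an Iface, an implementation and a Processor the way init() does, then calls
    // the Processor's one-argument constructor; registration happens as a side effect
    static boolean wire(Class<?> iface, Class<?> impl, Class<?> processor) throws Exception {
        if (!outerName(iface).equals(outerName(processor)) || !iface.isAssignableFrom(impl)) {
            return false;
        }
        for (Constructor<?> c : processor.getDeclaredConstructors()) {
            Class<?>[] params = c.getParameterTypes();
            if (params.length == 1 && params[0].isAssignableFrom(iface)) {
                c.newInstance(impl.getDeclaredConstructor().newInstance());
                return true;
            }
        }
        return false;
    }
}

// Hypothetical stand-in for a generated file TestService.java
final class TestService {
    interface Iface {
        String getUserName(long id);
    }

    static final class Processor {
        Processor(Iface iface) {
            // In the article this registration lives in the TBaseProcessor constructor
            ProcessorWiringDemo.PROCESSORS.put(ProcessorWiringDemo.outerName(getClass()), iface);
        }
    }
}

// Hypothetical implementation; in the article it would carry the Service annotation
final class TestServiceImpl implements TestService.Iface {
    public String getUserName(long id) {
        return "user-" + id;
    }
}
```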

 

2. Modify TBaseProcessor

 

public abstract class TBaseProcessor<I> implements TProcessor {
    private final I iface; // kept; the original processMap is moved into TProcessorFactory

    /**
     * Modified constructor: registers this Processor's functions and implementation
     * in TProcessorFactory instead of keeping its own processMap.
     * @param iface              service implementation
     * @param processFunctionMap method name -> ProcessFunction, from the generated getProcessMap
     */
    protected TBaseProcessor(I iface, Map<String, ProcessFunction<I, ? extends TBase>> processFunctionMap) {
        // Keep this field so the original process method does not need to change
        this.iface = iface;
        // Derive the service name from the Processor class, using the same rule as the client:
        // com.test.rpc.TestService$Processor -> TestService
        String className = this.getClass().getName().split("\\$")[0];
        className = className.substring(className.lastIndexOf('.') + 1);
        if (processFunctionMap != null) {
            for (String key : processFunctionMap.keySet()) {
                // Register each method of this Processor; key format: ClassName_methodName
                TProcessorFactory.addFunction(className + "_" + key, processFunctionMap.get(key));
            }
        }
        // Register this Processor's implementation; key format: ClassName
        TProcessorFactory.addProcessor(className, iface);
    }

    /**
     * No longer actually used; kept because generated code references it.
     */
    @Override
    public boolean process(TProtocol in, TProtocol out) throws TException {
        TMessage msg = in.readMessageBegin();
        // ProcessFunction fn = (ProcessFunction) processMap.get(msg.name);
        // Changed: look up the function in TProcessorFactory
        ProcessFunction fn = TProcessorFactory.getFunction(msg.name);
        if (fn == null) {
            TProtocolUtil.skip(in, TType.STRUCT);
            in.readMessageEnd();
            TApplicationException x = new TApplicationException(
                    TApplicationException.UNKNOWN_METHOD,
                    "Invalid method name: '" + msg.name + "'");
            out.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid));
            x.write(out);
            out.writeMessageEnd();
            out.getTransport().flush();
            return true;
        }
        fn.process(msg.seqid, in, out, iface);
        return true;
    }

    /**
     * Same as process, but static: both the function and the processor (implementation)
     * are looked up in the container, using the service name before the first '_'.
     */
    public static boolean processing(TProtocol in, TProtocol out) throws TException {
        TMessage msg = in.readMessageBegin();
        ProcessFunction fn = TProcessorFactory.getFunction(msg.name);
        if (fn == null) {
            TProtocolUtil.skip(in, TType.STRUCT);
            in.readMessageEnd();
            TApplicationException x = new TApplicationException(
                    TApplicationException.UNKNOWN_METHOD,
                    "Invalid method name: '" + msg.name + "'");
            out.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid));
            x.write(out);
            out.writeMessageEnd();
            out.getTransport().flush();
            return true;
        }
        // Changed: get the processor (service implementation) from the container
        fn.process(msg.seqid, in, out, TProcessorFactory.getProcessor(msg.name.split("_")[0]));
        return true;
    }
}
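processing() derives the processor key with msg.name.split("_")[0], that is, the text before the first underscore. Function lookup uses the full name, so underscores in method names are harmless. A service whose own name contains an underscore would resolve to the wrong key, because Thrift identifiers may contain '_'. The sketch below compares that rule with a separator that cannot appear in an identifier. Using ':' is an assumption here; it is also the separator that TMultiplexedProtocol uses in later Thrift releases. ServiceKeyDemo and its methods are hypothetical.

```java
final class ServiceKeyDemo {
    public static void main(String[] args) {
        System.out.println(keyBeforeUnderscore("TestService_get_user")); // TestService
        System.out.println(keyBeforeUnderscore("Test_Service_getUser")); // Test (wrong service)
        System.out.println(keyBeforeColon("Test_Service:getUser"));      // Test_Service
    }

    // Rule used in TBaseProcessor.processing: text before the first '_'
    static String keyBeforeUnderscore(String messageName) {
        return messageName.split("_")[0];
    }

    // Alternative rule, assuming the client joins names with ':' instead of '_'
    static String keyBeforeColon(String messageName) {
        int sep = messageName.indexOf(':');
        return sep < 0 ? "" : messageName.substring(0, sep);
    }
}
```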

 

3. Next, modify TThreadPoolServer.WorkerProcess (TThreadPoolServer is the server normally used).

Only the run method needs to change. The modified version:

public void run() {
    // The processor instance is now handled by TBaseProcessor
    // TProcessor processor = null;
    TTransport inputTransport = null;
    TTransport outputTransport = null;
    TProtocol inputProtocol = null;
    TProtocol outputProtocol = null;
    try {
        System.out.println("client " + client_.getClass().getName()); // debug output
        // processor = processorFactory_.getProcessor();
        inputTransport = inputTransportFactory_.getTransport(client_);
        outputTransport = outputTransportFactory_.getTransport(client_);
        inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
        outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
        // Main change: call the static method processing
        while (!stopped_ && TBaseProcessor.processing(inputProtocol, outputProtocol)) {
        }
    } catch (TTransportException ttx) {
        // Assume the client died and continue silently
    } catch (TException tx) {
        LOGGER.error("Thrift error occurred during processing of message.", tx);
    } catch (Exception x) {
        LOGGER.error("Error occurred during processing of message.", x);
    }
    if (inputTransport != null) {
        inputTransport.close();
    }
    if (outputTransport != null) {
        outputTransport.close();
    }
}

 

TSimpleServer:

Modify the serve method:

public void serve() {
    stopped_ = false;
    try {
        serverTransport_.listen();
    } catch (TTransportException ttx) {
        LOGGER.error("Error occurred during listening.", ttx);
        return;
    }

    setServing(true);

    while (!stopped_) {
        TTransport client = null;
        TTransport inputTransport = null;
        TTransport outputTransport = null;
        TProtocol inputProtocol = null;
        TProtocol outputProtocol = null;
        try {
            client = serverTransport_.accept();
            if (client != null) {
                // processorFactory_ is no longer consulted: with TProcessorFactory.init,
                // no processor is set in Args, so processorFactory_ may be null here
                inputTransport = inputTransportFactory_.getTransport(client);
                outputTransport = outputTransportFactory_.getTransport(client);
                inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
                outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
                // Changed to call TBaseProcessor.processing
                // while (processor.process(inputProtocol, outputProtocol)) {}
                while (TBaseProcessor.processing(inputProtocol, outputProtocol)) {
                }
            }
        } catch (TTransportException ttx) {
            // Client died, just move on
        } catch (TException tx) {
            if (!stopped_) {
                LOGGER.error("Thrift error occurred during processing of message.", tx);
            }
        } catch (Exception x) {
            if (!stopped_) {
                LOGGER.error("Error occurred during processing of message.", x);
            }
        }

        if (inputTransport != null) {
            inputTransport.close();
        }

        if (outputTransport != null) {
            outputTransport.close();
        }
    }
    setServing(false);
}

 

For the non-blocking mode, modify AbstractNonblockingServer.FrameBuffer:

public void invoke() {
    TTransport inTrans = getInputTransport();
    TProtocol inProt = inputProtocolFactory_.getProtocol(inTrans);
    TProtocol outProt = outputProtocolFactory_.getProtocol(getOutputTransport());

    try {
        // Changed to call TBaseProcessor.processing
        // processorFactory_.getProcessor(inTrans).process(inProt, outProt);
        TBaseProcessor.processing(inProt, outProt);
        responseReady();
        return;
    } catch (TException te) {
        LOGGER.warn("Exception while invoking!", te);
    } catch (Throwable t) {
        LOGGER.error("Unexpected throwable while invoking!", t);
    }
    // This will only be reached when there is a throwable.
    state_ = FrameBufferState.AWAITING_CLOSE;
    requestSelectInterestChange();
}

 

 

At this point, the server-side multi-interface support is complete.

 

 

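Next, server initialization. You specify the base packages that contain the services, covering both the generated-code package and the service-implementation package; several package paths can be given. Startup differs little from standard Thrift. The only difference is that, instead of setting the Processor in Args, you call TProcessorFactory's init method.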
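The article does not include ClassScanner, and it does not explain the arguments in ClassScanner.getClasses(basePackages, null, true, false, true). For illustration only, here is a much simpler hypothetical scanner that lists class names under a package on directory classpath entries. It ignores JAR files and does not load the classes. ClassNameScanner, scan and collect are hypothetical names.

```java
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;

final class ClassNameScanner {
    public static void main(String[] args) throws IOException, URISyntaxException {
        // Lists class names under the given package, e.g. com.test.rpc
        for (String name : scan(args.length > 0 ? args[0] : "com.test.rpc")) {
            System.out.println(name);
        }
    }

    // Walks every directory classpath root that contains basePackage
    static List<String> scan(String basePackage) throws IOException, URISyntaxException {
        List<String> names = new ArrayList<String>();
        String path = basePackage.replace('.', '/');
        Enumeration<URL> roots = Thread.currentThread().getContextClassLoader().getResources(path);
        while (roots.hasMoreElements()) {
            URL url = roots.nextElement();
            if ("file".equals(url.getProtocol())) { // "jar:" URLs are skipped in this sketch
                collect(new File(url.toURI()), basePackage, names);
            }
        }
        return names;
    }

    // Recursively maps dir/sub/Foo$Bar.class to packageName.sub.Foo$Bar
    static void collect(File dir, String packageName, List<String> out) {
        File[] files = dir.listFiles();
        if (files == null) {
            return; // not a directory, or not readable
        }
        for (File f : files) {
            String fileName = f.getName();
            if (f.isDirectory()) {
                collect(f, packageName + "." + fileName, out);
            } else if (fileName.endsWith(".class")) {
                out.add(packageName + "." + fileName.substring(0, fileName.length() - ".class".length()));
            }
        }
    }
}
```

Loading each name with Class.forName(name) would then give the Class objects that init() sorts into the Iface, Processor and annotated-implementation lists.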

The code below shows initialization for all three server types.

public class Server {
    public static void main(String[] args) throws Exception {
        // Initialize the Processor container
        TProcessorFactory.init("com.test.rpc");
        // Start the SimpleServer
        new SimpleServer(8810).start();
        Thread.sleep(1000);
        // Start the ThreadPoolServer
        new ThreadPoolServer(8811).start();
        Thread.sleep(1000);
        // Start the NonBlockingServer
        new NonBlockingServer(8812).start();
    }

    static class SimpleServer extends Thread {
        private int port = 0;

        public SimpleServer(int port) {
            this.port = port;
        }

        public void run() {
            try {
                Server.startSimpleServer(port);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    static class ThreadPoolServer extends Thread {
        private int port = 0;

        public ThreadPoolServer(int port) {
            this.port = port;
        }

        public void run() {
            try {
                Server.startThreadPoolServer(port);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    static class NonBlockingServer extends Thread {
        private int port = 0;

        public NonBlockingServer(int port) {
            this.port = port;
        }

        public void run() {
            try {
                Server.startNonBlockingServer(port);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public static void startSimpleServer(int port) throws Exception {
        TServerSocket serverTransport = new TServerSocket(port);
        TBinaryProtocol.Factory protFactory = new TBinaryProtocol.Factory(true, true);
        // A simple server is built from TServer.Args and TSimpleServer (not TThreadPoolServer)
        TServer.Args rpcArgs = new TServer.Args(serverTransport);
        rpcArgs.protocolFactory(protFactory);
        // No per-server Processor is needed, but TSimpleServer.serve() may still call
        // processorFactory_.getProcessor(), so set one that delegates to the container
        rpcArgs.processor(new TProcessor() {
            public boolean process(TProtocol in, TProtocol out) throws TException {
                return TBaseProcessor.processing(in, out);
            }
        });
        TServer server = new TSimpleServer(rpcArgs);
        System.out.println("Starting simple server on port " + port + " ...");
        server.serve();
    }

    public static void startThreadPoolServer(int port) throws Exception {
        TServerSocket serverTransport = new TServerSocket(port);
        TBinaryProtocol.Factory protFactory = new TBinaryProtocol.Factory(true, true);
        TThreadPoolServer.Args rpcArgs = new TThreadPoolServer.Args(serverTransport);
        rpcArgs.protocolFactory(protFactory);
        TServer server = new TThreadPoolServer(rpcArgs);
        System.out.println("Starting thread-pool server on port " + port + " ...");
        server.serve();
    }

    public static void startNonBlockingServer(int port) throws Exception {
        TNonblockingServerSocket socket = new TNonblockingServerSocket(port);
        THsHaServer.Args arg = new THsHaServer.Args(socket);
        // Non-blocking mode: data is transferred in frames, similar to Java NIO
        arg.protocolFactory(new TCompactProtocol.Factory());
        arg.transportFactory(new TFramedTransport.Factory());
        TServer server = new THsHaServer(arg);
        System.out.println("Starting non-blocking server on port " + port + " ...");
        server.serve();
    }
}

 

Start the servers; the output looks like this:

2012-07-02 00:16:36,925[org.apache.thrift.TProcessorFactory]-[INFO]Loaded Processor: com.test.rpc.impl.TestServiceImpl
2012-07-02 00:16:36,950[org.apache.thrift.TProcessorFactory]-[INFO]Loaded Processor: com.test.rpc.impl.TestService2Impl
Starting simple server on port 8810 ...
Starting thread-pool server on port 8811 ...
Starting non-blocking server on port 8812 ...

 

Client calls are unchanged:

public class Client {
    public static void main(String[] args) throws Exception {
        testSimpleServer();
        testThreadPoolServer();
        testNonBlockingServer();
    }

    public static void testSimpleServer() throws Exception {
        TSocket tsocket = new TSocket("localhost", 8810);
        tsocket.open();
        TProtocol protocol = new TBinaryProtocol(tsocket);

        System.out.println("test simple server...");
        TestService.Client client1 = new TestService.Client(protocol);
        System.out.println("response TestService.getUserName:" + client1.getUserName(3L));

        TestService1.Client client2 = new TestService1.Client(protocol);
        System.out.println("response TestService1.test:" + client2.test());
        tsocket.close();
    }

    public static void testThreadPoolServer() throws Exception {
        TSocket tsocket = new TSocket("localhost", 8811);
        tsocket.open();
        TProtocol protocol = new TBinaryProtocol(tsocket);

        System.out.println("test thread-pool server...");
        TestService.Client client1 = new TestService.Client(protocol);
        System.out.println("response TestService.getUserName:" + client1.getUserName(3L));

        TestService1.Client client2 = new TestService1.Client(protocol);
        System.out.println("response TestService1.test:" + client2.test());
        tsocket.close();
    }

    public static void testNonBlockingServer() throws Exception {
        TAsyncClientManager clientManager = new TAsyncClientManager();
        TNonblockingTransport transport = new TNonblockingSocket("localhost", 8812, 10000);
        TProtocolFactory protocol = new TCompactProtocol.Factory();
        TestService.AsyncClient asyncClient = new TestService.AsyncClient(protocol, clientManager, transport);
        System.out.println("Client calls .....");
        MyCallback callBack = new MyCallback();
        asyncClient.getUserName(3L, callBack);

        TestService1.AsyncClient asyncClient1 = new TestService1.AsyncClient(protocol, clientManager, transport);
        asyncClient1.test(new TestService1CB());

        while (true) {
            Thread.sleep(1000);
        }
    }

    public static class MyCallback implements AsyncMethodCallback<TestService.AsyncClient.getUserName_call> {
        // Called with the result
        @Override
        public void onComplete(TestService.AsyncClient.getUserName_call response) {
            try {
                System.out.println(response.getResult().toString());
            } catch (TException e) {
                e.printStackTrace();
            }
        }

        // Called on error
        @Override
        public void onError(Exception exception) {
            System.out.println("onError");
        }
    }

    static class TestService1CB implements AsyncMethodCallback<TestService1.AsyncClient.test_call> {
        // Called with the result
        @Override
        public void onComplete(TestService1.AsyncClient.test_call response) {
            try {
                System.out.println(response.getResult().toString());
            } catch (TException e) {
                e.printStackTrace();
            }
        }

        // Called on error
        @Override
        public void onError(Exception exception) {
            System.out.println("onError");
        }
    }
}

 

The output is:

test simple server...
response TestService.getUserName:hello, I'm TestService.getUserName();
response TestService1.test:hello, I'm TestService1.test();
test thread-pool server...
response TestService.getUserName:hello, I'm TestService.getUserName();
response TestService1.test:hello, I'm TestService1.test();
test non-blocking server...
response TestService.getUserName:hello, I'm TestService.getUserName();
response TestService1.test:hello, I'm TestService1.test();

 

1. Calling the blocking and non-blocking clients together throws an exception; the cause is still unclear.

 
