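Designing client access to RPC
The simplest RPC design takes only a short piece of code. Most of the performance cost comes from network communication and from serialization and deserialization. Here we design a fairly general RPC framework that can make arbitrary remote calls in Java.
We use Netty for network communication and, for now, support three Java serialization options: Hessian, Kryo, and the JDK's built-in serialization.
Communication data
Given the class name, method name, method parameters, and interface name carried in a request, the request can be handled through reflection and proxies.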
public class MessageRequest implements Serializable {
private static final long serialVersionUID = 779639215038924077L;
private String messageId; // unique ID
private String className; // interface name
private String methodName; // method name
private Class<?>[] typeParameters; // parameter types
private Object[] parametersVal; // parameter values
public String getMessageId() {
return messageId;
}
public void setMessageId(String messageId) {
this.messageId = messageId;
}
public String getClassName() {
return className;
}
public void setClassName(String className) {
this.className = className;
}
public String getMethodName() {
return methodName;
}
public void setMethodName(String methodName) {
this.methodName = methodName;
}
public Class<?>[] getTypeParameters() {
return typeParameters;
}
public void setTypeParameters(Class<?>[] typeParameters) {
this.typeParameters = typeParameters;
}
public Object[] getParameters() {
return parametersVal;
}
public void setParameters(Object[] parametersVal) {
this.parametersVal = parametersVal;
}
@Override
public String toString() {
return "MessageRequest{" +
"messageId='" + messageId + '\'' +
", className='" + className + '\'' +
", methodName='" + methodName + '\'' +
", typeParameters=" + Arrays.toString(typeParameters) +
", parametersVal=" + Arrays.toString(parametersVal) +
'}';
}
}
public class MessageResponse implements Serializable {
private static final long serialVersionUID = -4628239730293658445L;
private String messageId; // unique ID
private String error; // error message
private Object resultDesc; // result of the method call
public String getMessageId() {
return messageId;
}
public void setMessageId(String messageId) {
this.messageId = messageId;
}
public String getError() {
return error;
}
public void setError(String error) {
this.error = error;
}
public Object getResult() {
return resultDesc;
}
public void setResult(Object resultDesc) {
this.resultDesc = resultDesc;
}
@Override
public String toString() {
return "MessageResponse{" +
"messageId='" + messageId + '\'' +
", error='" + error + '\'' +
", resultDesc=" + resultDesc +
'}';
}
}
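public class MessageCallBack {
    private MessageResponse response;
    private final Lock lock = new ReentrantLock();
    private final Condition finish = lock.newCondition();
    // Blocks until the response arrives or 10 seconds elapse; returns null on timeout.
    public Object start() throws InterruptedException {
        lock.lock(); // acquire before try so that finally never unlocks a lock we do not hold
        try {
            long nanos = TimeUnit.SECONDS.toNanos(10);
            // Loop to guard against spurious wakeups and against a response that arrived before we started waiting.
            while (this.response == null && nanos > 0) {
                nanos = finish.awaitNanos(nanos);
            }
            return this.response != null ? this.response.getResult() : null;
        } finally {
            lock.unlock();
        }
    }
    // Called by the I/O thread when the matching response has been received.
    public void over(MessageResponse response) {
        lock.lock();
        try {
            this.response = response; // publish the result before waking the waiter
            finish.signal();
        } finally {
            lock.unlock();
        }
    }
}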
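To make the "reflection and proxies" idea concrete, here is a minimal client-side sketch using a JDK dynamic proxy. It is not the framework's actual MessageSendExecutor (which is not shown in this post): `CallInfo` is a hypothetical, trimmed-down stand-in for MessageRequest, `Greeter` is a made-up sample interface, and the `transport` function is an assumed placeholder for "send over Netty and wait for the response".

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.UUID;
import java.util.function.Function;

// Hypothetical stand-in for MessageRequest, reduced to the fields used here.
final class CallInfo {
    final String messageId;
    final String className;
    final String methodName;
    final Class<?>[] typeParameters;
    final Object[] parameters;

    CallInfo(String messageId, String className, String methodName,
             Class<?>[] typeParameters, Object[] parameters) {
        this.messageId = messageId;
        this.className = className;
        this.methodName = methodName;
        this.typeParameters = typeParameters;
        this.parameters = parameters;
    }
}

// Made-up sample interface, used only to demonstrate the proxy.
interface Greeter {
    String greet(String name);
}

final class RpcProxySketch {
    // Returns a proxy for iface; each interface call becomes a CallInfo handed to the transport.
    static <T> T create(Class<T> iface, Function<CallInfo, Object> transport) {
        InvocationHandler handler = (proxy, method, args) -> {
            // hashCode/equals/toString are answered locally instead of being sent remotely.
            if (method.getDeclaringClass() == Object.class) {
                switch (method.getName()) {
                    case "hashCode": return System.identityHashCode(proxy);
                    case "equals": return proxy == args[0];
                    default: return iface.getName() + "$proxy";
                }
            }
            CallInfo call = new CallInfo(
                    UUID.randomUUID().toString(),   // unique ID used to match the response
                    iface.getName(),                // interface name, the lookup key on the server
                    method.getName(),
                    method.getParameterTypes(),
                    args == null ? new Object[0] : args);
            return transport.apply(call);
        };
        Object proxy = Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[] {iface}, handler);
        return iface.cast(proxy);
    }
}
```

A caller would write something like `Greeter g = RpcProxySketch.create(Greeter.class, transport); g.greet("Ann");` and never touch the request object directly.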
Serialization and deserialization
For encoding and decoding messages, we design an interface:
public interface MessageCodecUtil {
int MESSAGE_LENGTH = 4;
void encode(final ByteBuf out, final Object message) throws IOException;
Object decode(byte[] body) throws IOException;
}
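The post does not say what MESSAGE_LENGTH = 4 is used for. One plausible reading, and it is only an assumption, is a 4-byte length header in front of every serialized body so the receiver knows where a message ends. The sketch below illustrates that framing with plain `java.nio.ByteBuffer` rather than Netty's ByteBuf; the class name `LengthPrefixedFrame` is hypothetical.

```java
import java.nio.ByteBuffer;

final class LengthPrefixedFrame {
    // Assumed meaning of MESSAGE_LENGTH: size in bytes of the length header.
    static final int MESSAGE_LENGTH = 4;

    // Prepends a 4-byte big-endian length to the body.
    static byte[] encode(byte[] body) {
        ByteBuffer buf = ByteBuffer.allocate(MESSAGE_LENGTH + body.length);
        buf.putInt(body.length);
        buf.put(body);
        return buf.array();
    }

    // Returns the body of the next complete frame, or null if more bytes are needed.
    // When null is returned the buffer position is left unchanged.
    static byte[] decode(ByteBuffer in) {
        if (in.remaining() < MESSAGE_LENGTH) {
            return null; // header not fully received yet
        }
        in.mark();
        int length = in.getInt();
        if (length < 0) {
            throw new IllegalArgumentException("negative frame length: " + length);
        }
        if (in.remaining() < length) {
            in.reset(); // body incomplete: rewind so the header is read again next time
            return null;
        }
        byte[] body = new byte[length];
        in.get(body);
        return body;
    }
}
```

With this layout the decoder can pull whole messages out of a TCP stream, which delivers bytes without preserving message boundaries.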
Both encode and decode rely on serialization and deserialization, so we design another interface:
public interface RpcSerialize {
void serialize(OutputStream output, Object object) throws IOException;
Object deserialize(InputStream input) throws IOException;
}
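Of the three serialization options, the JDK one is the simplest to illustrate. The sketch below mirrors the two method shapes of RpcSerialize with `ObjectOutputStream` and `ObjectInputStream`. The class name `JdkSerializeSketch` and the `roundTrip` helper are hypothetical, and the post's real implementation (which is not shown) also uses pooling.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;

final class JdkSerializeSketch {
    // Writes the object with JDK serialization; the object must implement Serializable.
    static void serialize(OutputStream output, Object object) throws IOException {
        ObjectOutputStream oos = new ObjectOutputStream(output);
        oos.writeObject(object);
        oos.flush(); // flush but do not close: the caller owns the stream
    }

    // Reads one object back; a missing class is reported as an IOException.
    static Object deserialize(InputStream input) throws IOException {
        try {
            return new ObjectInputStream(input).readObject();
        } catch (ClassNotFoundException e) {
            throw new IOException("class of serialized object not found", e);
        }
    }

    // Hypothetical helper: serializes to memory and reads the value back.
    static Object roundTrip(Object value) {
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            serialize(out, value);
            return deserialize(new ByteArrayInputStream(out.toByteArray()));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

JDK serialization needs no extra dependency, which is why MessageRequest and MessageResponse implement Serializable. Deserializing data from untrusted peers with it is risky, so it should only be used between trusted endpoints.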
Netty's childHandler has to add a series of handlers to the pipeline, so we design one more interface:
public interface SerializeFrame {
void select(ChannelPipeline pipeline, Map handlerMap);
}
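Because there are several serialization options to choose from, we use the strategy pattern here: each strategy corresponds to a different serialization and encoding scheme.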
public class SerializeContext {
private SerializeFrame serializeFrame;
public SerializeContext(SerializeFrame serializeFrame) {
this.serializeFrame = serializeFrame;
}
public void setSerializeFrame(SerializeFrame serializeFrame) {
this.serializeFrame = serializeFrame;
}
public void select(ChannelPipeline pipeline, Map handlerMap) {
serializeFrame.select(pipeline,handlerMap);
}
}
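The concrete implementations are not listed here. The one point worth noting is that serializer objects are obtained through object pooling: Kryo ships its own KryoPool, while Hessian needs commons-pool2 to pool its objects.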
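The idea behind pooling is to reuse serializer objects that are expensive to build instead of creating one per message. The following is a hand-rolled illustration of that idea only. It is neither KryoPool nor commons-pool2, and the class name `SimplePool` is hypothetical.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.Supplier;

final class SimplePool<T> {
    private final BlockingQueue<T> idle;   // objects currently not in use
    private final Supplier<T> factory;     // creates a new object when the pool is empty

    SimplePool(int capacity, Supplier<T> factory) {
        this.idle = new ArrayBlockingQueue<>(capacity);
        this.factory = factory;
    }

    // Reuses an idle object if one exists, otherwise creates a fresh one.
    T borrow() {
        T pooled = idle.poll();
        return pooled != null ? pooled : factory.get();
    }

    // Returns an object to the pool; it is simply dropped if the pool is already full.
    void release(T object) {
        idle.offer(object);
    }
}
```

Callers should wrap the work in try/finally so that `release` runs even when serialization throws.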
In addition, the server needs to define an annotation that identifies the corresponding server-side implementation class:
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface RpcServer {
Class<?> value();
}
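As a usage sketch, the annotation goes on the implementation class and names the interface it serves. The server can then read it at startup and key the service instance by interface name. The annotation is redeclared below only so the sketch compiles on its own; `EchoService`, `EchoServiceImpl`, and `HandlerRegistrySketch` are hypothetical names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Same shape as the annotation above, redeclared to keep this sketch self-contained.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
@interface RpcServer {
    Class<?> value();
}

// Hypothetical service interface shared by client and server.
interface EchoService {
    String echo(String text);
}

// Hypothetical server-side implementation, tagged with the interface it implements.
@RpcServer(EchoService.class)
final class EchoServiceImpl implements EchoService {
    @Override
    public String echo(String text) {
        return text;
    }
}

final class HandlerRegistrySketch {
    // Builds an "interface name -> service bean" map from already instantiated beans.
    static Map<String, Object> register(Object... serviceBeans) {
        Map<String, Object> handlerMap = new ConcurrentHashMap<>();
        for (Object bean : serviceBeans) {
            RpcServer anno = bean.getClass().getAnnotation(RpcServer.class);
            if (anno != null) { // beans without the annotation are ignored
                handlerMap.put(anno.value().getName(), bean);
            }
        }
        return handlerMap;
    }
}
```

RetentionPolicy.RUNTIME is what makes `getAnnotation` work here. With the default retention the annotation would not be visible through reflection.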
To process this annotation, we imitate Spring's dependency injection when the service starts and instantiate all services up front:
public class RpcAnnotationFactory {
public static Map getBeansWithAnnotation(Class<? extends Annotation> annotation,String packageName){
Map handlerMap = new ConcurrentHashMap<>();
String packageDirName = packageName.replace('.', '/');
Enumeration<URL> dirs;
try {
dirs = Thread.currentThread().getContextClassLoader().getResources(
packageDirName);
while (dirs.hasMoreElements()){
URL url = dirs.nextElement();
String protocol = url.getProtocol();
if ("file".equals(protocol)) {
String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
findAndAddClassesInPackageByFile(annotation, packageName, filePath,
handlerMap);
}
}
} catch (IOException e) {
e.printStackTrace();
}
return handlerMap;
}
private static void findAndAddClassesInPackageByFile(Class<? extends Annotation> annotation, String packageName, String packagePath, Map handlerMap){
File dir = new File(packagePath);
if (!dir.exists() || !dir.isDirectory()) {
return;
}
File[] dirfiles = dir.listFiles(new FileFilter(){
@Override
public boolean accept(File pathname) {
return pathname.isDirectory() || pathname.getName().endsWith(".class");
}
});
for (File file : dirfiles){
if (file.isDirectory()) {
findAndAddClassesInPackageByFile(annotation,packageName + "."
+ file.getName(), file.getAbsolutePath(),
handlerMap);
}else{
String className = file.getName().substring(0, file.getName().length() - 6);
Object serviceBean = null;
try {
Class clazz = Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className);
Annotation anno = clazz.getAnnotation(annotation);
if(anno != null){
serviceBean = clazz.newInstance();
handlerMap.put(anno,serviceBean);
}
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
e.printStackTrace();
}
}
}
}
}
When MessageRecvExecutor is initialized, it performs this dependency injection. Here we again use Google Guava's ListeningExecutorService for asynchronous execution:
public class MessageRecvExecutor {
private static Logger LOG = Logger.getLogger(MessageRecvExecutor.class);
private String serverAddress;
private SerializeFrame serializeFrame;
private final static String DELIMITER = ":";
private Map handlerMap = new ConcurrentHashMap<>();
private static ListeningExecutorService threadPoolExecutor;
public MessageRecvExecutor(String serverAddress, SerializeFrame serializeFrame) {
this.serverAddress = serverAddress;
this.serializeFrame = serializeFrame;
}
public static void submit(Callable<Boolean> task, final ChannelHandlerContext context, final MessageRequest request, final MessageResponse response){
if (threadPoolExecutor == null) {
synchronized (MessageRecvExecutor.class) {
if (threadPoolExecutor == null) {
threadPoolExecutor = MoreExecutors.listeningDecorator((ThreadPoolExecutor) RpcThreadPool.getExecutor(16, -1));
}
}
}
ListenableFuture<Boolean> listenableFuture = threadPoolExecutor.submit(task);
Futures.addCallback(listenableFuture, new FutureCallback<Boolean>() {
@Override
public void onSuccess(@Nullable Boolean aBoolean) {
context.writeAndFlush(response).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOG.info("RPC Server Send message-id response:" + request.getMessageId());
}
});
}
@Override
public void onFailure(Throwable throwable) {
throwable.printStackTrace();
}
}, threadPoolExecutor);
}
public void init(String packageName){
Map rpcServiceObject = RpcAnnotationFactory.getBeansWithAnnotation(RpcServer.class,packageName);
if(MapUtils.isNotEmpty(rpcServiceObject)) {
for(Object serviceBean : rpcServiceObject.values()) {
String interfaceName = serviceBean.getClass().getAnnotation(RpcServer.class).value().getName();
handlerMap.put(interfaceName, serviceBean);
}
}
}
public void afterPropertiesSet(){
ThreadFactory threadRpcFactory = new NamedThreadFactory("NettyRPC ThreadFactory");
int parallel = Runtime.getRuntime().availableProcessors() * 2;
EventLoopGroup boss = new NioEventLoopGroup();
EventLoopGroup worker = new NioEventLoopGroup(parallel, threadRpcFactory, SelectorProvider.provider());
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(boss, worker).channel(NioServerSocketChannel.class)
.childHandler(new MessageRecvChannelInitializer(handlerMap).buildRpcSerializeProtocol(new SerializeContext(serializeFrame)))
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
String[] ipAddr = serverAddress.split(MessageRecvExecutor.DELIMITER);
if(ipAddr.length == 2){
String host = ipAddr[0];
int port = Integer.parseInt(ipAddr[1]);
try {
ChannelFuture future = bootstrap.bind(host, port).sync();
LOG.info(String.format("Netty RPC Server start success!\nip:%s\nport:%d\nprotocol:%s\n\n", host, port, new SerializeContext(serializeFrame).toString()));
future.channel().closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
worker.shutdownGracefully();
boss.shutdownGracefully();
}
}else {
LOG.error("Netty RPC Server start fail!\n");
}
}
}
On the server side, the actual computation is placed in a Callable and carried out through reflection:
public class MessageRecvInitializeTask implements Callable<Boolean> {
private MessageRequest request = null;
private MessageResponse response = null;
private Map handlerMap = null;
public MessageResponse getResponse() {
return response;
}
public MessageRequest getRequest() {
return request;
}
public void setRequest(MessageRequest request) {
this.request = request;
}
MessageRecvInitializeTask(MessageRequest request, MessageResponse response, Map handlerMap) {
this.request = request;
this.response = response;
this.handlerMap = handlerMap;
}
@Override
public Boolean call() throws Exception {
response.setMessageId(request.getMessageId());
try {
Object result = reflect(request);
response.setResult(result);
return Boolean.TRUE;
} catch (Throwable t) {
response.setError(t.toString());
t.printStackTrace();
return Boolean.FALSE;
}
}
private Object reflect(MessageRequest request) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
String className = request.getClassName();
Object serviceBean = handlerMap.get(className);
String methodName = request.getMethodName();
Object[] parameters = request.getParameters();
return MethodUtils.invokeMethod(serviceBean, methodName, parameters);
}
}
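The task above relies on commons-lang's MethodUtils to find a matching method from the argument values. Since the request also carries the parameter types, an alternative is to resolve the exact overload with plain JDK reflection. The sketch below shows that variant as an illustration, not as the post's implementation. `ReflectiveInvokerSketch` is a hypothetical name, and wrapping the checked exceptions in unchecked ones is a choice made only for this sketch.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;

final class ReflectiveInvokerSketch {
    // Looks up the bean by interface name, resolves the method by name and parameter types, and invokes it.
    static Object invoke(Map<String, Object> handlerMap, String className, String methodName,
                         Class<?>[] typeParameters, Object[] parameters) {
        Object serviceBean = handlerMap.get(className);
        if (serviceBean == null) {
            throw new IllegalArgumentException("no service registered for " + className);
        }
        try {
            // getMethod only finds public methods, which is what an interface call needs.
            Method method = serviceBean.getClass().getMethod(methodName, typeParameters);
            return method.invoke(serviceBean, parameters);
        } catch (InvocationTargetException e) {
            // The service method itself threw: surface the original cause.
            throw new IllegalStateException("service method threw an exception", e.getCause());
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new IllegalArgumentException("cannot call " + className + "#" + methodName, e);
        }
    }
}
```

Passing the parameter types removes ambiguity between overloads and handles null arguments, whose type cannot be inferred from the value.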
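For an actual remote call, both sides must agree on the same interface and implementation; the client then wraps the call in MessageSendExecutor. The client likewise uses Netty to send the MessageRequest and deserializes the returned message into a MessageResponse.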
public class RpcServerLoader {
private volatile static RpcServerLoader rpcServerLoader;
private final static String DELIMITER = ":";
private final static int parallel = Math.max(2,Runtime.getRuntime().availableProcessors() * 2);
private EventLoopGroup eventLoopGroup = new NioEventLoopGroup(parallel);
private static ListeningExecutorService threadPoolExecutor = MoreExecutors.listeningDecorator((ThreadPoolExecutor) RpcThreadPool.getExecutor(16, -1));
private MessageSendHandler messageSendHandler;
private Lock lock = new ReentrantLock();
private Condition connectStatus = lock.newCondition();
private Condition handlerStatus = lock.newCondition();
private RpcServerLoader() {
}
public static RpcServerLoader getInstance() {
if (rpcServerLoader == null) {
synchronized (RpcServerLoader.class) {
if (rpcServerLoader == null) {
rpcServerLoader = new RpcServerLoader();
}
}
}
return rpcServerLoader;
}
public void load(String serverAddress, SerializeFrame serializeFrame) {
String[] ipAddr = serverAddress.split(RpcServerLoader.DELIMITER);
if (ipAddr.length == 2) {
String host = ipAddr[0];
int port = Integer.parseInt(ipAddr[1]);
final InetSocketAddress remoteAddr = new InetSocketAddress(host, port);
ListenableFuture<Boolean> listenableFuture = threadPoolExecutor.submit(new MessageSendInitializeTask(eventLoopGroup, remoteAddr, serializeFrame));
Futures.addCallback(listenableFuture, new FutureCallback<Boolean>() {
@Override
public void onSuccess(@Nullable Boolean aBoolean) {
try {
lock.lock();
// Wait in a loop to guard against spurious wakeups.
while (messageSendHandler == null) handlerStatus.await();
if (Boolean.TRUE.equals(aBoolean) && messageSendHandler != null) connectStatus.signalAll();
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
lock.unlock();
}
}
@Override
public void onFailure(Throwable throwable) {
throwable.printStackTrace();
}
}, threadPoolExecutor);
}
}
public void setMessageSendHandler(MessageSendHandler messageInHandler) {
try {
lock.lock();
this.messageSendHandler = messageInHandler;
handlerStatus.signal();
} finally {
lock.unlock();
}
}
public MessageSendHandler getMessageSendHandler() throws InterruptedException {
try {
lock.lock();
while (messageSendHandler == null) { // loop to guard against spurious wakeups
connectStatus.await();
}
return messageSendHandler;
} finally {
lock.unlock();
}
}
public void unLoad() {
messageSendHandler.close();
threadPoolExecutor.shutdown();
eventLoopGroup.shutdownGracefully();
}
}
The Callable task submitted by the client:
public class MessageSendInitializeTask implements Callable<Boolean> {
private EventLoopGroup eventLoopGroup = null;
private InetSocketAddress serverAddress;
private SerializeFrame serializeFrame;
public MessageSendInitializeTask(EventLoopGroup eventLoopGroup, InetSocketAddress serverAddress, SerializeFrame serializeFrame) {
this.eventLoopGroup = eventLoopGroup;
this.serverAddress = serverAddress;
this.serializeFrame = serializeFrame;
}
@Override
public Boolean call() throws Exception {
Bootstrap b = new Bootstrap();
b.group(eventLoopGroup).channel(NioSocketChannel.class).option(ChannelOption.SO_KEEPALIVE, true);
b.handler(new MessageSendChannelInitializer().buildRpcSerializeProtocol(new SerializeContext(serializeFrame)));
ChannelFuture future = b.connect(serverAddress);
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if(future.isSuccess()){
MessageSendHandler handler = future.channel().pipeline().get(MessageSendHandler.class);
RpcServerLoader.getInstance().setMessageSendHandler(handler);
}
}
});
return Boolean.TRUE;
}
}
MessageSendHandler is mainly responsible for handling the returned MessageResponse:
public class MessageSendHandler extends ChannelInboundHandlerAdapter {
private ConcurrentHashMap<String, MessageCallBack> mapCallBack = new ConcurrentHashMap<>();
private volatile Channel channel;
private SocketAddress remoteAddr;
public Channel getChannel() {
return channel;
}
public SocketAddress getRemoteAddr() {
return remoteAddr;
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
super.channelRegistered(ctx);
this.channel = ctx.channel();
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
this.remoteAddr = this.channel.remoteAddress();
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
MessageResponse response = (MessageResponse) msg;
String messageId = response.getMessageId();
MessageCallBack callBack = mapCallBack.get(messageId);
if (callBack != null) {
mapCallBack.remove(messageId);
callBack.over(response);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.close();
}
public void close() {
channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
}
public MessageCallBack sendRequest(MessageRequest request){
MessageCallBack callBack = new MessageCallBack();
mapCallBack.put(request.getMessageId(), callBack);
channel.writeAndFlush(request);
return callBack;
}
}
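The core of this handler is correlating each response with its pending request through messageId. The same pattern can be sketched with the JDK's `CompletableFuture` in place of the lock-and-condition MessageCallBack. This swaps in a different mechanism purely for illustration, and `PendingCallsSketch` is a hypothetical name.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

final class PendingCallsSketch {
    // messageId -> future that will receive the result of that call
    private final Map<String, CompletableFuture<Object>> pending = new ConcurrentHashMap<>();

    // Must be called before the request is written, so a fast response always finds its entry.
    CompletableFuture<Object> register(String messageId) {
        CompletableFuture<Object> future = new CompletableFuture<>();
        pending.put(messageId, future);
        return future;
    }

    // Called when a response arrives; returns false if no call is waiting for this ID.
    boolean complete(String messageId, Object result) {
        CompletableFuture<Object> future = pending.remove(messageId); // remove and fetch atomically
        return future != null && future.complete(result);
    }

    // Number of calls still waiting for a response.
    int size() {
        return pending.size();
    }
}
```

The caller can then block with `future.get(10, TimeUnit.SECONDS)` to get the same 10-second timeout as MessageCallBack. On timeout the entry should also be removed from the map, otherwise abandoned calls accumulate.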