从上图可看出RPC主要分为三个部分:
(1)服务提供者(RPC Server),运行在服务器端,提供服务接口定义与服务实现类。
(2)服务中心(Registry),运行在服务器端,负责将本地服务发布成远程服务,管理远程服务,提供给服务消费者使用。
(3)服务消费者(RPC Client),运行在客户端,通过远程代理对象调用远程服务。
根据上面的RPC结构图,如果我们想实现一个RPC调用框架,就必须实现上述三个部分。
1、定义客户端调用的API(服务接口)。
2、自定义一套协议,即客户端和服务端交互的数据。
3、编写服务提供方,即API的具体实现。
基于netty编写服务端程序作为注册中心,功能包括:
1、自定义协议的编解码。
2、处理客户端请求的handler,该handler的主要功能包括:
2.1、扫描对外服务实现类所在的包,将服务的提供方对象实例保存到容器中。
2.2、解析客户端传来的协议,根据协议内容调用容器中服务提供者的方法,并将结果返回。
基于netty编写客户端程序(服务消费者),功能包括:
1、将自定义协议发送给服务中心,通过自定义handler接收服务端的响应。
package com.syx.rpc.api;
/**
 * Greeting service exposed over RPC. The client obtains a proxy for this interface, and the
 * registry dispatches the calls to the provider instance registered under this interface's name.
 */
public interface IRpcHelloService {
/**
 * @param name value supplied by the caller
 * @return greeting built by the provider ({@code "hello " + name} in RpcHelloServiceImpl)
 */
String hello(String name);
}
package com.syx.rpc.api;
/**
 * Integer arithmetic service exposed over RPC. The client obtains a proxy for this interface,
 * and the registry dispatches the calls to the provider registered under this interface's name.
 */
public interface IRpcService {
/** @return {@code a + b} in the bundled RpcServiceImpl */
int add(int a,int b);
/** @return {@code a - b} in the bundled RpcServiceImpl */
int sub(int a,int b);
/** @return {@code a * b} in the bundled RpcServiceImpl */
int mult(int a,int b);
}
package com.syx.rpc.protocol;
import lombok.Data;
import java.io.Serializable;
/**
 * Request message exchanged between client and server. It names the target service interface
 * and method, and carries the call arguments.
 *
 * <p>It is transported with Java serialization (Netty ObjectEncoder/ObjectDecoder), so every
 * argument value must itself be {@link Serializable}.
 */
@Data
public class InvokerProtocol implements Serializable {
    /**
     * Pinned explicitly because client and server may be compiled separately. A
     * compiler-computed default UID can differ between builds and make deserialization fail with
     * {@code InvalidClassException}.
     */
    private static final long serialVersionUID = 1L;

    /** Fully qualified name of the service interface to call. */
    private String className;
    /** Name of the method to invoke. */
    private String methodName;
    /** Declared parameter types of the method; used to pick the right overload. */
    private Class<?>[] paramTypes;
    /** Argument values, in declaration order. */
    private Object[] params;
}
package com.syx.rpc.provider;
import com.syx.rpc.api.IRpcHelloService;
/**
 * Provider implementation of {@link IRpcHelloService}. It lives in the package that the
 * registry scans for providers.
 */
public class RpcHelloServiceImpl implements IRpcHelloService {
    private static final String GREETING_PREFIX = "hello ";

    /**
     * Builds the greeting for {@code name}.
     *
     * @param name name to greet; a {@code null} name yields {@code "hello null"}
     * @return {@code "hello "} followed by {@code name}
     */
    @Override
    public String hello(String name) {
        StringBuilder greeting = new StringBuilder(GREETING_PREFIX);
        greeting.append(name);
        return greeting.toString();
    }
}
package com.syx.rpc.provider;
import com.syx.rpc.api.IRpcService;
/**
 * Provider implementation of {@link IRpcService}. It does plain 32-bit int arithmetic, so
 * results wrap on overflow exactly as the {@code +}, {@code -} and {@code *} operators do.
 */
public class RpcServiceImpl implements IRpcService {
    /** @return the (wrapping) sum of {@code a} and {@code b} */
    @Override
    public int add(int a, int b) {
        return Integer.sum(a, b);
    }

    /** @return the (wrapping) difference {@code a - b} */
    @Override
    public int sub(int a, int b) {
        final int difference = a - b;
        return difference;
    }

    /** @return the (wrapping) product of {@code a} and {@code b} */
    @Override
    public int mult(int a, int b) {
        final int product = a * b;
        return product;
    }
}
package com.syx.rpc.register;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolver;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
/**
 * RPC registry: a Netty server that receives {@code InvokerProtocol} requests and dispatches them
 * to the service providers registered by {@link RegisterHandler}.
 */
public class RpcRegister {
    /** Package scanned for provider implementations when none is given explicitly. */
    private static final String DEFAULT_PROVIDER_PACKAGE = "com.syx.rpc.provider";

    private final int port;
    private final String providerPackage;

    /**
     * Creates a registry that scans {@value #DEFAULT_PROVIDER_PACKAGE} for providers.
     *
     * @param port TCP port to listen on
     */
    public RpcRegister(int port) {
        this(port, DEFAULT_PROVIDER_PACKAGE);
    }

    /**
     * Creates a registry that scans the given package for providers.
     *
     * @param port            TCP port to listen on
     * @param providerPackage package whose classes are scanned and registered as providers
     */
    public RpcRegister(int port, String providerPackage) {
        this.port = port;
        this.providerPackage = providerPackage;
    }

    /**
     * Binds the server and blocks until its channel is closed, then shuts down both event loop
     * groups. Failures are printed rather than thrown. If the thread is interrupted while
     * waiting, its interrupt flag is restored before returning.
     */
    public void start() {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup work = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(boss, work)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Framing: 4-byte length prefix, stripped on decode and added on encode.
                            pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0,
                                    4, 0, 4));
                            pipeline.addLast(new LengthFieldPrepender(4));
                            // SECURITY(review): ObjectDecoder performs Java native deserialization,
                            // and this resolver accepts any class on the classpath. The port must
                            // therefore not be reachable by untrusted clients. Integer.MAX_VALUE also
                            // removes any practical limit on the object/frame size.
                            pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                            pipeline.addLast("encoder", new ObjectEncoder());
                            pipeline.addLast(new RegisterHandler(providerPackage));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true);
            ChannelFuture future = bootstrap.bind(port).sync();
            System.out.println("RPC Register Listening in port " + port);
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Preserve the interrupt so the caller can still observe it.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            boss.shutdownGracefully();
            work.shutdownGracefully();
        }
    }

    public static void main(String[] args) {
        new RpcRegister(9999).start();
    }
}
package com.syx.rpc.register;
import com.syx.rpc.protocol.InvokerProtocol;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
/**
 * Server-side handler of the RPC registry. It has two jobs:
 * <ol>
 *   <li>Scan the package containing the service implementations and keep one instance of each
 *       provider in a shared container, keyed by the name of every interface it implements.</li>
 *   <li>Read each {@link InvokerProtocol} sent by a client and invoke the requested interface
 *       method on the matching provider. Then write the result back and close the connection.</li>
 * </ol>
 * A new handler is created for every connection, but each package is scanned only once.
 */
public class RegisterHandler extends ChannelInboundHandlerAdapter {
    private static final String CLASS_SUFFIX = ".class";

    /** Interface name -> provider instance, shared by all handler instances. */
    private static final ConcurrentHashMap<String, Object> registerMap = new ConcurrentHashMap<>();

    /** Packages already scanned and registered; guarded by {@code RegisterHandler.class}. */
    private static final Set<String> scannedPackages = new HashSet<>();

    /**
     * @param providerPackage package holding the provider implementations. It is scanned and
     *                        registered on first use only, so later connections reuse the same
     *                        provider instances instead of rescanning and re-instantiating them.
     */
    public RegisterHandler(String providerPackage) {
        // Handlers can be created concurrently on different event-loop threads. The lock makes a
        // second connection wait until the first scan has finished populating registerMap.
        synchronized (RegisterHandler.class) {
            if (scannedPackages.add(providerPackage)) {
                doRegister(scanClass(providerPackage));
            }
        }
    }

    /**
     * Lists the classes found under a package directory.
     *
     * @param packageName dotted package name, e.g. {@code com.syx.rpc.provider}
     * @return fully qualified names of every {@code .class} file below the package, including
     *         sub-packages and nested classes. The list is empty when the package cannot be
     *         found as a directory on the file system (for example when it lives inside a jar).
     */
    private static List<String> scanClass(String packageName) {
        List<String> classNames = new ArrayList<>();
        URL url = Thread.currentThread().getContextClassLoader()
                .getResource(packageName.replace('.', '/'));
        if (url == null || !"file".equals(url.getProtocol())) {
            return classNames;
        }
        Path root;
        try {
            // toURI() decodes escapes such as %20, which URL.getFile() would leave in the path.
            root = Paths.get(url.toURI());
        } catch (URISyntaxException e) {
            e.printStackTrace();
            return classNames;
        }
        if (!Files.isDirectory(root)) {
            return classNames;
        }
        try (Stream<Path> files = Files.walk(root)) {
            Iterator<Path> it = files.iterator();
            while (it.hasNext()) {
                Path file = it.next();
                if (!Files.isRegularFile(file)) {
                    continue;
                }
                String relative = root.relativize(file).toString();
                if (!relative.endsWith(CLASS_SUFFIX)) {
                    continue;
                }
                // Derive the name from the path relative to the package root. The result then
                // does not depend on the absolute location or on the OS path separator.
                String simpleName = relative
                        .substring(0, relative.length() - CLASS_SUFFIX.length())
                        .replace(File.separator, ".");
                classNames.add(packageName + "." + simpleName);
            }
        } catch (IOException | UncheckedIOException e) {
            e.printStackTrace();
        }
        return classNames;
    }

    /**
     * Instantiates each concrete class that implements at least one interface, using its no-arg
     * constructor. The instance is registered under the name of every interface it implements.
     * Classes that cannot be loaded or instantiated are reported and skipped.
     *
     * @param className fully qualified class names, as produced by {@link #scanClass(String)}
     */
    private static void doRegister(List<String> className) {
        for (String name : className) {
            try {
                Class<?> clazz = Class.forName(name);
                if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) {
                    continue;
                }
                Class<?>[] interfaces = clazz.getInterfaces();
                if (interfaces.length == 0) {
                    continue;
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                for (Class<?> api : interfaces) {
                    registerMap.put(api.getName(), instance);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Handles one request. It resolves the method on the requested interface (so only interface
     * methods are callable) and invokes it on the registered provider. It writes a non-null
     * result back, then closes the connection.
     *
     * <p>When the service is unknown, the invocation fails, or the result is {@code null}, the
     * connection is simply closed. The client therefore never waits forever for a reply.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Object result = null;
        try {
            InvokerProtocol request = (InvokerProtocol) msg;
            Object instance = registerMap.get(request.getClassName());
            if (instance != null) {
                Class<?> api = Class.forName(request.getClassName(), false,
                        instance.getClass().getClassLoader());
                Method method = api.getMethod(request.getMethodName(), request.getParamTypes());
                result = method.invoke(instance, request.getParams());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (result == null) {
            ctx.close();
        } else {
            // Close only after the write has completed, so the reply is not discarded.
            ctx.writeAndFlush(result).addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Reports pipeline errors (e.g. undecodable requests) and closes the connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        ctx.close();
    }

    /** Debug entry point: prints the scanned class names and the resulting registrations. */
    public static void main(String[] args) {
        List<String> classNames = scanClass("com.syx.rpc.provider");
        classNames.forEach(System.out::println);
        System.out.println("***********");
        doRegister(classNames);
        registerMap.forEach((k, v) -> {
            System.out.println(k + " " + v);
        });
    }
}
package com.syx.rpc.consumer;
import com.syx.rpc.protocol.InvokerProtocol;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
/**
 * Client-side factory for RPC stubs. Every call on a returned proxy is packed into an
 * {@link InvokerProtocol} and sent to the registry over a fresh Netty connection. The object
 * the server writes back is returned as the call's result.
 */
public class RpcProxy {
    /** Registry host used by {@link #create(Class)}. */
    private static final String DEFAULT_HOST = "127.0.0.1";
    /** Registry port used by {@link #create(Class)}. */
    private static final int DEFAULT_PORT = 9999;

    /**
     * Creates a stub that talks to the registry at 127.0.0.1:9999.
     *
     * @param clazz service interface, or a class whose interfaces are proxied
     * @param <T>   type the caller assigns the proxy to; must be one of the proxied interfaces
     * @return proxy implementing the service interface(s)
     */
    public static <T> T create(Class<?> clazz) {
        return create(clazz, DEFAULT_HOST, DEFAULT_PORT);
    }

    /**
     * Creates a stub that talks to the registry at {@code host:port}.
     *
     * @param clazz service interface, or a class whose interfaces are proxied
     * @param host  registry host
     * @param port  registry port
     * @param <T>   type the caller assigns the proxy to; must be one of the proxied interfaces
     * @return proxy implementing the service interface(s)
     */
    // T is chosen by the caller; a mismatch surfaces as ClassCastException at the call site.
    @SuppressWarnings("unchecked")
    public static <T> T create(Class<?> clazz, String host, int port) {
        MethodProxy methodProxy = new MethodProxy(clazz, host, port);
        Class<?>[] interfaces = clazz.isInterface() ? new Class<?>[]{clazz} : clazz.getInterfaces();
        return (T) Proxy.newProxyInstance(clazz.getClassLoader(), interfaces, methodProxy);
    }

    /** Forwards non-Object methods of the proxy to the remote registry. */
    private static class MethodProxy implements InvocationHandler {
        private final Class<?> clazz;
        private final String host;
        private final int port;

        private MethodProxy(Class<?> clazz, String host, int port) {
            this.clazz = clazz;
            this.host = host;
            this.port = port;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // equals/hashCode/toString are answered locally by this handler.
            if (Object.class.equals(method.getDeclaringClass())) {
                return method.invoke(this, args);
            }
            return rpcInvoker(proxy, method, args);
        }

        /**
         * Performs one request/response round trip on a dedicated connection and event loop.
         *
         * @return the object written back by the server, or {@code null} if none arrived and
         *         the method's return type is a reference type or void
         * @throws IllegalStateException if the call fails, is interrupted, or gets no response
         *                               for a method with a primitive return type
         */
        private Object rpcInvoker(Object proxy, Method method, Object[] args) {
            InvokerProtocol request = new InvokerProtocol();
            request.setMethodName(method.getName());
            // NOTE(review): the registry keys providers by interface name. If clazz is an
            // implementation class rather than an interface, the lookup will not match.
            request.setClassName(clazz.getName());
            request.setParamTypes(method.getParameterTypes());
            request.setParams(args);
            String target = clazz.getName() + "#" + method.getName();

            NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup();
            RpcProxyHandler rpcProxyHandler = new RpcProxyHandler();
            try {
                Bootstrap bootstrap = new Bootstrap();
                bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class)
                        .option(ChannelOption.TCP_NODELAY, true)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
                                // Framing: 4-byte length prefix, must match the registry's pipeline.
                                pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0,
                                        4, 0, 4));
                                pipeline.addLast(new LengthFieldPrepender(4));
                                // Java-serialized objects on top of the framing.
                                pipeline.addLast("decoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                                pipeline.addLast("encoder", new ObjectEncoder());
                                pipeline.addLast(rpcProxyHandler);
                            }
                        });
                ChannelFuture connected = bootstrap.connect(host, port).sync();
                connected.channel().writeAndFlush(request).sync();
                // The reply has been captured by the time the connection is closed.
                connected.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while calling " + target, e);
            } catch (Exception e) {
                throw new IllegalStateException("RPC call failed: " + target, e);
            } finally {
                eventLoopGroup.shutdownGracefully();
            }

            Object response = rpcProxyHandler.getResponse();
            Class<?> returnType = method.getReturnType();
            // null cannot be unboxed; fail with a clear message instead of an NPE in the proxy.
            if (response == null && returnType.isPrimitive() && returnType != void.class) {
                throw new IllegalStateException("No response received for " + target);
            }
            return response;
        }
    }
}
package com.syx.rpc.consumer;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.EventExecutorGroup;
/**
 * Client-side inbound handler that captures the single response object of one RPC call.
 * RpcProxy creates one instance per call and reads {@link #getResponse()} after the connection
 * has closed.
 */
public class RpcProxyHandler extends ChannelInboundHandlerAdapter {
    /** Written on the Netty event-loop thread and read by the calling thread, hence volatile. */
    private volatile Object response;

    /** @return the object received from the server, or {@code null} if none arrived */
    public Object getResponse() {
        return response;
    }

    /**
     * Stores the response and closes the connection. One request gets one response, so the
     * caller's {@code closeFuture().sync()} returns even if the server keeps the socket open.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        setResponse(msg);
        ctx.close();
    }

    /** Reports the failure and closes the connection so the waiting caller is released. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        ctx.close();
    }

    /** @param response object received from the server */
    public void setResponse(Object response) {
        this.response = response;
    }
}
package com.syx.rpc.consumer;
import com.syx.rpc.api.IRpcHelloService;
import com.syx.rpc.api.IRpcService;
import com.syx.rpc.provider.RpcServiceImpl;
/**
 * Demo client. It calls both services through the RPC proxy, then compares the elapsed time
 * (in ms) of a remote {@code add} with a direct local call. The registry must already be
 * running.
 */
public class RpcConsumer {
    public static void main(String[] args) {
        IRpcHelloService helloService = RpcProxy.create(IRpcHelloService.class);
        String whh = helloService.hello("whh");
        System.out.println(whh);

        System.out.println("rpc调用");
        // nanoTime is monotonic, unlike the wall-clock currentTimeMillis; output stays in ms.
        long rpcStart = System.nanoTime();
        IRpcService rpcService = RpcProxy.create(IRpcService.class);
        System.out.println(rpcService.add(1, 2));
        System.out.println(elapsedMillis(rpcStart));

        System.out.println("本地调用");
        long localStart = System.nanoTime();
        RpcServiceImpl rpcService1 = new RpcServiceImpl();
        System.out.println(rpcService1.add(1, 2));
        System.out.println(elapsedMillis(localStart));
    }

    /** Milliseconds elapsed since a {@link System#nanoTime()} reading. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000;
    }
}