
手写dubbo 10-基于netty实现RPC



    首先科普一下RPC三个字母,即Remote Procedure Call。简单来说就是从一台机器(客户端)上通过参数传递的方式调用另一台机器(服务器)上的一个函数或方法(可以统称为服务)并得到返回的结果。





  1. ComputerA将自己需要调用的方法和参数封装好。

  2. 按照约定的方式,将封装好的方法和参数传给ComputerB。

  3. ComputerB收到约定的数据后,解析获得ComputerA需要调用的方法和参数。

  4. ComputerB按照ComputerA给的数据,执行对应的方法。

  5. ComputerB将执行结果按照约定返回给ComputerA。







<dependency> <groupId>io.netty</groupId> <artifactId>netty-all</artifactId> <version>4.1.36.Final</version></dependency>


/**
 * Provider-side transport of the RPC framework: publishes this process's services on an
 * address.
 *
 * <p>NOTE(review): {@code @FarSPI("netty")} presumably names the default extension that
 * {@code ExtensionLoader} picks for this SPI — confirm against the FarSPI definition.
 */
@FarSPI("netty")
public interface IProviderServer {

    /**
     * Starts serving.
     *
     * @param selfAddress address the provider listens on and advertises; the netty
     *     implementation splits it on {@code ':'} into host and port
     */
    void start(String selfAddress);
}
/**
 * Consumer-side transport of the RPC framework: sends one request to a provider and returns
 * its answer.
 *
 * <p>NOTE(review): {@code @FarSPI("netty")} presumably names the default extension that
 * {@code ExtensionLoader} picks for this SPI — confirm against the FarSPI definition.
 */
@FarSPI("netty")
public interface IConsumerServer {

    /**
     * Executes a remote call.
     *
     * @param address provider address (the netty implementation expects {@code host:port})
     * @param requestDTO description of the method to invoke and its arguments
     * @return the object sent back by the provider; the netty implementation returns
     *     {@code null} when no response was received
     */
    Object execute(String address, RequestDTO requestDTO);
}




/**
 * Marks a class as the implementation of an RPC service.
 *
 * <p>Retained at runtime so the provider container can discover annotated classes by
 * reflection and register an instance under the service interface's canonical name.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Provider {

    /**
     * The service interface this class implements; consumers address the service by this
     * interface's canonical name. Declared as {@code Class<?>} rather than the raw
     * {@code Class} type.
     */
    Class<?> interfaceClazz();

    /** Optional service name; empty by default. */
    String name() default "";
}


public class Container { private static final Logger logger = LoggerFactory.getLogger(Container.class); private static IRegistrar registrar = RegistrarFactory.getRegistrar(); private static Map<String, Object> providers = new HashMap<String, Object>();
static { Reflections reflections = new Reflections(new ConfigurationBuilder() .setUrls(ClasspathHelper.forPackage("com.ofcoder")) .setScanners(new TypeAnnotationsScanner())); Set<Class<?>> classes = reflections.getTypesAnnotatedWith(Provider.class, true); for (Class<?> clazz : classes) { try { Provider annotation = clazz.getAnnotation(Provider.class); Object provider = clazz.newInstance(); String canonicalName = annotation.interfaceClazz().getCanonicalName();
// 保存到本地容器 providers.put(canonicalName, provider); } catch (Exception e) { logger.error(e.getMessage(), e); } } }
public static void registerSelf(String selfAddress){ for (String service : providers.keySet()) { registrar.register(selfAddress, service); } }
public static Map<String, Object> getProviders() { return providers; }}


public class NettyProviderHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(NettyProviderHandler.class);
@Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { super.channelRead(ctx, msg); RequestDTO requestDTO = (RequestDTO) msg; Object result = new Object();
logger.info("receive request.. {}", requestDTO); if (Container.getProviders().containsKey(requestDTO.getClassName())) { Object provider = Container.getProviders().get(requestDTO.getClassName());
Class<?> providerClazz = provider.getClass(); Method method = providerClazz.getMethod(requestDTO.getMethodName(), requestDTO.getTypes()); // 反射执行指定的方法 result = method.invoke(provider, requestDTO.getParams()); }
// 将结果输出到消费端 ctx.write(result); ctx.flush(); ctx.close(); }}


public class NettyProviderServer implements IProviderServer { private static final Logger logger = LoggerFactory.getLogger(NettyProviderServer.class);
public void start(String selfAddress) { Container.registerSelf(selfAddress);
String[] addrs = selfAddress.split(":"); String ip = addrs[0]; Integer port = Integer.parseInt(addrs[1]);
publisher(ip, port); }
private void publisher(String ip, Integer port) { // 启动服务 try { EventLoopGroup bossGroup = new NioEventLoopGroup(); EventLoopGroup workerGroup = new NioEventLoopGroup();
ServerBootstrap bootstrap = new ServerBootstrap(); bootstrap.group(bossGroup, workerGroup) .channel(NioServerSocketChannel.class) .childHandler(new ChannelInitializer<Channel>() { @Override protected void initChannel(Channel channel) throws Exception { ChannelPipeline pipeline = channel.pipeline(); pipeline.addLast(new ObjectEncoder()); pipeline.addLast(new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(NettyProviderServer.class.getClassLoader()))); pipeline.addLast(new NettyProviderHandler()); } }).option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE, true);
ChannelFuture future = bootstrap.bind(ip, port).sync(); logger.info("netty server is started..."); future.channel().closeFuture().sync(); } catch (Exception e) { logger.error(e.getMessage(), e); } }}



public class NettyConsumerServer implements IConsumerServer { private static final Logger logger = LoggerFactory.getLogger(NettyConsumerServer.class);
public Object execute(String serivceAddress, RequestDTO requestDTO) { String[] addrs = serivceAddress.split(":"); String host = addrs[0]; Integer port = Integer.parseInt(addrs[1]);
final NettyConsumerHandler consumerHandler = new NettyConsumerHandler(); EventLoopGroup group = new NioEventLoopGroup(); try { Bootstrap bootstrap = new Bootstrap(); bootstrap.group(group) .channel(NioSocketChannel.class) .option(ChannelOption.TCP_NODELAY, true) .handler(new ChannelInitializer<Channel>() { @Override protected void initChannel(Channel channel) throws Exception { ChannelPipeline pipeline = channel.pipeline(); pipeline.addLast( new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(ConsumerProxy.class.getClassLoader()))); pipeline.addLast( new ObjectEncoder()); pipeline.addLast(consumerHandler); } }); ChannelFuture future = bootstrap.connect(host, port).sync();
Channel channel = future.channel(); channel.writeAndFlush(requestDTO); logger.info("send request..., {}", requestDTO); channel.closeFuture().sync(); } catch (Exception e) { logger.error(e.getMessage(), e); } finally { group.shutdownGracefully(); } return consumerHandler.getResponse();


/**
 * Consumer-side handler that captures the response object sent back by the provider.
 */
public class NettyConsumerHandler extends ChannelInboundHandlerAdapter {
    // Set on the Netty event-loop thread in channelRead() and read through getResponse() by
    // the thread that issued the call, hence volatile.
    private volatile Object response;

    /**
     * Returns the response received so far.
     *
     * @return the response, or {@code null} if none has arrived
     */
    public Object getResponse() {
        return response;
    }

    /**
     * Stores the response and closes the connection. Each connection carries one request and
     * one response, so the caller waiting for the close is released as soon as the answer is
     * in, even if the provider does not close its side.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        response = msg;
        ctx.close();
    }

    /**
     * Forwards the exception (so the pipeline still reports it) and closes the connection.
     * Otherwise a decode or I/O error would leave the caller waiting for a close that never
     * comes.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
        ctx.close();
    }
}



/**
 * Resolves the RPC transport implementations for the protocol configured in
 * {@code Property.Rpc.protocol}, using {@code ExtensionLoader}.
 */
public class RpcFactory {

    /**
     * Returns the consumer-side transport registered under the configured protocol name.
     *
     * @return the {@link IConsumerServer} extension for {@code Property.Rpc.protocol}
     */
    public static IConsumerServer getConsumerService() {
        final String configuredProtocol = Property.Rpc.protocol;
        return ExtensionLoader.getExtensionLoader(IConsumerServer.class)
                .getExtension(configuredProtocol);
    }

    /**
     * Returns the provider-side transport registered under the configured protocol name.
     *
     * @return the {@link IProviderServer} extension for {@code Property.Rpc.protocol}
     */
    public static IProviderServer getProviderServer() {
        final String configuredProtocol = Property.Rpc.protocol;
        return ExtensionLoader.getExtensionLoader(IProviderServer.class)
                .getExtension(configuredProtocol);
    }
}




/**
 * Manual test: starts the provider server for the configured protocol.
 *
 * <p>NOTE(review): {@code ""} is a placeholder; the netty provider splits the address on
 * {@code ':'} and rejects an empty one. TODO fill in the real {@code host:port}.
 */
@Test
public void provider() throws IOException {
    IProviderServer providerServer = RpcFactory.getProviderServer();
    providerServer.start("");
    // The netty start() blocks while the server runs; this only keeps the test alive if it
    // returns early.
    System.in.read();
}
/**
 * Manual test: sends one request through the configured consumer transport and prints the
 * response.
 *
 * <p>NOTE(review): the empty address and the empty {@code RequestDTO} are placeholders. TODO
 * fill in the provider's {@code host:port} and the interface/method/arguments to call.
 */
@Test
public void consumer() {
    IConsumerServer consumerService = RpcFactory.getConsumerService();
    Object response = consumerService.execute("", new RequestDTO());
    System.out.println("response: " + response);
}


...main INFO netty.NettyProviderServer: netty server is started...nioEventLoopGroup-3-1 INFO netty.NettyProviderHandler: receive request.. RequestDTO{className='null', methodName='null', types=null, params=null}



    整个过程可以总结为:首先服务消费者通过代理对象 Proxy 发起远程调用,接着通过网络客户端 Client 将编码后的请求发送到服务提供方的网络层,也就是 Server。Server 在收到请求后,首先要做的事情是对数据包进行解码,然后将解码后的请求发送至分发器 Dispatcher,再由分发器将请求派发到指定的线程池上,最后由线程池调用具体的服务。这就是一个远程调用请求的发送与接收过程。



Proxy.greet -> InvokerInvocationHandler.invoke -> MockClusterInvoker.invoke -> ... -> AbstractInvoker.invoke -> DubboInvoker.doInvoke


/**
 * Excerpt of Dubbo's DubboInvoker#doInvoke as quoted in this article ("..." marks code the
 * article omits). It picks an ExchangeClient and sends the invocation one of two ways:
 * one-way via send(), where no result is awaited, or two-way via request(), where the result
 * is delivered asynchronously through an AsyncRpcResult.
 */
protected Result doInvoke(final Invocation invocation) throws Throwable {
    RpcInvocation inv = (RpcInvocation) invocation;
    final String methodName = RpcUtils.getMethodName(invocation);
    // Attach the service path and version to the invocation.
    inv.setAttachment(PATH_KEY, getUrl().getPath());
    inv.setAttachment(VERSION_KEY, version);

    ExchangeClient currentClient;
    // With more than one client, rotate through them using the shared counter.
    // NOTE(review): assumes index never yields negative values; with a plain AtomicInteger the
    // modulo would go negative after overflow — confirm against the field declaration.
    if (clients.length == 1) {
        currentClient = clients[0];
    } else {
        currentClient = clients[index.getAndIncrement() % clients.length];
    }
    try {
        // Whether the call has a return value; true means one-way (no return value).
        boolean isOneway = RpcUtils.isOneway(getUrl(), invocation);
        int timeout = getUrl().getMethodParameter(methodName, TIMEOUT_KEY, DEFAULT_TIMEOUT);
        if (isOneway) {
            boolean isSent = getUrl().getMethodParameter(methodName, Constants.SENT_KEY, false);
            // Send the request
            currentClient.send(inv, isSent);
            // Clear the future
            RpcContext.getContext().setFuture(null);
            // The return value is not needed; return a default RpcResult
            return AsyncRpcResult.newDefaultAsyncResult(invocation);
        } else {
            // This class implements the Future interface and is used for asynchronous calls
            AsyncRpcResult asyncRpcResult = new AsyncRpcResult(inv);
            // Issue the call, which also returns a Future
            CompletableFuture<Object> responseFuture = currentClient.request(inv, timeout);
            // When the request's Future completes, asyncRpcResult is notified
            asyncRpcResult.subscribeTo(responseFuture);
            RpcContext.getContext().setFuture(new FutureAdapter(asyncRpcResult));
            return asyncRpcResult;
        }
    } catch (TimeoutException e) {
        ...
    } catch (RemotingException e) {
        ...
    }
}


ReferenceCountExchangeClient.request -> HeaderExchangeClient.request -> HeaderExchangeChannel.request -> AbstractClient.send -> NettyChannel.send -> NioClientSocketChannel.write


/**
 * Excerpt of Dubbo's NettyChannel#send as quoted in this article ("..." marks shortened text).
 * Writes a message to the channel and, when {@code sent} is true, waits up to the configured
 * timeout for the write to finish.
 *
 * <p>NOTE(review): {@code future.getCause()} / {@code future.await(long)} look like the Netty 3
 * ChannelFuture API — confirm which Dubbo transport module this was taken from.
 */
public void send(Object message, boolean sent) throws RemotingException {
    super.send(message, sent);

    boolean success = true;
    int timeout = 0;
    try {
        // Send the message (both request and response messages go through here)
        ChannelFuture future = channel.write(message);
        // The value of sent comes from the sent attribute in <dubbo:method sent="true/false" />,
        // which has two possible values:
        // 1. true: wait until the message has been sent; a failed send throws an exception
        // 2. false: do not wait; the message is put on the IO queue and the call returns at once
        // By default sent = false;
        if (sent) {
            timeout = getUrl().getPositiveParameter(Constants.TIMEOUT_KEY, Constants.DEFAULT_TIMEOUT);
            // Wait for the message to be sent; if it is not sent within the timeout,
            // success is set to false
            success = future.await(timeout);
        }
        Throwable cause = future.getCause();
        if (cause != null) {
            throw cause;
        }
    } catch (Throwable e) {
        // NOTE(review): as excerpted, e is not passed to the RemotingException, so the original
        // cause is lost; the article shortened this line with "..." — check the real source.
        throw new RemotingException(this, "Failed to send message ...");
    }

    // If success is false, throw an exception here
    if (!success) {
        throw new RemotingException(this, "Failed to send message ...");
    }
}


| 策略 | 用途 |
| - | - |
| all | 所有消息都派发到线程池,包括请求,响应,连接事件,断开事件等 |
| direct | 所有消息都不派发到线程池,全部在 IO 线程上直接执行 |
| message | 只有请求和响应消息派发到线程池,其它消息均在 IO 线程上执行 |
| execution | 只有请求消息派发到线程池,不含响应。其它消息均在 IO 线程上执行 |
| connection | 在 IO 线程上,将连接断开事件放入队列,有序逐个执行,其它消息派发到线程池 |


NettyHandler#messageReceived -> AbstractPeer#received -> MultiMessageHandler#received -> HeartbeatHandler#received -> AllChannelHandler#received -> ExecutorService#execute


/**
 * Excerpt of Dubbo's ChannelEventRunnable ("..." marks code omitted by the article). Runs one
 * channel event. Received messages (requests/responses) go to handler.received(); all other
 * states are handled in the switch.
 */
public class ChannelEventRunnable implements Runnable {
    @Override
    public void run() {
        // Check the channel state; for request or response messages the state is RECEIVED here.
        // The frequent RECEIVED case is tested first; the rarer states go through the switch.
        if (state == ChannelState.RECEIVED) {
            try {
                handler.received(channel, message);
            } catch (Exception e) {
                logger.warn("ChannelEventRunnable handle " + state + " operation error, channel is " + channel + ", message is " + message, e);
            }
        } else {
            switch (state) {
                case CONNECTED:
                    ...
                    break;
                case DISCONNECTED:
                    ...
                    break;
                case SENT:
                    ...
                    break;
                case CAUGHT:
                    ...
                    break;
                default:
                    logger.warn("unknown state: " + state + ", message is " + message);
            }
        }
        // NOTE: the excerpt ends here; the closing braces of run() and of the class are not
        // shown in the article.

    这里多说一句:先用 if 判断出现频率较高的消息类型,再用 switch 处理其他类型,而不是把高频类型和普通类型放在同一层级判断,以此提高效率。我们在开发过程中也可以借鉴这一点。

    ChannelEventRunnable 的作用类似于路由,将消息分别交给各自的 ChannelHandler 去处理。这里的处理对象为 DecodeHandler,它对 Request 或 Response 进行解码,然后继续传递给 HeaderExchangeHandler。

/**
 * Excerpt of Dubbo's HeaderExchangeHandler ("..." marks code omitted by the article). Routes
 * decoded messages: requests are handled (two-way requests via handleRequest), responses go to
 * handleResponse, and anything else is passed to the wrapped handler.
 */
public class HeaderExchangeHandler implements ChannelHandlerDelegate {
    @Override
    public void received(Channel channel, Object message) throws RemotingException {
        channel.setAttribute(KEY_READ_TIMESTAMP, System.currentTimeMillis());
        final ExchangeChannel exchangeChannel = HeaderExchangeChannel.getOrAddChannel(channel);
        try {
            if (message instanceof Request) {
                // handle request.
                Request request = (Request) message;
                if (request.isEvent()) {
                    handlerEvent(channel, request);
                } else {
                    // One-way or two-way call: decides whether a result has to be returned
                    if (request.isTwoWay()) {
                        // NOTE(review): as excerpted, the Response returned by handleRequest is
                        // discarded; for a two-way call it must be sent back to the caller
                        // somewhere — presumably in code the article omitted; check the source.
                        handleRequest(exchangeChannel, request);
                    } else {
                        handler.received(exchangeChannel, request.getData());
                    }
                }
            } else if (message instanceof Response) {
                handleResponse(channel, (Response) message);
            } else if (message instanceof String) {
                // telnet-related
                ...
            } else {
                handler.received(exchangeChannel, message);
            }
        } finally {
            HeaderExchangeChannel.removeChannelIfDisconnected(channel);
        }
    }

    /**
     * Builds the Response for a two-way request. Returns BAD_REQUEST for a broken (undecodable)
     * request, OK with the result on success, or SERVICE_ERROR if handler.reply() throws.
     */
    Response handleRequest(ExchangeChannel channel, Request req) throws RemotingException {
        Response res = new Response(req.getId(), req.getVersion());
        // Check whether the request is valid; if not, return a response with status BAD_REQUEST
        if (req.isBroken()) {
            Object data = req.getData();

            String msg;
            if (data == null)
                msg = null;
            else if (data instanceof Throwable)
                msg = StringUtils.toString((Throwable) data);
            else
                msg = data.toString();
            res.setErrorMessage("Fail to decode request due to: " + msg);
            // Set BAD_REQUEST status
            res.setStatus(Response.BAD_REQUEST);

            return res;
        }
        // Get the data field, i.e. the RpcInvocation object
        Object msg = req.getData();
        try {
            // Pass the call down to the next handler
            Object result = handler.reply(channel, msg);
            // Set the OK status code
            res.setStatus(Response.OK);
            // Set the invocation result
            res.setResult(result);
        } catch (Throwable e) {
            // If the invocation throws, set SERVICE_ERROR to signal a server-side error
            res.setStatus(Response.SERVICE_ERROR);
            res.setErrorMessage(StringUtils.toString(e));
        }
        return res;
    }
}


/**
 * Excerpt of Dubbo's DubboProtocol showing only its request handler ("..." marks code omitted
 * by the article).
 */
public class DubboProtocol extends AbstractProtocol {
    // Server-side handler: reply() turns an incoming Invocation into a call on the Invoker
    // returned by getInvoker().
    private ExchangeHandler requestHandler = new ExchangeHandlerAdapter() {

        @Override
        public Object reply(ExchangeChannel channel, Object message) throws RemotingException {
            if (message instanceof Invocation) {
                Invocation inv = (Invocation) message;
                // Look up the Invoker instance
                Invoker<?> invoker = getInvoker(channel, inv);
                if (Boolean.TRUE.toString().equals(inv.getAttachments().get(IS_CALLBACK_SERVICE_INVOKE))) {
                    // Callback-related; omitted
                }
                // Record the caller's address in the RpcContext
                RpcContext.getContext().setRemoteAddress(channel.getRemoteAddress());
                // Invoke the actual service through the Invoker
                return invoker.invoke(inv);
            }
            throw new RemotingException(channel, "Unsupported request: ...");
        }
        ...
    }
    // NOTE: the excerpt ends here; the ";" ending the field initializer and the class's
    // closing brace are not shown in the article.


/**
 * Excerpt of Dubbo's AbstractProxyInvoker ("..." marks code omitted by the article). invoke()
 * delegates to the abstract doInvoke() and adapts its outcome into a Result.
 */
public abstract class AbstractProxyInvoker<T> implements Invoker<T> {
    @Override
    public Result invoke(Invocation invocation) throws RpcException {
        try {
            Object value = doInvoke(proxy, invocation.getMethodName(), invocation.getParameterTypes(), invocation.getArguments());

            // Wrap the result in an AsyncRpcResult and return it
            ...
            return asyncRpcResult;
        } catch (InvocationTargetException e) {
            // An exception thrown by the service method itself is returned inside the result
            // rather than thrown — presumably so it can be reported back to the consumer.
            ...
            return AsyncRpcResult.newDefaultAsyncResult(null, e.getTargetException(), invocation);
        } catch (Throwable e) {
            // Any other failure is rethrown as an RpcException.
            throw new RpcException("Failed to invoke remote proxy method " + invocation.getMethodName() + " to " + getUrl() + ", cause: " + e.getMessage(), e);
        }
    }
}

    剩余的最后一个 doInvoke 是一个抽象方法,由子类实现。具体实现位于 JavassistProxyFactory.getInvoker() 方法中:该方法创建一个 AbstractProxyInvoker 的匿名子类,其 doInvoke 委托给运行时动态生成的 Wrapper 类的 invokeMethod 方法。最后生成的 Wrapper 类逻辑如下:

/**
 * Wrapper0 is generated at runtime; it can be decompiled with Arthas.
 */
public class Wrapper0 extends Wrapper implements ClassGenerator.DC {
    // NOTE(review): presumably property names, property types, method names, declared method
    // names and the parameter types of method 0 — confirm against Dubbo's Wrapper generator.
    public static String[] pns;
    public static Map pts;
    public static String[] mns;
    public static String[] dmns;
    public static Class[] mts0;

    // Other methods omitted

    /**
     * Calls the method named {@code string} on {@code object} (cast to DemoService), passing
     * {@code arrobject} as arguments. Only the method name and the parameter count are compared
     * when selecting the method.
     */
    public Object invokeMethod(Object object, String string, Class[] arrclass, Object[] arrobject) throws InvocationTargetException {
        DemoService demoService;
        try {
            // Type cast
            demoService = (DemoService)object;
        } catch (Throwable throwable) {
            throw new IllegalArgumentException(throwable);
        }
        try {
            // Invoke the method that matches the given name
            if ("sayHello".equals(string) && arrclass.length == 1) {
                return demoService.sayHello((String)arrobject[0]);
            }
        } catch (Throwable throwable) {
            // Exceptions from the service method are wrapped
            throw new InvocationTargetException(throwable);
        }
        // No method matched the name/parameter count
        throw new NoSuchMethodException(new StringBuffer().append("Not found method \"").append(string).append("\" in class com.alibaba.dubbo.demo.DemoService.").toString());
    }
}


并发笔记 发起了一个读者讨论 说出你的故事......