手把手教你 Netty 实现自定义协议!
扫描下方海报了解专栏详情
本文来源:
my.oschina.net/zhangxufeng/blog/3043768
《Java工程师面试突击(第3季)》重磅升级,由原来的70讲增至160讲,内容扩充一倍多,升级部分内容请参见文末
1. 协议规定
2. 协议实现
public class Message {
  private int magicNumber;
  private byte mainVersion;
  private byte subVersion;
  private byte modifyVersion;
  private String sessionId;
  private MessageTypeEnum messageType;
  private Map<String, String> attachments = new HashMap<>();
  private String body;
  public Map<String, String> getAttachments() {
    return Collections.unmodifiableMap(attachments);
  }
  public void setAttachments(Map<String, String> attachments) {
    this.attachments.clear();
    if (null != attachments) {
      this.attachments.putAll(attachments);
    }
  }
  public void addAttachment(String key, String value) {
    attachments.put(key, value);
  }
  // getter and setter...
}
public enum MessageTypeEnum {
  REQUEST((byte)1), RESPONSE((byte)2), PING((byte)3), PONG((byte)4), EMPTY((byte)5);
  private byte type;
  MessageTypeEnum(byte type) {
    this.type = type;
  }
  public int getType() {
    return type;
  }
  public static MessageTypeEnum get(byte type) {
    for (MessageTypeEnum value : values()) {
      if (value.type == type) {
        return value;
      }
    }
    throw new RuntimeException("unsupported type: " + type);
  }
}
public class MessageEncoder extends MessageToByteEncoder<Message> {
  @Override
  protected void encode(ChannelHandlerContext ctx, Message message, ByteBuf out) {
    // 这里会判断消息类型是不是EMPTY类型,如果是EMPTY类型,则表示当前消息不需要写入到管道中
    if (message.getMessageType() != MessageTypeEnum.EMPTY) {
      out.writeInt(Constants.MAGIC_NUMBER);	// 写入当前的魔数
      out.writeByte(Constants.MAIN_VERSION);	// 写入当前的主版本号
      out.writeByte(Constants.SUB_VERSION);	// 写入当前的次版本号
      out.writeByte(Constants.MODIFY_VERSION);	// 写入当前的修订版本号
      if (!StringUtils.hasText(message.getSessionId())) {
        // 生成一个sessionId,并将其写入到字节序列中
        String sessionId = SessionIdGenerator.generate();
        message.setSessionId(sessionId);
        out.writeCharSequence(sessionId, Charset.defaultCharset());
      }
      out.writeByte(message.getMessageType().getType());	// 写入当前消息的类型
      out.writeShort(message.getAttachments().size());	// 写入当前消息的附加参数数量
      message.getAttachments().forEach((key, value) -> {
        Charset charset = Charset.defaultCharset();
        out.writeInt(key.length());	// 写入键的长度
        out.writeCharSequence(key, charset);	// 写入键数据
        out.writeInt(value.length());	// 希尔值的长度
        out.writeCharSequence(value, charset);	// 写入值数据
      });
      if (null == message.getBody()) {
        out.writeInt(0);	// 如果消息体为空,则写入0,表示消息体长度为0
      } else {
        out.writeInt(message.getBody().length());
        out.writeCharSequence(message.getBody(), Charset.defaultCharset());
      }
    }
  }
}
public class MessageDecoder extends ByteToMessageDecoder {
  @Override
  protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
    Message message = new Message();
    message.setMagicNumber(byteBuf.readInt());  // 读取魔数
    message.setMainVersion(byteBuf.readByte()); // 读取主版本号
    message.setSubVersion(byteBuf.readByte()); // 读取次版本号
    message.setModifyVersion(byteBuf.readByte());	// 读取修订版本号
    CharSequence sessionId = byteBuf.readCharSequence(
        Constants.SESSION_ID_LENGTH, Charset.defaultCharset());	// 读取sessionId
    message.setSessionId((String)sessionId);
    message.setMessageType(MessageTypeEnum.get(byteBuf.readByte()));	// 读取当前的消息类型
    short attachmentSize = byteBuf.readShort();	// 读取附件长度
    for (short i = 0; i < attachmentSize; i++) {
      int keyLength = byteBuf.readInt();	// 读取键长度和数据
      CharSequence key = byteBuf.readCharSequence(keyLength, Charset.defaultCharset());
      int valueLength = byteBuf.readInt();	// 读取值长度和数据
      CharSequence value = byteBuf.readCharSequence(valueLength, Charset.defaultCharset());
      message.addAttachment(key.toString(), value.toString());
    }
    int bodyLength = byteBuf.readInt();	// 读取消息体长度和数据
    CharSequence body = byteBuf.readCharSequence(bodyLength, Charset.defaultCharset());
    message.setBody(body.toString());
    out.add(message);
  }
}
// 服务端消息处理器
public class ServerMessageHandler extends SimpleChannelInboundHandler<Message> {
  // 获取一个消息处理器工厂类实例
  private MessageResolverFactory resolverFactory = MessageResolverFactory.getInstance();
  @Override
  protected void channelRead0(ChannelHandlerContext ctx, Message message) throws Exception {
    Resolver resolver = resolverFactory.getMessageResolver(message);	// 获取消息处理器
    Message result = resolver.resolve(message);	// 对消息进行处理并获取响应数据
    ctx.writeAndFlush(result);	// 将响应数据写入到处理器中
  }
  @Override
  public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
    resolverFactory.registerResolver(new RequestMessageResolver());	// 注册request消息处理器
    resolverFactory.registerResolver(new ResponseMessageResolver());// 注册response消息处理器
    resolverFactory.registerResolver(new PingMessageResolver());	// 注册ping消息处理器
    resolverFactory.registerResolver(new PongMessageResolver());	// 注册pong消息处理器
  }
}
// 客户端消息处理器
public class ClientMessageHandler extends ServerMessageHandler {
  // 创建一个线程,模拟用户发送消息
  private ExecutorService executor = Executors.newSingleThreadExecutor();
  @Override
  public void channelActive(ChannelHandlerContext ctx) throws Exception {
    // 对于客户端,在建立连接之后,在一个独立线程中模拟用户发送数据给服务端
    executor.execute(new MessageSender(ctx));
  }
  /**
   * 这里userEventTriggered()主要是在一些用户事件触发时被调用,这里我们定义的事件是进行心跳检测的
   * ping和pong消息,当前触发器会在指定的触发器指定的时间返回内如果客户端没有被读取消息或者没有写入
   * 消息到管道,则会触发当前方法
   */
  @Override
  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent) {
      IdleStateEvent event = (IdleStateEvent) evt;
      if (event.state() == IdleState.READER_IDLE) {
        // 一定时间内,当前服务没有发生读取事件,也即没有消息发送到当前服务来时,
        // 其会发送一个Ping消息到服务器,以等待其响应Pong消息
        Message message = new Message();
        message.setMessageType(MessageTypeEnum.PING);
        ctx.writeAndFlush(message);
      } else if (event.state() == IdleState.WRITER_IDLE) {
        // 如果当前服务在指定时间内没有写入消息到管道,则关闭当前管道
        ctx.close();
      }
    }
  }
  private static final class MessageSender implements Runnable {
    private static final AtomicLong counter = new AtomicLong(1);
    private volatile ChannelHandlerContext ctx;
    public MessageSender(ChannelHandlerContext ctx) {
      this.ctx = ctx;
    }
    @Override
    public void run() {
      try {
        while (true) {
          // 模拟随机发送消息的过程
          TimeUnit.SECONDS.sleep(new Random().nextInt(3));
          Message message = new Message();
          message.setMessageType(MessageTypeEnum.REQUEST);
          message.setBody("this is my " + counter.getAndIncrement() + " message.");
          message.addAttachment("name", "xufeng");
          ctx.writeAndFlush(message);
        }
      } catch (InterruptedException e) {
        e.printStackTrace();
      }
    }
  }
}
public final class MessageResolverFactory {
  // 创建一个工厂类实例
  private static final MessageResolverFactory resolverFactory = new MessageResolverFactory();
  private static final List<Resolver> resolvers = new CopyOnWriteArrayList<>();
  private MessageResolverFactory() {}
  // 使用单例模式实例化当前工厂类实例
  public static MessageResolverFactory getInstance() {
    return resolverFactory;
  }
  public void registerResolver(Resolver resolver) {
    resolvers.add(resolver);
  }
  // 根据解码后的消息,在工厂类处理器中查找可以处理当前消息的处理器
  public Resolver getMessageResolver(Message message) {
    for (Resolver resolver : resolvers) {
      if (resolver.support(message)) {
        return resolver;
      }
    }
    throw new RuntimeException("cannot find resolver, message type: " + message.getMessageType());
  }
}
// request类型的消息
public class RequestMessageResolver implements Resolver {
  private static final AtomicInteger counter = new AtomicInteger(1);
  @Override
  public boolean support(Message message) {
    return message.getMessageType() == MessageTypeEnum.REQUEST;
  }
  @Override
  public Message resolve(Message message) {
    // 接收到request消息之后,对消息进行处理,这里主要是将其打印出来
    int index = counter.getAndIncrement();
    System.out.println("[trx: " + message.getSessionId() + "]"
        + index + ". receive request: " + message.getBody());
    System.out.println("[trx: " + message.getSessionId() + "]"
        + index + ". attachments: " + message.getAttachments());
    // 处理完成后,生成一个响应消息返回
    Message response = new Message();
    response.setMessageType(MessageTypeEnum.RESPONSE);
    response.setBody("nice to meet you too!");
    response.addAttachment("name", "xufeng");
    response.addAttachment("hometown", "wuhan");
    return response;
  }
}
// 响应消息处理器
public class ResponseMessageResolver implements Resolver {
  private static final AtomicInteger counter = new AtomicInteger(1);
  @Override
  public boolean support(Message message) {
    return message.getMessageType() == MessageTypeEnum.RESPONSE;
  }
  @Override
  public Message resolve(Message message) {
    // 接收到对方服务的响应消息之后,对响应消息进行处理,这里主要是将其打印出来
    int index = counter.getAndIncrement();
    System.out.println("[trx: " + message.getSessionId() + "]"
        + index + ". receive response: " + message.getBody());
    System.out.println("[trx: " + message.getSessionId() + "]"
        + index + ". attachments: " + message.getAttachments());
    // 响应消息不需要向对方服务再发送响应,因而这里写入一个空消息
    Message empty = new Message();
    empty.setMessageType(MessageTypeEnum.EMPTY);
    return empty;
  }
}
// ping消息处理器
public class PingMessageResolver implements Resolver {
  @Override
  public boolean support(Message message) {
    return message.getMessageType() == MessageTypeEnum.PING;
  }
  @Override
  public Message resolve(Message message) {
    // 接收到ping消息后,返回一个pong消息返回
    System.out.println("receive ping message: " + System.currentTimeMillis());
    Message pong = new Message();
    pong.setMessageType(MessageTypeEnum.PONG);
    return pong;
  }
}
// pong消息处理器
public class PongMessageResolver implements Resolver {
  @Override
  public boolean support(Message message) {
    return message.getMessageType() == MessageTypeEnum.PONG;
  }
  @Override
  public Message resolve(Message message) {
    // 接收到pong消息后,不需要进行处理,直接返回一个空的message
    System.out.println("receive pong message: " + System.currentTimeMillis());
    Message empty = new Message();
    empty.setMessageType(MessageTypeEnum.EMPTY);
    return empty;
  }
}
// 服务端
public class Server {
  public static void main(String[] args) {
    EventLoopGroup bossGroup = new NioEventLoopGroup();
    EventLoopGroup workerGroup = new NioEventLoopGroup();
    try {
      ServerBootstrap bootstrap = new ServerBootstrap();
      bootstrap.group(bossGroup, workerGroup)
          .channel(NioServerSocketChannel.class)
          .option(ChannelOption.SO_BACKLOG, 1024)
          .handler(new LoggingHandler(LogLevel.INFO))
          .childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
              ChannelPipeline pipeline = ch.pipeline();	
              // 添加用于处理粘包和拆包问题的处理器
              pipeline.addLast(new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4));
              pipeline.addLast(new LengthFieldPrepender(4));
              // 添加自定义协议消息的编码和解码处理器
              pipeline.addLast(new MessageEncoder());
              pipeline.addLast(new MessageDecoder());
              // 添加具体的消息处理器
              pipeline.addLast(new ServerMessageHandler());
            }
          });
      ChannelFuture future = bootstrap.bind(8585).sync();
      future.channel().closeFuture().sync();
    } catch (InterruptedException e) {
      e.printStackTrace();
    } finally {
      bossGroup.shutdownGracefully();
      workerGroup.shutdownGracefully();
    }
  }
}
public class Client {
  public static void main(String[] args) {
    NioEventLoopGroup group = new NioEventLoopGroup();
    Bootstrap bootstrap = new Bootstrap();
    try {
      bootstrap.group(group)
          .channel(NioSocketChannel.class)
          .option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
          .handler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
              ChannelPipeline pipeline = ch.pipeline();
              // 添加用于解决粘包和拆包问题的处理器
              pipeline.addLast(new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4));
              pipeline.addLast(new LengthFieldPrepender(4));
              // 添加用于进行心跳检测的处理器
              pipeline.addLast(new IdleStateHandler(1, 2, 0));
              // 添加用于根据自定义协议将消息与字节流进行相互转换的处理器
              pipeline.addLast(new MessageEncoder());
              pipeline.addLast(new MessageDecoder());
              // 添加客户端消息处理器
              pipeline.addLast(new ClientMessageHandler());
            }
          });
      ChannelFuture future = bootstrap.connect("127.0.0.1", 8585).sync();
      future.channel().closeFuture().sync();
    } catch (InterruptedException e) {
      e.printStackTrace();
    } finally {
      group.shutdownGracefully();
    }
  }
}
// 客户端
receive pong message: 1555123429356
[trx: d05024d2]1. receive response: nice to meet you too!
[trx: d05024d2]1. attachments: {hometown=wuhan, name=xufeng}
[trx: 66ee1438]2. receive response: nice to meet you too!
// 服务器
receive ping message: 1555123432279
[trx: f582444f]4. receive request: this is my 4 message.
[trx: f582444f]4. attachments: {name=xufeng}
3. 小结
END
《Java工程师面试突击第三季》加餐部分大纲:(注:1-66讲的大纲请扫描文末二维码,在课程详情页获取)
详细的课程内容,大家可以扫描下方二维码了解:
