gitee:https://gitee.com/niumazlb/mynettydemo
github:https://github.com/ishuaige/mynettydemo

服务端

这里用 netty+WebSocket 举例
首先我们需要一个服务端 WebSocketServer
创建服务端大致步骤:

  1. 创建 ServerBootstrap - bootstrap
  2. bootstrap 关联两个事件循环组(EventLoopGroup) boss 和 work
  3. bootstrap 配置项 option
  4. bootstrap 绑定 channel
  5. bootstrap 绑定 handler,这里可以做一个封装
  6. bootstrap 绑定端口
  7. bootstrap 绑定监听关闭连接事件,处理 boss 和 work 的关闭

以上 2-5 顺序随意

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import com.niuma.mynetty.server.handler.MyWebSocketChannelHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;

/**
 * Spring-managed Netty server that accepts WebSocket connections on port 8888.
 *
 * <p>This class is not a {@code ChannelHandler}, so it carries no
 * {@code @ChannelHandler.Sharable} annotation; the per-connection pipeline is
 * built by the injected {@link MyWebSocketChannelHandler}.
 *
 * @author niumazlb
 * @create 2023-01-09 20:30
 */
@Component
public class WebSocketServer {

    /** Event loop group that accepts incoming connections. */
    EventLoopGroup boss = new NioEventLoopGroup();
    /** Event loop group that serves the accepted connections. */
    EventLoopGroup work = new NioEventLoopGroup();

    @Resource
    MyWebSocketChannelHandler myWebSocketChannelHandler;

    /**
     * Configures the bootstrap, binds port 8888 and returns once the bind has completed.
     *
     * <p>Both event loop groups are shut down when the server channel closes, and
     * also when start-up fails, so a failed bind does not leave their threads running.
     */
    public void run() {
        boolean started = false;
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(boss, work);
            bootstrap.option(ChannelOption.SO_BACKLOG, 1024);
            bootstrap.channel(NioServerSocketChannel.class);
            // The per-channel handlers are encapsulated in a dedicated initializer.
            bootstrap.childHandler(myWebSocketChannelHandler);
            Channel channel = bootstrap.bind(8888).sync().channel();
            channel.closeFuture().addListener(future -> shutdownEventLoops());
            started = true;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            // sync() may also rethrow the bind failure itself (e.g. port already in use);
            // a finally block covers every failure path.
            if (!started) {
                shutdownEventLoops();
            }
        }
    }

    /** Gracefully stops both event loop groups. */
    private void shutdownEventLoops() {
        boss.shutdownGracefully();
        work.shutdownGracefully();
    }
}

注意 Handler 的继承喔

  • 这里的 Handler 是处理的 websocket 的,其他需要具体情况具体分析
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package com.niuma.mynetty.server.handler;

import com.niuma.mynetty.config.NettyConfig;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;


/**
* 用于初始化注册 Channel 中的各个 Handler
* @author niumazlb
* @create 2023-01-09 20:35
*/
@Component
public class MyWebSocketChannelHandler extends ChannelInitializer<SocketChannel> {
@Resource
MyWebSocketHandler myWebSocketHandler ;
@Resource
RegisterHandler registerHandler ;
@Resource
SingleMessageHandler singleMessageHandler ;
@Override
protected void initChannel(SocketChannel ch) throws Exception {
//webSocket连接基本就是以下的几个handler
//http服务端的解码器
ch.pipeline().addLast("http-codec", new HttpServerCodec())
//日志
.addLast("log",new LoggingHandler(LogLevel.DEBUG))
/*
通过 HttpObjectAggregator 可以把 HttpMessage 和 HttpContent 聚合成一个 FullHttpRequest,
并定义可以接受的数据大小,在文件上传时,可以支持params+multipart
*/
.addLast("aggregator",new HttpObjectAggregator(65536)) //httpContent消息聚合
//块写入写出Handler
.addLast("http-chunked",new ChunkedWriteHandler()) // HttpContent 压缩
/*
netty内置的WebSocketServerProtocolHandler作为Websocket协议的主要处理器
是处理HTTP相关的数据的
*/
.addLast("protocolHandler",new WebSocketServerProtocolHandler("/websocket"))

// 到此数据都处理完了,后面就是我们定义的处理器了==============================
// 用于消息的分发,接收 WebSocketFrame 客户端传来的帧
.addLast(myWebSocketHandler)
// 处理各种的消息
.addLast(registerHandler)
.addLast(singleMessageHandler);



}
//客户端与服务端创建连接
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
NettyConfig.group.add(ctx.channel());
System.out.println("客户端与服务端连接开启....");
}

//客户端与服务端断开连接
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
NettyConfig.group.remove(ctx.channel());
System.out.println("客户端与服务端连接关闭....");
}

//接收结束之后 read相对于服务端
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}

// 出现异常时调用
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
cause.printStackTrace();
ctx.close();
}

}

简单的 RPC 框架

参考黑马程序员 netty 课程

协议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package cn.itcast.protocol;

import cn.itcast.config.Config;
import cn.itcast.message.Message;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageCodec;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

/**
 * Sharable codec between {@link ByteBuf} frames and {@link Message} objects.
 *
 * <p>Must be used together with a LengthFieldBasedFrameDecoder so that every
 * ByteBuf handed to {@code decode} contains one complete frame.
 *
 * <p>Frame layout (16-byte header followed by the body):
 * magic number (4) | version (1) | serializer (1) | message type (1) |
 * sequence id (4) | padding (1) | body length (4) | body.
 */
@Slf4j
@ChannelHandler.Sharable
public class MessageCodecSharable extends MessageToMessageCodec<ByteBuf, Message> {

    /** The magic bytes 1, 2, 3, 4 interpreted as one big-endian int. */
    private static final int MAGIC_NUMBER = 0x01020304;

    /**
     * Encodes a message into a single frame and appends it to {@code outList}.
     *
     * @param ctx     channel context, used to allocate the output buffer
     * @param msg     message to encode
     * @param outList receives the encoded buffer
     */
    @Override
    public void encode(ChannelHandlerContext ctx, Message msg, List<Object> outList) throws Exception {
        // Serialize before allocating so a serialization failure cannot leak the buffer.
        byte[] bytes = Config.getSerializerAlgorithm().serialize(msg);

        ByteBuf out = ctx.alloc().buffer();
        // 1. 4-byte magic number
        out.writeInt(MAGIC_NUMBER);
        // 2. 1-byte version
        out.writeByte(1);
        // 3. 1-byte serializer: jdk 0, json 1
        out.writeByte(Config.getSerializerAlgorithm().ordinal());
        // 4. 1-byte message type
        out.writeByte(msg.getMessageType());
        // 5. 4-byte sequence id
        out.writeInt(msg.getSequenceId());
        // Meaningless padding byte that aligns the header.
        out.writeByte(0xff);
        // 6. 4-byte body length
        out.writeInt(bytes.length);
        // 7. body
        out.writeBytes(bytes);
        outList.add(out);
    }

    /**
     * Decodes one complete frame into a {@link Message} and appends it to {@code out}.
     *
     * @param ctx channel context
     * @param in  buffer holding exactly one frame
     * @param out receives the decoded message
     * @throws IllegalArgumentException if the magic number, the serializer index or
     *                                  the message type of the frame is not recognized
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        int magicNum = in.readInt();
        if (magicNum != MAGIC_NUMBER) {
            throw new IllegalArgumentException(
                    "unexpected magic number: 0x" + Integer.toHexString(magicNum));
        }
        byte version = in.readByte();
        byte serializerAlgorithm = in.readByte(); // 0 or 1
        byte messageType = in.readByte(); // 0, 1, 2...
        int sequenceId = in.readInt();
        in.readByte(); // skip the padding byte
        int length = in.readInt();
        byte[] bytes = new byte[length];
        in.readBytes(bytes, 0, length);

        // Look up the deserialization algorithm.
        Serializer.Algorithm[] algorithms = Serializer.Algorithm.values();
        if (serializerAlgorithm < 0 || serializerAlgorithm >= algorithms.length) {
            throw new IllegalArgumentException("unknown serializer algorithm: " + serializerAlgorithm);
        }
        Serializer.Algorithm algorithm = algorithms[serializerAlgorithm];

        // Resolve the concrete message class.
        Class<? extends Message> messageClass = Message.getMessageClass(messageType);
        if (messageClass == null) {
            throw new IllegalArgumentException("unknown message type: " + messageType);
        }

        Message message = algorithm.deserialize(messageClass, bytes);
        log.debug("{}, {}, {}, {}, {}, {}", magicNum, version, serializerAlgorithm, messageType, sequenceId, length);
        log.debug("{}", message);
        out.add(message);
    }

}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package cn.itcast.protocol;

import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

/**
 * Frame decoder that relies on the length field of the message format to deal
 * with half packets and sticky packets.
 */
public class ProtocolFrameDecoder extends LengthFieldBasedFrameDecoder {

    /** Largest frame accepted by the default configuration, in bytes. */
    private static final int DEFAULT_MAX_FRAME_LENGTH = 1024;
    /** Number of header bytes that precede the length field. */
    private static final int DEFAULT_LENGTH_FIELD_OFFSET = 12;
    /** Width of the length field, in bytes. */
    private static final int DEFAULT_LENGTH_FIELD_LENGTH = 4;
    /** No correction is applied to the value of the length field. */
    private static final int DEFAULT_LENGTH_ADJUSTMENT = 0;
    /** Nothing is stripped: the whole frame is passed on. */
    private static final int DEFAULT_INITIAL_BYTES_TO_STRIP = 0;

    /** Creates a decoder with the default frame layout. */
    public ProtocolFrameDecoder() {
        this(
                DEFAULT_MAX_FRAME_LENGTH,
                DEFAULT_LENGTH_FIELD_OFFSET,
                DEFAULT_LENGTH_FIELD_LENGTH,
                DEFAULT_LENGTH_ADJUSTMENT,
                DEFAULT_INITIAL_BYTES_TO_STRIP);
    }

    /**
     * Creates a decoder with an explicit frame layout; all arguments are passed
     * straight to {@link LengthFieldBasedFrameDecoder}.
     */
    public ProtocolFrameDecoder(
            int maxFrameLength,
            int lengthFieldOffset,
            int lengthFieldLength,
            int lengthAdjustment,
            int initialBytesToStrip) {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
    }
}

服务端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package cn.itcast.server;

import cn.itcast.protocol.MessageCodecSharable;
import cn.itcast.protocol.ProtocolFrameDecoder;
import cn.itcast.server.handler.RpcRequestMessageHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import lombok.extern.slf4j.Slf4j;

/**
 * Entry point of the RPC server: listens on port 8080 and blocks until the
 * server channel is closed.
 */
@Slf4j
public class RpcServer {
    public static void main(String[] args) {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup worker = new NioEventLoopGroup();
        // One instance of each of these handlers is reused by every channel.
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable messageCodec = new MessageCodecSharable();
        RpcRequestMessageHandler rpcHandler = new RpcRequestMessageHandler();
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap()
                    .group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline()
                                    // The frame decoder must come first: it handles
                                    // half packets and sticky packets.
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(loggingHandler)
                                    .addLast(messageCodec)
                                    .addLast(rpcHandler);
                        }
                    });
            Channel serverChannel = serverBootstrap.bind(8080).sync().channel();
            // Block the main thread until the server channel closes.
            serverChannel.closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("server error", e);
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package cn.itcast.server.handler;


import cn.itcast.message.RpcRequestMessage;
import cn.itcast.message.RpcResponseMessage;
import cn.itcast.server.service.ServicesFactory;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
* 处理Rpc请求的handler
* @author niumazlb
* @create 2023-01-08 15:37
*/

public class RpcRequestMessageHandler extends SimpleChannelInboundHandler<RpcRequestMessage> {

@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage msg) throws Exception {
String interfaceName = msg.getInterfaceName();
String methodName = msg.getMethodName();
int sequenceId = msg.getSequenceId();
Object[] parameterValue = msg.getParameterValue();
Class[] parameterTypes = msg.getParameterTypes();

RpcResponseMessage response = new RpcResponseMessage();
response.setSequenceId(sequenceId);
//通过反射来方法调用
try {
Object service = ServicesFactory.getService(Class.forName(interfaceName));
Method method = service.getClass().getMethod(methodName,parameterTypes);
Object invoke = method.invoke(service, parameterValue);
response.setReturnValue(invoke);
} catch (Exception e) {
e.printStackTrace();
response.setExceptionValue(new Exception("error:"+e.getCause().getMessage()));
}
ctx.writeAndFlush(response);
}
}

客户端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package cn.itcast.client;

import cn.itcast.client.handler.RpcResponseMessageHandler;
import cn.itcast.message.RpcRequestMessage;
import cn.itcast.protocol.MessageCodecSharable;
import cn.itcast.protocol.ProtocolFrameDecoder;
import cn.itcast.protocol.SequenceIdGenerator;
import cn.itcast.server.service.HelloService;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.DefaultPromise;

import java.lang.reflect.Proxy;

/**
 * RPC client: creates dynamic proxies that turn interface calls into
 * {@link RpcRequestMessage}s sent over one lazily created, shared channel.
 *
 * @author niumazlb
 * @create 2023-01-08 15:51
 */
public class RpcClientManager {
    public static void main(String[] args) {
        HelloService proxyService = getProxyService(HelloService.class);
        System.out.println(proxyService.sayHello("zhangsan"));
        // System.out.println(proxyService.sayHello("lisi"));
        // System.out.println(proxyService.sayHello("wangwu"));
        // System.out.println(proxyService.sayHello("zhaoliu"));

    }

    /**
     * Creates a proxy of {@code serviceClass} whose method calls are executed remotely.
     *
     * <p>Each call blocks the calling thread until the matching response arrives or
     * the request fails to be written.
     *
     * @param serviceClass the service interface to proxy
     * @param <T>          the service interface type
     * @return a proxy implementing {@code serviceClass}
     */
    @SuppressWarnings("unchecked") // the proxy implements serviceClass by construction
    public static <T> T getProxyService(Class<T> serviceClass) {
        ClassLoader loader = serviceClass.getClassLoader();

        Class<?>[] interfaces = new Class[]{serviceClass};
        Object o = Proxy.newProxyInstance(loader, interfaces, (proxy, method, args) -> {
            int sequenceId = SequenceIdGenerator.nextId();
            RpcRequestMessage message = new RpcRequestMessage(
                    sequenceId,
                    serviceClass.getName(),
                    method.getName(),
                    method.getReturnType(),
                    method.getParameterTypes(),
                    args
            );

            Channel ch = getChannel();

            // Register the promise BEFORE sending: otherwise a fast response could be
            // handled before the promise exists and would be silently dropped.
            DefaultPromise<Object> promise = new DefaultPromise<>(ch.eventLoop());
            RpcResponseMessageHandler.PROMISES.put(sequenceId, promise);

            ch.writeAndFlush(message).addListener(future -> {
                // A request that never left will never be answered: fail the promise
                // instead of letting the caller wait forever.
                if (!future.isSuccess()) {
                    RpcResponseMessageHandler.PROMISES.remove(sequenceId);
                    promise.tryFailure(future.cause());
                }
            });

            try {
                promise.await();
            } finally {
                // No-op when the response handler already removed the entry; cleans
                // up the pending promise when the wait is interrupted.
                RpcResponseMessageHandler.PROMISES.remove(sequenceId);
            }

            if (promise.isSuccess()) {
                return promise.getNow();
            } else {
                throw new RuntimeException(promise.cause());
            }
        });
        return (T) o;
    }


    // volatile: the field is read outside the lock in getChannel().
    private static volatile Channel channel = null;
    private static final Object LOCK = new Object();

    /**
     * Returns the shared client channel, connecting on first use.
     *
     * @return the connected channel
     * @throws IllegalStateException if the thread is interrupted while connecting
     */
    public static Channel getChannel() {
        if (channel != null) {
            return channel;
        }
        synchronized (LOCK) {
            if (channel != null) {
                return channel;
            }
            initChannel();
            return channel;
        }
    }

    /**
     * Connects to localhost:8080 and stores the channel in {@link #channel}.
     * The event loop group is shut down when the channel closes or when the
     * connection attempt fails.
     */
    private static void initChannel() {
        NioEventLoopGroup group = new NioEventLoopGroup();
        LoggingHandler LOGGING_HANDLER = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable MESSAGE_CODEC = new MessageCodecSharable();
        RpcResponseMessageHandler RPC_HANDLER = new RpcResponseMessageHandler();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.channel(NioSocketChannel.class);
        bootstrap.group(group);
        bootstrap.handler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
                ch.pipeline().addLast(new ProtocolFrameDecoder());
                ch.pipeline().addLast(LOGGING_HANDLER);
                ch.pipeline().addLast(MESSAGE_CODEC);
                ch.pipeline().addLast(RPC_HANDLER);
            }
        });
        boolean connected = false;
        try {
            Channel connectedChannel = bootstrap.connect("localhost", 8080).sync().channel();
            connectedChannel.closeFuture().addListener(future -> {
                group.shutdownGracefully();
            });
            channel = connectedChannel;
            connected = true;
        } catch (InterruptedException e) {
            // Restore the interrupt flag and fail loudly instead of handing out a null channel.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while connecting to the rpc server", e);
        } finally {
            // sync() also rethrows connection failures; never leak the event loop threads.
            if (!connected) {
                group.shutdownGracefully();
            }
        }
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package cn.itcast.client.handler;

import cn.itcast.message.RpcResponseMessage;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Client-side handler that completes the promise waiting for each RPC response.
 */
@Slf4j
@ChannelHandler.Sharable
public class RpcResponseMessageHandler extends SimpleChannelInboundHandler<RpcResponseMessage> {

    /** Pending calls: sequence id mapped to the promise that receives the result. */
    public static final Map<Integer, Promise<Object>> PROMISES = new ConcurrentHashMap<>();

    /**
     * Removes the promise registered under the response's sequence id and completes
     * it with the return value, or fails it with the transported exception.
     * A response without a registered promise is ignored.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) throws Exception {
        log.debug("{}", msg);

        Promise<Object> pending = PROMISES.remove(msg.getSequenceId());
        if (pending == null) {
            return;
        }

        Exception failure = msg.getExceptionValue();
        if (failure == null) {
            pending.setSuccess(msg.getReturnValue());
        } else {
            pending.setFailure(failure);
        }
    }
}

消息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package cn.itcast.message;

import lombok.Data;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

/**
 * Base class of all protocol messages.
 *
 * <p>Also acts as the registry that maps the message-type byte found in a frame
 * header to the concrete message class used for deserialization.
 */
@Data
public abstract class Message implements Serializable {

    /**
     * Message-type byte value of an RPC request.
     */
    public static final int RPC_MESSAGE_TYPE_REQUEST = 101;
    /**
     * Message-type byte value of an RPC response.
     */
    public static final int RPC_MESSAGE_TYPE_RESPONSE = 102;

    /** Registry: message type to concrete message class; filled once at class load. */
    private static final Map<Integer, Class<? extends Message>> messageClasses = new HashMap<>();

    static {
        messageClasses.put(RPC_MESSAGE_TYPE_REQUEST, RpcRequestMessage.class);
        messageClasses.put(RPC_MESSAGE_TYPE_RESPONSE, RpcResponseMessage.class);
    }

    /** Correlates a response with the request it answers. */
    private int sequenceId;

    /** Kept for backward compatibility; subclasses report their type via {@link #getMessageType()}. */
    private int messageType;

    /**
     * Looks up the message class registered for a message type.
     *
     * <p>The return type is bounded by {@code Message} so the result can be passed
     * to a deserializer without a cast.
     *
     * @param messageType the message-type value from the frame header
     * @return the registered class, or {@code null} if the type is unknown
     */
    public static Class<? extends Message> getMessageClass(int messageType) {
        return messageClasses.get(messageType);
    }

    /**
     * @return the message-type value identifying the concrete message class
     */
    public abstract int getMessageType();
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package cn.itcast.message;

import lombok.Getter;
import lombok.ToString;

/**
 * Request half of an RPC exchange: names the interface method to call and
 * carries its arguments.
 *
 * @author yihang
 */
@Getter
@ToString(callSuper = true)
public class RpcRequestMessage extends Message {

    /** Fully qualified name of the called interface; the server uses it to find the implementation. */
    private String interfaceName;

    /** Name of the interface method to call. */
    private String methodName;

    /** Return type of the method. */
    private Class<?> returnType;

    /** Parameter types of the method. */
    private Class[] parameterTypes;

    /** Argument values of the call. */
    private Object[] parameterValue;

    /**
     * Creates a fully populated request.
     *
     * @param sequenceId     id stored on the base message
     * @param interfaceName  fully qualified interface name
     * @param methodName     method to call
     * @param returnType     return type of the method
     * @param parameterTypes parameter types of the method
     * @param parameterValue argument values
     */
    public RpcRequestMessage(
            int sequenceId,
            String interfaceName,
            String methodName,
            Class<?> returnType,
            Class[] parameterTypes,
            Object[] parameterValue) {
        // Which method is called...
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.returnType = returnType;
        // ...with which arguments...
        this.parameterTypes = parameterTypes;
        this.parameterValue = parameterValue;
        // ...and under which id.
        super.setSequenceId(sequenceId);
    }

    /**
     * @return {@link Message#RPC_MESSAGE_TYPE_REQUEST}
     */
    @Override
    public int getMessageType() {
        return RPC_MESSAGE_TYPE_REQUEST;
    }
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package cn.itcast.message;

import lombok.Data;
import lombok.ToString;

/**
 * Response half of an RPC exchange: carries either the return value of the
 * call or the exception describing why it failed.
 *
 * @author yihang
 */
@Data
@ToString(callSuper = true)
public class RpcResponseMessage extends Message {
/**
* Return value of the call
*/
private Object returnValue;
/**
* Exception value; the response handler treats a non-null value as a failed call
*/
private Exception exceptionValue;

// Identifies this class as the RPC response message type.
@Override
public int getMessageType() {
return RPC_MESSAGE_TYPE_RESPONSE;
}
}