使用netty开发长连接协议


Netty 是什么

Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients.

官方的解释最精准了,其中最吸引人的就是高性能了。但是很多人会有这样的疑问:直接用 NIO 实现的话,一定会更快吧?就像我直接手写 JDBC 虽然代码量大了点,但是一定比 iBatis 快!

但是,如果了解 Netty 后你才会发现,这个还真不一定!

利用 Netty 而不用 NIO 直接写的优势有这些:

  • 高性能高扩展的架构设计,大部分情况下你只需要关注业务而不需要关注架构
  • Zero-Copy 技术尽量减少内存拷贝
  • 为 Linux 实现 Native 版 Socket
  • 写同一份代码,兼容 java 1.7 的 NIO2 和 1.7 之前版本的 NIO
  • Pooled Buffers 大大减轻分配 Buffer 和释放 Buffer 的压力
  • ……

实践

使用netty开发长连接服务器,支持心跳,websocket,自定义数据传输格式和对数据进行校验

依赖

<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <optional>true</optional>
</dependency>
<dependency>
    <groupId>cn.hutool</groupId>
    <artifactId>hutool-all</artifactId>
    <version>5.7.22</version>
</dependency>
<!--	Netty	-->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.74.Final</version>
</dependency>

公共实体

消息的定义(Message),消息类型(MessageType),请求头(MessageHead),心跳实体(UavHeartbeat),消息编码(MessageEncoder),消息解码(MeaasgeDecoder),数据加密工具类(Md5Utils),消息序列化工具类(MsgUtils)

自定义协议 数据包格式:

/**
 * -----------------------------------
 * | 协议开始标志(4个字节) | 包长度(4个字节)|消息类型(定长4个字节)|令牌 (定长50个字节)|令牌生成时间(定长20个字节)| 包内容 |
 * -----------------------------------
 * 协议头长度: 4 + 4 + 4 + 50 + 20 = 82
 *
 * 令牌生成规则
 *  协议开始标志 +包长度+消息类型+令牌生成时间+包内容+服务器与客户端约定的秘钥
 *
 */

消息的定义(Message)

消息实体类为netty中传输的数据对象,如果不传输此对象消息在解码中会被直接丢弃

import com.example.netty.util.Md5Utils;
import lombok.Data;

import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;

/**
 * A frame of the custom protocol.
 * -----------------------------------
 * | start marker (4) | body length (4) | message type (4) | token (50) | token time (20) | body |
 * -----------------------------------
 * Header length: 4 + 4 + 4 + 50 + 20
 *
 * Token rule:
 *  MD5(start marker + body length + message type + token time + body + secret shared by client and server)
 * @author wmg
 *
 */

@Data
public class Message {

	/** Pattern of the token timestamp; must match the pattern used by the encoder/decoder. */
	private static final DateTimeFormatter TOKEN_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
	/** Secret shared by client and server. NOTE(review): hard-coded; consider moving to configuration. */
	private static final String SHARED_SECRET = "hnaiot";
	/** Maximum token age (and maximum tolerated clock skew into the future), in minutes. */
	private static final long TOKEN_VALID_MINUTES = 60;
	/** Offset in which the zone-less createDate is interpreted (UTC+8, as in the original protocol). */
	private static final ZoneOffset PROTOCOL_ZONE = ZoneOffset.ofHours(8);
	/** Charset used to turn bytes into text for hashing, fixed so all JVMs agree. */
	private static final java.nio.charset.Charset TOKEN_CHARSET = java.nio.charset.StandardCharsets.UTF_8;

	public Message(MessageHead head,byte[] content) {
		this.Header=head;
		this.content=content;
	}
	// protocol header
	private MessageHead Header;

	// message body
	private byte[] content;

	/**
	 * Builds the token for this message.
	 * The length term is the body length actually written to the wire (content.length), so a
	 * sender that forgot to set header.length still produces a token the receiver can verify.
	 *
	 * @return lower-case hex MD5 of the token input
	 */
	public String buidToken() {
		String time = this.getHeader().getCreateDate().format(TOKEN_TIME_FORMAT);
		byte[] body = this.getContent() == null ? new byte[0] : this.getContent();

		StringBuilder allData = new StringBuilder()
				.append(this.getHeader().getHeadData())
				.append(body.length)
				.append(this.getHeader().getMessageType())
				.append(time)
				.append(new String(body, TOKEN_CHARSET))
				.append(SHARED_SECRET);

		return Md5Utils.strMD5(allData.toString());
	}

	/**
	 * Checks the header token against an expected token and verifies that the token is fresh.
	 *
	 * @param token the expected token, normally {@link #buidToken()}
	 * @return true if the tokens match and createDate lies within TOKEN_VALID_MINUTES of now
	 */
	public boolean authorization(String token) {
		String received = this.getHeader().getToken();
		// a mismatch means the message was tampered with; constant-time compare avoids timing leaks
		if (token == null || received == null
				|| !java.security.MessageDigest.isEqual(token.getBytes(TOKEN_CHARSET), received.getBytes(TOKEN_CHARSET))) {
			return false;
		}
		long createdMillis = getHeader().getCreateDate().toInstant(PROTOCOL_ZONE).toEpochMilli();
		long ageMinutes = (System.currentTimeMillis() - createdMillis) / (1000 * 60);
		// reject stale tokens, and tokens dated far in the future (those would otherwise never expire)
		return Math.abs(ageMinutes) <= TOKEN_VALID_MINUTES;
	}
}

消息类型(MessageType)

消息类型可根据自己的需求进行自定义

/**
 * Message type codes and fixed protocol constants.
 * @author wmg
 */

public final class MessageType {
	/** Heartbeat request/response. */
	public static final Integer HEARTBEAT = 0;
	/** Instruction message. */
	public static final Integer INSTRUCT = 1;
	/** Plain string message. */
	public static final Integer STR = 2;
	/**
	 * Fixed header length in bytes:
	 * | start marker (4) | body length (4) | message type (4) | token (50) | token time (20) | body |
	 * Header length: 4 + 4 + 4 + 50 + 20
	 */
	public static final Integer BASE_LENGTH = 4+4+4+50+20;
	/** Start-of-frame marker, written as a 4-byte int at the start of each frame. */
	public static final Integer HEAD_DATA = 0x76;

	private MessageType() {
		// constants holder; not meant to be instantiated
	}
}

心跳实体(UavHeartbeat)

心跳实体可按照自己的需求进行自定义

/**
 * Heartbeat payload; adapt it to your needs, or simply use "ping"/"pong" instead.
 * The client serializes it with MsgUtils.objectToBytes (ClientHandle), and the server
 * reads it back with MsgUtils.bytesToBean (SocketHandler).
 * NOTE(review): field order is kept as-is because it is part of the serialized payload.
 * @author wmg
 */
@Data
@Accessors(chain = true)
public class UavHeartbeat implements Serializable {
	// device code, e.g. "dji_m300_001" in ClientHandle
	private String uavCode;
	// device number, e.g. "m300_001" in ClientHandle
	private String uavNumber;
	// device status; the client sends 1 — meaning of the values not defined here, TODO confirm
	private Integer uavStatus;
}

请求头/协议头(MessageHead)

消息头包含了消息的认证信息和长度,用来认证信息的合法来源和消息的截取

**
 * 消息头
 *  消息头包含了消息的认证信息和长度,用来认证信息的合法来源和消息的截取
 * @author wmg
 */
@Data
@Accessors(chain = true)
public class MessageHead {
	//协议开始标志
	private int headData = 0X76;
	//包的长度
	private int length;
	//认证的Token,可以设置时效
	private String token;
	// 时间
	private LocalDateTime createDate;
	//消息类型
	private Integer messageType;
}

消息编码(MessageEncoder)

自定义消息编码格式

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.time.format.DateTimeFormatter;

/**
 * Encodes a Message into the custom binary frame:
 * | start marker (4) | body length (4) | message type (4) | token (50) | token time (20) | body |
 * The field order must match what MeaasgeDecoder reads.
 */
@Component
public class MessageEncoder extends MessageToByteEncoder<Message> {

	/** Width of the zero-padded token field. */
	private static final int TOKEN_LENGTH = 50;
	/** Width of the zero-padded token-time field. */
	private static final int CREATE_DATE_LENGTH = 20;
	/** Pattern of the token time; must match the decoder and Message#buidToken. */
	private static final DateTimeFormatter CREATE_DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

	@Override
	protected void encode(ChannelHandlerContext ctx, Message msg, ByteBuf out) throws Exception {
		// a null body is sent as an empty body instead of failing with an NPE
		byte[] content = msg.getContent() == null ? new byte[0] : msg.getContent();

		// start marker
		out.writeInt(msg.getHeader().getHeadData());
		// body length: always the real body size, which is what the decoder trusts
		out.writeInt(content.length);
		// message type
		out.writeInt(msg.getHeader().getMessageType());
		// token (null-safe, zero-padded/truncated to 50 bytes)
		writeByte(out, msg.getHeader().getToken(), TOKEN_LENGTH);
		// token creation time (zero-padded/truncated to 20 bytes)
		writeByte(out, msg.getHeader().getCreateDate().format(CREATE_DATE_FORMAT), CREATE_DATE_LENGTH);
		// body
		out.writeBytes(content);
	}

	/**
	 * Writes exactly {@code length} bytes: the first bytes of {@code bytes}, then zero padding.
	 */
	private void writeByte(ByteBuf out, byte[] bytes, int length) {
		int copied = Math.min(bytes.length, length);
		out.writeBytes(bytes, 0, copied);
		out.writeZero(length - copied);
	}

	/**
	 * Writes a string as a fixed-width field; null or empty strings become all zeros.
	 */
	private void writeByte(ByteBuf out, String content, int length) {
		if (!StringUtils.hasLength(content)) {
			content = "";
		}
		writeByte(out, content.getBytes(), length);
	}
}

消息解码(MeaasgeDecoder)

消息编解码的操作流程一定要一致否则将出错

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;

/**
 * Decodes custom-protocol frames into Message objects.
 * Frame layout:
 * | start marker (4) | body length (4) | message type (4) | token (50) | token time (20) | body |
 * Header length: 4 + 4 + 4 + 50 + 20 = MessageType.BASE_LENGTH
 * The field order must match MessageEncoder exactly.
 * A malformed token time makes LocalDateTime.parse throw; that exception propagates to the pipeline.
 */
@Component
public class MeaasgeDecoder extends ByteToMessageDecoder {

	/** Cap on buffered bytes; anything larger is discarded as unreasonable (possible flooding attack). */
	private static final int MAX_BUFFERED_BYTES = 1024 * 1024;
	/** Largest body that still fits, together with its header, under MAX_BUFFERED_BYTES. */
	private static final int MAX_BODY_LENGTH = MAX_BUFFERED_BYTES - MessageType.BASE_LENGTH;
	/** Header bytes after the start marker and length fields: type (4) + token (50) + token time (20). */
	private static final int HEAD_REMAINDER = MessageType.BASE_LENGTH - 4 - 4;
	private static final int TOKEN_LENGTH = 50;
	private static final int CREATE_DATE_LENGTH = 20;
	private static final DateTimeFormatter CREATE_DATE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

	@Override
	protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) {
		decodeFrame(ctx, buffer, out);
	}

	/**
	 * Decodes at most one frame from {@code buffer}.
	 * If the frame is incomplete, the reader index is left at the frame start and nothing is emitted;
	 * ByteToMessageDecoder calls again when more bytes arrive.
	 */
	private void decodeFrame(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) {
		// protect against stream floods: drop everything when the accumulation is unreasonably large
		if (buffer.readableBytes() > MAX_BUFFERED_BYTES) {
			buffer.skipBytes(buffer.readableBytes());
			return;
		}

		// scan forward one byte at a time until a start marker is found (resynchronizes after garbage)
		int beginIndex;
		while (true) {
			if (buffer.readableBytes() < MessageType.BASE_LENGTH) {
				// not even a full header yet: wait for more data
				return;
			}
			beginIndex = buffer.readerIndex();
			if (buffer.readInt() == MessageType.HEAD_DATA) {
				break;
			}
			// not a marker: skip exactly one byte and try again
			buffer.readerIndex(beginIndex + 1);
		}

		int length = buffer.readInt();
		if (length < 0 || length > MAX_BODY_LENGTH) {
			// impossible length: treat this marker as a false match and resume scanning after it
			buffer.readerIndex(beginIndex + 1);
			return;
		}
		// is the rest of the header plus the whole body available?
		if (buffer.readableBytes() - HEAD_REMAINDER < length) {
			buffer.readerIndex(beginIndex);
			return;
		}

		int messageType = buffer.readInt();
		String token = readFixedString(buffer, TOKEN_LENGTH);
		String timeStr = readFixedString(buffer, CREATE_DATE_LENGTH);
		byte[] data = new byte[length];
		buffer.readBytes(data);

		LocalDateTime dateTime = LocalDateTime.parse(timeStr, CREATE_DATE_FORMAT);
		MessageHead head = new MessageHead();
		head.setHeadData(MessageType.HEAD_DATA)
				.setToken(token)
				.setCreateDate(dateTime)
				.setLength(length)
				.setMessageType(messageType);

		Message message = new Message(head, data);
		if (!message.authorization(message.buidToken())) {
			// authentication failed: drop anything else buffered and close the connection
			buffer.skipBytes(buffer.readableBytes());
			ctx.close();
			return;
		}
		// no discardReadBytes() here: ByteToMessageDecoder manages its cumulation buffer itself
		out.add(message);
	}

	/** Reads a fixed-width, zero-padded field and returns it with padding/whitespace trimmed. */
	private static String readFixedString(ByteBuf buffer, int length) {
		byte[] bytes = new byte[length];
		buffer.readBytes(bytes);
		return new String(bytes).trim();
	}
}

数据加密工具类(Md5Utils)

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/**
 * MD5 helpers.
 * strMD5  -> MD5 of a string (encoded as UTF-8), as 32 lower-case hex characters.
 * fileMD5 -> MD5 of a file's contents, streamed through a DigestInputStream so large files are fine.
 *
 * Every call uses its own MessageDigest, so these methods are safe to call from several threads
 * (e.g. several Netty event loops computing tokens at once).
 */
public class Md5Utils {

    protected static char[] hexDigits = {'0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f'};

    /**
     * Kept only for source compatibility; no longer used internally, because one MessageDigest
     * shared between threads produces corrupted digests.
     *
     * @deprecated obtain a fresh instance with {@code MessageDigest.getInstance("MD5")} instead
     */
    @Deprecated
    protected static MessageDigest messageDigest = newMd5Digest();

    /** Creates a fresh MD5 digest; every Java platform is required to provide MD5. */
    private static MessageDigest newMd5Digest() {
        try {
            return MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(Md5Utils.class.getName() + ": MessageDigest does not support MD5", e);
        }
    }

    /** Converts bytes to lower-case hex, two characters per byte. */
    private static String bufferToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            sb.append(hexDigits[(b & 0xf0) >> 4]).append(hexDigits[b & 0x0f]);
        }
        return sb.toString();
    }

    /**
     * MD5 of a string.
     * @param input text to hash; encoded as UTF-8 so every JVM produces the same digest
     * @return 32-character lower-case hex digest
     */
    public static String strMD5(String input) {
        return bufferToHex(newMd5Digest().digest(input.getBytes(StandardCharsets.UTF_8)));
    }

    /**
     * MD5 of a file's contents.
     * @param inputFile path of the file to hash
     * @return 32-character lower-case hex digest
     * @throws IOException if the file cannot be opened or read
     */
    public static String fileMD5(String inputFile) throws IOException {
        // read buffer size (could be made a parameter)
        int bufferSize = 256 * 1024;
        MessageDigest digest = newMd5Digest();
        try (FileInputStream fileInputStream = new FileInputStream(inputFile);
             DigestInputStream digestInputStream = new DigestInputStream(fileInputStream, digest)) {
            byte[] buffer = new byte[bufferSize];
            // every read() also feeds the digest; just drain the stream
            while (digestInputStream.read(buffer) != -1) {
                // intentionally empty
            }
        }
        return bufferToHex(digest.digest());
    }
}

消息序列化工具类(MsgUtils)

消息的格式转换在不同客户端(如安卓)上可能不一样,可按照实际情况进行修改

import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.example.netty.entity.Message;
import com.example.netty.entity.MessageHead;
import com.example.netty.entity.MessageType;

import java.io.*;
import java.time.LocalDateTime;

/**
 * Helpers that build protocol Messages and convert payloads to and from bytes.
 * @author wmg
 */
public class MsgUtils {

	private MsgUtils() {
		// static helpers only
	}

	/**
	 * Wraps an object (serialized via objectToBytes) in a signed Message.
	 * NOTE(review): the message type is HEARTBEAT, as in the original code; confirm that this is intended
	 * for arbitrary objects.
	 */
	public static Message getObjMessage(Object obj) throws IOException {
		return buildMessage(MessageType.HEARTBEAT, objectToBytes(obj));
	}

	/**
	 * Wraps a string (platform default charset) in a signed STR Message.
	 */
	public static Message getStrMessage(String str) throws IOException {
		return buildMessage(MessageType.STR, str.getBytes());
	}

	/**
	 * Creates a Message with the current time, the given type, the real body length and a fresh token.
	 * Setting the length matters because the receiver rebuilds the token from the body length on the wire.
	 */
	private static Message buildMessage(Integer messageType, byte[] content) {
		MessageHead head = new MessageHead();
		Message msg = new Message(head, content);
		head.setCreateDate(LocalDateTime.now())
				.setMessageType(messageType)
				.setLength(content.length)
				.setToken(msg.buidToken());
		return msg;
	}

	/**
	 * Converts byte[] to String using the platform default charset (the same charset getStrMessage uses).
	 */
	public static String bytesToStr(byte[] bytes){
		return new String(bytes);
	}

	/**
	 * Converts byte[] (produced by objectToBytes) into a bean of the given class.
	 * SECURITY NOTE(review): this runs Java native deserialization on bytes received from the network.
	 * Untrusted input can trigger gadget-chain attacks. Consider sending plain JSON text instead, or
	 * installing an ObjectInputFilter whitelist.
	 */
	public static <T> T bytesToBean(byte[] bytes, Class<T> beanClass) throws Exception {
		Object obj;
		try (ObjectInputStream oi = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
			obj = oi.readObject();
		}
		return JSONUtil.toBean(JSONUtil.parseObj(obj), beanClass);
	}

	/**
	 * Converts a bean to byte[]: the bean is turned into a hutool JSON object, then Java-serialized.
	 */
	public static byte[] objectToBytes(Object obj) throws IOException {
		ByteArrayOutputStream bo = new ByteArrayOutputStream();
		try (ObjectOutputStream oo = new ObjectOutputStream(bo)) {
			oo.writeObject(JSONUtil.parse(obj));
		}
		// read only after the object stream is closed (and therefore flushed)
		return bo.toByteArray();
	}
}

服务端

服务端包含:Socket拦截器(SocketHandler),Websocket拦截器(TextWebSocketHandler),入站I/O事件处理器(BaseHanlderAdapter),Socket 初始化器(SocketInitializer),Socket服务配置(SocketServer),Socket启动器(NettyStartListener)

Socket拦截器(SocketHandler)

import com.example.netty.entity.Message;
import com.example.netty.entity.MessageHead;
import com.example.netty.entity.MessageType;
import com.example.netty.entity.UavHeartbeat;
import com.example.netty.util.MsgUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;

/**
 * Socket拦截器,用于处理客户端的行为
 *
 * @author wmg
 **/
@Slf4j
public class SocketHandler extends ChannelInboundHandlerAdapter {
	public static final ChannelGroup HTTP_CLIENTS = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

	/**
	 * 读取到客户端发来的消息
	 *
	 * @param ctx ChannelHandlerContext
	 * @param msg msg
	 * @throws Exception e
	 */
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		log.info(ctx.channel()+"--发送的消息: " + msg);
		Message message = (Message) msg;
		MessageHead header = message.getHeader();
		// 心跳
		if (header.getMessageType().equals(MessageType.HEARTBEAT)){
			if (message.getContent().length>0){
				UavHeartbeat uavHeartbeat = MsgUtils.bytesToBean(message.getContent(),UavHeartbeat.class);
				if (uavHeartbeat != null){
					// TODO
					log.info("uavHeartbeat--->"+uavHeartbeat);
					// 心跳回应
					sendUavHeartbeatRes(ctx);
				}
			}
		}
	}



	/**
	 * 监听连接事件
	 * */
	@Override
	public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
		log.info("新的客户端链接:" + ctx.channel().id().asShortText());
		HTTP_CLIENTS.add(ctx.channel());
	}

	/**
	 * 监听离线事件
	 * */
	@Override
	public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
		log.error("客户端断开连接:{}",ctx.channel());
		super.handlerRemoved(ctx);
	}


	/**处理异常, 一般是需要关闭通道*/
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
//		cause.printStackTrace();
		ctx.channel().close();
		HTTP_CLIENTS.remove(ctx.channel());
	}

	/**
	 * 心跳回应
	 * */
	private void sendUavHeartbeatRes(ChannelHandlerContext ctx) throws IOException {
		MessageHead head=new MessageHead();
		byte[] content = "".getBytes(StandardCharsets.UTF_8);

		Message msg=new Message(head,content);
		head.setCreateDate(LocalDateTime.now())
				.setMessageType(MessageType.HEARTBEAT)
				.setToken(msg.buidToken());

		ctx.writeAndFlush(msg);
	}
}

Websocket拦截器(TextWebSocketHandler)

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;

/**
 * 只处理WebSocketFrame
 * WebSocketFrame有6种类型:
 *      BinaryWebSocketFrame  消息传输的方式
 *      CloseWebSocketFrame  代表关闭连接的frame
 *      ContinuationWebSocketFrame  消息中多于一个frame的表示
 *      PingWebSocketFrame  PingWebSocketFrame和PongWebSocketFrame是两个特殊的frame,他们主要用来做服务器和客户端的探测
 *      PongWebSocketFrame
 *      TextWebSocketFrame  消息传输的方式
 * */
@Slf4j
@Component
public class TextWebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame>  {

    public static final ChannelGroup WS_CLIENTS = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame frame) throws Exception {
        log.info("ws:"+ctx.channel()+": " + frame.text());
        ctx.channel().writeAndFlush(new TextWebSocketFrame("服务器接收时间: " + LocalDateTime.now()));
    }
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        log.info("ws:客户端建立连接: " + ctx.channel().id().asLongText());
        WS_CLIENTS.add(ctx.channel());
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        log.error("ws:客户端断开连接: {}" + ctx.channel().id().asLongText());
        super.handlerRemoved(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("ws:异常发生:"+cause);
        WS_CLIENTS.remove(ctx.channel());
        ctx.close();
    }
}

入站I/O事件处理器(BaseHanlderAdapter)

import com.example.netty.entity.MeaasgeDecoder;
import com.example.netty.entity.MessageEncoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.ByteProcessor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit;


/**
 * Sniffs the first bytes of a new connection to pick a protocol.
 * If the first line ends with an "HTTP..." token, the HTTP/WebSocket handlers are installed;
 * otherwise the custom-protocol handlers are. This handler then removes itself and forwards the
 * original bytes to the newly installed pipeline.
 * The line splitting below mirrors Netty's LineBasedFrameDecoder.
 * @author wmg
 */
@Component
@Slf4j
public class BaseHanlderAdapter extends ChannelInboundHandlerAdapter {

    private final int maxLength = 4096;

    private final boolean failFast = false;


    private boolean discarding;
    private int discardedBytes;

    /**
     * Last scan position.
     */
    private int offset;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof ByteBuf) {
            // sniff on a wrapping view so msg itself is forwarded intact below
            ByteBuf byteBuf = Unpooled.wrappedBuffer((ByteBuf) msg);
            boolean isHttp = false;
            Object obj = null;
            try {
                obj = decode(ctx, byteBuf);
            } catch (Exception e) {
                log.warn("lineParse decode failed: {}", ctx.channel(), e);
            }
            if (obj instanceof ByteBuf) {
                ByteBuf b = (ByteBuf) obj;
                try {
                    String str = b.toString(Charset.defaultCharset());
                    String[] arr = str.split("\\s+");
                    // a whitespace-only line splits into an empty array: guard before indexing
                    if (arr.length > 0 && arr[arr.length - 1].contains("HTTP")) {
                        isHttp = true;
                    }
                    log.info("lineParse decode:{}", str);
                } finally {
                    // decode() returns readRetainedSlice(...): release it or the buffer leaks
                    b.release();
                }
            }
            log.info("{}Http协议", isHttp ? "是" : "否");
            ChannelPipeline pipeline = ctx.pipeline();
            if (isHttp) {
                // HTTP codec
                pipeline.addLast(new HttpServerCodec());
                // aggregator, needed for the WebSocket handshake
                pipeline.addLast(new HttpObjectAggregator(64 * 1024));
                // chunked transfer of large data
                pipeline.addLast(new ChunkedWriteHandler());

                pipeline.addLast(new WebSocketServerProtocolHandler("/ws"));
                pipeline.addLast(new TextWebSocketHandler());
            } else {
                // nothing received for 7 seconds means the client dropped (network issues etc.)
                pipeline.addLast("heartbeat",new IdleStateHandler(7, 0, 0, TimeUnit.SECONDS));
                // custom-protocol codec
                pipeline.addLast("msgDecoder",new MeaasgeDecoder());
                pipeline.addLast("msgEncoder",new MessageEncoder());
                // business handler for custom-protocol clients
                pipeline.addLast(new SocketHandler());
            }
            pipeline.remove(this);
        } else {
            log.info("msg not instanceof ByteBuf");
        }
        super.channelRead(ctx, msg);
    }


    protected Object decode(ChannelHandlerContext ctx, ByteBuf buffer) {
        final int eol = findEndOfLine(buffer);
        if (!discarding) {
            if (eol >= 0) {
                final int length = eol - buffer.readerIndex();
                final int delimLength = buffer.getByte(eol) == '\r' ? 2 : 1;

                if (length > maxLength) {
                    buffer.readerIndex(eol + delimLength);
                    fail(ctx, length);
                    return null;
                }
                return buffer.readRetainedSlice(length + delimLength);
            } else {
                final int length = buffer.readableBytes();
                if (length > maxLength) {
                    discardedBytes = length;
                    buffer.readerIndex(buffer.writerIndex());
                    discarding = true;
                    offset = 0;
                    if (failFast) {
                        fail(ctx, "over " + discardedBytes);
                    }
                }
                return null;
            }
        } else {
            if (eol >= 0) {
                final int length = discardedBytes + eol - buffer.readerIndex();
                final int delimLength = buffer.getByte(eol) == '\r' ? 2 : 1;
                buffer.readerIndex(eol + delimLength);
                discardedBytes = 0;
                discarding = false;
                if (!failFast) {
                    fail(ctx, length);
                }
            } else {
                discardedBytes += buffer.readableBytes();
                buffer.readerIndex(buffer.writerIndex());
                // We skip everything in the buffer, we need to set the offset to 0 again.
                offset = 0;
            }
            return null;
        }
    }

    private int findEndOfLine(final ByteBuf buffer) {
        int totalLength = buffer.readableBytes();
        int i = buffer.forEachByte(buffer.readerIndex() + offset, totalLength - offset, ByteProcessor.FIND_LF);
        if (i >= 0) {
            offset = 0;
            if (i > 0 && buffer.getByte(i - 1) == '\r') {
                i--;
            }
        } else {
            offset = totalLength;
        }
        return i;
    }

    private void fail(final ChannelHandlerContext ctx, String length) {
        ctx.fireExceptionCaught(
                new TooLongFrameException(
                        "frame length (" + length + ") exceeds the allowed maximum (" + maxLength + ')'));
    }

    private void fail(final ChannelHandlerContext ctx, int length) {
        fail(ctx, String.valueOf(length));
    }
}

Socket 初始化器(SocketInitializer)

import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import org.springframework.stereotype.Component;

/**
 * Sets up the pipeline of every accepted child channel.
 * Only the protocol-sniffing BaseHanlderAdapter is installed here; it later replaces itself
 * with the HTTP/WebSocket or the custom-protocol handlers.
 * @author wmg
 **/
@Component
public class SocketInitializer extends ChannelInitializer<SocketChannel> {
	@Override
	protected void initChannel(SocketChannel socketChannel) throws Exception {
		// a fresh instance per channel, because BaseHanlderAdapter keeps per-connection fields
		socketChannel.pipeline().addLast(new BaseHanlderAdapter());
	}
}

Socket服务配置(SocketServer)

import com.example.netty.entity.Message;
import com.example.netty.util.MsgUtils;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.io.IOException;

/**
 * @author wmg
 **/
@Slf4j
@Component
public class SocketServer {
	@Resource
	private SocketInitializer socketInitializer;

	@Getter
	private ServerBootstrap serverBootstrap;

	/**
	 * netty服务监听端口
	 */
	@Value("${netty.port:7003}")
	private int port;
	/**
	 * 主线程组数量
	 */
	@Value("${netty.bossThread:1}")
	private int bossThread;

	/**
	 * 启动netty服务器
	 */
	public void start() throws InterruptedException {
		this.init();
		this.serverBootstrap.bind(this.port).sync();
		log.info("Netty started on port: {} (TCP) with boss thread {}", this.port, this.bossThread);
	}

	/**
	 * 初始化netty配置
	 */
	private void init() {
		// 创建两个线程组,bossGroup为接收请求的线程组,一般1-2个就行
		NioEventLoopGroup bossGroup = new NioEventLoopGroup(this.bossThread);
		// 实际工作的线程组
		NioEventLoopGroup workerGroup = new NioEventLoopGroup();
		this.serverBootstrap = new ServerBootstrap();

		this.serverBootstrap
				// 两个线程组加入进来
				.group(bossGroup, workerGroup)
				// 配置为nio类型
				.channel(NioServerSocketChannel.class)
				// 加入自己的初始化器
				.childHandler(this.socketInitializer);
	}

	/**
	 * 发消息给指定客户端
	 * */
	public static void sendObjMsgToClient(Channel client, Object obj) throws IOException {
		client.writeAndFlush(MsgUtils.getObjMessage(obj));
	}

	/**
	 * 广播对象
	 * */
	public static void sendObjMsgToClients(Object obj) throws IOException {
		for (Channel client : SocketHandler.HTTP_CLIENTS) {
			client.writeAndFlush(MsgUtils.getObjMessage(obj));
		}
	}
	/**
	 * 广播字符串
	 * */
	public static void sendStrMsgToClients(String str) throws IOException {
		for (Channel client : SocketHandler.HTTP_CLIENTS) {
			Message message = MsgUtils.getStrMessage(str);
			client.writeAndFlush(message);
		}
	}

	/**
	 * WS: 发消息给指定客户端
	 * */
	public static void sendObjMsgToWSClient(Channel client, Object obj) throws IOException {
		client.writeAndFlush(obj);
	}

	/**
	 * WS: 广播对象
	 * */
	public static void sendObjMsgToWSClients(Object obj) throws IOException {
		for (Channel client : TextWebSocketHandler.WS_CLIENTS) {
			client.writeAndFlush(obj);
		}
	}
}

Socket启动器(NettyStartListener)

import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;

/**
 * ApplicationRunner that starts the Netty server once Spring has finished starting up.
 * @author wmg
 **/
@Component
public class NettyStartListener implements ApplicationRunner {
	@Resource
	private SocketServer socketServer;

	/**
	 * Delegates to SocketServer#start; any exception it throws propagates to Spring.
	 */
	@Override
	public void run(ApplicationArguments args) throws Exception {
		socketServer.start();
	}
}

客户端

客户端配置(Client),初始化(ClientChannelInitializer),客户端拦截器(ClientHandle),客户端启动器(ClientMain)

客户端配置(Client)

import com.example.netty.entity.Message;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;

/**
 * Custom-protocol TCP client.
 * connection() starts a background thread that connects to ip:port. Whenever the connection
 * drops, the thread reconnects after RECONNECT_DELAY_MS, until close() is called.
 */
public class Client implements Runnable{
	/** Pause between reconnect attempts, so a dead server is not hammered in a tight loop. */
	private static final long RECONNECT_DELAY_MS = 3000;

	private final String ip;// ip
	private final int port;// port
	/** Written by the Netty I/O thread and read by callers of sendMsg, so it must be volatile. */
	private volatile boolean isConnection = false;
	/** Set by close(); stops the reconnect loop. */
	private volatile boolean closed = false;
	private volatile ChannelHandlerContext serverChannel;

	public Client(String ip, int port) {
		this.ip = ip;
		this.port = port;
	}

	// connect to the server on a background thread
	public void connection() {
		new Thread(this).start();
	}

	/**
	 * Connect/reconnect loop. One event loop group is reused for all attempts and shut down when the
	 * loop ends. The loop is iterative; the former recursive reconnect overflowed the stack and leaked
	 * a group per attempt.
	 */
	@Override
	public void run() {
		EventLoopGroup group = new NioEventLoopGroup();// handles I/O from the server
		try {
			Bootstrap bootstrap = new Bootstrap();
			bootstrap.group(group).channel(NioSocketChannel.class).option(ChannelOption.TCP_NODELAY, true)
					.handler(new ClientChannelInitializer(this));// NIO transport
			while (!closed) {
				try {
					ChannelFuture channelFuture = bootstrap.connect(ip, port).sync();
					channelFuture.channel().closeFuture().sync();
				} catch (InterruptedException e) {
					Thread.currentThread().interrupt();
					System.out.println("连接服务器失败");
					return;
				} catch (Exception e) {
					// e.g. connection refused: fall through to the reconnect delay
					System.out.println("连接服务器失败");
				} finally {
					isConnection = false;
				}
				if (closed) {
					return;
				}
				// try to reconnect
				System.out.println("正在重连");
				try {
					Thread.sleep(RECONNECT_DELAY_MS);
				} catch (InterruptedException e) {
					Thread.currentThread().interrupt();
					return;
				}
			}
		} finally {
			group.shutdownGracefully();
		}
	}

	/** Closes the current connection and stops reconnecting. */
	public void close() {
		closed = true;
		ChannelHandlerContext ctx = serverChannel;
		if (ctx != null) {
			ctx.close();
		}
	}

	public boolean isConnection() {
		return isConnection;
	}

	public void setConnection(boolean isConnection) {
		this.isConnection = isConnection;
	}

	/**
	 * Sends the message once if currently connected; otherwise the message is dropped.
	 * The old while-loop resent the same message forever while connected.
	 */
	public void sendMsg(Message msg) {
		ChannelHandlerContext ctx = serverChannel;
		if (isConnection && ctx != null) {
			ctx.writeAndFlush(msg);
		}
	}

	public ChannelHandlerContext getServerChannel() {
		return serverChannel;
	}

	public void setServerChannel(ChannelHandlerContext serverChannel) {
		this.serverChannel = serverChannel;
	}

}

初始化(ClientChannelInitializer)

import com.example.netty.entity.MeaasgeDecoder;
import com.example.netty.entity.MessageEncoder;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;

import java.util.concurrent.TimeUnit;

/**
 * Client-side pipeline: idle detection, the custom-protocol codec, then ClientHandle.
 */
public class ClientChannelInitializer extends ChannelInitializer<SocketChannel> {

	private final Client client;

	public  ClientChannelInitializer(Client client) {
		this.client=client;
	}

	@Override
	protected void initChannel(SocketChannel socketChannel) throws Exception {
		socketChannel.pipeline()
		// writer idle after 5s -> send a heartbeat; reader idle after 10s -> server considered dead
		.addLast("ping",new IdleStateHandler(10, 5, 0, TimeUnit.SECONDS))
		// handler names now match their roles (they were swapped)
		.addLast("encoder",new MessageEncoder())
		.addLast("decoder",new MeaasgeDecoder())
		.addLast(new ClientHandle(client));// business handler
	}
}

客户端拦截器(ClientHandle)

import com.example.netty.entity.Message;
import com.example.netty.entity.MessageHead;
import com.example.netty.entity.MessageType;
import com.example.netty.entity.UavHeartbeat;
import com.example.netty.util.MsgUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;

import java.io.IOException;
import java.time.LocalDateTime;

public class ClientHandle extends ChannelInboundHandlerAdapter {
	
	Client client;
    public  ClientHandle(Client client) {
       this.client=client;
    }
    /**
	 * 读写超时事事件
     * @throws Exception 
	 */
	@Override
	public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
		if(evt instanceof IdleStateEvent) {
			IdleStateEvent idleStateEvent=((IdleStateEvent) evt);
			/**
			 * 如果没有收到服务端的写 则表示服务器超时 判断是否断开连接
			 */
	        if(idleStateEvent.state()==IdleState.READER_IDLE) {
	        	if(!ctx.channel().isOpen()) {
	        		System.err.println("正在重连");
	        		client.connection();
	        		System.err.println("重连成功");

	        	}
	        }else if(idleStateEvent.state()==IdleState.WRITER_IDLE) {
	        	//如果没有触发写事件则向服务器发送一次心跳包
	        	sendUavHeartbeat(ctx);
	        }
		}else {
			super.userEventTriggered(ctx, evt);
		}
	}
	//建立连接时回调
	@Override
	public void channelActive(ChannelHandlerContext ctx) throws Exception {
		
		System.out.println("与服务器建立连接成功");
		client.setServerChannel(ctx);
		client.setConnection(true);
		sendUavHeartbeat(ctx);
		//ctx.fireChannelActive();//如果注册多个handle 下一个handel的事件需要触发需要调用这个方法
		
	}
	//读取服务器发送信息时回调
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		Message message=(Message) msg;
		Integer messageType = message.getHeader().getMessageType();
		if (messageType.equals(MessageType.HEARTBEAT)){
			//心跳包
			System.out.println("心跳包返回:"+message);
		}else if (messageType.equals(MessageType.INSTRUCT)){
			//指令 TODO
		}else if (messageType.equals(MessageType.STR)){
			//字符串
			System.out.println("接收服务器发送的Str数据:"+ MsgUtils.bytesToStr(message.getContent()));
		}else {
			System.out.println("接收服务器发送的数据-无法解析:"+ message.getContent());
		}
	
	}

	//发生异常时回调
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
	    System.out.println("发生异常 与服务器断开连接");
		ctx.close();//关闭连接
	}

	private void sendUavHeartbeat(ChannelHandlerContext ctx) throws IOException {
		MessageHead head=new MessageHead();
		UavHeartbeat uavHeartbeat = new UavHeartbeat();

		uavHeartbeat.setUavCode("dji_m300_001")
				.setUavNumber("m300_001")
				.setUavStatus(1);
		byte[] content = MsgUtils.objectToBytes(uavHeartbeat);

		Message msg=new Message(head,content);
		head.setCreateDate(LocalDateTime.now())
				.setMessageType(MessageType.HEARTBEAT)
				.setLength(content.length)
				.setToken(msg.buidToken());

//		System.out.println("正在向服务端发送心跳包:"+msg);
		ctx.writeAndFlush(msg);
	}
}

客户端启动器(ClientMain)

import com.example.netty.entity.Message;
import com.example.netty.entity.MessageHead;
import com.example.netty.entity.MessageType;

import java.time.LocalDateTime;


/**
 * Demo client: connects, waits for the connection, then sends one string message.
 * Usage: ClientMain [host] [port]; the defaults are the original hard-coded values.
 * For a local server started with the default config, use: ClientMain 127.0.0.1 7003
 */
public class ClientMain {
	private static final String DEFAULT_HOST = "202.104.29.139";
	private static final int DEFAULT_PORT = 8915;
	/** How long to wait for the first connection before giving up on sending. */
	private static final long CONNECT_TIMEOUT_MS = 10_000;

	public static void main(String[] args) {
		String host = args.length > 0 ? args[0] : DEFAULT_HOST;
		int port = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_PORT;

		Client client1 = new Client(host, port);
		client1.connection();

		// connection() is asynchronous; sending right away would silently drop the message
		if (!awaitConnection(client1, CONNECT_TIMEOUT_MS)) {
			System.out.println("等待连接超时,消息未发送");
			return;
		}

		String content = "哈哈哈哈!";
		byte[] bts = content.getBytes();
		MessageHead head = new MessageHead();
		// token creation time
		head.setCreateDate(LocalDateTime.now());
		head.setMessageType(MessageType.STR);
		head.setLength(bts.length);

		Message message = new Message(head, bts);
		message.getHeader().setToken(message.buidToken());

		client1.sendMsg(message);
	}

	/** Polls until the client reports it is connected or the timeout expires. */
	private static boolean awaitConnection(Client client, long timeoutMs) {
		long deadline = System.currentTimeMillis() + timeoutMs;
		while (!client.isConnection()) {
			if (System.currentTimeMillis() >= deadline) {
				return false;
			}
			try {
				Thread.sleep(100);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				return false;
			}
		}
		return true;
	}
}

websocket测试

websocket在线测试网址:http://websocket.jsonin.com/

连接路径: ws://127.0.0.1:7003/ws


文章作者: wmg
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 wmg !
  目录