在网上找了很多关于 WebSocket 协议(RFC 6455)的资料,发现大部分资料要么记录得不完整,要么只给出了最基本的实现。于是,我花了一周的业余时间写了一个相对完整的实现。
首先是解码器部分:
public class WSDecoder extends CumulativeProtocolDecoder { private final static String REQUEST_CONTEXT_KEY = "__REQUEST_DATA_CONTEXT"; private final static String END_TAG = "\r\n"; private enum FrameType { Text,Binary,Control; } private class RequestDataContext { private IoBuffer _tmp; private CharsetDecoder _charsetDecoder; private FrameType _frameType; RequestDataContext(String charset) { this._tmp = IoBuffer.allocate(512).setAutoExpand(true); this._charsetDecoder = Charset.forName("utf-8").newDecoder(); } public FrameType getFrameType() { return this._frameType; } public String getDataAsString() { try { _tmp.flip(); return _tmp.getString(_charsetDecoder); } catch (CharacterCodingException e) { return null; } } public byte[] getDataAsArray() { _tmp.flip(); byte[] data = new byte[_tmp.remaining()]; _tmp.get(data); return data; } void append(byte[] data) { this._tmp.put(data); } void setFrameType(FrameType _frameType) { this._frameType = _frameType; } } private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN; public void setByteOrder(ByteOrder byteOrder) { this.byteOrder = byteOrder; } private CharsetDecoder charsetDecoder = Charset.forName("utf-8") .newDecoder(); public void setCharsetDecoder(CharsetDecoder charsetDecoder) { this.charsetDecoder = charsetDecoder; } @Override protected boolean doDecode(IoSession session, IoBuffer in, ProtocolDecoderOutput decoderOutput) throws CharacterCodingException, NoSuchAlgorithmException { if (!in.hasRemaining()) return false; WSSessionState state = getSessionState(session); switch (state) { case Handshake: doHandshake(session, in); break; case Connected: if (in.remaining() < 2) return false; in.mark().order(this.byteOrder); byte fstByte = in.get(); int opCode = fstByte & 0xf; switch (opCode) { case 0x0: case 0x1: case 0x2: boolean isFinalFrame = fstByte < 0; boolean isRsvColZero = (fstByte & 0x70) == 0; if (!isRsvColZero) { closeConnection(session, in); break; } byte secByte = in.get(); boolean isMasking = secByte < 0; int 
dataLength = 0; byte payload = (byte) (secByte & 0x7f); if (payload == 126) dataLength = in.getUnsignedShort(); else if (payload == 127) dataLength = (int) in.getLong(); else dataLength = payload; if (in.remaining() < (isMasking ? dataLength + 4 : dataLength)) { in.reset(); return false; } byte[] mask = new byte[4]; byte[] data = new byte[dataLength]; if (isMasking) in.get(mask); in.get(data); // 用掩码处理数据。 for( int i=0, maskLength=mask.length, looplimit=data.length; i<looplimit; i++ ) data[i] = (byte)(data[i] ^ mask[i % maskLength]); // 创建一个对象保存“数据帧的数据类型”。协议规定——对于分片的数据只有第一帧会携带数据类型信息,所以要新建对象保存数据类型,以应对分片。 RequestDataContext context = (RequestDataContext) session .getAttribute(REQUEST_CONTEXT_KEY); if (context == null) { context = new RequestDataContext(charsetDecoder.charset() .name()); context.setFrameType((opCode == 0x1) ? FrameType.Text : FrameType.Binary); session.setAttribute(REQUEST_CONTEXT_KEY, context); } context.append(data); if (isFinalFrame) { context = (RequestDataContext) session.removeAttribute(REQUEST_CONTEXT_KEY); if (context.getFrameType() == FrameType.Text) decoderOutput.write(context.getDataAsString()); else decoderOutput.write(context.getDataAsArray()); return true; } else return false; case 0x3: case 0x4: case 0x5: case 0x6: case 0x7: break; case 0x8: closeConnection(session, in); break; case 0x9: case 0xA: default: closeConnection(session, in); break; } break; default: closeConnection(session, in); break; } return true; } private void doHandshake(IoSession session, IoBuffer in) throws CharacterCodingException, NoSuchAlgorithmException { String handshakeMessage = in.getString(charsetDecoder); String[] msgColumns = splitHandshakeMessage(handshakeMessage); String requestURI = msgColumns[0]; String httpVersion = requestURI.substring( requestURI.lastIndexOf("/") + 1, requestURI.length()); String upgradeCol = getMessageColumnValue(msgColumns, "Upgrade:"); String connectionCol = getMessageColumnValue(msgColumns, "Connection:"); String secWsProtocolCol = 
getMessageColumnValue(msgColumns, "Sec-WebSocket-Protocol:"); String secWskeyCol = getMessageColumnValue(msgColumns, "Sec-WebSocket-Key:"); String wsVersionCol = getMessageColumnValue(msgColumns, "Sec-WebSocket-Version:"); // 校验重要字段。任何字段不满足条件,都会导致握手失败! boolean hasWebsocket = contains(upgradeCol, "websocket"); boolean hasUpgrade = contains(connectionCol, "upgrade"); boolean isGetMethod = "GET".equalsIgnoreCase(subString(requestURI, 1, " ")); boolean isSecWsKeyNull = secWskeyCol == null || secWskeyCol.isEmpty(); boolean isValidVersion = "13".equals(wsVersionCol); boolean isValidHttpVer = Float.parseFloat(httpVersion) >= 1.1F; if (!hasWebsocket || !hasUpgrade || !isGetMethod) throw new WSException("Invalid websocket request!"); if (isSecWsKeyNull || !isValidVersion || !isValidHttpVer) throw new WSException("Invalid websocket request!"); String secWsAccept = getSecWebsocketAccept(secWskeyCol); String response = getResponseData(upgradeCol, connectionCol, secWsAccept, secWsProtocolCol); initRequestContext(session, msgColumns); session.write(response); } private String[] splitHandshakeMessage(String handshakeMessage) { StringTokenizer st = new StringTokenizer(handshakeMessage, END_TAG); String[] result = new String[st.countTokens()]; for (int i = 0; st.hasMoreTokens(); i++) result[i] = st.nextToken(); return result; } private boolean contains(String src, String target) { if (src == null || src.isEmpty()) return false; else return src.toLowerCase().contains(target); } private String getSecWebsocketAccept(String secWebsocketkey) throws NoSuchAlgorithmException { StringBuilder srcBuilder = new StringBuilder(); srcBuilder.append(secWebsocketkey); srcBuilder.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); MessageDigest md = MessageDigest.getInstance("SHA-1"); md.update(srcBuilder.toString().getBytes(charsetDecoder.charset())); byte[] ciphertext = md.digest(); String result = new String(Base64.encodeBase64(ciphertext), charsetDecoder.charset()); return result; } private String 
getResponseData(String upgrade, String connection, String secWsAccept, String secWsProtocol) { StringBuilder result = new StringBuilder(); result.append("HTTP/1.1 101 Switching Protocols\r\n"); result.append("Upgrade:").append(upgrade).append(END_TAG); result.append("Connection:").append(connection).append(END_TAG); result.append("Sec-WebSocket-Accept:").append(secWsAccept) .append(END_TAG); if (secWsProtocol != null && !"".equals(secWsProtocol)) result.append("Sec-WebSocket-Protocol:").append(secWsProtocol) .append(END_TAG); result.append(END_TAG); return result.toString(); } private void closeConnection(IoSession session, IoBuffer buffer) { buffer.free(); session.close(true); } @SuppressWarnings("unused") private void closeConnection(IoSession session, String errorMsg) { session.write(errorMsg).addListener( new IoFutureListener<WriteFuture>() { @Override public void operationComplete(WriteFuture future) { future.getSession().close(true); } }); } private void initRequestContext(IoSession session, String[] data) { session.setAttribute("__SESSION_CONTEXT", data); } }
下面是编码器部分:
public class WSEncoder extends ProtocolEncoderAdapter { private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN; public void setByteOrder(ByteOrder byteOrder) { this.byteOrder = byteOrder; } private CharsetEncoder charsetEncoder = Charset.forName("utf-8") .newEncoder(); public void setCharsetEncoder(CharsetEncoder charsetEncoder) { this.charsetEncoder = charsetEncoder; } private int defaultPageSize = 65536; public void setDefaultPageSize(int defaultPageSize) { this.defaultPageSize = defaultPageSize; } // 返回的数据默认不必进行掩码运算。 private boolean isMasking = false; public void setIsMasking(boolean masking) { this.isMasking = masking; } @Override public void encode(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) throws CharacterCodingException { IoBuffer buff = IoBuffer.allocate(1024).setAutoExpand(true); WSSessionState status = getSessionState(session); switch (status) { case Handshake: try { buff.putString((String) message, charsetEncoder) .flip(); encoderOutput.write(buff); } catch (CharacterCodingException e) { session.close(true); } session.setAttribute(WSSessionState.ATTRIBUTE_KEY, WSSessionState.Connected); break; case Connected: if (!session.isConnected() || message == null) return; byte dataType = 1; // 将数据统一转换成byte数组进行处理。 byte[] data = null; if (message instanceof String) data = ((String) message).getBytes(charsetEncoder.charset()); else { data = (byte[]) message; dataType = 2; } // 生成掩码。 byte[] mask = new byte[4]; Random random = new Random(); // 用掩码处理数据。 for( int i=0,limit=data.length; i<limit; i++) data[i] = (byte)(data[i] ^ mask[i%4]); /** * 以分片的方式向客户端推送数据。 */ int pageSize = this.defaultPageSize; // 分页大小 int remainLength = data.length; // 剩余数据长度 int realLength = 0; // 数据帧中“负载数据”的实际长度 int dataIndex = 0; for( boolean isFirstFrame=true, isFinalFrame = false; !isFinalFrame; buff.clear(), isFirstFrame=false) { int headerLeng = 2; int payload = 0; if (remainLength > 0 && remainLength <= 125) { payload = remainLength; } else if (remainLength > 125 && 
remainLength <=65536) { payload = 126; } else { payload = 127; } switch(payload) { case 126 : headerLeng += 2; break; case 127 : headerLeng += 8; break; default : headerLeng += 0; break; } headerLeng += isMasking ? 4 : 0; // 计算当前帧中剩余的可用于保存“负载数据”的字节长度。 realLength = ( pageSize - headerLeng ) >= remainLength ? remainLength : ( pageSize - headerLeng ); // 判断当前帧是否为最后一帧。 isFinalFrame = (remainLength - (pageSize - headerLeng)) < 0; // 生成第一个字节 byte fstByte = (byte)(isFinalFrame ? 0x80 : 0x0); // 若当前帧为第一帧,则还需保存数据类型信息。 fstByte += isFirstFrame ? dataType : 0; buff.put(fstByte); // 生成第二个字节。判断是否需要掩码,若需要掩码,则置标志位的值为1. byte sndByte = (byte)(isMasking ? 0x80 : 0); // 保存payload信息。 sndByte += payload; buff.put(sndByte); switch(payload) { case 126 : buff.putUnsignedShort(realLength); break; case 127 : buff.putLong(realLength); break; default : break; } if (isMasking) { random.nextBytes(mask); buff.put(mask); } for(int loopCount=dataIndex+realLength, i=0; dataIndex<loopCount; dataIndex++,i++) buff.put((byte)(data[i] ^ mask[i%4])); buff.flip(); encoderOutput.write(buff); remainLength -= (pageSize - headerLeng); } break; default: session.close(true); break; } } }
下面是工具类实现:
/** Static helpers shared by the WebSocket encoder and decoder. */
final class WSToolKit {

    private WSToolKit() {
    }

    /** Connection phase of a WebSocket session, stored as a session attribute. */
    static enum WSSessionState {
        Handshake, Connected;

        /** Session attribute key under which the current state is stored. */
        public final static String ATTRIBUTE_KEY = "__SESSION_STATE";
    }

    /** Returns {@code t1}, or {@code t2} when {@code t1} is null. */
    static <T> T nvl(T t1, T t2) {
        return t1 == null ? t2 : t1;
    }

    /** Returns the session's state, storing {@code Handshake} on first access. */
    static WSSessionState getSessionState(IoSession session) {
        WSSessionState state = (WSSessionState) session.getAttribute(WSSessionState.ATTRIBUTE_KEY);
        if (state == null) {
            state = WSSessionState.Handshake;
            session.setAttribute(WSSessionState.ATTRIBUTE_KEY, state);
        }
        return state;
    }

    /**
     * Finds the first header line starting with {@code headerTag} (e.g. "Upgrade:") and returns
     * the trimmed text after it, or {@code null} if no line matches. The name is matched
     * case-insensitively because HTTP field names are case-insensitive.
     */
    static String getMessageColumnValue(String[] headers, String headerTag) {
        for (String header : headers) {
            if (header != null && header.regionMatches(true, 0, headerTag, 0, headerTag.length()))
                return header.substring(headerTag.length()).trim();
        }
        return null;
    }

    /**
     * Returns the {@code order}-th (1-based) field of {@code src} separated by {@code flag}.
     * Returns "" when {@code src} has fewer fields; an {@code order} below 1 behaves like 1.
     */
    static String subString(String src, int order, String flag) {
        for (int i = 1, j = 0, k = 0;; i++) {
            j = src.indexOf(flag, k);
            if (i < order) {
                if (j == -1)
                    return "";
                else
                    k = j + 1;
            } else {
                if (j == -1)
                    return src.substring(k, src.length());
                else
                    return src.substring(k, j);
            }
        }
    }
}
如果大家想转载,请注明出处!谢谢。