基于 Mina2 的 WebSocket 实现

    我在网上找了很多关于 WebSocket 协议的资料，发现大部分资料或实现记录得都不完整，或者只给出了最基本的实现。于是，我花了一周的业余时间写了一个相对完整的实现。

    首先是解码器部分:

    

/**
 * Server-side WebSocket (RFC 6455) decoder for Apache MINA 2.
 *
 * <p>While the session is in the {@code Handshake} state it parses the HTTP upgrade request and
 * writes the 101 response. Once {@code Connected} it decodes frames and emits a {@link String}
 * for text messages and a {@code byte[]} for binary messages. Fragmented messages are
 * reassembled and emitted once, when their final fragment arrives.
 *
 * <p>Pong frames are ignored. Close, ping and reserved opcodes, non-zero RSV bits, and malformed
 * frames close the session.
 */
public class WSDecoder extends CumulativeProtocolDecoder {
	private final static String REQUEST_CONTEXT_KEY = "__REQUEST_DATA_CONTEXT";
	private final static String END_TAG = "\r\n";
	/** The blank line that ends the header section of the HTTP handshake request. */
	private final static byte[] HEADER_TERMINATOR = { '\r', '\n', '\r', '\n' };

	private enum FrameType {
		Text, Binary, Control;
	}

	/**
	 * Accumulates the payload of one data message across its fragments.
	 *
	 * <p>RFC 6455 only carries the message type (text/binary) on the first fragment, so it is
	 * remembered here until the final fragment arrives.
	 */
	private static class RequestDataContext {

		private final IoBuffer _tmp = IoBuffer.allocate(512).setAutoExpand(true);
		// RFC 6455 §5.6: text payloads are always UTF-8, whatever charset the handshake uses.
		private final CharsetDecoder _charsetDecoder = Charset.forName("utf-8").newDecoder();
		private FrameType _frameType;

		FrameType getFrameType() {
			return this._frameType;
		}

		/**
		 * Returns the accumulated payload decoded as UTF-8, or {@code null} if it is not valid
		 * UTF-8. The whole payload is decoded; embedded NUL characters are kept.
		 */
		String getDataAsString() {
			_tmp.flip();
			try {
				return _charsetDecoder.decode(_tmp.buf()).toString();
			} catch (CharacterCodingException e) {
				return null;
			}
		}

		/** Returns a copy of the accumulated payload bytes. */
		byte[] getDataAsArray() {
			_tmp.flip();

			byte[] data = new byte[_tmp.remaining()];
			_tmp.get(data);

			return data;
		}

		void append(byte[] data) {
			this._tmp.put(data);
		}

		void setFrameType(FrameType frameType) {
			this._frameType = frameType;
		}
	}

	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	private CharsetDecoder charsetDecoder = Charset.forName("utf-8")
			.newDecoder();

	/** Sets the decoder used for the HTTP handshake request. */
	public void setCharsetDecoder(CharsetDecoder charsetDecoder) {
		this.charsetDecoder = charsetDecoder;
	}

	/**
	 * Decodes either the handshake or one frame.
	 *
	 * @return {@code true} when input was consumed, so {@link CumulativeProtocolDecoder} keeps
	 *         decoding any remaining bytes; {@code false} when more bytes are needed
	 */
	@Override
	protected boolean doDecode(IoSession session, IoBuffer in,
			ProtocolDecoderOutput decoderOutput)
			throws CharacterCodingException, NoSuchAlgorithmException {

		if (!in.hasRemaining())
			return false;

		switch (getSessionState(session)) {
			case Handshake:
				return doHandshake(session, in);

			case Connected:
				return decodeFrame(session, in, decoderOutput);

			default:
				closeConnection(session, in);
				return true;
		}
	}

	/**
	 * Decodes at most one WebSocket frame from {@code in}.
	 *
	 * @return {@code false} (with the buffer rewound to the frame start) when the frame is not
	 *         completely buffered yet; {@code true} when a frame was consumed or the session was
	 *         closed
	 */
	private boolean decodeFrame(IoSession session, IoBuffer in,
			ProtocolDecoderOutput decoderOutput) {
		if (in.remaining() < 2)
			return false;

		in.mark().order(this.byteOrder);

		byte fstByte = in.get();
		byte secByte = in.get();

		boolean isFinalFrame = fstByte < 0;
		boolean isRsvColZero = (fstByte & 0x70) == 0;
		int opCode = fstByte & 0xf;

		// Only data frames (0x0-0x2) and pongs (0xA) are handled. Close (0x8), ping (0x9),
		// reserved opcodes and non-zero RSV bits (no extension is negotiated) end the session.
		boolean isHandledOpCode = opCode <= 0x2 || opCode == 0xA;
		if (!isRsvColZero || !isHandledOpCode) {
			closeConnection(session, in);
			return true;
		}

		boolean isMasking = secByte < 0;

		long dataLength = secByte & 0x7f;
		if (dataLength == 126) {
			// The extended length may not have arrived yet.
			if (in.remaining() < 2) {
				in.reset();
				return false;
			}
			dataLength = in.getUnsignedShort();
		} else if (dataLength == 127) {
			if (in.remaining() < 8) {
				in.reset();
				return false;
			}
			dataLength = in.getLong();
		}

		// A 64-bit length with the MSB set, or one larger than a Java array, cannot be handled.
		if (dataLength < 0 || dataLength > Integer.MAX_VALUE) {
			closeConnection(session, in);
			return true;
		}

		int maskLength = isMasking ? 4 : 0;
		if (in.remaining() < dataLength + maskLength) {
			in.reset();
			return false;
		}

		byte[] mask = new byte[4];
		if (isMasking)
			in.get(mask);

		byte[] data = new byte[(int) dataLength];
		in.get(data);

		// Unmask the payload (RFC 6455 §5.3). For unmasked frames the mask stays all-zero, so
		// this loop leaves the data unchanged.
		for (int i = 0; i < data.length; i++)
			data[i] ^= mask[i % 4];

		if (opCode == 0xA)
			return true; // Pong: no reply is expected (RFC 6455 §5.5.3).

		handleDataFrame(session, in, opCode, isFinalFrame, data, decoderOutput);
		return true;
	}

	/**
	 * Appends a data frame's payload to the current message and emits the message once its
	 * final fragment has arrived.
	 */
	private void handleDataFrame(IoSession session, IoBuffer in, int opCode,
			boolean isFinalFrame, byte[] data, ProtocolDecoderOutput decoderOutput) {
		RequestDataContext context = (RequestDataContext) session
				.getAttribute(REQUEST_CONTEXT_KEY);
		if (context == null) {
			if (opCode == 0x0) {
				// A continuation frame without a preceding first fragment is a protocol error.
				closeConnection(session, in);
				return;
			}
			context = new RequestDataContext();
			context.setFrameType((opCode == 0x1) ? FrameType.Text : FrameType.Binary);
			session.setAttribute(REQUEST_CONTEXT_KEY, context);
		}

		context.append(data);

		if (!isFinalFrame)
			return;

		session.removeAttribute(REQUEST_CONTEXT_KEY);

		if (context.getFrameType() == FrameType.Text) {
			String text = context.getDataAsString();
			if (text == null) {
				// Invalid UTF-8 in a text message fails the connection (RFC 6455 §8.1).
				closeConnection(session, in);
				return;
			}
			decoderOutput.write(text);
		} else {
			decoderOutput.write(context.getDataAsArray());
		}
	}

	/**
	 * Parses the HTTP upgrade request and writes the 101 response.
	 *
	 * @return {@code false} when the request headers are not completely buffered yet (nothing
	 *         is consumed); {@code true} once the request has been consumed and answered
	 * @throws WSException if the request is not a valid WebSocket upgrade request
	 */
	private boolean doHandshake(IoSession session, IoBuffer in)
			throws CharacterCodingException, NoSuchAlgorithmException {
		int headerEndStart = indexOf(in, HEADER_TERMINATOR);
		if (headerEndStart < 0)
			return false;

		// Decode exactly the header section. Any bytes after it stay in the buffer.
		int end = headerEndStart + HEADER_TERMINATOR.length;
		int oldLimit = in.limit();
		String handshakeMessage;
		try {
			in.limit(end);
			handshakeMessage = in.getString(charsetDecoder);
		} finally {
			in.limit(oldLimit);
			in.position(end);
		}

		String[] msgColumns = splitHandshakeMessage(handshakeMessage);
		if (msgColumns.length == 0)
			throw new WSException("Invalid websocket request!");

		String requestURI = msgColumns[0];
		String httpVersion = requestURI.substring(requestURI.lastIndexOf("/") + 1);

		String upgradeCol = getMessageColumnValue(msgColumns, "Upgrade:");
		String connectionCol = getMessageColumnValue(msgColumns, "Connection:");
		String secWsProtocolCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Protocol:");
		String secWskeyCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Key:");
		String wsVersionCol = getMessageColumnValue(msgColumns,
				"Sec-WebSocket-Version:");

		// Validate the mandatory fields. Any failure aborts the handshake.
		boolean hasWebsocket = contains(upgradeCol, "websocket");
		boolean hasUpgrade = contains(connectionCol, "upgrade");
		boolean isGetMethod = "GET".equalsIgnoreCase(subString(requestURI, 1,
				" "));

		boolean isSecWsKeyNull = secWskeyCol == null || secWskeyCol.isEmpty();
		boolean isValidVersion = "13".equals(wsVersionCol);
		boolean isValidHttpVer = isHttpVersionAtLeast11(httpVersion);

		if (!hasWebsocket || !hasUpgrade || !isGetMethod)
			throw new WSException("Invalid websocket request!");

		if (isSecWsKeyNull || !isValidVersion || !isValidHttpVer)
			throw new WSException("Invalid websocket request!");

		String secWsAccept = getSecWebsocketAccept(secWskeyCol);
		String response = getResponseData(secWsAccept,
				selectSubprotocol(secWsProtocolCol));

		initRequestContext(session, msgColumns);

		session.write(response);
		return true;
	}

	/** Returns the absolute index of {@code pattern} in {@code in}'s remaining bytes, or -1. */
	private static int indexOf(IoBuffer in, byte[] pattern) {
		outer:
		for (int i = in.position(), last = in.limit() - pattern.length; i <= last; i++) {
			for (int j = 0; j < pattern.length; j++) {
				if (in.get(i + j) != pattern[j])
					continue outer;
			}
			return i;
		}
		return -1;
	}

	/** Returns whether the request line's HTTP version is at least 1.1; false if unparsable. */
	private static boolean isHttpVersionAtLeast11(String httpVersion) {
		try {
			return Float.parseFloat(httpVersion) >= 1.1F;
		} catch (NumberFormatException e) {
			return false;
		}
	}

	/** Splits the handshake into non-empty header lines; the request line comes first. */
	private String[] splitHandshakeMessage(String handshakeMessage) {
		StringTokenizer st = new StringTokenizer(handshakeMessage, END_TAG);
		String[] result = new String[st.countTokens()];

		for (int i = 0; st.hasMoreTokens(); i++)
			result[i] = st.nextToken();

		return result;
	}

	/** Case-insensitive substring check; {@code target} must already be lower case. */
	private boolean contains(String src, String target) {
		if (src == null || src.isEmpty())
			return false;
		else
			return src.toLowerCase().contains(target);
	}

	/** Computes Sec-WebSocket-Accept = base64(SHA-1(key + GUID)) per RFC 6455 §4.2.2. */
	private String getSecWebsocketAccept(String secWebsocketkey)
			throws NoSuchAlgorithmException {

		StringBuilder srcBuilder = new StringBuilder();
		srcBuilder.append(secWebsocketkey);
		srcBuilder.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");

		MessageDigest md = MessageDigest.getInstance("SHA-1");

		md.update(srcBuilder.toString().getBytes(charsetDecoder.charset()));
		byte[] ciphertext = md.digest();

		String result = new String(Base64.encodeBase64(ciphertext),
				charsetDecoder.charset());
		return result;
	}

	/**
	 * Picks the subprotocol to confirm: the first one the client offered.
	 *
	 * <p>The server must answer with a single value from the client's list. Echoing the whole
	 * list makes browsers fail the handshake.
	 *
	 * @return the selected subprotocol, or {@code null} if none was offered
	 */
	private static String selectSubprotocol(String offered) {
		if (offered == null)
			return null;
		int comma = offered.indexOf(',');
		String first = (comma < 0 ? offered : offered.substring(0, comma)).trim();
		return first.isEmpty() ? null : first;
	}

	/**
	 * Builds the 101 response. Per RFC 6455 §4.2.2 the server sends the fixed values
	 * {@code Upgrade: websocket} and {@code Connection: Upgrade}, not the client's header values.
	 */
	private String getResponseData(String secWsAccept, String secWsProtocol) {

		StringBuilder result = new StringBuilder();

		result.append("HTTP/1.1 101 Switching Protocols").append(END_TAG);
		result.append("Upgrade: websocket").append(END_TAG);
		result.append("Connection: Upgrade").append(END_TAG);
		result.append("Sec-WebSocket-Accept: ").append(secWsAccept)
				.append(END_TAG);

		if (secWsProtocol != null && !secWsProtocol.isEmpty())
			result.append("Sec-WebSocket-Protocol: ").append(secWsProtocol)
					.append(END_TAG);

		result.append(END_TAG);

		return result.toString();
	}

	/**
	 * Closes the session immediately and discards the rest of the input.
	 *
	 * <p>The buffer belongs to {@link CumulativeProtocolDecoder}. It is drained rather than freed
	 * so that its remaining bytes are not decoded again.
	 */
	private void closeConnection(IoSession session, IoBuffer buffer) {
		buffer.position(buffer.limit());
		session.close(true);
	}

	/** Writes {@code errorMsg} and closes the session once the write completes. */
	@SuppressWarnings("unused")
	private void closeConnection(IoSession session, String errorMsg) {
		session.write(errorMsg).addListener(
				new IoFutureListener<WriteFuture>() {
					@Override
					public void operationComplete(WriteFuture future) {
						future.getSession().close(true);
					}
				});
	}

	/** Stores the raw handshake header lines on the session under "__SESSION_CONTEXT". */
	private void initRequestContext(IoSession session, String[] data) {
		session.setAttribute("__SESSION_CONTEXT", data);
	}
}

     

    下面是编码器部分:

    

/**
 * Server-side WebSocket (RFC 6455) encoder for Apache MINA 2.
 *
 * <p>In the {@code Handshake} state the message must be the HTTP response {@link String}
 * produced by the decoder. It is written as is, and the session then switches to
 * {@code Connected}.
 *
 * <p>In the {@code Connected} state a {@link String} is sent as a text message and a
 * {@code byte[]} as a binary message. Messages are split into frames of at most
 * {@code defaultPageSize} bytes, header included.
 */
public class WSEncoder extends ProtocolEncoderAdapter {

	// NOTE(review): frame lengths are big-endian by definition (RFC 6455 §5.2). This setting is
	// kept for API compatibility but is not applied to outgoing frames.
	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	private CharsetEncoder charsetEncoder = Charset.forName("utf-8")
			.newEncoder();

	public void setCharsetEncoder(CharsetEncoder charsetEncoder) {
		this.charsetEncoder = charsetEncoder;
	}

	/** Maximum size of one outgoing frame, header included. */
	private int defaultPageSize = 65536;

	public void setDefaultPageSize(int defaultPageSize) {
		this.defaultPageSize = defaultPageSize;
	}

	// Server-to-client frames are not masked by default (RFC 6455 §5.1).
	private boolean isMasking = false;
	public void setIsMasking(boolean masking) {
		this.isMasking = masking;
	}

	@Override
	public void encode(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) throws CharacterCodingException {
		switch (getSessionState(session)) {
			case Handshake:
				encodeHandshake(session, message, encoderOutput);
				break;

			case Connected:
				encodeMessage(session, message, encoderOutput);
				break;

			default:
				session.close(true);
				break;
		}
	}

	/**
	 * Writes the handshake response. The session is marked Connected only if encoding succeeds;
	 * otherwise the session is closed.
	 */
	private void encodeHandshake(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) {
		IoBuffer buff = IoBuffer.allocate(1024).setAutoExpand(true);
		try {
			buff.putString((String) message, charsetEncoder).flip();
		} catch (CharacterCodingException e) {
			session.close(true);
			return;
		}

		encoderOutput.write(buff);
		session.setAttribute(WSSessionState.ATTRIBUTE_KEY, WSSessionState.Connected);
	}

	/** Sends {@code message} (String → text, byte[] → binary) as one or more frames. */
	private void encodeMessage(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) {
		if (!session.isConnected() || message == null)
			return;

		byte opCode;
		byte[] data;
		if (message instanceof String) {
			data = ((String) message).getBytes(charsetEncoder.charset());
			opCode = 0x1;
		} else {
			data = (byte[]) message;
			opCode = 0x2;
		}

		int maskLength = isMasking ? 4 : 0;
		byte[] mask = new byte[4];
		Random random = isMasking ? new Random() : null;

		int offset = 0;
		boolean isFirstFrame = true;
		// do/while so that an empty message still produces one (empty, final) frame.
		do {
			int remaining = data.length - offset;
			int chunk = frameChunkLength(remaining, maskLength);
			boolean isFinalFrame = chunk == remaining;

			// Each frame gets a fresh buffer. MINA queues written buffers until they are sent,
			// so reusing one buffer would overwrite frames that are still pending.
			IoBuffer frame = IoBuffer.allocate(2 + extendedLengthSize(chunk) + maskLength + chunk);

			// FIN bit, plus the opcode on the first frame only (later frames are continuations).
			frame.put((byte) ((isFinalFrame ? 0x80 : 0x00) | (isFirstFrame ? opCode : 0x0)));

			// Use the minimal length encoding required by RFC 6455 §5.2.
			int maskBit = isMasking ? 0x80 : 0x00;
			if (chunk <= 125) {
				frame.put((byte) (maskBit | chunk));
			} else if (chunk <= 0xFFFF) {
				frame.put((byte) (maskBit | 126));
				frame.putUnsignedShort(chunk);
			} else {
				frame.put((byte) (maskBit | 127));
				frame.putLong(chunk);
			}

			if (isMasking) {
				random.nextBytes(mask);
				frame.put(mask);
				for (int i = 0; i < chunk; i++)
					frame.put((byte) (data[offset + i] ^ mask[i % 4]));
			} else {
				frame.put(data, offset, chunk);
			}

			frame.flip();
			encoderOutput.write(frame);

			offset += chunk;
			isFirstFrame = false;
		} while (offset < data.length);
	}

	/**
	 * Returns how many payload bytes the next frame carries so that header + payload fits in
	 * {@code defaultPageSize}. The shortest header is tried first; a longer length field is used
	 * only when the payload needs it.
	 *
	 * <p>The result is always at least 1 while data remains, so the loop makes progress even when
	 * the page size is smaller than a frame header.
	 */
	private int frameChunkLength(int remaining, int maskLength) {
		int chunk = Math.min(remaining, defaultPageSize - 2 - maskLength);
		if (chunk > 125)
			chunk = Math.min(remaining, defaultPageSize - 4 - maskLength);
		if (chunk > 0xFFFF)
			chunk = Math.min(remaining, defaultPageSize - 10 - maskLength);
		return Math.max(chunk, Math.min(remaining, 1));
	}

	/** Returns the size of the extended payload-length field for a payload of this length. */
	private static int extendedLengthSize(int payloadLength) {
		if (payloadLength <= 125)
			return 0;
		return payloadLength <= 0xFFFF ? 2 : 8;
	}
}

     

    下面是工具类实现:

    

/** Static helpers shared by the WebSocket encoder and decoder. */
final class WSToolKit {
	private WSToolKit() {
	}

	/** Protocol phase of a session, stored under {@link #ATTRIBUTE_KEY}. */
	static enum WSSessionState {
		Handshake, Connected;
		public final static String ATTRIBUTE_KEY = "__SESSION_STATE";
	}

	/** Returns {@code t1} unless it is null, in which case {@code t2}. */
	static <T> T nvl(T t1, T t2) {
		return t1 == null ? t2 : t1;
	}

	/**
	 * Returns the session's protocol state. A session with no state yet is initialised to
	 * {@code Handshake} and that value is stored on the session.
	 */
	static WSSessionState getSessionState (IoSession session) {
		WSSessionState state = (WSSessionState) session
				.getAttribute(WSSessionState.ATTRIBUTE_KEY);

		if (state == null) {
			state = WSSessionState.Handshake;
			session.setAttribute(WSSessionState.ATTRIBUTE_KEY, state);
		}

		return state;
	}

	/**
	 * Finds the first header line starting with {@code headerTag} (for example {@code "Upgrade:"})
	 * and returns the trimmed rest of that line, or {@code null} if there is none.
	 *
	 * <p>The name is matched case-insensitively, because HTTP header field names are
	 * case-insensitive (RFC 7230 §3.2).
	 */
	static String getMessageColumnValue(String[] headers, String headerTag) {
		for (String header : headers) {
			if (header.regionMatches(true, 0, headerTag, 0, headerTag.length()))
				return header.substring(headerTag.length()).trim();
		}

		return null;
	}

	/**
	 * Returns the {@code order}-th (1-based) field of {@code src} when split on {@code flag}.
	 * Returns "" if there are fewer separators than needed to reach that field; the last field
	 * runs to the end of the string.
	 */
	static String subString(String src, int order, String flag) {
		for (int i = 1, j = 0, k = 0;; i++) {
			j = src.indexOf(flag, k);
			if (i < order) {
				if (j == -1)
					return "";
				else
					k = j + 1;
			} else {
				if (j == -1)
					return src.substring(k, src.length());
				else
					return src.substring(k, j);
			}
		}
	}
}

     

    如果大家想转载,请注明出处!谢谢。

你可能感兴趣的:(java,实现,websocket,协议,mina2)