基于Mina2的Websocket实现

阅读更多

    在网上找了很多关于websocket协议的资料，发现大部分资料要么记录得不完整，要么只给出了最基本的实现。于是，我花了一周的业余时间写了一个相对完整的实现。

    首先是解码器部分:

    

public class WSDecoder extends CumulativeProtocolDecoder {
	/** Key name reserved for the request-data context (its use is not visible in this listing). */
	private static final String REQUEST_CONTEXT_KEY = "__REQUEST_DATA_CONTEXT";
	/** HTTP line terminator; used to split and to build handshake messages. */
	private static final String END_TAG = "\r\n";

	/** Frame categories tracked by {@link RequestDataContext}. */
	private enum FrameType {
		Text, Binary, Control
	}
	
	private class RequestDataContext {
		
		private IoBuffer _tmp;
		private CharsetDecoder _charsetDecoder;
		private FrameType _frameType;
		
		RequestDataContext(String charset) {
			this._tmp = IoBuffer.allocate(512).setAutoExpand(true);
			
			this._charsetDecoder = Charset.forName("utf-8").newDecoder();
		}
		
		public FrameType getFrameType() {
			return this._frameType;
		}
		
		public String getDataAsString() {
			try {
				_tmp.flip();
				return _tmp.getString(_charsetDecoder);
			} catch (CharacterCodingException e) {
				return null;
			}
		}
		
		public byte[] getDataAsArray() {
			_tmp.flip();
			
			byte[] data = new byte[_tmp.remaining()];
			_tmp.get(data);
			
			return data;
		}
		
		void append(byte[] data) {
			this._tmp.put(data);
		}
		
		void setFrameType(FrameType _frameType) {
			this._frameType = _frameType;
		}
	}

	/** Byte order applied to the inbound buffer before a frame header is parsed. */
	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	/** Charset decoder configured for this decoder; UTF-8 unless replaced. */
	private CharsetDecoder charsetDecoder = Charset.forName("utf-8").newDecoder();

	/**
	 * Replaces the byte order used while parsing frames (big-endian by default).
	 *
	 * @param byteOrder the new byte order
	 */
	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	/**
	 * Replaces the charset decoder.
	 *
	 * @param charsetDecoder the new decoder
	 */
	public void setCharsetDecoder(CharsetDecoder charsetDecoder) {
		this.charsetDecoder = charsetDecoder;
	}

	/**
	 * Decodes inbound bytes according to the session's protocol state.
	 *
	 * Handshake state: the bytes are handed to doHandshake.
	 * Connected state: a frame header is parsed and, once the whole frame is
	 * buffered, the payload is read and unmasked.
	 *
	 * NOTE(review): this listing appears damaged by HTML extraction (text
	 * between a '<' and the next '>' was stripped). The end of doDecode and the
	 * beginning of doHandshake are missing (see the marked line below), so the
	 * code is not compilable as shown.
	 */
	@Override
	protected boolean doDecode(IoSession session, IoBuffer in,
			ProtocolDecoderOutput decoderOutput)
			throws CharacterCodingException, NoSuchAlgorithmException {

		if (!in.hasRemaining())
			return false;

		WSSessionState state = getSessionState(session);

		switch (state) {
			case Handshake:
				doHandshake(session, in);
				break;
	
			case Connected:
				// At least the two fixed header bytes are required.
				if (in.remaining() < 2)
					return false;
	
				// Mark the frame start so an incomplete frame can be re-read later.
				in.mark().order(this.byteOrder);
	
				byte fstByte = in.get();
	
				// Low 4 bits of the first byte are the opcode. Only continuation
				// (0x0), text (0x1) and binary (0x2) appear in the visible cases.
				int opCode = fstByte & 0xf;
				switch (opCode) {
					case 0x0:
					case 0x1:
					case 0x2:
						// Sign bit of the first byte = FIN flag.
						boolean isFinalFrame = fstByte < 0;
						// RSV1-3 (bits 0x70) must be zero; otherwise the connection is closed.
						boolean isRsvColZero = (fstByte & 0x70) == 0;
						if (!isRsvColZero) {
							closeConnection(session, in);
							// NOTE(review): this break only leaves the inner switch; what
							// follows is in the lost text.
							break;
						}
		
						// Sign bit of the second byte = MASK flag.
						byte secByte = in.get();
						boolean isMasking = secByte < 0;
		
						// 7-bit length: 126 => 16-bit extended length, 127 => 64-bit.
						// NOTE(review): the extended-length bytes are read without first
						// checking in.remaining() (only 2 bytes are checked above), so a
						// partially received header may throw BufferUnderflowException
						// instead of waiting for more data. The (int) cast truncates
						// 64-bit lengths, a hostile length may be negative or huge, and
						// no maximum frame size is enforced.
						int dataLength = 0;
						byte payload = (byte) (secByte & 0x7f);
						if (payload == 126)
							dataLength = in.getUnsignedShort();
						else if (payload == 127)
							dataLength = (int) in.getLong();
						else
							dataLength = payload;
		
						// Whole frame (plus the 4-byte mask, if any) not buffered yet:
						// rewind to the mark and wait for more data.
						if (in.remaining() < (isMasking ? dataLength + 4 : dataLength)) {
							in.reset();
							return false;
						}
		
						byte[] mask = new byte[4];
						byte[] data = new byte[dataLength];
		
						// NOTE(review): unmasked client frames are accepted here; RFC 6455
						// requires a server to close the connection in that case.
						if (isMasking)
							in.get(mask);
		
						in.get(data);
						
						// Unmask the payload data.
						// NOTE(review): extraction damage - the text between "i" and "= 1.1F"
						// on the next line is missing: the rest of this loop and of
						// doDecode, and the start of doHandshake. What remains after it is
						// the tail of doHandshake (presumably an HTTP-version >= 1.1 check).
						for( int i=0, maskLength=mask.length, looplimit=data.length; i= 1.1F;

		// The flags are computed in the lost text; judging by their names, they
		// cover the Upgrade/Connection headers, the GET method, the
		// Sec-WebSocket-Key/-Version headers and the HTTP version.
		if (!hasWebsocket || !hasUpgrade || !isGetMethod)
			throw new WSException("Invalid websocket request!");

		if (isSecWsKeyNull || !isValidVersion || !isValidHttpVer)
			throw new WSException("Invalid websocket request!");

		// Derive Sec-WebSocket-Accept from the client's key and build the 101 response.
		String secWsAccept = getSecWebsocketAccept(secWskeyCol);
		String response = getResponseData(upgradeCol, connectionCol,
				secWsAccept, secWsProtocolCol);

		// Keep the request lines (msgColumns, presumably the split request) on the session.
		initRequestContext(session, msgColumns);

		// Written as a String: WSEncoder.encode's Handshake branch writes it verbatim
		// and then switches the session state to Connected.
		session.write(response);
	}

	/**
	 * Breaks the raw handshake text into its non-empty lines.
	 *
	 * END_TAG is given to StringTokenizer as a delimiter SET, so any run of
	 * '\r' and/or '\n' characters separates tokens. Empty lines, such as the
	 * blank line that ends the HTTP header, never appear in the result.
	 *
	 * @param handshakeMessage the raw request text
	 * @return the request lines, in order
	 */
	private String[] splitHandshakeMessage(String handshakeMessage) {
		StringTokenizer tokenizer = new StringTokenizer(handshakeMessage, END_TAG);
		String[] lines = new String[tokenizer.countTokens()];

		int next = 0;
		while (tokenizer.hasMoreTokens()) {
			lines[next++] = tokenizer.nextToken();
		}
		return lines;
	}

	/**
	 * Case-insensitive substring test (presumably applied to handshake header values).
	 *
	 * @param src    text to search; {@code null} or empty never matches
	 * @param target token to look for; compared case-insensitively
	 * @return {@code true} if {@code src} contains {@code target}, ignoring case
	 */
	private boolean contains(String src, String target) {
		if (src == null || src.isEmpty()) {
			return false;
		}
		// Locale.ROOT: default-locale case mapping is locale-sensitive (e.g. the
		// Turkish dotted/dotless i). Lowering the target as well keeps the match
		// case-insensitive on both sides.
		return src.toLowerCase(java.util.Locale.ROOT)
				.contains(target.toLowerCase(java.util.Locale.ROOT));
	}

	/**
	 * Computes the Sec-WebSocket-Accept header value:
	 * Base64(SHA-1(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")).
	 *
	 * RFC 6455 defines both the key and the GUID as ASCII. US-ASCII is therefore
	 * used explicitly instead of the configurable decoder charset: a charset that
	 * is not ASCII-compatible (e.g. UTF-16) would have produced a wrong digest.
	 *
	 * @param secWebsocketkey the client's Sec-WebSocket-Key value
	 * @return the Base64-encoded accept value
	 * @throws NoSuchAlgorithmException if SHA-1 is unavailable
	 */
	private String getSecWebsocketAccept(String secWebsocketkey)
			throws NoSuchAlgorithmException {
		final String websocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
		Charset ascii = Charset.forName("US-ASCII");

		MessageDigest md = MessageDigest.getInstance("SHA-1");
		byte[] digest = md.digest((secWebsocketkey + websocketGuid).getBytes(ascii));

		return new String(Base64.encodeBase64(digest), ascii);
	}

	/**
	 * Builds the "101 Switching Protocols" handshake response.
	 *
	 * The Upgrade and Connection values are echoed back as given. A
	 * Sec-WebSocket-Protocol line is added only when {@code secWsProtocol} is
	 * non-empty, and the response ends with an empty line.
	 */
	private String getResponseData(String upgrade, String connection,
			String secWsAccept, String secWsProtocol) {
		StringBuilder response = new StringBuilder("HTTP/1.1 101 Switching Protocols\r\n")
				.append("Upgrade:").append(upgrade).append(END_TAG)
				.append("Connection:").append(connection).append(END_TAG)
				.append("Sec-WebSocket-Accept:").append(secWsAccept).append(END_TAG);

		boolean hasProtocol = secWsProtocol != null && secWsProtocol.length() > 0;
		if (hasProtocol) {
			response.append("Sec-WebSocket-Protocol:").append(secWsProtocol).append(END_TAG);
		}

		return response.append(END_TAG).toString();
	}

	/**
	 * Discards the unread inbound bytes and closes the session immediately.
	 *
	 * The buffer is the one CumulativeProtocolDecoder passed to doDecode and
	 * still owns. It is therefore marked fully consumed rather than free()d:
	 * freeing a buffer the framework still references is unsafe with pooling
	 * allocators. Consuming it also ensures the leftover bytes of the rejected
	 * frame are not fed to doDecode again.
	 *
	 * @param session the session to close
	 * @param buffer  the decoder's inbound buffer
	 */
	private void closeConnection(IoSession session, IoBuffer buffer) {
		buffer.position(buffer.limit());
		session.close(true);
	}

	/**
	 * Sends {@code errorMsg} to the peer and closes the session once that write
	 * has completed.
	 *
	 * @param session  the session to close
	 * @param errorMsg message written before closing
	 */
	@SuppressWarnings("unused")
	private void closeConnection(IoSession session, String errorMsg) {
		// The listener must be parameterized with WriteFuture. With the raw type,
		// operationComplete(WriteFuture) does not override
		// operationComplete(IoFuture), so the @Override does not compile.
		session.write(errorMsg).addListener(
				new IoFutureListener<WriteFuture>() {
					@Override
					public void operationComplete(WriteFuture future) {
						future.getSession().close(true);
					}
				});
	}

	/**
	 * Stores {@code data} on the session under the "__SESSION_CONTEXT" attribute.
	 *
	 * @param session the session to annotate
	 * @param data    the lines to store (doHandshake passes its msgColumns)
	 */
	private void initRequestContext(IoSession session, String[] data) {
		final String contextKey = "__SESSION_CONTEXT";
		session.setAttribute(contextKey, data);
	}
}

     

    下面是编码器部分:

    

public class WSEncoder extends ProtocolEncoderAdapter {

	/** Byte order setting; big-endian unless replaced. */
	private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

	/** Encoder for String messages and the handshake response; UTF-8 unless replaced. */
	private CharsetEncoder charsetEncoder = Charset.forName("utf-8").newEncoder();

	/**
	 * Page size in bytes (default 65536); presumably the upper bound on an
	 * outgoing frame when a message is split into fragments - TODO confirm
	 * against the truncated encode().
	 */
	private int defaultPageSize = 65536;

	/** Whether outgoing frames are masked; off by default, since replies need not be masked. */
	private boolean isMasking = false;

	public void setByteOrder(ByteOrder byteOrder) {
		this.byteOrder = byteOrder;
	}

	public void setCharsetEncoder(CharsetEncoder charsetEncoder) {
		this.charsetEncoder = charsetEncoder;
	}

	public void setDefaultPageSize(int defaultPageSize) {
		this.defaultPageSize = defaultPageSize;
	}

	public void setIsMasking(boolean masking) {
		isMasking = masking;
	}

	/**
	 * Encodes an outbound message according to the session state.
	 *
	 * Handshake state: {@code message} is cast to String (the handshake
	 * response), written as-is, and the session is then switched to Connected.
	 * Connected state: a String is sent with data type 1 (text); anything else
	 * is cast to byte[] and sent with data type 2 (binary). The message is split
	 * into frames, apparently of at most pageSize bytes each (pageSize is defined
	 * in the lost text).
	 *
	 * NOTE(review): this listing appears damaged by HTML extraction (text
	 * between a '<' and the next '>' was stripped). The masking loop, the frame
	 * loop header and the end of the method and class are missing.
	 */
	@Override
	public void encode(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) throws CharacterCodingException {
		IoBuffer buff = IoBuffer.allocate(1024).setAutoExpand(true);

		WSSessionState status = getSessionState(session);
		
		switch (status) {
			case Handshake:
				try {
					buff.putString((String) message, charsetEncoder)
						.flip();

					encoderOutput.write(buff);
				} catch (CharacterCodingException e) {
					session.close(true);
				}

				// NOTE(review): reached even when encoding failed and the session was closed.
				session.setAttribute(WSSessionState.ATTRIBUTE_KEY, WSSessionState.Connected);
				break;
	
			case Connected:
				if (!session.isConnected() || message == null)
					return;
				
				// Data type / opcode: 1 = text, 2 = binary.
				byte dataType = 1;
				
				// Convert every message to a byte array so both kinds are handled uniformly.
				// NOTE(review): any non-String message is cast to byte[] (ClassCastException
				// otherwise), and RFC 6455 text frames must be UTF-8 whatever charsetEncoder is.
				byte[] data = null;
				if (message instanceof String)
					data = ((String) message).getBytes(charsetEncoder.charset());
				else {
					data = (byte[]) message;
					dataType = 2;
				}
				
				// Generate the mask.
				// NOTE(review): a new java.util.Random per message; RFC 6455 requires
				// masking keys from a strong entropy source (consider SecureRandom).
				byte[] mask = new byte[4];
				Random random = new Random();
				
				// Mask the data.
				// NOTE(review): extraction damage - the text between "i" and "0 &&" on the
				// next line is missing (the rest of this loop and the start of the frame
				// loop). In the length classification that follows, a 16-bit extended
				// length holds at most 65535, so "<= 65536" looks off by one. Also, the
				// classification uses remainLength while the frame carries realLength -
				// verify this yields the minimal length encoding RFC 6455 requires.
				for( int i=0,limit=data.length; i 0 && remainLength <= 125) {
						payload = remainLength;
					} else if (remainLength > 125 && remainLength <=65536) {
						payload = 126;
					} else {
						payload = 127;
					}
					
					// Extended payload length adds 2 (126) or 8 (127) header bytes.
					switch(payload) {
						case 126 : 
							headerLeng += 2;
							break;
							
						case 127 :
							headerLeng += 8;
							break;
							
						default  :
							headerLeng += 0;
							break;
					}
					
					// The masking key adds 4 header bytes.
					headerLeng += isMasking ? 4 : 0;
					
					// Compute how many payload bytes still fit in the current frame.
					realLength = ( pageSize - headerLeng )  >= remainLength ? remainLength : ( pageSize - headerLeng );
					
					
					// Decide whether this is the last frame.
					// NOTE(review): when remainLength equals the available space exactly, this
					// is false - check in the lost loop tail whether an empty trailing frame follows.
					isFinalFrame = (remainLength - (pageSize - headerLeng)) < 0;
					
					// Build the first byte (FIN bit).
					byte fstByte = (byte)(isFinalFrame ? 0x80 : 0x0);
					// Only the first frame carries the data type; later frames use opcode 0 (continuation).
					fstByte += isFirstFrame ? dataType : 0;
					buff.put(fstByte);
					
					// Build the second byte: set the MASK bit to 1 when masking is enabled.
					byte sndByte = (byte)(isMasking ? 0x80 : 0);
					// Store the payload length indicator.
					sndByte += payload;
					buff.put(sndByte);
					
					// Write the extended payload length, if any.
					switch(payload) {
						case 126 : 
							buff.putUnsignedShort(realLength);
							break;
							
						case 127 :
							buff.putLong(realLength);
							break;
							
						default  :
							break;
					}
					
					if (isMasking) {
						random.nextBytes(mask);
						buff.put(mask);
					}
					
					// NOTE(review): extraction damage - the rest of this loop, the method and the class is missing.
					for(int loopCount=dataIndex+realLength, i=0; dataIndex 
  

     

    下面是工具类实现:

    

final class WSToolKit {
	private WSToolKit() {
	}

	static enum WSSessionState {
		Handshake, Connected;
		public final static String ATTRIBUTE_KEY = "__SESSION_STATE";
	}

	static  T nvl(T t1, T t2) {
		return t1 == null ? t2 : t1;
	}
	
	static WSSessionState getSessionState (IoSession session) {
		WSSessionState state = (WSSessionState) session
				.getAttribute(WSSessionState.ATTRIBUTE_KEY);
		
		if (state == null) {
			state = WSSessionState.Handshake;
			session.setAttribute(WSSessionState.ATTRIBUTE_KEY, state);
		}
		
		return state;
	}
	
	static String getMessageColumnValue(String[] headers, String headerTag) {
		for (String header : headers) {
			if( header.startsWith(headerTag) )
				return header.substring(headerTag.length(), header.length())
						.trim();
		}

		return null;
	}
	
	static String subString(String src, int order, String flag) {
		for (int i = 1, j = 0, k = 0;; i++) {
			j = src.indexOf(flag, k);
			if (i < order) {
				if (j == -1)
					return "";
				else
					k = j + 1;
			} else {
				if (j == -1)
					return src.substring(k, src.length());
				else
					return src.substring(k, j);
			}
		}
	}
}

     

    如果大家想转载,请注明出处!谢谢。

你可能感兴趣的:(mina2,websocket,协议,实现,java)