我在网上查找了很多关于 WebSocket 协议的资料,发现大部分资料或实现记录得都不完整,或者只给出了最基本的实现。于是,我花了一周的业余时间,写了一个相对完整的实现。
首先是解码器部分(基于 Apache MINA 的 CumulativeProtocolDecoder 实现):
public class WSDecoder extends CumulativeProtocolDecoder {
// Key for a request data context; where it is used lies in a part of the listing that is not visible.
private static final String REQUEST_CONTEXT_KEY = "__REQUEST_DATA_CONTEXT";
// Line terminator of the HTTP handshake; used to split the request and to build the response.
private static final String END_TAG = "\r\n";

/** Category of a WebSocket frame: text, binary or control. */
private enum FrameType {
    Text,
    Binary,
    Control
}
private class RequestDataContext {
private IoBuffer _tmp;
private CharsetDecoder _charsetDecoder;
private FrameType _frameType;
RequestDataContext(String charset) {
this._tmp = IoBuffer.allocate(512).setAutoExpand(true);
this._charsetDecoder = Charset.forName("utf-8").newDecoder();
}
public FrameType getFrameType() {
return this._frameType;
}
public String getDataAsString() {
try {
_tmp.flip();
return _tmp.getString(_charsetDecoder);
} catch (CharacterCodingException e) {
return null;
}
}
public byte[] getDataAsArray() {
_tmp.flip();
byte[] data = new byte[_tmp.remaining()];
_tmp.get(data);
return data;
}
void append(byte[] data) {
this._tmp.put(data);
}
void setFrameType(FrameType _frameType) {
this._frameType = _frameType;
}
}
// Byte order applied to the input buffer in doDecode before a frame header is read.
private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

/**
 * Sets the byte order used to read frame headers. RFC 6455 transmits extended
 * payload lengths in network (big-endian) order, so the default should normally
 * be kept.
 *
 * @param byteOrder the byte order to use; must not be null
 * @throws NullPointerException if {@code byteOrder} is null
 */
public void setByteOrder(ByteOrder byteOrder) {
    // Fail fast here instead of later, when the order is applied to the input buffer.
    if (byteOrder == null) {
        throw new NullPointerException("byteOrder");
    }
    this.byteOrder = byteOrder;
}

// Its charset is used (at least) when computing the handshake's Sec-WebSocket-Accept value.
private CharsetDecoder charsetDecoder = Charset.forName("utf-8")
        .newDecoder();

/**
 * Replaces the decoder whose charset the handshake code uses.
 *
 * @param charsetDecoder the decoder to use; must not be null
 * @throws NullPointerException if {@code charsetDecoder} is null
 */
public void setCharsetDecoder(CharsetDecoder charsetDecoder) {
    // Fail fast instead of throwing in the middle of a handshake.
    if (charsetDecoder == null) {
        throw new NullPointerException("charsetDecoder");
    }
    this.charsetDecoder = charsetDecoder;
}
@Override
protected boolean doDecode(IoSession session, IoBuffer in,
ProtocolDecoderOutput decoderOutput)
throws CharacterCodingException, NoSuchAlgorithmException {
if (!in.hasRemaining())
return false;
WSSessionState state = getSessionState(session);
switch (state) {
case Handshake:
doHandshake(session, in);
break;
case Connected:
if (in.remaining() < 2)
return false;
in.mark().order(this.byteOrder);
byte fstByte = in.get();
int opCode = fstByte & 0xf;
switch (opCode) {
case 0x0:
case 0x1:
case 0x2:
boolean isFinalFrame = fstByte < 0;
boolean isRsvColZero = (fstByte & 0x70) == 0;
if (!isRsvColZero) {
closeConnection(session, in);
break;
}
byte secByte = in.get();
boolean isMasking = secByte < 0;
int dataLength = 0;
byte payload = (byte) (secByte & 0x7f);
if (payload == 126)
dataLength = in.getUnsignedShort();
else if (payload == 127)
dataLength = (int) in.getLong();
else
dataLength = payload;
if (in.remaining() < (isMasking ? dataLength + 4 : dataLength)) {
in.reset();
return false;
}
byte[] mask = new byte[4];
byte[] data = new byte[dataLength];
if (isMasking)
in.get(mask);
in.get(data);
// 用掩码处理数据。
for( int i=0, maskLength=mask.length, looplimit=data.length; i= 1.1F;
if (!hasWebsocket || !hasUpgrade || !isGetMethod)
throw new WSException("Invalid websocket request!");
if (isSecWsKeyNull || !isValidVersion || !isValidHttpVer)
throw new WSException("Invalid websocket request!");
String secWsAccept = getSecWebsocketAccept(secWskeyCol);
String response = getResponseData(upgradeCol, connectionCol,
secWsAccept, secWsProtocolCol);
initRequestContext(session, msgColumns);
session.write(response);
}
/**
 * Splits a raw handshake request into its non-empty lines.
 *
 * <p>Each character of {@code END_TAG} ('\r' and '\n') acts as a separator on
 * its own and empty lines are dropped, so {@code "A\r\n\r\nB"} yields
 * {@code {"A", "B"}}.
 *
 * @param handshakeMessage the raw request text
 * @return the non-empty lines, in order
 */
private String[] splitHandshakeMessage(String handshakeMessage) {
    // Greedy character class: a run of separators never produces empty pieces
    // between lines, and trailing empties are discarded by split().
    String[] pieces = handshakeMessage.split("[" + END_TAG + "]+");
    // A separator at the very start yields one empty leading piece; skip it.
    int first = (pieces.length > 0 && pieces[0].isEmpty()) ? 1 : 0;
    String[] lines = new String[pieces.length - first];
    System.arraycopy(pieces, first, lines, 0, lines.length);
    return lines;
}
/**
 * Case-insensitive substring test.
 *
 * @param src    text to search; may be null
 * @param target text to look for; must not be null
 * @return {@code true} if {@code src} is non-empty and contains {@code target}
 *         ignoring case, {@code false} when {@code src} is null or empty
 */
private boolean contains(String src, String target) {
    if (src == null || src.isEmpty()) {
        return false;
    }
    // Locale.ROOT: HTTP tokens are ASCII, and case mapping in the default locale
    // (e.g. Turkish, where 'I' lowercases to a dotless i) can break matching.
    // Folding target too makes the comparison symmetric.
    java.util.Locale root = java.util.Locale.ROOT;
    return src.toLowerCase(root).contains(target.toLowerCase(root));
}
/**
 * Computes the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key as
 * defined by RFC 6455 section 4.2.2: Base64(SHA-1(key + GUID)).
 *
 * @param secWebsocketkey the client's Sec-WebSocket-Key header value
 * @return the Base64-encoded SHA-1 digest
 * @throws NoSuchAlgorithmException if SHA-1 is not available
 */
private String getSecWebsocketAccept(String secWebsocketkey)
        throws NoSuchAlgorithmException {
    // Fixed GUID defined by RFC 6455.
    final String websocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    // The key, the GUID and the Base64 output are all ASCII. Using the
    // configurable decoder charset here would produce a wrong digest if it were
    // set to a non-ASCII-compatible charset such as UTF-16.
    Charset ascii = Charset.forName("US-ASCII");
    MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
    byte[] digest = sha1.digest((secWebsocketkey + websocketGuid).getBytes(ascii));
    return new String(Base64.encodeBase64(digest), ascii);
}
/**
 * Builds the "101 Switching Protocols" response of the opening handshake.
 *
 * @param upgrade       value sent back in the Upgrade header
 * @param connection    value sent back in the Connection header
 * @param secWsAccept   the computed Sec-WebSocket-Accept value
 * @param secWsProtocol the client's Sec-WebSocket-Protocol value (a comma-separated
 *                      list), or null/empty when none was requested
 * @return the complete response head, terminated by an empty line
 */
private String getResponseData(String upgrade, String connection,
        String secWsAccept, String secWsProtocol) {
    StringBuilder result = new StringBuilder();
    result.append("HTTP/1.1 101 Switching Protocols\r\n");
    result.append("Upgrade:").append(upgrade).append(END_TAG);
    result.append("Connection:").append(connection).append(END_TAG);
    result.append("Sec-WebSocket-Accept:").append(secWsAccept)
            .append(END_TAG);

    // RFC 6455 requires the server to answer with ONE of the offered
    // subprotocols; echoing the whole list makes clients fail the connection.
    // The first offered entry is chosen. NOTE(review): this code does not know
    // which subprotocols the application supports; adjust the policy if needed.
    String selectedProtocol = null;
    if (secWsProtocol != null) {
        int comma = secWsProtocol.indexOf(',');
        String firstOffered = (comma < 0 ? secWsProtocol : secWsProtocol.substring(0, comma)).trim();
        if (!firstOffered.isEmpty()) {
            selectedProtocol = firstOffered;
        }
    }
    if (selectedProtocol != null) {
        result.append("Sec-WebSocket-Protocol:").append(selectedProtocol)
                .append(END_TAG);
    }

    // Blank line terminating the response head.
    result.append(END_TAG);
    return result.toString();
}
/**
 * Frees the given buffer, then closes the session without waiting for queued writes.
 *
 * <p>NOTE(review): at the visible call site the buffer is doDecode's input buffer;
 * confirm that freeing it here is safe for the IoBuffer allocator in use.
 *
 * @param session the session to close
 * @param buffer  the buffer to release
 */
private void closeConnection(final IoSession session, final IoBuffer buffer) {
    buffer.free();
    final boolean immediately = true;
    session.close(immediately);
}
/**
 * Writes an error message to the peer and closes the session once that write's
 * future has completed.
 *
 * @param session  the session to close
 * @param errorMsg the message written before closing
 */
@SuppressWarnings("unused")
private void closeConnection(IoSession session, String errorMsg) {
    // The listener must be parameterized with WriteFuture. With the raw type,
    // operationComplete(WriteFuture) does not override the interface's
    // operationComplete(IoFuture), so the anonymous class would not compile.
    session.write(errorMsg).addListener(
            new IoFutureListener<WriteFuture>() {
                @Override
                public void operationComplete(WriteFuture future) {
                    future.getSession().close(true);
                }
            });
}
/**
 * Stores the given handshake data on the session under the attribute
 * "__SESSION_CONTEXT".
 *
 * @param session the session to annotate
 * @param data    the handshake message columns passed by the visible call site
 */
private void initRequestContext(IoSession session, String[] data) {
    final String attributeKey = "__SESSION_CONTEXT";
    session.setAttribute(attributeKey, data);
}
}
下面是编码器部分:
public class WSEncoder extends ProtocolEncoderAdapter {
// Configured byte order. NOTE(review): it is not referenced in the visible part
// of encode(); confirm it is applied to outgoing buffers.
private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN;

/**
 * Sets the encoder's byte order.
 *
 * @param byteOrder the byte order; must not be null
 * @throws NullPointerException if {@code byteOrder} is null
 */
public void setByteOrder(ByteOrder byteOrder) {
    if (byteOrder == null) {
        throw new NullPointerException("byteOrder");
    }
    this.byteOrder = byteOrder;
}

// Encodes handshake strings and supplies the charset for text payloads.
private CharsetEncoder charsetEncoder = Charset.forName("utf-8")
        .newEncoder();

/**
 * Sets the encoder used for handshake strings and text payloads.
 *
 * @param charsetEncoder the encoder; must not be null
 * @throws NullPointerException if {@code charsetEncoder} is null
 */
public void setCharsetEncoder(CharsetEncoder charsetEncoder) {
    if (charsetEncoder == null) {
        throw new NullPointerException("charsetEncoder");
    }
    this.charsetEncoder = charsetEncoder;
}

// Size of each outgoing frame ("page"), header included, when a message is
// split into frames. Presumably read by encode(); confirm in its truncated part.
private int defaultPageSize = 65536;

/**
 * Sets the frame size used when splitting outgoing messages.
 *
 * @param defaultPageSize frame size in bytes; must be positive
 * @throws IllegalArgumentException if {@code defaultPageSize} is not positive
 */
public void setDefaultPageSize(int defaultPageSize) {
    // A non-positive page cannot hold any frame; reject it up front.
    if (defaultPageSize <= 0) {
        throw new IllegalArgumentException("defaultPageSize must be positive: " + defaultPageSize);
    }
    this.defaultPageSize = defaultPageSize;
}

// Outgoing data is not masked by default. RFC 6455 (section 5.1) forbids a server
// from masking frames it sends to a client, so enable this only for client use.
private boolean isMasking = false;

/**
 * Enables or disables masking of outgoing frames.
 *
 * @param masking {@code true} to mask outgoing frames
 */
public void setIsMasking(boolean masking) {
    this.isMasking = masking;
}
@Override
public void encode(IoSession session, Object message, ProtocolEncoderOutput encoderOutput) throws CharacterCodingException {
IoBuffer buff = IoBuffer.allocate(1024).setAutoExpand(true);
WSSessionState status = getSessionState(session);
switch (status) {
case Handshake:
try {
buff.putString((String) message, charsetEncoder)
.flip();
encoderOutput.write(buff);
} catch (CharacterCodingException e) {
session.close(true);
}
session.setAttribute(WSSessionState.ATTRIBUTE_KEY, WSSessionState.Connected);
break;
case Connected:
if (!session.isConnected() || message == null)
return;
byte dataType = 1;
// 将数据统一转换成byte数组进行处理。
byte[] data = null;
if (message instanceof String)
data = ((String) message).getBytes(charsetEncoder.charset());
else {
data = (byte[]) message;
dataType = 2;
}
// 生成掩码。
byte[] mask = new byte[4];
Random random = new Random();
// 用掩码处理数据。
for( int i=0,limit=data.length; i 0 && remainLength <= 125) {
payload = remainLength;
} else if (remainLength > 125 && remainLength <=65536) {
payload = 126;
} else {
payload = 127;
}
switch(payload) {
case 126 :
headerLeng += 2;
break;
case 127 :
headerLeng += 8;
break;
default :
headerLeng += 0;
break;
}
headerLeng += isMasking ? 4 : 0;
// 计算当前帧中剩余的可用于保存“负载数据”的字节长度。
realLength = ( pageSize - headerLeng ) >= remainLength ? remainLength : ( pageSize - headerLeng );
// 判断当前帧是否为最后一帧。
isFinalFrame = (remainLength - (pageSize - headerLeng)) < 0;
// 生成第一个字节
byte fstByte = (byte)(isFinalFrame ? 0x80 : 0x0);
// 若当前帧为第一帧,则还需保存数据类型信息。
fstByte += isFirstFrame ? dataType : 0;
buff.put(fstByte);
// 生成第二个字节。判断是否需要掩码,若需要掩码,则置标志位的值为1.
byte sndByte = (byte)(isMasking ? 0x80 : 0);
// 保存payload信息。
sndByte += payload;
buff.put(sndByte);
switch(payload) {
case 126 :
buff.putUnsignedShort(realLength);
break;
case 127 :
buff.putLong(realLength);
break;
default :
break;
}
if (isMasking) {
random.nextBytes(mask);
buff.put(mask);
}
for(int loopCount=dataIndex+realLength, i=0; dataIndex
下面是工具类实现:
final class WSToolKit {
private WSToolKit() {
}
static enum WSSessionState {
Handshake, Connected;
public final static String ATTRIBUTE_KEY = "__SESSION_STATE";
}
static T nvl(T t1, T t2) {
return t1 == null ? t2 : t1;
}
static WSSessionState getSessionState (IoSession session) {
WSSessionState state = (WSSessionState) session
.getAttribute(WSSessionState.ATTRIBUTE_KEY);
if (state == null) {
state = WSSessionState.Handshake;
session.setAttribute(WSSessionState.ATTRIBUTE_KEY, state);
}
return state;
}
static String getMessageColumnValue(String[] headers, String headerTag) {
for (String header : headers) {
if( header.startsWith(headerTag) )
return header.substring(headerTag.length(), header.length())
.trim();
}
return null;
}
static String subString(String src, int order, String flag) {
for (int i = 1, j = 0, k = 0;; i++) {
j = src.indexOf(flag, k);
if (i < order) {
if (j == -1)
return "";
else
k = j + 1;
} else {
if (j == -1)
return src.substring(k, src.length());
else
return src.substring(k, j);
}
}
}
}
如果大家想转载,请注明出处!谢谢。