在阅读本文前需要对socket以及自定义协议有一个基本的了解,可以先查看上一篇文章《基于Java Socket的自定义协议,实现Android与服务器的长连接(一)》学习相关的基础知识点。
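上一篇文章中,我们对socket编程和自定义协议做了简单的了解。本文将在此基础上进一步深入,实现Android与服务器之间的长连接。现定义协议如下(长度均以字节为单位,整数按大端序存放,字符串按UTF-8编码):
- 公共头部:数据总长度(4字节)+ 版本号(1字节,高3位为预留位,低5位为版本号)+ 协议类型(1字节,0表示数据,1表示数据Ack,2表示ping,3表示pingAck);
- 数据协议(类型0):公共头部 + 业务类型pattion(1字节)+ 业务数据格式dtype(1字节)+ 消息id msgId(4字节)+ 业务数据(变长);
- 数据Ack协议(类型1):公共头部 + 被确认的消息id ackMsgId(4字节)+ unused(变长);
- ping协议(类型2):公共头部 + pingId(4字节)+ unused(变长);
- pingAck协议(类型3):公共头部 + 被确认的ackPingId(4字节)+ unused(变长)。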
从上述的协议定义中,我们可以看出,四种协议有共同的3个要素,分别是:长度、版本号、数据类型,那么我们可以先抽象出一个基本的协议,如下:
import android.util.Log;
import com.shandiangou.sdgprotocol.lib.Config;
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;
import java.io.ByteArrayOutputStream;
/**
 * Base class of every message exchanged over the long connection.
 *
 * <p>All messages start with a 6-byte header, in this order:
 * <ol>
 *   <li>{@link #LENGTH_LEN} (4) bytes: total message length as an int;</li>
 *   <li>{@link #VER_LEN} (1) byte: top 3 bits reserved, low 5 bits version number;</li>
 *   <li>{@link #TYPE_LEN} (1) byte: protocol type — 0 data, 1 data ack, 2 ping, 3 ping ack.</li>
 * </ol>
 * Subclasses append their own fields after this header.
 *
 * Created by meishan on 16/12/1.
 */
public abstract class BasicProtocol {
    // All lengths are measured in bytes.
    /** Size of the field that records the total message length. */
    public static final int LENGTH_LEN = 4;
    /** Size of the version field (3 reserved bits followed by 5 version bits). */
    protected static final int VER_LEN = 1;
    /** Size of the protocol-type field. */
    protected static final int TYPE_LEN = 1;
    /** Number of low-order bits of the version byte that carry the version number. */
    private static final int VERSION_BITS = 5;

    private int reserved = 0; // reserved bits (top bits of the version byte)
    private int version = Config.VERSION; // protocol version number

    /**
     * Returns the total encoded length of this message in bytes. The base implementation
     * returns the header length; subclasses add the size of their own fields.
     *
     * @return total message length in bytes
     */
    protected int getLength() {
        return LENGTH_LEN + VER_LEN + TYPE_LEN;
    }

    public int getReserved() {
        return reserved;
    }

    public void setReserved(int reserved) {
        this.reserved = reserved;
    }

    public int getVersion() {
        return version;
    }

    public void setVersion(int version) {
        this.version = version;
    }

    /**
     * Returns the protocol type code written into the type byte; implemented by subclasses.
     *
     * @return protocol type code
     */
    public abstract int getProtocolType();

    /**
     * Packs the reserved bits and the version number into one version byte.
     *
     * @param r    reserved value; only its low {@code 8 - vLen} bits are used
     * @param v    version number; only its low {@code vLen} bits are used
     * @param vLen number of bits allotted to the version number
     * @return the packed byte value (0-255)
     */
    private int getVer(byte r, byte v, int vLen) {
        int reservedMask = (1 << (8 - vLen)) - 1;
        int versionMask = (1 << vLen) - 1;
        // Reserved bits sit above the version bits; mask both so an out-of-range value
        // cannot bleed into the neighbouring field.
        return ((r & reservedMask) << vLen) | (v & versionMask);
    }

    /**
     * Serialises the common header (length, version byte, type byte), in that order.
     * Subclasses append their own fields after it.
     *
     * @return the encoded header bytes
     */
    public byte[] genContentData() {
        byte[] length = SocketUtil.int2ByteArrays(getLength());
        byte ver = (byte) getVer((byte) getReserved(), (byte) getVersion(), VERSION_BITS);
        byte type = (byte) getProtocolType();
        ByteArrayOutputStream baos = new ByteArrayOutputStream(LENGTH_LEN + VER_LEN + TYPE_LEN);
        baos.write(length, 0, LENGTH_LEN);
        baos.write(ver);
        baos.write(type);
        return baos.toByteArray();
    }

    /**
     * Reads the total message length from the first {@link #LENGTH_LEN} bytes.
     *
     * @param data encoded message
     * @return total message length
     */
    protected int parseLength(byte[] data) {
        return SocketUtil.byteArrayToInt(data, 0, LENGTH_LEN);
    }

    /**
     * Extracts the reserved bits (top 3 bits) of the version byte.
     *
     * @param data encoded message
     * @return reserved value in the range 0-7
     */
    protected int parseReserved(byte[] data) {
        // Bytes 0-3 hold the length; byte 4 is the version byte.
        // Convert to an unsigned value before shifting: a signed byte with the top bit set
        // would otherwise sign-extend and produce a value larger than 7.
        return (data[LENGTH_LEN] & 0xFF) >>> VERSION_BITS;
    }

    /**
     * Extracts the version number (low 5 bits) of the version byte.
     *
     * @param data encoded message
     * @return version number in the range 0-31
     */
    protected int parseVersion(byte[] data) {
        return data[LENGTH_LEN] & ((1 << VERSION_BITS) - 1);
    }

    /**
     * Reads the protocol type byte, which follows the length field and the version byte.
     *
     * @param data encoded message
     * @return protocol type code (0-255)
     */
    public static int parseType(byte[] data) {
        return data[LENGTH_LEN + VER_LEN] & 0xFF;
    }

    /**
     * Validates the common header of a received message. Subclasses parse their own fields
     * starting at the returned offset.
     *
     * @param data encoded message
     * @return offset of the first byte after the common header
     * @throws ProtocolException if the data is shorter than the header or the version differs
     *                           from this object's version
     */
    public int parseContentData(byte[] data) throws ProtocolException {
        int headerLen = LENGTH_LEN + VER_LEN + TYPE_LEN;
        if (data == null) {
            throw new ProtocolException("input data is null");
        }
        if (data.length < headerLen) {
            throw new ProtocolException("input data is too short: " + data.length);
        }
        int version = parseVersion(data);
        if (version != getVersion()) {
            throw new ProtocolException("input version is error: " + version);
        }
        return headerLen;
    }

    @Override
    public String toString() {
        return "Version: " + getVersion() + ", Type: " + getProtocolType();
    }
}
上述涉及到的Config类和SocketUtil类如下:
/**
 * Connection settings shared by the client and the server.
 *
 * Created by meishan on 16/12/2.
 */
public class Config {
    /** Port the server listens on and the client connects to. */
    public static final int PORT = 9013;

    /** Address of the server. */
    public static final String ADDRESS = "10.17.64.237";

    /** Protocol version written into, and checked against, every message header. */
    public static final int VERSION = 1;
}
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
/**
 * Static helpers for framing {@link BasicProtocol} messages on socket streams.
 *
 * <p>{@link #write2Stream} writes each message as a 4-byte length prefix followed by the bytes
 * from {@link BasicProtocol#genContentData()}. {@link #readFromStream} reads one such frame back
 * and turns it into the matching protocol object.
 *
 * Created by meishan on 16/12/1.
 */
public class SocketUtil {
    /** Smallest valid frame body: length field + version byte + type byte. */
    private static final int MIN_CONTENT_LEN = BasicProtocol.LENGTH_LEN + 2;
    /** Upper bound on a frame body, so a corrupt length prefix cannot trigger a huge allocation. */
    private static final int MAX_CONTENT_LEN = 16 * 1024 * 1024;

    /** Maps a protocol type code to the fully-qualified name of its implementing class. */
    private static final Map<Integer, String> msgImp = new HashMap<>();

    static {
        msgImp.put(DataProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.DataProtocol"); //0
        msgImp.put(DataAckProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.DataAckProtocol"); //1
        msgImp.put(PingProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.PingProtocol"); //2
        msgImp.put(PingAckProtocol.PROTOCOL_TYPE, "com.shandiangou.sdgprotocol.lib.protocol.PingAckProtocol"); //3
    }

    /**
     * Instantiates the protocol class registered for the frame's type byte and parses the frame.
     *
     * @param data frame body (without the outer length prefix)
     * @return the parsed protocol, or {@code null} if the type is unknown or parsing fails
     */
    public static BasicProtocol parseContentMsg(byte[] data) {
        try {
            int protocolType = BasicProtocol.parseType(data);
            String className = msgImp.get(protocolType);
            if (className == null) {
                System.err.println("unknown protocol type: " + protocolType);
                return null;
            }
            BasicProtocol basicProtocol = (BasicProtocol) Class.forName(className).newInstance();
            basicProtocol.parseContentData(data);
            return basicProtocol;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Reads one length-prefixed frame, as written by {@link #write2Stream}, and parses it.
     *
     * <p>The stream is read directly rather than through a fresh {@code BufferedInputStream}
     * per call: such a wrapper reads ahead, and the bytes it buffered for the next frame would
     * be discarded when it goes out of scope.
     *
     * @param inputStream stream to read from; {@code null} yields {@code null}
     * @return the parsed protocol, or {@code null} on end of stream, I/O error, an invalid length
     *         prefix, or an unparseable frame
     */
    public static BasicProtocol readFromStream(InputStream inputStream) {
        if (inputStream == null) {
            return null;
        }
        try {
            // The header holds the length of the whole frame body (4 bytes); write2Stream
            // writes it before the body.
            byte[] header = new byte[BasicProtocol.LENGTH_LEN];
            if (!readFully(inputStream, header)) {
                inputStream.close(); // peer closed the connection
                return null;
            }
            int length = byteArrayToInt(header);
            if (length < MIN_CONTENT_LEN || length > MAX_CONTENT_LEN) {
                System.err.println("invalid frame length: " + length);
                return null;
            }
            byte[] content = new byte[length];
            if (!readFully(inputStream, content)) {
                inputStream.close(); // stream ended in the middle of a frame
                return null;
            }
            return parseContentMsg(content);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Fills {@code buffer} completely from {@code in}.
     *
     * @return {@code true} if the buffer was filled, {@code false} if the stream ended first
     */
    private static boolean readFully(InputStream in, byte[] buffer) throws IOException {
        int filled = 0;
        while (filled < buffer.length) {
            int count = in.read(buffer, filled, buffer.length - filled);
            if (count == -1) {
                return false;
            }
            filled += count;
        }
        return true;
    }

    /**
     * Writes one frame: a 4-byte length prefix followed by the protocol's encoded bytes.
     * I/O errors are printed and otherwise ignored.
     *
     * @param protocol     message to send
     * @param outputStream destination stream
     */
    public static void write2Stream(BasicProtocol protocol, OutputStream outputStream) {
        byte[] body = protocol.genContentData();
        byte[] header = int2ByteArrays(body.length);
        // Assemble the whole frame first so it is handed to the stream in a single write.
        byte[] frame = new byte[header.length + body.length];
        System.arraycopy(header, 0, frame, 0, header.length);
        System.arraycopy(body, 0, frame, header.length, body.length);
        try {
            outputStream.write(frame);
            outputStream.flush();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Closes an input stream, ignoring {@code null} and printing any I/O error.
     *
     * @param is stream to close
     */
    public static void closeInputStream(InputStream is) {
        try {
            if (is != null) {
                is.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Closes an output stream, ignoring {@code null} and printing any I/O error.
     *
     * @param os stream to close
     */
    public static void closeOutputStream(OutputStream os) {
        try {
            if (os != null) {
                os.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Encodes an int as 4 big-endian bytes. */
    public static byte[] int2ByteArrays(int i) {
        byte[] result = new byte[4];
        result[0] = (byte) ((i >> 24) & 0xFF);
        result[1] = (byte) ((i >> 16) & 0xFF);
        result[2] = (byte) ((i >> 8) & 0xFF);
        result[3] = (byte) (i & 0xFF);
        return result;
    }

    /** Decodes up to 4 big-endian bytes (the whole array) into an int. */
    public static int byteArrayToInt(byte[] b) {
        return byteArrayToInt(b, 0, b.length);
    }

    /**
     * Decodes {@code byteCount} big-endian bytes starting at {@code byteOffset} into an int.
     * The first byte is treated as the most significant of a 4-byte int.
     */
    public static int byteArrayToInt(byte[] b, int byteOffset, int byteCount) {
        int intValue = 0;
        for (int i = byteOffset; i < (byteOffset + byteCount); i++) {
            intValue += (b[i] & 0xFF) << (8 * (3 - (i - byteOffset)));
        }
        return intValue;
    }

    /** Decodes 4 big-endian bytes starting at {@code byteOffset} into an int. */
    public static int bytes2Int(byte[] b, int byteOffset) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(Integer.SIZE / Byte.SIZE);
        byteBuffer.put(b, byteOffset, 4); // an int occupies 4 bytes
        byteBuffer.flip();
        return byteBuffer.getInt();
    }
}
接下来我们实现具体的协议。
import android.util.Log;
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;
import java.io.ByteArrayOutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
/**
 * Business-data message (protocol type 0).
 *
 * <p>Layout after the common header: 1 byte pattion (business type), 1 byte dtype (business
 * data format), 4 bytes msgId, then the UTF-8 encoded payload.
 *
 * Created by meishan on 16/12/1.
 */
public class DataProtocol extends BasicProtocol implements Serializable {
    public static final int PROTOCOL_TYPE = 0;
    private static final int PATTION_LEN = 1;
    private static final int DTYPE_LEN = 1;
    private static final int MSGID_LEN = 4;
    // One fixed charset for both encoding and decoding, independent of the platform default.
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.Charset.forName("UTF-8");

    private int pattion;
    private int dtype;
    private int msgId;
    private String data;

    @Override
    public int getLength() {
        return super.getLength() + PATTION_LEN + DTYPE_LEN + MSGID_LEN + dataBytes().length;
    }

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getPattion() {
        return pattion;
    }

    public void setPattion(int pattion) {
        this.pattion = pattion;
    }

    public int getDtype() {
        return dtype;
    }

    public void setDtype(int dtype) {
        this.dtype = dtype;
    }

    public void setMsgId(int msgId) {
        this.msgId = msgId;
    }

    public int getMsgId() {
        return msgId;
    }

    public String getData() {
        return data;
    }

    public void setData(String data) {
        this.data = data;
    }

    /**
     * Encodes the payload as UTF-8, so it matches what {@link #parseContentData} decodes.
     * A {@code null} payload is encoded as empty.
     */
    private byte[] dataBytes() {
        return data == null ? new byte[0] : data.getBytes(UTF_8);
    }

    /**
     * Serialises the common header followed by pattion, dtype, msgId and the payload.
     *
     * @return the encoded message
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] msgIdBytes = SocketUtil.int2ByteArrays(this.msgId);
        byte[] payload = dataBytes();
        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length); // data length + version + protocol type
        baos.write((byte) this.pattion); // business type (PATTION_LEN = 1 byte)
        baos.write((byte) this.dtype); // business data format (DTYPE_LEN = 1 byte)
        baos.write(msgIdBytes, 0, MSGID_LEN); // message id
        baos.write(payload, 0, payload.length); // business data
        return baos.toByteArray();
    }

    /**
     * Parses a received message field by field, in wire order.
     *
     * @param data encoded message
     * @return offset at which the payload starts
     * @throws ProtocolException if the header is invalid or the fixed fields are truncated
     */
    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);
        if (data.length < pos + PATTION_LEN + DTYPE_LEN + MSGID_LEN) {
            throw new ProtocolException("input data is too short: " + data.length);
        }
        pattion = data[pos] & 0xFF;
        pos += PATTION_LEN;
        dtype = data[pos] & 0xFF;
        pos += DTYPE_LEN;
        msgId = SocketUtil.byteArrayToInt(data, pos, MSGID_LEN);
        pos += MSGID_LEN;
        // The rest of the frame is the payload.
        this.data = new String(data, pos, data.length - pos, UTF_8);
        return pos;
    }

    @Override
    public String toString() {
        return "data: " + data;
    }
}
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;
import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
/**
 * Acknowledgement of a data message (protocol type 1).
 *
 * <p>Layout after the common header: 4 bytes ackMsgId, then the UTF-8 encoded {@code unused}
 * text.
 *
 * Created by meishan on 16/12/1.
 */
public class DataAckProtocol extends BasicProtocol {
    public static final int PROTOCOL_TYPE = 1;
    private static final int ACKMSGID_LEN = 4;
    // One fixed charset for both encoding and decoding, independent of the platform default.
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.Charset.forName("UTF-8");

    private int ackMsgId;
    private String unused;

    @Override
    public int getLength() {
        return super.getLength() + ACKMSGID_LEN + unusedBytes().length;
    }

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getAckMsgId() {
        return ackMsgId;
    }

    public void setAckMsgId(int ackMsgId) {
        this.ackMsgId = ackMsgId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /** UTF-8 bytes of {@code unused}; {@code null} is encoded as empty. */
    private byte[] unusedBytes() {
        return unused == null ? new byte[0] : unused.getBytes(UTF_8);
    }

    /**
     * Serialises the common header followed by ackMsgId and the unused text.
     *
     * @return the encoded message
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] ackMsgIdBytes = SocketUtil.int2ByteArrays(this.ackMsgId);
        byte[] unusedBytes = unusedBytes();
        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length); // data length + version + protocol type
        baos.write(ackMsgIdBytes, 0, ACKMSGID_LEN); // id of the acknowledged message
        baos.write(unusedBytes, 0, unusedBytes.length); // unused
        return baos.toByteArray();
    }

    /**
     * Parses ackMsgId and the unused text after the common header.
     *
     * @param data encoded message
     * @return offset at which the unused text starts
     * @throws ProtocolException if the header is invalid or ackMsgId is truncated
     */
    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);
        if (data.length < pos + ACKMSGID_LEN) {
            throw new ProtocolException("input data is too short: " + data.length);
        }
        ackMsgId = SocketUtil.byteArrayToInt(data, pos, ACKMSGID_LEN);
        pos += ACKMSGID_LEN;
        unused = new String(data, pos, data.length - pos, UTF_8);
        return pos;
    }
}
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;
import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
/**
 * Heartbeat request (protocol type 2).
 *
 * <p>Layout after the common header: 4 bytes pingId, then the UTF-8 encoded {@code unused}
 * text.
 *
 * Created by meishan on 16/12/1.
 */
public class PingProtocol extends BasicProtocol {
    public static final int PROTOCOL_TYPE = 2;
    private static final int PINGID_LEN = 4;
    // One fixed charset for both encoding and decoding, independent of the platform default.
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.Charset.forName("UTF-8");

    private int pingId;
    private String unused;

    @Override
    public int getLength() {
        return super.getLength() + PINGID_LEN + unusedBytes().length;
    }

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getPingId() {
        return pingId;
    }

    public void setPingId(int pingId) {
        this.pingId = pingId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /** UTF-8 bytes of {@code unused}; {@code null} is encoded as empty. */
    private byte[] unusedBytes() {
        return unused == null ? new byte[0] : unused.getBytes(UTF_8);
    }

    /**
     * Serialises the common header followed by pingId and the unused text.
     *
     * @return the encoded message
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] pingIdBytes = SocketUtil.int2ByteArrays(this.pingId);
        byte[] unusedBytes = unusedBytes();
        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length); // data length + version + protocol type
        baos.write(pingIdBytes, 0, PINGID_LEN); // ping id
        baos.write(unusedBytes, 0, unusedBytes.length); // unused
        return baos.toByteArray();
    }

    /**
     * Parses pingId and the unused text after the common header.
     *
     * @param data encoded message
     * @return offset at which the unused text starts
     * @throws ProtocolException if the header is invalid or pingId is truncated
     */
    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);
        if (data.length < pos + PINGID_LEN) {
            throw new ProtocolException("input data is too short: " + data.length);
        }
        pingId = SocketUtil.byteArrayToInt(data, pos, PINGID_LEN);
        pos += PINGID_LEN;
        unused = new String(data, pos, data.length - pos, UTF_8);
        return pos;
    }
}
import com.shandiangou.sdgprotocol.lib.ProtocolException;
import com.shandiangou.sdgprotocol.lib.SocketUtil;
import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
/**
 * Reply to a heartbeat (protocol type 3).
 *
 * <p>Layout after the common header: 4 bytes ackPingId, then the UTF-8 encoded {@code unused}
 * text.
 *
 * Created by meishan on 16/12/1.
 */
public class PingAckProtocol extends BasicProtocol {
    public static final int PROTOCOL_TYPE = 3;
    private static final int ACKPINGID_LEN = 4;
    // One fixed charset for both encoding and decoding, independent of the platform default.
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.Charset.forName("UTF-8");

    private int ackPingId;
    private String unused;

    @Override
    public int getLength() {
        return super.getLength() + ACKPINGID_LEN + unusedBytes().length;
    }

    @Override
    public int getProtocolType() {
        return PROTOCOL_TYPE;
    }

    public int getAckPingId() {
        return ackPingId;
    }

    public void setAckPingId(int ackPingId) {
        this.ackPingId = ackPingId;
    }

    public String getUnused() {
        return unused;
    }

    public void setUnused(String unused) {
        this.unused = unused;
    }

    /** UTF-8 bytes of {@code unused}; {@code null} is encoded as empty. */
    private byte[] unusedBytes() {
        return unused == null ? new byte[0] : unused.getBytes(UTF_8);
    }

    /**
     * Serialises the common header followed by ackPingId and the unused text.
     *
     * @return the encoded message
     */
    @Override
    public byte[] genContentData() {
        byte[] base = super.genContentData();
        byte[] ackPingIdBytes = SocketUtil.int2ByteArrays(this.ackPingId);
        byte[] unusedBytes = unusedBytes();
        ByteArrayOutputStream baos = new ByteArrayOutputStream(getLength());
        baos.write(base, 0, base.length); // data length + version + protocol type
        baos.write(ackPingIdBytes, 0, ACKPINGID_LEN); // id of the acknowledged ping
        baos.write(unusedBytes, 0, unusedBytes.length); // unused
        return baos.toByteArray();
    }

    /**
     * Parses ackPingId and the unused text after the common header.
     *
     * @param data encoded message
     * @return offset at which the unused text starts
     * @throws ProtocolException if the header is invalid or ackPingId is truncated
     */
    @Override
    public int parseContentData(byte[] data) throws ProtocolException {
        int pos = super.parseContentData(data);
        if (data.length < pos + ACKPINGID_LEN) {
            throw new ProtocolException("input data is too short: " + data.length);
        }
        ackPingId = SocketUtil.byteArrayToInt(data, pos, ACKPINGID_LEN);
        pos += ACKPINGID_LEN;
        unused = new String(data, pos, data.length - pos, UTF_8);
        return pos;
    }
}
上述已经给出了四种协议的实现,接下来我们将使用它们来实现app和服务端之间的通信,这里我们把数据的发送、接收和心跳分别用一个线程去实现,具体如下:
import android.os.Handler;
import android.os.Looper;
import android.os.Message;
import android.util.Log;
import com.shandiangou.sdgprotocol.lib.protocol.BasicProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;
import com.shandiangou.sdgprotocol.lib.protocol.PingProtocol;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ConnectException;
import java.net.Socket;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.SocketFactory;
/**
 * Client side of the long connection.
 *
 * <p>{@link #run()} opens the socket and starts three worker threads: a receiver, a sender and,
 * for long connections, a heartbeat. The sender loops until cancelled. When there is nothing to
 * send it waits on {@link #dataQueue} and is woken by {@link #addRequest} or {@link #stop()}.
 * Results are posted to the {@link RequestCallBack} on the main looper.
 *
 * Created by meishan on 16/12/1.
 */
public class ClientRequestTask implements Runnable {
    private static final int SUCCESS = 100;
    private static final int FAILED = -1;

    private volatile boolean isLongConnection = true;
    private Handler mHandler;
    private volatile SendTask mSendTask;
    private volatile ReciveTask mReciveTask;
    private volatile HeartBeatTask mHeartBeatTask;
    private volatile Socket mSocket;
    private volatile boolean isSocketAvailable;
    private volatile boolean closeSendTask;
    // Set by the first stop() call. Every worker calls stop() when it notices the socket is gone;
    // later calls return early so the disconnect callback is delivered only once.
    private boolean isStopped;
    // Serialises all writes to, and the closing of, the socket output stream so that data frames
    // and heartbeat frames written from different threads never interleave.
    private final Object writeLock = new Object();
    protected volatile ConcurrentLinkedQueue<BasicProtocol> dataQueue = new ConcurrentLinkedQueue<>();

    public ClientRequestTask(RequestCallBack requestCallBacks) {
        mHandler = new MyHandler(requestCallBacks);
    }

    /**
     * Connects to {@link Config#ADDRESS}:{@link Config#PORT} and starts the worker threads.
     * Connection failures are reported through the callback.
     */
    @Override
    public void run() {
        try {
            try {
                mSocket = SocketFactory.getDefault().createSocket(Config.ADDRESS, Config.PORT);
            } catch (ConnectException e) {
                failedMessage(-1, "服务器连接异常,请检查网络");
                return;
            }
            isSocketAvailable = true;
            // Start the receiver thread.
            mReciveTask = new ReciveTask();
            mReciveTask.inputStream = mSocket.getInputStream();
            mReciveTask.start();
            // Start the sender thread.
            mSendTask = new SendTask();
            mSendTask.outputStream = mSocket.getOutputStream();
            mSendTask.start();
            // Start the heartbeat thread.
            if (isLongConnection) {
                mHeartBeatTask = new HeartBeatTask();
                mHeartBeatTask.outputStream = mSocket.getOutputStream();
                mHeartBeatTask.start();
            }
        } catch (IOException e) {
            failedMessage(-1, "网络发生异常,请稍后重试");
            e.printStackTrace();
            // Do not leak a socket whose streams could not be set up.
            closeSocket();
        }
    }

    /**
     * Queues a message for sending and wakes the sender thread.
     *
     * @param data message to send
     */
    public void addRequest(DataProtocol data) {
        dataQueue.add(data);
        toNotifyAll(dataQueue); // new data is pending: wake the sender
    }

    /**
     * Stops all worker threads, closes the socket, drops pending data and reports
     * "断开连接" through the callback. Only the first call has any effect.
     */
    public synchronized void stop() {
        if (isStopped) {
            return;
        }
        isStopped = true;
        // Stop the receiver.
        closeReciveTask();
        // Ask the sender to stop. It closes itself once it wakes up.
        closeSendTask = true;
        toNotifyAll(dataQueue);
        // Stop the heartbeat.
        closeHeartBeatTask();
        // Close the socket.
        closeSocket();
        // Clear pending data.
        clearData();
        failedMessage(-1, "断开连接");
    }

    /**
     * Cancels the receiver thread and closes its input stream.
     */
    private void closeReciveTask() {
        ReciveTask task = mReciveTask;
        if (task == null) {
            return;
        }
        task.interrupt();
        task.isCancle = true;
        if (task.inputStream != null) {
            try {
                Socket socket = mSocket;
                if (isSocketAvailable && !socket.isClosed() && socket.isConnected()) {
                    // Shut down input first; this works around a java.net.SocketException
                    // on the blocked reader.
                    socket.shutdownInput();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            SocketUtil.closeInputStream(task.inputStream);
            task.inputStream = null;
        }
        mReciveTask = null;
    }

    /**
     * Cancels the sender thread and closes its output stream. Holding {@link #writeLock} makes
     * an in-progress write finish before the stream is closed.
     */
    private void closeSendTask() {
        SendTask task = mSendTask;
        if (task == null) {
            return;
        }
        task.isCancle = true;
        task.interrupt();
        synchronized (writeLock) {
            SocketUtil.closeOutputStream(task.outputStream);
            task.outputStream = null;
        }
        mSendTask = null;
    }

    /**
     * Cancels the heartbeat thread, waking it from its sleep, and closes its output stream.
     */
    private void closeHeartBeatTask() {
        HeartBeatTask task = mHeartBeatTask;
        if (task == null) {
            return;
        }
        task.isCancle = true;
        task.interrupt();
        synchronized (writeLock) {
            SocketUtil.closeOutputStream(task.outputStream);
            task.outputStream = null;
        }
        mHeartBeatTask = null;
    }

    /**
     * Closes the socket, if one was created.
     */
    private void closeSocket() {
        Socket socket = mSocket;
        if (socket != null) {
            try {
                socket.close();
                isSocketAvailable = false;
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Drops pending messages and disables the long-connection flag.
     */
    private void clearData() {
        dataQueue.clear();
        isLongConnection = false;
    }

    /**
     * Blocks the sender until data is queued or the sender is asked to close.
     *
     * <p>The emptiness check and the wait run under the same monitor that {@link #addRequest}
     * and {@link #stop()} notify on, so a notification cannot slip in between the sender's
     * {@code poll()} and its {@code wait()} and be lost.
     */
    private void awaitData() {
        synchronized (dataQueue) {
            while (dataQueue.isEmpty() && !closeSendTask) {
                try {
                    dataQueue.wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // preserve the cancellation signal
                    return;
                }
            }
        }
    }

    /**
     * Wakes every thread waiting on {@code o}. The lock is released only when the synchronized
     * block ends, not at the {@code notifyAll()} call itself.
     *
     * @param o monitor to notify
     */
    protected void toNotifyAll(Object o) {
        synchronized (o) {
            o.notifyAll();
        }
    }

    private void failedMessage(int code, String msg) {
        mHandler.sendMessage(mHandler.obtainMessage(FAILED, code, 0, msg));
    }

    private void successMessage(BasicProtocol protocol) {
        mHandler.sendMessage(mHandler.obtainMessage(SUCCESS, protocol));
    }

    /**
     * Returns whether the socket is still open and connected. If it is not, the whole task is
     * stopped.
     */
    private boolean isConnected() {
        Socket socket = mSocket;
        if (socket == null || socket.isClosed() || !socket.isConnected()) {
            ClientRequestTask.this.stop();
            return false;
        }
        return true;
    }

    /**
     * Delivers results to the callback on the main thread.
     */
    public class MyHandler extends Handler {
        private RequestCallBack mRequestCallBack;

        public MyHandler(RequestCallBack callBack) {
            super(Looper.getMainLooper());
            this.mRequestCallBack = callBack;
        }

        @Override
        public void handleMessage(Message msg) {
            super.handleMessage(msg);
            switch (msg.what) {
                case SUCCESS:
                    mRequestCallBack.onSuccess((BasicProtocol) msg.obj);
                    break;
                case FAILED:
                    mRequestCallBack.onFailed(msg.arg1, (String) msg.obj);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Receiver thread: reads frames and forwards acknowledgements (type 1 data ack, type 3 ping
     * ack) to the callback.
     */
    public class ReciveTask extends Thread {
        private volatile boolean isCancle = false;
        private volatile InputStream inputStream;

        @Override
        public void run() {
            while (!isCancle) {
                if (!isConnected()) {
                    break;
                }
                InputStream in = inputStream;
                if (in != null) {
                    BasicProtocol reciverData = SocketUtil.readFromStream(in);
                    if (reciverData == null) {
                        break; // stream ended or frame could not be read
                    }
                    int type = reciverData.getProtocolType();
                    if (type == 1 || type == 3) {
                        successMessage(reciverData);
                    }
                }
            }
            SocketUtil.closeInputStream(inputStream); // loop finished: close the input stream
        }
    }

    /**
     * Sender thread: drains {@link #dataQueue} and waits while there is nothing to send.
     */
    public class SendTask extends Thread {
        private volatile boolean isCancle = false;
        private OutputStream outputStream; // guarded by writeLock once the thread is running

        @Override
        public void run() {
            while (!isCancle && !isInterrupted()) {
                if (!isConnected()) {
                    break;
                }
                BasicProtocol dataContent = dataQueue.poll();
                if (dataContent == null) {
                    awaitData(); // nothing to send: wait
                    if (closeSendTask) {
                        closeSendTask();
                    }
                } else {
                    synchronized (writeLock) {
                        if (outputStream != null) {
                            SocketUtil.write2Stream(dataContent, outputStream);
                        }
                    }
                }
            }
            synchronized (writeLock) {
                SocketUtil.closeOutputStream(outputStream); // loop finished: close the output stream
            }
        }
    }

    /**
     * Heartbeat thread: sends urgent data and a ping frame every 5 seconds.
     * Created by meishan on 16/12/1.
     */
    public class HeartBeatTask extends Thread {
        private static final int REPEATTIME = 5000;
        private volatile boolean isCancle = false;
        private OutputStream outputStream; // guarded by writeLock once the thread is running
        private int pingId;

        @Override
        public void run() {
            pingId = 1;
            while (!isCancle) {
                if (!isConnected()) {
                    break;
                }
                try {
                    mSocket.sendUrgentData(0xFF);
                } catch (IOException e) {
                    isSocketAvailable = false;
                    ClientRequestTask.this.stop();
                    break;
                }
                synchronized (writeLock) {
                    if (outputStream != null) {
                        PingProtocol pingProtocol = new PingProtocol();
                        pingProtocol.setPingId(pingId);
                        pingProtocol.setUnused("ping...");
                        SocketUtil.write2Stream(pingProtocol, outputStream);
                        pingId = pingId + 2;
                    }
                }
                try {
                    Thread.sleep(REPEATTIME);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break; // interrupted by closeHeartBeatTask()
                }
            }
            synchronized (writeLock) {
                SocketUtil.closeOutputStream(outputStream);
            }
        }
    }
}
其中涉及到的RequestCallBack接口如下:
/**
 * Receives the outcome of client-side connection activity.
 *
 * Created by meishan on 16/12/1.
 */
public interface RequestCallBack {
    /**
     * Reports a failure or a disconnection.
     *
     * @param errorCode numeric error code
     * @param msg       human-readable description of what went wrong
     */
    void onFailed(int errorCode, String msg);

    /**
     * Delivers a protocol message received from the server.
     *
     * @param msg the parsed message
     */
    void onSuccess(BasicProtocol msg);
}
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.net.Socket;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
* Created by meishan on 16/12/1.
*/
public class ServerResponseTask implements Runnable {
private ReciveTask reciveTask;
private SendTask sendTask;
private Socket socket;
private ResponseCallback tBack;
private volatile ConcurrentLinkedQueue dataQueue = new ConcurrentLinkedQueue<>();
private static ConcurrentHashMap onLineClient = new ConcurrentHashMap<>();
private String userIP;
public String getUserIP() {
return userIP;
}
public ServerResponseTask(Socket socket, ResponseCallback tBack) {
this.socket = socket;
this.tBack = tBack;
this.userIP = socket.getInetAddress().getHostAddress();
System.out.println("用户IP地址:" + userIP);
}
@Override
public void run() {
try {
//开启接收线程
reciveTask = new ReciveTask();
reciveTask.inputStream = new DataInputStream(socket.getInputStream());
reciveTask.start();
//开启发送线程
sendTask = new SendTask();
sendTask.outputStream = new DataOutputStream(socket.getOutputStream());
sendTask.start();
} catch (Exception e) {
e.printStackTrace();
}
}
public void stop() {
if (reciveTask != null) {
reciveTask.isCancle = true;
reciveTask.interrupt();
if (reciveTask.inputStream != null) {
SocketUtil.closeInputStream(reciveTask.inputStream);
reciveTask.inputStream = null;
}
reciveTask = null;
}
if (sendTask != null) {
sendTask.isCancle = true;
sendTask.interrupt();
if (sendTask.outputStream != null) {
synchronized (sendTask.outputStream) {//防止写数据时停止,写完再停
sendTask.outputStream = null;
}
}
sendTask = null;
}
}
public void addMessage(BasicProtocol data) {
if (!isConnected()) {
return;
}
dataQueue.offer(data);
toNotifyAll(dataQueue);//有新增待发送数据,则唤醒发送线程
}
public Socket getConnectdClient(String clientID) {
return onLineClient.get(clientID);
}
/**
* 打印已经链接的客户端
*/
public static void printAllClient() {
if (onLineClient == null) {
return;
}
Iterator inter = onLineClient.keySet().iterator();
while (inter.hasNext()) {
System.out.println("client:" + inter.next());
}
}
public void toWaitAll(Object o) {
synchronized (o) {
try {
o.wait();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
public void toNotifyAll(Object obj) {
synchronized (obj) {
obj.notifyAll();
}
}
private boolean isConnected() {
if (socket.isClosed() || !socket.isConnected()) {
onLineClient.remove(userIP);
ServerResponseTask.this.stop();
System.out.println("socket closed...");
return false;
}
return true;
}
public class ReciveTask extends Thread {
private DataInputStream inputStream;
private boolean isCancle;
@Override
public void run() {
while (!isCancle) {
if (!isConnected()) {
isCancle = true;
break;
}
BasicProtocol clientData = SocketUtil.readFromStream(inputStream);
if (clientData != null) {
if (clientData.getProtocolType() == 0) {
System.out.println("dtype: " + ((DataProtocol) clientData).getDtype() + ", pattion: " + ((DataProtocol) clientData).getPattion() + ", msgId: " + ((DataProtocol) clientData).getMsgId() + ", data: " + ((DataProtocol) clientData).getData());
DataAckProtocol dataAck = new DataAckProtocol();
dataAck.setUnused("收到消息:" + ((DataProtocol) clientData).getData());
dataQueue.offer(dataAck);
toNotifyAll(dataQueue); //唤醒发送线程
tBack.targetIsOnline(userIP);
} else if (clientData.getProtocolType() == 2) {
System.out.println("pingId: " + ((PingProtocol) clientData).getPingId());
PingAckProtocol pingAck = new PingAckProtocol();
pingAck.setUnused("收到心跳");
dataQueue.offer(pingAck);
toNotifyAll(dataQueue); //唤醒发送线程
tBack.targetIsOnline(userIP);
}
} else {
System.out.println("client is offline...");
break;
}
}
SocketUtil.closeInputStream(inputStream);
}
}
public class SendTask extends Thread {
private DataOutputStream outputStream;
private boolean isCancle;
@Override
public void run() {
while (!isCancle) {
if (!isConnected()) {
isCancle = true;
break;
}
BasicProtocol procotol = dataQueue.poll();
if (procotol == null) {
toWaitAll(dataQueue);
} else if (outputStream != null) {
synchronized (outputStream) {
SocketUtil.write2Stream(procotol, outputStream);
}
}
}
SocketUtil.closeOutputStream(outputStream);
}
}
其中涉及到的ResponseCallback接口如下:
/**
 * Server-side notifications about a client's presence.
 *
 * Created by meishan on 16/12/1.
 */
public interface ResponseCallback {
    /**
     * Reports that a client is online (it has just sent a frame).
     *
     * @param clientIp IP address of that client
     */
    void targetIsOnline(String clientIp);

    /**
     * Reports that the target of a message is not online.
     * NOTE(review): no caller of this method is visible here — confirm when it fires.
     *
     * @param reciveMsg the affected message; implementations should tolerate {@code null}
     */
    void targetIsOffline(DataProtocol reciveMsg);
}
上述代码处理了几种异常情况。例如,建立连接后服务端停止运行,而此时客户端的输入流仍阻塞在读取操作上,如何保证客户端不抛出异常?这类处理可以结合SocketUtil类一起阅读。
import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;
/**
 * Client facade: runs a {@link ClientRequestTask} on its own thread and forwards requests to it
 * until {@link #closeConnect()} is called.
 *
 * Created by meishan on 16/12/1.
 */
public class ConnectionClient {
    private boolean isClosed;
    private ClientRequestTask mClientRequestTask;

    /**
     * Creates the request task and starts it on a new thread.
     *
     * @param requestCallBack receiver of results and failures
     */
    public ConnectionClient(RequestCallBack requestCallBack) {
        mClientRequestTask = new ClientRequestTask(requestCallBack);
        Thread worker = new Thread(mClientRequestTask);
        worker.start();
    }

    /**
     * Queues a data message; ignored once the connection has been closed.
     *
     * @param data message to send
     */
    public void addNewRequest(DataProtocol data) {
        if (isClosed || mClientRequestTask == null) {
            return;
        }
        mClientRequestTask.addRequest(data);
    }

    /**
     * Marks the client closed and stops the underlying request task.
     */
    public void closeConnect() {
        isClosed = true;
        mClientRequestTask.stop();
    }
}
import com.shandiangou.sdgprotocol.lib.protocol.DataProtocol;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Entry point of the demo server. It accepts client connections on {@link Config#PORT} and
 * runs a {@link ServerResponseTask} for each one on a cached thread pool.
 *
 * Created by meishan on 16/12/1.
 */
public class ConnectionServer {
    private static boolean isStart = true;
    // Task of the most recently accepted client; stopped when the server shuts down.
    private static ServerResponseTask serverResponseTask;

    public ConnectionServer() {
    }

    public static void main(String[] args) {
        ServerSocket serverSocket = null;
        ExecutorService executorService = Executors.newCachedThreadPool();
        try {
            serverSocket = new ServerSocket(Config.PORT);
            while (isStart) {
                Socket socket = serverSocket.accept();
                serverResponseTask = new ServerResponseTask(socket,
                        new ResponseCallback() {
                            @Override
                            public void targetIsOffline(DataProtocol reciveMsg) { // peer is not online
                                if (reciveMsg != null) {
                                    System.out.println(reciveMsg.getData());
                                }
                            }

                            @Override
                            public void targetIsOnline(String clientIp) {
                                System.out.println(clientIp + " is onLine");
                                System.out.println("-----------------------------------------");
                            }
                        });
                if (socket.isConnected()) {
                    executorService.execute(serverResponseTask);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            isStart = false;
            if (serverSocket != null) {
                try {
                    serverSocket.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            // No client may have connected yet, so the task can still be null.
            if (serverResponseTask != null) {
                serverResponseTask.stop();
            }
            executorService.shutdown();
        }
    }
}
参考资料:
http://blog.csdn.net/mr_oorange/article/details/52353626
http://vtrtbb.iteye.com/blog/849336
http://blog.csdn.net/a8128230/article/details/46676149
http://www.tuicool.com/articles/6j2ARrB
http://weixiaolu.iteye.com/blog/1479656