TTransport=>TIOStreamTransport=>TSocket
重要参数设置:
// Options applied to the underlying socket (socket_, presumably a java.net.Socket) before connecting.
socket_.setSoLinger(false, 0); // SO_LINGER disabled: close() returns immediately (the linger value is ignored)
socket_.setTcpNoDelay(true); // disable Nagle's algorithm so small writes are not delayed
socket_.setSoTimeout(timeout_);// client read timeout in milliseconds (0 = wait indefinitely)
socket_.connect(new InetSocketAddress(host_, port_), timeout_);// client connect timeout in milliseconds
public abstract class TTransport { //底层实现socket_.isConnected() public abstract boolean isOpen(); //底层实现socket_.isConnected() public boolean peek() { return isOpen(); } //socket_.connect(new InetSocketAddress(host_, port_), timeout_); public abstract void open() throws TTransportException; //socket_.close(); public abstract void close(); //读取len个字节到buf buf中起始位置off 返回实际读取的字节数 public abstract int read(byte[] buf, int off, int len) throws TTransportException; //读满len个字节到buf 起始位置off public int readAll(byte[] buf, int off, int len) throws TTransportException { int got = 0; int ret = 0; while (got < len) { ret = read(buf, off+got, len-got); if (ret <= 0) { throw new TTransportException( "Cannot read. Remote side has closed. Tried to read " + len + " bytes, but only got " + got + " bytes. (This is often indicative of an internal error on the server side. Please check your server logs.)"); } got += ret; } return got; } //写buf到流 public void write(byte[] buf) throws TTransportException { write(buf, 0, buf.length); } //写buf到流起始位置off 长度len public abstract void write(byte[] buf, int off, int len) throws TTransportException; //flush刷新 public void flush() throws TTransportException {} //下面基于NIO实现 //获取到buffer public byte[] getBuffer() { return null; } //获取buffer位置 public int getBufferPosition() { return 0; } //读取buffer中剩余的字节 public int getBytesRemainingInBuffer() { return -1; } public void consumeBuffer(int len) {} }
TTransport=>TNonblockingTransport=>TNonblockingSocket
重要参数:
// Excerpt of non-blocking socket setup; setTimeout(...) and socketChannel_ presumably belong to the
// enclosing transport class, which is not shown here.
Selector selector = SelectorProvider.provider().openSelector(); // create the selector
SocketChannel socketChannel = SocketChannel.open();// open a SocketChannel; like a Socket, and it wraps a java.net.Socket internally
socketChannel.configureBlocking(false);// switch the channel to non-blocking mode
Socket socket = socketChannel.socket();// get the wrapped Socket to tune low-level options
socket.setSoLinger(false, 0); // SO_LINGER disabled: close() returns immediately
socket.setTcpNoDelay(true); // disable Nagle's algorithm so small writes are not delayed
setTimeout(timeout); // NOTE(review): method not shown; presumably sets SO_TIMEOUT, which does not bound non-blocking channel reads — confirm where timeouts are enforced
socketChannel_.register(selector,SelectionKey.OP_CONNECT);// register interest in the connect-finished event; NOTE(review): uses socketChannel_ while the local above is socketChannel — confirm they are the same channel
/**
 * A {@link TTransport} driven through a {@link Selector}. It is connected in two phases
 * ({@link #startConnect()}, then {@link #finishConnect()}), registered with a selector, and read and
 * written through {@link ByteBuffer}s.
 */
public abstract class TNonblockingTransport extends TTransport {

  /**
   * Registers the underlying channel with {@code selector} for the given interest set.
   *
   * @return the selection key produced by the registration (presumably the channel's key)
   */
  public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException;

  /**
   * Begins connecting. TNonblockingSocket implements this with
   * socketChannel_.connect(socketAddress_); the result presumably mirrors SocketChannel.connect
   * (true if the connection was established immediately).
   */
  public abstract boolean startConnect() throws IOException;

  /**
   * Completes a pending connect. TNonblockingSocket implements this with
   * socketChannel_.finishConnect(). Once it succeeds, read/write interest can be registered.
   */
  public abstract boolean finishConnect() throws IOException;

  /** Writes from {@code buffer}; returns the number of bytes written. */
  public abstract int write(ByteBuffer buffer) throws IOException;

  /** Reads into {@code buffer}; returns the number of bytes read. */
  public abstract int read(ByteBuffer buffer) throws IOException;
}
NIO客户端Demo
package sampleNio;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.nio.charset.Charset;
import java.util.Iterator;

/**
 * Single-connection NIO client demo.
 *
 * <p>It connects to {@code hostAddress:port}, writes a fixed greeting and prints the first chunk of
 * the server's reply. It then closes the connection and returns from {@link #run()}. All channel
 * and buffer state is used only by the thread executing {@link #run()}.
 *
 * @author jason
 */
public class NioClient implements Runnable {
  /** Explicit charset so request/response bytes do not depend on the platform default. */
  private static final Charset UTF_8 = Charset.forName("UTF-8");

  private final InetAddress hostAddress;
  private final int port;
  private Selector selector;
  private final ByteBuffer readBuffer = ByteBuffer.allocate(8192);
  private final ByteBuffer outBuffer = ByteBuffer.wrap("nice to meet you".getBytes(UTF_8));

  /**
   * Opens the selector and starts a non-blocking connect to {@code hostAddress:port}.
   *
   * @throws IOException if the selector or channel cannot be opened or the connect cannot start
   */
  public NioClient(InetAddress hostAddress, int port) throws IOException {
    this.hostAddress = hostAddress;
    this.port = port;
    initSelector();
  }

  public static void main(String[] args) {
    try {
      NioClient client = new NioClient(InetAddress.getByName("localhost"), 9090);
      new Thread(client).start();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /**
   * Runs the selector loop until no registered key is valid any more, then closes the selector.
   * Keys stop being valid when the exchange finishes or the connection fails.
   */
  @Override
  public void run() {
    try {
      while (hasValidKey()) {
        selector.select();
        Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator();
        while (selectedKeys.hasNext()) {
          SelectionKey key = selectedKeys.next();
          selectedKeys.remove();
          if (!key.isValid()) {
            continue;
          }
          try {
            if (key.isConnectable()) {
              finishConnection(key);
            } else if (key.isReadable()) {
              read(key);
            } else if (key.isWritable()) {
              write(key);
            }
          } catch (Exception e) {
            // Event-loop boundary: report the failure and drop this connection rather than
            // retrying the same failing key forever.
            e.printStackTrace();
            closeQuietly(key.channel());
          }
        }
      }
    } catch (IOException e) {
      e.printStackTrace();
    } finally {
      closeSelector();
    }
  }

  /** Opens the selector and channel, starts connecting and registers the matching interest. */
  private void initSelector() throws IOException {
    // Create the selector that multiplexes this client's channel.
    selector = SelectorProvider.provider().openSelector();
    SocketChannel socketChannel = SocketChannel.open();
    try {
      socketChannel.configureBlocking(false);
      // In non-blocking mode connect() may complete immediately (typical for local connections).
      // OP_CONNECT is then never signalled, so go straight to writing in that case.
      boolean connected =
          socketChannel.connect(new InetSocketAddress(this.hostAddress, this.port));
      socketChannel.register(
          selector, connected ? SelectionKey.OP_WRITE : SelectionKey.OP_CONNECT);
    } catch (IOException e) {
      // Do not leak the channel or the selector when setup fails.
      closeQuietly(socketChannel);
      closeSelector();
      throw e;
    }
  }

  /** Completes a pending connect and switches interest to OP_WRITE; drops the channel on failure. */
  private void finishConnection(SelectionKey key) {
    SocketChannel socketChannel = (SocketChannel) key.channel();
    try {
      // Throws if the connection attempt failed; false means it is still in progress.
      if (!socketChannel.finishConnect()) {
        return;
      }
    } catch (IOException e) {
      e.printStackTrace();
      // Closing the channel also cancels its key.
      closeQuietly(socketChannel);
      return;
    }
    key.interestOps(SelectionKey.OP_WRITE);
  }

  /**
   * Reads one chunk of the response. On end-of-stream it prints "close connection" and closes
   * the channel.
   *
   * @param key readable key of the client channel
   * @throws IOException if closing the channel fails
   */
  private void read(SelectionKey key) throws IOException {
    SocketChannel socketChannel = (SocketChannel) key.channel();
    readBuffer.clear();
    int numRead;
    try {
      numRead = socketChannel.read(readBuffer);
    } catch (IOException e) {
      // e.g. connection reset by the peer; closing the channel cancels its key.
      closeQuietly(socketChannel);
      return;
    }
    if (numRead == -1) {
      // End of stream: the server closed its side of the connection.
      System.out.println("close connection");
      socketChannel.close();
      return;
    }
    if (numRead == 0) {
      // Spurious wake-up with nothing to read; wait for the next OP_READ.
      return;
    }
    handleResponse(socketChannel, readBuffer.array(), numRead);
  }

  /**
   * Prints the first {@code numRead} bytes of {@code data}, then closes the channel. In this demo
   * there is one request and one response chunk.
   *
   * @param socketChannel client channel
   * @param data backing array of the read buffer
   * @param numRead number of valid bytes at the start of {@code data}
   * @throws IOException if closing the channel fails
   */
  private void handleResponse(SocketChannel socketChannel, byte[] data, int numRead)
      throws IOException {
    System.out.println(new String(data, 0, numRead, UTF_8));
    // Closing the channel cancels its selection key as well.
    socketChannel.close();
  }

  /**
   * Writes as much of the greeting as the socket accepts. Once everything is sent, it switches
   * interest to OP_READ.
   *
   * @param key writable key of the client channel
   * @throws IOException if the write fails
   */
  private void write(SelectionKey key) throws IOException {
    SocketChannel socketChannel = (SocketChannel) key.channel();
    socketChannel.write(outBuffer);
    if (outBuffer.hasRemaining()) {
      // Socket send buffer is full; continue on the next OP_WRITE.
      return;
    }
    key.interestOps(SelectionKey.OP_READ);
  }

  /** Returns true while at least one registered key is still valid, i.e. work remains. */
  private boolean hasValidKey() {
    for (SelectionKey key : selector.keys()) {
      if (key.isValid()) {
        return true;
      }
    }
    return false;
  }

  /** Closes the selector, reporting but not propagating a failure to close. */
  private void closeSelector() {
    try {
      selector.close();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /** Closes {@code channel}, reporting but not propagating a failure to close. */
  private static void closeQuietly(Channel channel) {
    try {
      channel.close();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}
TServerTransport=>TNonblockingServerTransport=>TNonblockingServerSocket
/** Server-side transport that listens for client connections and hands them out as transports. */
public abstract class TServerTransport {

  /** Starts listening for client connections. */
  public abstract void listen() throws TTransportException;

  /**
   * Accepts the next client connection via {@link #acceptImpl()}.
   *
   * @return the accepted client transport, never {@code null}
   * @throws TTransportException if {@link #acceptImpl()} fails or returns {@code null}
   */
  public final TTransport accept() throws TTransportException {
    final TTransport client = acceptImpl();
    if (client != null) {
      return client;
    }
    throw new TTransportException("accept() may not return NULL");
  }

  /** Closes the server transport. */
  public abstract void close();

  /** Subclass hook that performs the actual accept; {@link #accept()} rejects a null result. */
  protected abstract TTransport acceptImpl() throws TTransportException;

  /**
   * Optional method. Signals the server transport to break out of any accept() or listen() it is
   * currently blocked on. If implemented, it MUST be thread safe, because it may be called from a
   * different thread than the other TServerTransport methods. The default does nothing.
   */
  public void interrupt() {}
}
/** A {@link TServerTransport} whose accept events are delivered through a {@link Selector}. */
public abstract class TNonblockingServerTransport extends TServerTransport {

  /** Registers this server transport with {@code selector} so that accept events can be selected. */
  public abstract void registerSelector(Selector selector);
}
public class TNonblockingServerSocket extends TNonblockingServerTransport { private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingServerTransport.class.getName()); /** * This channel is where all the nonblocking magic happens. */ private ServerSocketChannel serverSocketChannel = null; /** * Underlying ServerSocket object */ private ServerSocket serverSocket_ = null; /** * Timeout for client sockets from accept */ private int clientTimeout_ = 0; /** * Creates just a port listening server socket */ public TNonblockingServerSocket(int port) throws TTransportException { this(port, 0); } /** * Creates just a port listening server socket */ public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException { this(new InetSocketAddress(port), clientTimeout); } public TNonblockingServerSocket(InetSocketAddress bindAddr) throws TTransportException { this(bindAddr, 0); } public TNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException { clientTimeout_ = clientTimeout; try { serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); // Make server socket serverSocket_ = serverSocketChannel.socket(); // Prevent 2MSL delay problem on server restarts serverSocket_.setReuseAddress(true); // Bind to listening port serverSocket_.bind(bindAddr); } catch (IOException ioe) { serverSocket_ = null; throw new TTransportException("Could not create ServerSocket on address " + bindAddr.toString() + "."); } } public void listen() throws TTransportException { // Make sure not to block on accept if (serverSocket_ != null) { try { serverSocket_.setSoTimeout(0); } catch (SocketException sx) { sx.printStackTrace(); } } } protected TNonblockingSocket acceptImpl() throws TTransportException { if (serverSocket_ == null) { throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); } try { SocketChannel socketChannel = serverSocketChannel.accept(); if 
(socketChannel == null) { return null; } TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel); tsocket.setTimeout(clientTimeout_); return tsocket; } catch (IOException iox) { throw new TTransportException(iox); } } //注册连接接收事件 public void registerSelector(Selector selector) { try { // Register the server socket channel, indicating an interest in // accepting new connections serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); } catch (ClosedChannelException e) { // this shouldn't happen, ideally... // TODO: decide what to do with this. } } public void close() { if (serverSocket_ != null) { try { serverSocket_.close(); } catch (IOException iox) { LOGGER.warn("WARNING: Could not close server socket: " + iox.getMessage()); } serverSocket_ = null; } } //可能存在线程安全问题 虽然java文档声称线程安全的 public void interrupt() { // The thread-safeness of this is dubious, but Java documentation suggests // that it is safe to do this from a different thread context close(); } }
TTransport=>TFramedTransport
/**
 * Transport decorator that frames messages.
 *
 * <p>Writes are buffered until {@link #flush()}. Flush sends a 4-byte big-endian frame size
 * followed by the buffered bytes. Reads pull one whole frame at a time from the underlying
 * transport into an in-memory buffer and serve {@link #read(byte[], int, int)} from it.
 *
 * <p>There is no synchronization: the read, write and header buffers are plain fields, so an
 * instance must not be shared across threads without external locking.
 */
public class TFramedTransport extends TTransport {
  /** Default upper bound on the size of an incoming frame, in bytes. */
  protected static final int DEFAULT_MAX_LENGTH = 16384000;

  /** Largest frame size {@link #readFrame()} accepts. */
  private final int maxLength_;

  /** Underlying transport */
  private final TTransport transport_;

  /** Buffer for output; drained by {@link #flush()}. */
  private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(1024);

  /** Buffer for input; holds the unread remainder of the current frame. */
  private final TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);

  /** Scratch space for encoding/decoding the 4-byte frame header. */
  private final byte[] i32buf = new byte[4];

  /** Factory wrapping base transports in {@link TFramedTransport}s with a fixed max frame size. */
  public static class Factory extends TTransportFactory {
    private final int maxLength_;

    public Factory() {
      this(TFramedTransport.DEFAULT_MAX_LENGTH);
    }

    public Factory(int maxLength) {
      maxLength_ = maxLength;
    }

    @Override
    public TTransport getTransport(TTransport base) {
      return new TFramedTransport(base, maxLength_);
    }
  }

  /**
   * Constructor wraps around another transport.
   *
   * @param transport underlying transport carrying the frames
   * @param maxLength largest incoming frame size accepted, in bytes
   */
  public TFramedTransport(TTransport transport, int maxLength) {
    transport_ = transport;
    maxLength_ = maxLength;
  }

  /** Wraps {@code transport} using {@link #DEFAULT_MAX_LENGTH}. */
  public TFramedTransport(TTransport transport) {
    this(transport, TFramedTransport.DEFAULT_MAX_LENGTH);
  }

  @Override
  public void open() throws TTransportException {
    transport_.open();
  }

  @Override
  public boolean isOpen() {
    return transport_.isOpen();
  }

  @Override
  public void close() {
    transport_.close();
  }

  /**
   * Reads up to {@code len} bytes. Data left in the current frame is served first; otherwise
   * frames are read from the underlying transport until a non-empty one arrives.
   *
   * @return the number of bytes copied into {@code buf}; 0 only when {@code len <= 0}
   * @throws TTransportException if a frame header or body cannot be read, or a frame is invalid
   */
  @Override
  public int read(byte[] buf, int off, int len) throws TTransportException {
    int got = readBuffer_.read(buf, off, len);
    if (got > 0) {
      return got;
    }
    if (len <= 0) {
      // Nothing requested: do not block on (and silently consume) the next frame.
      return 0;
    }
    // flush() with an empty write buffer emits a zero-length frame. Skip such frames instead of
    // returning 0, which readAll() would misreport as "Remote side has closed".
    do {
      readFrame();
      got = readBuffer_.read(buf, off, len);
    } while (got <= 0);
    return got;
  }

  @Override
  public byte[] getBuffer() {
    return readBuffer_.getBuffer();
  }

  @Override
  public int getBufferPosition() {
    return readBuffer_.getBufferPosition();
  }

  @Override
  public int getBytesRemainingInBuffer() {
    return readBuffer_.getBytesRemainingInBuffer();
  }

  @Override
  public void consumeBuffer(int len) {
    readBuffer_.consumeBuffer(len);
  }

  /**
   * Reads one frame (4-byte size header and body) from the underlying transport into readBuffer_.
   *
   * @throws TTransportException if the size is negative or exceeds maxLength_, or the read fails
   */
  private void readFrame() throws TTransportException {
    transport_.readAll(i32buf, 0, 4);
    int size = decodeFrameSize(i32buf);

    if (size < 0) {
      throw new TTransportException("Read a negative frame size (" + size + ")!");
    }
    if (size > maxLength_) {
      throw new TTransportException(
          "Frame size (" + size + ") larger than max length (" + maxLength_ + ")!");
    }

    byte[] buff = new byte[size];
    transport_.readAll(buff, 0, size);
    readBuffer_.reset(buff);
  }

  /** Buffers the bytes; nothing reaches the underlying transport until {@link #flush()}. */
  @Override
  public void write(byte[] buf, int off, int len) throws TTransportException {
    writeBuffer_.write(buf, off, len);
  }

  /** Sends the buffered bytes as one frame (size header, then body) and flushes the transport. */
  @Override
  public void flush() throws TTransportException {
    byte[] buf = writeBuffer_.get();
    int len = writeBuffer_.len();
    // Presumably reset() only rewinds the length (as ByteArrayOutputStream does), so buf keeps
    // its contents until the next write().
    writeBuffer_.reset();

    encodeFrameSize(len, i32buf);
    transport_.write(i32buf, 0, 4);
    transport_.write(buf, 0, len);
    transport_.flush();
  }

  /** Encodes {@code frameSize} into {@code buf[0..3]} in big-endian order. */
  public static final void encodeFrameSize(final int frameSize, final byte[] buf) {
    buf[0] = (byte) (0xff & (frameSize >> 24));
    buf[1] = (byte) (0xff & (frameSize >> 16));
    buf[2] = (byte) (0xff & (frameSize >> 8));
    buf[3] = (byte) (0xff & (frameSize));
  }

  /** Decodes a big-endian 32-bit frame size from {@code buf[0..3]}. */
  public static final int decodeFrameSize(final byte[] buf) {
    return ((buf[0] & 0xff) << 24)
        | ((buf[1] & 0xff) << 16)
        | ((buf[2] & 0xff) << 8)
        | ((buf[3] & 0xff));
  }

  /** Round-trip demo for the frame-size encoding; prints 99999. */
  public static void main(String[] args) {
    int number = 99999;
    byte[] buf = new byte[4];
    encodeFrameSize(number, buf);
    System.out.println(decodeFrameSize(buf));
  }
}