// 【代码积累】NIO server (code notes: a multi-threaded Java NIO echo server)

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;


/**
 * Multi-threaded NIO echo server.
 *
 * <p>Two dispatcher threads are used:
 * <ul>
 * <li>"ServerTask" selects OP_ACCEPT on the listening channel and hands each accept to the
 * connection thread pool;</li>
 * <li>"ServiceTask" selects OP_READ / OP_WRITE on accepted connections and hands each readiness
 * event to the service thread pool.</li>
 * </ul>
 *
 * <p>Before a selected key is dispatched to a worker, its selector thread removes it from the
 * selected-key set and mutes it ({@code interestOps(0)}). The worker restores the interest set and
 * wakes the selector when it is done. This keeps the selected-key set (which is not thread-safe)
 * confined to one thread and guarantees at most one in-flight task per key.
 */
public class NioServer {
	/* Returned by ServiceHandler's I/O steps when the connection has been closed. */
	private static final int CONNECTION_CLOSED = -1;

	/* Local address to bind, resolved once in the constructor. */
	private InetSocketAddress localAddress = null;
	private ServerSocketChannel serverChannel = null;
	/* Handles only the OP_ACCEPT events of serverChannel. */
	private Selector connectionSelector = null;
	/* Handles only the read/write events of accepted connections. */
	private Selector ioSelector = null;
	private Thread task = null;
	private Thread service = null;
	private ThreadPoolExecutor connectionThreadpool = null;
	private ThreadPoolExecutor serviceThreadpool = null;
	private ConnectionPool connectionPool = new ConnectionPool();
	private CmdSetHandler cmdsetHandler = new CmdSetHandler();

	/**
	 * Creates a server that will listen on the given address once {@link #launch()} is called.
	 *
	 * @param localAddress host name or IP literal to bind to
	 * @param localPort    TCP port to bind to
	 * @throws IllegalArgumentException if {@code localAddress} cannot be resolved
	 */
	public NioServer(String localAddress,int localPort) {
		try {
			this.localAddress = new InetSocketAddress(InetAddress.getByName(localAddress), localPort);
		} catch (UnknownHostException e) {
			/* Fail fast: a null address would make bind() silently pick a wildcard ephemeral port. */
			throw new IllegalArgumentException("Cannot resolve local address: " + localAddress, e);
		}
	}

	private void startAcceptorTask() {
		ServerTask servertask = new ServerTask();
		task = new Thread(servertask);
		task.setName("ServerTask");
		task.start();
	}

	private void startServiceTask() {
		ServiceTask servicetask = new ServiceTask();
		service = new Thread(servicetask);
		service.setName("ServiceTask");
		service.start();
	}

	/**
	 * Binds the listening socket, opens the selectors, creates the worker pools and starts the
	 * acceptor and I/O dispatcher threads.
	 *
	 * <p>All initialization runs on the calling thread before any server thread is started. As a
	 * result, the shared fields are safely published to those threads by {@link Thread#start()}, and
	 * the socket is already bound when this method returns. If initialization fails, the error is
	 * printed, any partially opened resources are closed, and no thread is started.
	 */
	public void launch() {
		if (!init()) {
			System.out.println("Server failed to start.");
			return;
		}
		startAcceptorTask();
		startServiceTask();

		System.out.println("Server starting...");
	}

	/*
	 * Opens both selectors and the listening channel, registers it for OP_ACCEPT, then creates the
	 * thread pools. Returns false (after printing the error and closing what was opened) on failure.
	 */
	private boolean init() {
		try {
			connectionSelector = Selector.open();
			ioSelector = Selector.open();

			serverChannel = ServerSocketChannel.open();
			serverChannel.configureBlocking(false);
			/* bind(addr) uses backlog 0, i.e. an implementation-specific default backlog. */
			serverChannel.bind(localAddress);
			System.out.println("Server:bind IP="+localAddress.getHostName()+":port="+localAddress.getPort());

			serverChannel.register(connectionSelector, SelectionKey.OP_ACCEPT);
		} catch (IOException e) {
			e.printStackTrace();
			closeAndLog(serverChannel);
			closeAndLog(ioSelector);
			closeAndLog(connectionSelector);
			return false;
		}
		initThreadPool();
		return true;
	}

	/* Creates both worker pools; bounded queues keep a flood of events from exhausting the server. */
	private void initThreadPool() {
		connectionThreadpool = new ThreadPoolExecutor(3,
				10,
				20,
				TimeUnit.SECONDS,
				new LinkedBlockingQueue<Runnable>(100),
				new NamedThreadFactory("ServerConnectionThread"),
				new ThreadPoolExecutor.AbortPolicy());

		serviceThreadpool = new ThreadPoolExecutor(3,
				10,
				20,
				TimeUnit.SECONDS,
				new LinkedBlockingQueue<Runnable>(100),
				new NamedThreadFactory("ServerServiceThread"),
				new ThreadPoolExecutor.AbortPolicy());
	}

	/*
	 * Re-enables ops on a key that a dispatcher muted before handing it to a worker, then wakes the
	 * key's selector so that a select() already in progress picks up the new interest set.
	 * Does nothing if the key was cancelled (the worker closed the connection). Only the worker
	 * that owns the key closes its channel, so the key cannot be cancelled between the check and
	 * the update.
	 */
	private static void resumeInterest(SelectionKey key, int ops) {
		if (key.isValid()) {
			key.interestOps(ops);
			key.selector().wakeup();
		}
	}

	/* Closes resource (if non-null), printing rather than propagating any failure. */
	private static void closeAndLog(AutoCloseable resource) {
		if (resource == null) {
			return;
		}
		try {
			resource.close();
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/* Live connections keyed by "remoteIP:remotePort"; entries are removed when a connection closes. */
	private class ConnectionPool {
		private final ConcurrentHashMap<String,SocketChannel> connectionPool = new ConcurrentHashMap<String,SocketChannel>();

		public void putConnection(String key,SocketChannel channel) {
			connectionPool.put(key, channel);
		}

		public void removeConnection(String key) {
			connectionPool.remove(key);
		}
	}

	/*
	 * Commands that must run on the ServiceTask thread (e.g. register() with ioSelector).
	 * Other threads enqueue a command and wake ioSelector; the ServiceTask drains the queue
	 * before each select().
	 */
	private class CmdSetHandler {
		private final LinkedBlockingQueue<CmdSetBase> cmdset = new LinkedBlockingQueue<CmdSetBase>();

		public void addCmd(CmdSetBase cmd) {
			/* Unbounded queue: offer() never blocks and always succeeds. */
			cmdset.offer(cmd);
		}

		/* Runs and removes every queued command, in FIFO order. */
		public void executeCmdSet() {
			CmdSetBase cmd;
			while ((cmd = cmdset.poll()) != null) {
				cmd.execute();
			}
		}
	}

	private abstract class CmdSetBase {
		public abstract void execute();
	}

	/* Registers a channel with a selector; executed on that selector's own thread. */
	private class CmdSetRegister extends CmdSetBase {
		private final SocketChannel channel;
		private final Selector selector;
		private final int interestOps;
		private final Object attachment;

		public CmdSetRegister(SocketChannel channel, Selector selector,int interestOps,Object attachment) {
			super();
			this.channel = channel;
			this.selector = selector;
			this.interestOps = interestOps;
			this.attachment = attachment;
		}

		@Override
		public void execute() {
			try {
				channel.register(selector, interestOps, attachment);
			} catch (ClosedChannelException e) {
				e.printStackTrace();
			}
		}
	}

	/* Names threads "<prefix>[n]" with n = 1, 2, ...; newThread may be called concurrently by the pool. */
	private static final class NamedThreadFactory implements ThreadFactory {
		private final String prefix;
		private int cnt = 0;

		NamedThreadFactory(String prefix) {
			this.prefix = prefix;
		}

		@Override
		public synchronized Thread newThread(Runnable r) {
			Thread thread = new Thread(r);
			thread.setName(prefix + "[" + (++cnt) + "]");
			return thread;
		}
	}

	/*
	 * One ServiceHandler per connection, attached to the connection's SelectionKey. Different
	 * connections may need independent state, so a handler is never shared between connections.
	 * Each readiness event re-submits the connection's own handler to the service pool. The
	 * dispatcher mutes the key while a run is in flight, so runs of the same handler never overlap.
	 *
	 * Echo protocol: read up to 128 bytes, switch interest to OP_WRITE until they are fully
	 * written back, then switch back to OP_READ. Pending output is never overwritten by a new read.
	 */
	private class ServiceHandler implements Runnable {
		private final String connectionKey;
		private final ByteBuffer sendbuffer = ByteBuffer.allocate(128);
		private final ByteBuffer recvbuffer = ByteBuffer.allocate(128);
		private SelectionKey key = null;
		private int readyOps = 0;
		private int interestOps = 0;

		public ServiceHandler(String connectionKey) {
			this.connectionKey = connectionKey;
			sendbuffer.limit(0); /* nothing to send yet */
		}

		/* Called by the dispatcher just before submitting this handler; returns this for chaining. */
		public ServiceHandler setKey(SelectionKey key, int readyOps, int interestOps) {
			this.key = key;
			this.readyOps = readyOps;
			this.interestOps = interestOps;
			return this;
		}

		/* Reads once and queues the bytes for echo; returns the next interest set or CONNECTION_CLOSED. */
		private int readData() {
			SocketChannel channel = (SocketChannel) key.channel();
			try {
				recvbuffer.clear();
				int recvlen = channel.read(recvbuffer);
				if (recvlen < 0) {
					/* End-of-stream: the peer closed the socket or shut down its output. */
					closeConnection();
					return CONNECTION_CLOSED;
				}
				if (recvlen == 0) {
					return SelectionKey.OP_READ;
				}
				recvbuffer.flip();
				System.out.println("Server:recv msg="+new String(recvbuffer.array(),0,recvlen));

				System.out.println("Server:send msg back to client.");
				sendbuffer.clear();
				sendbuffer.put(recvbuffer);
				sendbuffer.flip();
				/* Stop reading until the echo is flushed, so pending output cannot be overwritten. */
				return SelectionKey.OP_WRITE;
			} catch (IOException e) {
				e.printStackTrace();
				closeConnection();
				return CONNECTION_CLOSED;
			}
		}

		/* Writes as much pending output as the socket accepts; returns the next interest set or CONNECTION_CLOSED. */
		private int writeData() {
			SocketChannel channel = (SocketChannel) key.channel();
			try {
				channel.write(sendbuffer);
			} catch (IOException e) {
				e.printStackTrace();
				closeConnection();
				return CONNECTION_CLOSED;
			}
			/* Partial write: keep waiting for OP_WRITE until the buffer is drained. */
			return sendbuffer.hasRemaining() ? SelectionKey.OP_WRITE : SelectionKey.OP_READ;
		}

		private void closeConnection() {
			key.cancel();
			closeAndLog(key.channel());
			connectionPool.removeConnection(connectionKey);
		}

		@Override
		public void run() {
			int nextOps;
			if ((readyOps & SelectionKey.OP_WRITE) != 0) {
				System.out.println("ServiceHandler:writable");
				nextOps = writeData();
			} else if ((readyOps & SelectionKey.OP_READ) != 0) {
				System.out.println("ServiceHandler:readable");
				nextOps = readData();
			} else {
				nextOps = interestOps;
			}
			if (nextOps != CONNECTION_CLOSED) {
				resumeInterest(key, nextOps);
			}
		}
	}

	/* Accepts one pending connection, then re-enables OP_ACCEPT on the listening key. */
	private class ConnectionHandler implements Runnable {
		private final SelectionKey key;

		public ConnectionHandler(SelectionKey key) {
			super();
			this.key = key;
		}

		@Override
		public void run() {
			SocketChannel clientChannel = null;
			try {
				clientChannel = ((ServerSocketChannel)key.channel()).accept();
				if( null != clientChannel ) {
					clientChannel.configureBlocking(false);
					InetSocketAddress remoteAddress = (InetSocketAddress)clientChannel.getRemoteAddress();
					String remoteIP = remoteAddress.getHostString();
					int remotePort = remoteAddress.getPort();
					System.out.println("Server:Accepted connection from IP="+remoteIP+":port="+remotePort);

					/* Add to the pool first, so that a fast disconnect cannot leave a stale entry. */
					String poolKey = remoteIP+":"+remotePort;
					connectionPool.putConnection(poolKey, clientChannel);

					/* register() must run on the ServiceTask thread: queue it and wake ioSelector. */
					cmdsetHandler.addCmd(new CmdSetRegister(clientChannel,ioSelector,SelectionKey.OP_READ,new ServiceHandler(poolKey)));
					ioSelector.wakeup();
					clientChannel = null; /* ownership handed over; do not close below */
				}
			} catch (IOException e) {
				e.printStackTrace();
			} finally {
				closeAndLog(clientChannel); /* only non-null if setup failed */
				resumeInterest(key, SelectionKey.OP_ACCEPT);
			}
		}
	}

	/* Acceptor dispatcher loop over connectionSelector. */
	private class ServerTask implements Runnable {
		private void handleKey(SelectionKey key) {
			if( key.isValid() && key.isAcceptable() ) {
				System.out.println("Server:Acceptable");
				/* Mute the key so the same accept is not dispatched twice; ConnectionHandler restores it. */
				key.interestOps(0);
				try {
					connectionThreadpool.execute(new ConnectionHandler(key));
				} catch (RejectedExecutionException e) {
					e.printStackTrace();
					key.interestOps(SelectionKey.OP_ACCEPT);
				}
			} else {
				System.out.println("Server:Not handled process.");
			}
		}

		@Override
		public void run() {
			while(!Thread.currentThread().isInterrupted()) {
				try {
					connectionSelector.select();
					Set<SelectionKey> keyset = connectionSelector.selectedKeys();
					Iterator<SelectionKey> ite = keyset.iterator();
					while(ite.hasNext()) {
						SelectionKey key = ite.next();
						ite.remove(); /* the selector never removes keys itself */
						handleKey(key);
					}
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/* I/O dispatcher loop over ioSelector; it only dispatches, the ServiceHandler does the I/O. */
	private class ServiceTask implements Runnable {

		private void handleKey(SelectionKey key) {
			if (!key.isValid()) {
				return;
			}
			int interest = key.interestOps();
			int ready = key.readyOps();
			ServiceHandler handler = (ServiceHandler) key.attachment();
			/* Mute the key until the handler finishes, so at most one task per connection is in flight. */
			key.interestOps(0);
			try {
				serviceThreadpool.execute(handler.setKey(key, ready, interest));
			} catch (RejectedExecutionException e) {
				e.printStackTrace();
				key.interestOps(interest);
			}
		}

		@Override
		public void run() {
			while(!Thread.currentThread().isInterrupted()) {
				try {
					/* Perform registrations queued by other threads before blocking again. */
					cmdsetHandler.executeCmdSet();

					int num = ioSelector.select();
					if( num > 0 ) {
						System.out.println("ioSelector set number = "+num);
					}
					Set<SelectionKey> keySet = ioSelector.selectedKeys();
					Iterator<SelectionKey> ite = keySet.iterator();
					while(ite.hasNext()) {
						SelectionKey key = ite.next();
						ite.remove();
						handleKey(key);
					}
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}
}

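/* Major NIO pitfalls:
 * 1. Every new SocketChannel needs its blocking mode set explicitly.
 * 		Whether it comes from open() or from accept(), a channel is in blocking mode by default;
 * 		call configureBlocking(false) before registering it with a Selector.
 * 2. Keys returned by Selector.selectedKeys() must be removed once they have been handled.
 * 		The selector adds keys to the selected-key set but never removes them. A key left in the set is
 * 		NOT masked: its ready set keeps being updated, and the key is processed again on the next pass,
 * 		so the same event is handled repeatedly. Remove each key (e.g. with Iterator.remove()) on the
 * 		selecting thread. The set is not thread-safe, so do not remove keys from worker threads.
 * 3. A Selector can select in blocking or non-blocking mode.
 * 		select() and select(timeout) block; selectNow() returns immediately.
 * 		Calling selectNow() clears the effect of any earlier wakeup(), so a wakeup issued before a
 * 		selectNow() will not make the following select() return early.
 * 4. select() and register() contend for the same lock (on JDKs before 11).
 * 		If one thread is blocked in select() and another thread tries to register a new channel with the
 * 		same selector, register() blocks until select() returns. select() does not return just because
 * 		a registration is waiting, so the registering thread can hang indefinitely. A common fix is to
 * 		queue the registration, call wakeup(), and let the selector thread call register() itself.
 * 5. The OP_READ event
 * 		is not triggered only when data arrives. It also fires when the channel:
 * 		is ready for reading
 * 		has reached end-of-stream
 * 		has been remotely shut down for further reading
 * 		has an error pending
 * 		So when an exception is thrown, close the channel and cancel the key. If read() returns -1
 * 		(end-of-stream), close the channel or stop watching OP_READ.
 * 		End-of-stream means the peer closed the socket or shut down its output.
 * 6. Open questions
 * 		1. When selected keys are dispatched to worker threads, the same key is handed out repeatedly,
 * 		   producing several identical tasks. Is a synchronous loop the only way to avoid this?
 * 		   (No: remove the key from the selected set on the selector thread, set its interest ops to 0
 * 		   before dispatching, and have the worker restore them and call wakeup() when it is done.)
 * 		2. The OP_WRITE and OP_READ events need a more detailed explanation.
 * */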

// You may also be interested in: (【代码积累】NIO server)