NIO SSL Socket Server

 学习NIO和SSL的结合

参考文档:http://docs.oracle.com/javase/1.5.0/docs/guide/security/jsse/JSSERefGuide.html#KeyClasses

NIO中有socketChannel但是没有sslSocketChannel,据文档中说,如果要实现ssl的socketChannel会牵涉很多代码的实现,增加api的复杂程度;ssl的实现也不应该依赖于是NIO还是传统的基于stream的IO,应该给程序员自己进行组合的自由,因此就没有在java标准中提供相关的api,而留给程序员实现。为了支持nio实现ssl通讯,从java1.5开始增加了javax.net.ssl.SSLEngine用于支持nio的ssl,这个SSLEngine的作用就像一个状态机,维护着ssl通讯中各个状态以及下一个状态。

 

为了了解SSL通讯可以参考文档中的SSL通讯过程,主要包括握手,对话,关闭对话,三个步骤。其中握手部分的主要内容有协商协议,相互验证,生成并交换对称密钥,其中相互验证和对称密钥的交换是由非对称加密来完成的。在对话过程中,实际的明文是由之前生成的对称密钥来加密的。当对话结束后,互相发送结束信号结束通讯。

 

在上述过程中不仅仅是通讯双方简单的交换数据,更重要的是要根据SSL协议的要求,在特定的状态下发送或接受特定的数据,并且这些数据是经过处理的数据,也就是在tcp头和通讯正文之间还要包括一些ssl的信息,并且正文是由特定形式加密的。SSLEngine正是完成了管理状态,封装应用程序数据发往网络,解析网络数据并传递给应用程序的角色。

 

NIO SSL Socket Server_第1张图片

先看看SSLEngine有哪些状态,有哪些工作要做,假设从一个ssl server看。首先在握手阶段,需要和client程序多次握手,进行身份验证,对称密钥生成等工作,在这段时间里,并没有实际的应用层数据交换,而只有SSL协议数据的交换。且不看实际传输的内容和意义,在握手过程中,也就是上图中step1 到step13。Server的SSLEngine初始化后总是等待client的请求(等待接收数据),此时它的状态是NEED_UNWRAP,unwrap是解包的意思,这意味着,SSLEngine等待解析一个SSL的数据包,当server收到数据包后,在nio中数据包总是放在一个buffer里而不再是读stream,我们把这个buffer交给SSLEngine,调用它的unwrap方法,SSLEngine会解析这个数据包,把其中关于SSL握手的信息提取出来,并改变自己的状态,此处它将变成NEED_WRAP状态,意味着打包,它需要把对应的SSL回复内容写到数据包中返回到客户端,也就是step2-6中所作的事情。以此类推,SSLEngine多数时间总是在解包和打包两个状态间切换,尤其是在实际通讯时,注意到在unwrap和wrap函数中都有一个源buffer和一个目的buffer,因为SSLEngine不仅提取SSL协议相关的内容还要解密网络数据并把明文传递给应用程序,这其实才是这两个函数名字的来源,只不过在握手过程中,并没有实际的数据,而只有SSL协议信息,所以那个目的buffer总是没有东西。可以把SSL通讯看做交换礼物,SSLEngine把包裹拆了把礼物给你,或者他把礼物包起来送走,只是在SSL握手时,那个包裹里没有礼物,SSLEngine只是拆了个空包裹或是寄了个空包裹。那么还有没有其它状态,有一个FINISHED 状态那是在server端处于step13时所处的状态,表示这次handshake完成了;而当进入实际交换数据的时候,也就是step14的状态,这个状态是NOT_HANDSHAKE,表示当前不在握手,一般这个时候只需要在socket可读时,调用unwrap函数解密来自网络的SSL数据包,在socket可写的时候调用wrap函数把明文数据加密发送出去。还有一个状态NEED_TASK,首先要知道一点SSLEngine是异步的,wrap和unwrap函数调用都会立刻返回,比如在server收到client第一次请求后,会调用unwrap,但实际上SSLEngine还会做很多工作,比如访问Keystore文件,这些操作是费时的,但是实际上函数却立刻返回了,这时候SSLEngine会进入NEED_TASK状态,而不是立刻进入NEED_WRAP状态,所以必须让SSLEngine完成手头的工作,才能进入下一步工作,这时可以调用SSLEngine的getDelegatedTask()方法获得那个尚未完成的工作,它是一个Runnable的对象,可以调用它的run方法等待他完成,如果你是个高并发的server,也可以在这个时候做其他事情,等待这个工作完成,再接下去做wrap工作。另外还有一个非常容易出错的地方,一个NEED_UNWRAP状态的下一个状态然有可能是NEED_UNWRAP,并且一次调用unwrap方法并不一定把buffer中的所有内容都解包出来,可能还有内容需要在一次unwrap才能把所有东西都解析完,我遇到的这种情况发生在用nio的server和老的SSLSocket通讯时,在step7-11的过程中,client只向server一次性发送了这些数据,而server端需要连续两次unwrap才能把client的数据完整处理掉。

 

除了上述4个状态描述了SSLEngine的状态,还有4个状态用于描述每次调用wrap和unwrap后的结果状态。它们分别是BUFFER_OVERFLOW表示目标buffer没有足够的空间来存放解包的内容,这往往是因为你的目的buffer太小,或者在buffer在写入前没有clear;BUFFER_UNDERFLOW表示源buffer没足够内容让SSLEngine来解包,这往往是因为,可能还有数据尚未到达,或者在buffer读取前没有flip;CLOSED表示通讯的某一段正试图结束这个SSL通讯;OK,你懂的。

 

了解了SSLEngine的状态以及wrap和unwrap的原理,完成一个基于NIO的SSLsocket也就不会没想法了。

首先NIO的socket基本都通过Selector来实现,把socket 的accept,read,write事件都注册到selector上,不断的循环select()就可以,只是对于一个SSL Server Socket而言,它只是个普通的ServerSocket,首先只关心accept事件,所以首先这在selector上注册一个事件。

当serversocket接收到一个SSL client的请求后,就要开始进行握手,这个过程是同步的,所以先不要吧read和write事件也注册到selector上,当完成握手后,才注册这两个事件,并把socket设置成非阻塞。当select到socket可读时先调用unwrap方法,可写时先调用wrap方法。

每个socket都有两组buffer,分别是appIn,netIn和appOut,netOut,其中netXX都代表从socket中读取或写入的东西,他们都是加了密的,而appXX代表应用程序可理解的数据内容,它们都通过SSLEngine的wrap和unwrap方法才能与netXX相互转换。

 

粘代码

package com.red.nio.ssl;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.Iterator;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;

public class SSLNewServer {

	private static boolean logging = true;
	
	private boolean handshakeDone = false;
	
	private Selector selector;
	private SSLEngine sslEngine;
	private SSLContext sslContext;
	
	private ByteBuffer appOut; // clear text buffer for out
	private ByteBuffer appIn; // clear text buffer for in
	private ByteBuffer netOut; // encrypted buffer for out
	private ByteBuffer netIn; // encrypted buffer for in

	private CharsetEncoder encoder = Charset.forName("UTF8").newEncoder();
	private CharsetDecoder decoder = Charset.forName("UTF8").newDecoder();
	
	public SSLNewServer() {
		try
		{
			createServerSocket();
		} catch (IOException e)
		{
			System.out.println("initializing server failed");
			e.printStackTrace();
		}
		
		try
		{
			createSSLContext();
		} catch (GeneralSecurityException e)
		{
			System.out.println("initializing SSL context failed");
			e.printStackTrace();
		} catch (IOException e)
		{
			System.out.println("reading keystore or truststore file failed");
			e.printStackTrace();
		}
		
		createSSLEngines();
		createBuffers();
	}
	
	private void createBuffers()
	{
		SSLSession session = sslEngine.getSession();
		int appBufferMax = session.getApplicationBufferSize();
		int netBufferMax = session.getPacketBufferSize();
		
		appOut = ByteBuffer.wrap("This is an SSL Server".getBytes());//server only reply this sentence 
		appIn = ByteBuffer.allocate(appBufferMax + 10);//appIn is bigger than the allowed max application buffer siz
		netOut = ByteBuffer.allocateDirect(netBufferMax);//direct allocate for better performance
		netIn = ByteBuffer.allocateDirect(netBufferMax);
	}

	//the ssl context initialization
	private void createSSLContext() throws GeneralSecurityException, FileNotFoundException, IOException
	{
		KeyStore ks = KeyStore.getInstance("JKS");
		KeyStore ts = KeyStore.getInstance("JKS");

		char[] passphrase = "123456".toCharArray();

		ks.load(new FileInputStream("ssl/kserver.keystore"), passphrase);
		ts.load(new FileInputStream("ssl/tserver.keystore"), passphrase);

		KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
		kmf.init(ks, passphrase);

		TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
		tmf.init(ts);

		SSLContext sslCtx = SSLContext.getInstance("SSL");

		sslCtx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);

		sslContext = sslCtx;
		
	}

	//create the server socket, bind it to port 1234, set unblock and register the "accept" only
	private void createServerSocket() throws IOException
	{
		selector = Selector.open();
		ServerSocketChannel ssc = ServerSocketChannel.open();
		ssc.socket().bind(new InetSocketAddress(1234));
		ssc.configureBlocking(false);
		ssc.register(selector, SelectionKey.OP_ACCEPT);
	}

	private void createSSLEngines() 
	{
		sslEngine = sslContext.createSSLEngine();
		sslEngine.setUseClientMode(false);//work in a server mode
		sslEngine.setNeedClientAuth(true);//need client authentication
	}
	
	public void selecting() {
		while (true)
		{
			try
			{
				selector.select();
			} catch (IOException e)
			{
				e.printStackTrace();
			}
			Iterator<SelectionKey> iter = selector.selectedKeys().iterator();
			while (iter.hasNext())
			{
				SelectionKey key = (SelectionKey) iter.next();
				iter.remove();
				try
				{
					handle(key);
				} catch (SSLException e)
				{
					// TODO Auto-generated catch block
					e.printStackTrace();
				} catch (IOException e)
				{
					// TODO Auto-generated catch block
					e.printStackTrace();
				}
			}
		}
	}

	private void handle(SelectionKey key) throws IOException
	{
		if(key.isAcceptable()) {
			
			try
			{
				SocketChannel sc = ((ServerSocketChannel)key.channel()).accept();
				doHandShake(sc);//if it is an accept event, do the handshake in a blocking mode
			} catch (ClosedChannelException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		}
		else if(key.isReadable()) {
			if (sslEngine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING)
			{
				SocketChannel sc = (SocketChannel) key.channel();
				sc.read(netIn);
				netIn.flip();
				
				SSLEngineResult engineResult = sslEngine.unwrap(netIn, appIn);
				log("server unwrap: ", engineResult);
				doTask();
				//runDelegatedTasks(engineResult, sslEngine);
				netIn.compact();
				if (engineResult.getStatus() == SSLEngineResult.Status.OK)
				{
					System.out.println("text recieved");
					appIn.flip();// ready for reading
					System.out.println(decoder.decode(appIn));
					appIn.compact();
				}
				else if(engineResult.getStatus() == SSLEngineResult.Status.CLOSED) {
					doSSLClose(key);
				}

			}

		}
		else if(key.isWritable()) {
			SocketChannel sc = (SocketChannel) key.channel();
			//if(!sslEngine.isOutboundDone()) {
				//netOut.clear();
			SSLEngineResult engineResult = sslEngine.wrap(appOut, netOut);
			log("server wrap: ", engineResult);
			doTask();
			//runDelegatedTasks(engineResult, sslEngine);
			if (engineResult.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING)
			{
				System.out.println("text sent");
			}
			netOut.flip();
			sc.write(netOut);
			netOut.compact();
			//}
		}
		
	}

	/*public static HandshakeStatus runDelegatedTasks(SSLEngineResult engineResult, SSLEngine sslEngine)
	{
		if (engineResult.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
		{
			Runnable runnable;
			while ((runnable = sslEngine.getDelegatedTask()) != null)
			{
				System.out.println("\trunning delegated task...");
				runnable.run();
			}
			HandshakeStatus hsStatus = sslEngine.getHandshakeStatus();
			if (hsStatus == HandshakeStatus.NEED_TASK)
			{
				//throw new Exception("handshake shouldn't need additional tasks");
				System.out.println("handshake shouldn't need additional tasks");
			}
			System.out.println("\tnew HandshakeStatus: " + hsStatus);
		}
		return sslEngine.getHandshakeStatus();
		
	}*/

	/*
	 * Logging code
	 */
	private static boolean resultOnce = true;

	public static void log(String str, SSLEngineResult result)
	{
		if (!logging)
		{
			return;
		}
		if (resultOnce)
		{
			resultOnce = false;
			System.out.println("The format of the SSLEngineResult is: \n"
					+ "\t\"getStatus() / getHandshakeStatus()\" +\n"
					+ "\t\"bytesConsumed() / bytesProduced()\"\n");
		}
		HandshakeStatus hsStatus = result.getHandshakeStatus();
		log(str + result.getStatus() + "/" + hsStatus + ", " + result.bytesConsumed() + "/"
				+ result.bytesProduced() + " bytes");
		if (hsStatus == HandshakeStatus.FINISHED)
		{
			log("\t...ready for application data");
		}
	}

	public static void log(String str)
	{
		if (logging)
		{
			System.out.println(str);
		}
	}
	
	
	private void doHandShake(SocketChannel sc) throws IOException
	{
		
		sslEngine.beginHandshake();//explicitly begin the handshake
		HandshakeStatus hsStatus = sslEngine.getHandshakeStatus();
		while (!handshakeDone)
		{
			switch(hsStatus){
				case FINISHED:
					//the status become FINISHED only when the ssl handshake is finished
					//but we still need to send data, so do nothing here
					break;
				case NEED_TASK:
					//do the delegate task if there is some extra work such as checking the keystore during the handshake
					hsStatus = doTask();
					break;
				case NEED_UNWRAP:
					//unwrap means unwrap the ssl packet to get ssl handshake information
					sc.read(netIn);
					netIn.flip();
					hsStatus = doUnwrap();
					break;
				case NEED_WRAP:
					//wrap means wrap the app packet into an ssl packet to add ssl handshake information
					hsStatus = doWrap();
					sc.write(netOut);
					netOut.clear();
					break;
				case NOT_HANDSHAKING:
					//now it is not in a handshake or say byebye status. here it means handshake is over and ready for ssl talk
					sc.configureBlocking(false);//set the socket to unblocking mode
					sc.register(selector, SelectionKey.OP_READ|SelectionKey.OP_WRITE);//register the read and write event
					handshakeDone = true;
					break;
			}
		}
		
	}
	
	private HandshakeStatus doTask() {
		Runnable runnable;
		while ((runnable = sslEngine.getDelegatedTask()) != null)
		{
			System.out.println("\trunning delegated task...");
			runnable.run();
		}
		HandshakeStatus hsStatus = sslEngine.getHandshakeStatus();
		if (hsStatus == HandshakeStatus.NEED_TASK)
		{
			//throw new Exception("handshake shouldn't need additional tasks");
			System.out.println("handshake shouldn't need additional tasks");
		}
		System.out.println("\tnew HandshakeStatus: " + hsStatus);
		
		return hsStatus;
	}
	
	private HandshakeStatus doUnwrap() throws SSLException{
		HandshakeStatus hsStatus;
		do{//do unwrap until the state is change to "NEED_WRAP"
			SSLEngineResult engineResult = sslEngine.unwrap(netIn, appIn);
			log("server unwrap: ", engineResult);
			hsStatus = doTask();
		}while(hsStatus ==  SSLEngineResult.HandshakeStatus.NEED_UNWRAP && netIn.remaining()>0);
		System.out.println("\tnew HandshakeStatus: " + hsStatus);
		netIn.clear();
		return hsStatus;
	}
	
	private HandshakeStatus doWrap() throws SSLException{
		HandshakeStatus hsStatus;
		SSLEngineResult engineResult = sslEngine.wrap(appOut, netOut);
		log("server wrap: ", engineResult);
		hsStatus = doTask();
		System.out.println("\tnew HandshakeStatus: " + hsStatus);
		netOut.flip();
		return hsStatus;
	}
	
	//close an ssl talk, similar to the handshake steps
	private void doSSLClose(SelectionKey key) throws IOException {
		SocketChannel sc = (SocketChannel) key.channel();
		key.cancel();
		
		try
		{
			sc.configureBlocking(true);
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		HandshakeStatus hsStatus = sslEngine.getHandshakeStatus();
		while(handshakeDone) {
			switch(hsStatus) {
			case FINISHED:
				
				break;
			case NEED_TASK:
				hsStatus = doTask();
				break;
			case NEED_UNWRAP:
				sc.read(netIn);
				netIn.flip();
				hsStatus = doUnwrap();
				break;
			case NEED_WRAP:
				hsStatus = doWrap();
				sc.write(netOut);
				netOut.clear();
				break;
			case NOT_HANDSHAKING:
				handshakeDone = false;
				sc.close();
				break;
			}
		}
	}
	
	
	
	public static void main(String[] args) {
		SSLNewServer sns = new SSLNewServer();
		sns.selecting();
	}
}


 

你可能感兴趣的:(server,socket,ssl,buffer,logging,通讯)