Java沙箱实现是重写类加载器和安全管理器,通过设置的全局安全管理器来控制执行程序的权限
说明: 安全策略只对安装安全管理器之后的类生效,之前的类不再此管理范围之内,利用这一点可以预先设置我们需要的操作,而对某个点之后的所有非法操作进行权限设置.
类加载器重写
/** * [重写的类加载器] * 沙箱程序类加载器,可根据指定路径加载制定类class文件. * * [说明] * 仅包内可见 * * @author 刘金鑫 * @version 1.0 * */ package org.hljoj.core.judge.sandbox; import java.io.File; import java.io.FileInputStream; import org.hljoj.core.judge.util.ConstantParam; class SandboxClassLoader extends ClassLoader{ /**默认classPath*/ private String _classPath; /** * 构造函数 * @param classPath 类加载器默认classPath * */ public SandboxClassLoader(String classPath) { this._classPath = classPath; } @Override protected Class> findClass(String className) throws ClassNotFoundException { return loadClass(_classPath, className); } /** * 更改类加载器加载类的classpath,在制定路径下加载制定的类class文件 * @param classPath 要加载的类路径 * @param className 要加载的类名 * 最为限定,只能加载不含包的类. * */ public Class> loadClass(String classPath, String className) throws ClassNotFoundException{ if(className.indexOf('.') >= 0) { throw new ClassNotFoundException(className); } File classFile = new File(classPath + ConstantParam.SEPARATOR + className + ".class"); byte[] mainClass = new byte[(int) classFile.length()]; try { FileInputStream in = new FileInputStream(classFile); in.read(mainClass); in.close(); } catch (Exception e) { //e.printStackTrace(); throw new ClassNotFoundException(className); } return super.defineClass(className, mainClass, 0, mainClass.length); } /** * 获取classPath * @return String classPath * */ public String getClassPath(){ return _classPath + ConstantParam.SEPARATOR; } }
重写安全管理器
/** * [重写的安全管理器] * 安全管理器用来限制客户端提交的Java源程序运行的功能, * 对程序读/写本地文件系统,修改系统属性连接网络, * 数据库等一切可能对本地计算机系统造成危害的操作进行屏蔽, * 如有这些操作将抛出SecurityException异常,并终止程序执行. * * [说明]: * 仅包内可见 * 不允许提交的源程序执行exit(n)函数-即不允许源程序中途 * 终止虚拟机的运行,但是调用源代码端可执行exit(n)函数. * * @author 刘金鑫 * @version 1.0 * */ package org.hljoj.core.judge.sandbox; import java.io.FilePermission; import java.lang.reflect.ReflectPermission; import java.security.Permission; import java.security.SecurityPermission; import java.util.PropertyPermission; import org.hljoj.core.judge.util.ConstantParam; class SandboxSecurityManager extends SecurityManager { public static final int EXIT = ConstantParam.RANDOM.nextInt(); /** * 重写强行退出检测 * 防止用户自行终止虚拟机的运行,但是调用程序端可以执行退出 * */ public void checkExit(int status) { if (status != EXIT) throw new SecurityException("Exit On Client Is Not Allowed!"); } /** * 策略权限查看 * 当执行操作时调用,如果操作允许则返回,操作不允许抛出SecurityException * */ private void sandboxCheck(Permission perm) throws SecurityException { // 设置只读属性 if (perm instanceof SecurityPermission) { if (perm.getName().startsWith("getProperty")) { return; } } else if (perm instanceof PropertyPermission) { if (perm.getActions().equals("read")) { return; } } else if (perm instanceof FilePermission) { if (perm.getActions().equals("read")) { return; } } else if (perm instanceof RuntimePermission || perm instanceof ReflectPermission){ return; } throw new SecurityException(perm.toString()); } @Override public void checkPermission(Permission perm) { this.sandboxCheck(perm); } @Override public void checkPermission(Permission perm, Object context) { this.sandboxCheck(perm); } }
沙箱操作
/** * [沙箱] * 模拟沙箱功能,限制执行的Java程序的所有可能对本地机器的危害. * 被执行的Java程序对于本机只有读属性,其他文件操作,更改系统属性, * Socket网络连接,连接数据库等功能全部禁止. * * 沙箱与主模块之间使用Socket进行通信,完全独立于系统接收执行信 * 息类(包含要执行程序的相关信息)后执行,执行完毕后返回给主模块 * 结果类(包含执行结果的相关信息) * * @author 刘金鑫 * @version 1.0 * */ package org.hljoj.core.judge.sandbox; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.PrintStream; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.ServerSocket; import java.net.Socket; import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.Callable; import java.util.concurrent.FutureTask; import org.hljoj.core.judge.util.ConstantParam; import org.hljoj.core.judge.util.JudgeResult; public final class Sandbox { //用来收集线程的内存使用量 private static MemoryMXBean _memoryBean = ManagementFactory.getMemoryMXBean(); //定向输出 private static ByteArrayOutputStream _baos = new ByteArrayOutputStream(1024); //Socket通信 private static Socket _socket = null; private static ServerSocket _serverSocket = null; private static ObjectInputStream _inputStream = null; private static ObjectOutputStream _outputStream = null; //执行提交程序线程 private static Thread _thread = null; //系统提供的默认的classpath private static String _classPath = null; private static long _timeStart = 0; //程序运行时间 private static long _timeUsed = 0; private static int _baseMemory = 0; //程序运行空间 private static int _memoryUsed = 0; //测评结果 private static String _result = null; /** * 核心执行函数
* 用于执行指定Main.class的main函数.
* 根据线程ID可获取运行时间. * */ private static String process(int runId, final int timeLimit) throws Exception { FutureTask
* 若超时,则中断执行线程
* @return TimerTask * */ private static TimerTask getTimerTask(){ return new TimerTask(){ public void run() { if (_thread != null) _thread.interrupt(); } }; } private static int _outputSize = 0; /** * 初始化 * @param classPath 系统默认classPath路径 * @param port socket服务器监听端口 * */ private static void inita(String classPath, int port) throws Exception{ _classPath = classPath; _serverSocket = new ServerSocket(port); _socket = _serverSocket.accept(); _outputStream = new ObjectOutputStream(_socket.getOutputStream()); _inputStream = new ObjectInputStream(_socket.getInputStream()); //重新定向输出流 System.setOut(new PrintStream(new BufferedOutputStream(_baos) { public void write(byte[] b, int off, int len) throws IOException { _outputSize += len - off; try { super.write(b, off, len); if (_outputSize > ConstantParam.OUTPUT_LIMIT) throw new RuntimeException("Output Limit Exceed" + _outputSize); } catch (IOException e) { if(e.getMessage().equals("Output Limit Exceed")){ throw e; } } } })); } private static SandboxClassLoader _classLoader = null; /** * 获取指定路径下Main.class类的main入口函数. * @param runId 指定类路径 * @return Method 返回的main方法 * */ private static Method getMainMethod(int runId) throws Exception{ _classLoader = new SandboxClassLoader(_classPath); Class> targetClass = _classLoader.loadClass(_classLoader.getClassPath() + runId, "Main"); Method mainMethod = null; mainMethod = targetClass.getMethod("main", String[].class); if(!Modifier.isStatic(mainMethod.getModifiers())) throw new Exception("Method Of Main Is Not Static"); mainMethod.setAccessible(true); return mainMethod ; } /** * 测评接口. * 运行接收到的Java程序. * * @param runId 执行id * @param problemId 提交问题id * @param submitId 提交用户id * @param timeLimit 时间限制 * @param memoryLimit 空间限制 * @param standardInput 程序标准输入字符串 * @param standardOutput 程序标准输出字符串 * */ public static void run(int runId, int timeLimit, int memoryLimit, String standardInput, String standardOutput) { _timeUsed = 0; _memoryUsed = 0; _baos.reset(); _outputSize = 0; setResult(JudgeResult.WRONG_ANSWER); //定向输入流 System.setIn(new BufferedInputStream(new ByteArrayInputStream(standardInput.getBytes()))); String output = null; try { // 必须在执行前对垃圾回收,否则不准确. System.gc(); output = process(runId, timeLimit); // 将程序输出与标准输出作比较 setResult(matchOutput(standardOutput.getBytes(), output.getBytes())); // 获取程序运行时间和空间 _timeUsed = System.currentTimeMillis() - _timeStart; _memoryUsed = (int) ((_memoryBean.getHeapMemoryUsage().getUsed() - _baseMemory) / 1000); } catch (Exception e) { if (e.getMessage().equals("Initalization Error")){ setResult(JudgeResult.WRONG_ANSWER); } } if (_memoryUsed > memoryLimit) setResult(JudgeResult.MEMORY_LIMIT_EXCEED); try { //向主模块返回执行结果 sendResult(runId, (int)_timeUsed, _memoryUsed, _result); } catch (IOException e) { e.printStackTrace(); } } /** * 向主模块发送运行结果. * * @param runId 运行runId * @param timeUsed 代码运行时间(MS) * @param memoryUsed 代码运行空间(B) * @param result 代码执行结果 * */ private static void sendResult(int runId,int timeUsed, int memoryUsed, String result) throws IOException{ _outputStream.writeInt(runId); _outputStream.writeInt(timeUsed); _outputStream.writeInt(memoryUsed); _outputStream.writeUTF(result); } /** * 接收运行参数 * * @param runId 运行runId * @param timeLimit 限制代码运行时间(MS) * @param memoryLimit 限制代码运行空间(B) * @param standardInput 标准输入 * @param standardOutput 标准输出 * */ private static void receiveMsg() throws IOException{ int runId = _inputStream.readInt(); int timeLimit = _inputStream.readInt(); int memoryLimit = _inputStream.readInt(); String standardInput = _inputStream.readUTF(); String standardOutput = _inputStream.readUTF(); run(runId, timeLimit, memoryLimit, standardInput, standardOutput); } /** * 比较程序输出和标准输出,并返回比较结果.
* 标准返回1.Accepted 2.Wrong Answer 3.Presenting Error结果 * * @param standOutput 标准输出结果 * @param output 程序输出结果 * @return int 比较结果 * */ private static int matchOutput(byte[] standardOutput, byte[] output){ int i = 0; int j = 0; do{ while (i < standardOutput.length && (standardOutput[i] == ConstantParam.SPACE || standardOutput[i] == '/t' || standardOutput[i] == '/r' || standardOutput[i] == '/n')) i++; while (j < output.length && (output[j] == ConstantParam.SPACE || output[j] == '/t' || output[j] == '/r' || output[j] == '/n')) j++; if (i < standardOutput.length && j < output.length && standardOutput[i] != output[j]) return JudgeResult.WRONG_ANSWER; i++; j++; }while(j <= i && i < standardOutput.length && j < output.length); if (i != j) return JudgeResult.PRESENTING_ERROR; return JudgeResult.ACCEPTED; } /** * 设置测评结果. * @param JudgeResult结果类型. * */ private static void setResult(int resultType){ _result = JudgeResult.toString(resultType); } /** * 设置测评结果. * @param JudgeResult结果类型. * @param remark 测评结果的备注. * */ private static void setResult(int resultType, String remark){ setResult(resultType); if (remark.endsWith("StackOverflowError")) _result += "(" + JudgeResult.toString(JudgeResult.RUNTIME_STACK_OVERFLOW) + ")"; else if (remark.endsWith("/ by zero")) _result += "(" + JudgeResult.toString(JudgeResult.RUNTIME_DIVIDE_BY_ZERO) + ")"; else if (remark.contains("ArrayIndexOutOfBoundsException")) _result += "(" + JudgeResult.toString(JudgeResult.RUNTIME_ACCESS_VIOLATION) + ")"; else _result += "(" + JudgeResult.toString(JudgeResult.RUNTIME_ARRAY_BOUNDS_EXCEEDED) + ")"; } /** * 关闭网络连接 * */ private static void close(){ try { if (_inputStream != null) _inputStream.close(); if (_outputStream != null) _outputStream.close(); if (_socket != null) _socket.close(); } catch (IOException e) { e.printStackTrace(); } } /** * 沙盒入口 * 传入参数 :
* classPath -- args[0] ------ 保存class的classpath
* port -- args[1] ------- 监听端口 * */ public static void main(String[] args) throws Exception{ inita(args[0], Integer.parseInt(args[1])); SecurityManager security = System.getSecurityManager(); if (security == null) System.setSecurityManager(new SandboxSecurityManager()); while (!_socket.isClosed()){ receiveMsg(); } close(); } }