Spring Boot 项目启动时自动从数据库动态加载 Groovy 脚本

将 Groovy 脚本保存在数据库中,并提供页面对脚本进行动态增删改查;启动 Spring Boot 项目时,从数据库读取 Groovy 函数配置表并编译其中的脚本,之后项目中即可直接调用这些脚本。

开发环境:Spring Boot + MyBatis-Plus

脚本实体类:Func.java

package com.zhou.sct.dao;

import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;

import java.io.Serializable;
import java.util.Date;

/**
 * User-defined Groovy function used by reconciliation (cross-check) rules, persisted in table {@code gv_func}.
 * Getters/setters, equals/hashCode and toString are generated by Lombok {@code @Data}.
 * @author lang.zhou
 * @since 2022/7/20 15:19
 */
@Data
@TableName("gv_func")
public class Func implements Serializable {
    private static final long serialVersionUID=1L;
    /**
     * Primary key (column ID).
     */
    @TableId(value = "ID")
    private String id;
    /**
     * Function name (column FUNC_NAME); GroovyUtil.execute builds the call expression "name(argv0,...)" from it.
     */
    @TableField("FUNC_NAME")
    private String funcName;
    /**
     * Groovy source of the function definition (column FUNC_BODY); evaluated into a script engine at load time.
     */
    @TableField("FUNC_BODY")
    private String funcBody;
    /**
     * Human-readable description of the function (column DESCRIPTION).
     */
    @TableField("DESCRIPTION")
    private String description;
    /**
     * Whether the function may be edited (column EDITABLE); defaults to 1.
     * NOTE(review): presumably 1 = editable, 0 = read-only — confirm against the UI.
     */
    @TableField("EDITABLE")
    private Integer editable = 1;
    /**
     * Example / test expression (column TEST_EXPRESS); GroovyUtil.test evaluates it when it is not blank.
     */
    @TableField("TEST_EXPRESS")
    private String testExpress ;

    /**
     * Function category (column CATALOG). GroovyUtil loads category 4 ("model" functions) first,
     * then 3 ("utility" functions), then every other value.
     */
    @TableField("CATALOG")
    private Integer catalog ;

    // Presumably the creation timestamp (column CREATE_DT) — confirm how it is populated.
    @TableField("CREATE_DT")
    private Date createDt ;

    // Presumably the last-update timestamp (column UPDATE_DT) — confirm how it is populated.
    @TableField("UPDATE_DT")
    private Date updateDt ;
}

创建 Spring Boot 启动时执行的任务:GroovyApplicationRunner.java

import com.zhou.sct.service.FuncService;
import com.zhou.sct.util.GroovyUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

/**
 * Startup hook that pre-compiles every Groovy function stored in the database, so that
 * {@link GroovyUtil} can call them as soon as the application is running.
 *
 * @author lang.zhou
 * @since 2023/1/17 17:54
 */
@Component
@Slf4j
public class GroovyApplicationRunner implements ApplicationRunner {

    /** Source of the stored function definitions. */
    private final FuncService funcService;

    /**
     * Constructor injection keeps the dependency final and lets the runner be built without a
     * Spring context (e.g. in unit tests).
     *
     * @param funcService service used to read all stored functions
     */
    @Autowired
    public GroovyApplicationRunner(FuncService funcService) {
        this.funcService = funcService;
    }

    /**
     * Loads (compiles and evaluates) all stored functions into GroovyUtil's shared engine.
     *
     * @param args application arguments (unused)
     */
    @Override
    public void run(ApplicationArguments args) throws Exception {
        // Pre-compile the stored functions at startup.
        GroovyUtil.loadDbFunc(funcService);
    }
}

GroovyUtil.java

package com.zhou.sct.util;

import com.zhou.sct.common.ScriptLoader;
import com.zhou.sct.dao.Func;
import com.zhou.sct.service.FuncService;
import lombok.extern.slf4j.Slf4j;
import org.codehaus.groovy.jsr223.GroovyScriptEngineImpl;

import javax.script.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * @author lang.zhou
 * @since  2022/7/20 15:51
 */
@Slf4j
public class GroovyUtil {

    /**
     * 大量动态计算表达式会导致内存溢出,这里缓存编译后的表达式,防止内存溢出
     */
    private static final Map scriptMap = new ConcurrentHashMap<>(100);
    
    private static GroovyScriptEngineImpl engine = ScriptLoader.createSecureScript();

    /**
     * 变量作用域,不同线程进行隔离,保证并发计算下不出错
     */
    private static Map bindingMap = new HashMap<>(1);

    /**
     * 编译函数并缓存
     */
    private static boolean cacheFunc(Func func) {
        try{
            CompiledScript script = getCompiledScript(func.getFuncBody());
            script.eval();
            return true;
        }catch (Exception e){
            //log.error("函数【{}】配置错误:{}",func.getFuncName(),e.getMessage());
            scriptMap.remove(func.getFuncBody());
            return false;
        }
    }

    /**
     * 将函数进行分组编译,由于函数之间可能存在相互依赖,把不依赖的函数进行优先编译,可以减少编译失败的次数
     */
    private static void loadFunctionGroup(List list, List failedList){
        List> funcListGroup = new ArrayList<>(4);
        //模型类函数
        List modelFuncList = list.stream().filter(a -> Objects.equals(a.getCatalog(), 4)).collect(Collectors.toList());
        log.info("模型类函数数量:{}",modelFuncList.size());
        //工具类函数
        List utilFuncList = list.stream().filter(a -> Objects.equals(a.getCatalog(), 3)).collect(Collectors.toList());
        log.info("工具类函数数量:{}",utilFuncList.size());
        //其他函数
        List otherFuncList = list.stream().filter(a -> !Objects.equals(a.getCatalog(), 4) && !Objects.equals(a.getCatalog(), 3)).collect(Collectors.toList());
        log.info("其他函数数量:{}",otherFuncList.size());
        funcListGroup.add(modelFuncList);
        funcListGroup.add(utilFuncList);
        funcListGroup.add(otherFuncList);
        for (List funcList : funcListGroup) {
            for (Func func : funcList) {
                boolean b = cacheFunc(func);
                if(!b){
                    failedList.add(func);
                }
            }
        }
    }

    /**
     * 编译数据库表中的函数,预编译表达式
     */
    public static void loadDbFunc(FuncService funcService) {
        log.info("====开始加载groovy脚本函数");
        //查询全部函数
        List funcList = funcService.list();
        //加载失败的函数
        List failedList = new ArrayList<>();
        //按顺序加载函数
        loadFunctionGroup(funcList,failedList);
        //函数加载受执行先后顺序的影响,将执行失败的函数进行重复执行
        if(failedList.size() > 0){
            List errorFuncList = cacheFailFunc(failedList);
            //将编译报错的函数打印出来
            if(errorFuncList != null && errorFuncList.size() > 0){
                for (Func func : errorFuncList) {
                    log.error("函数【{}】配置错误",func.getFuncName());
                }
            }
        }
        log.info("====加载groovy脚本函数完成");
    }
    /**
     * 将失败的函数脚本进行编译
     */
    private static List cacheFailFunc(List failedList){
        int n = failedList.size();
        for (Iterator iterator = failedList.iterator(); iterator.hasNext(); ) {
            Func func = iterator.next();
            boolean b = cacheFunc(func);
            if(b){
                iterator.remove();
            }
        }
        //全部执行成功或者没有新的函数执行成功则返回
        if(failedList.size() == 0 || failedList.size() == n){
            return null;
        }else{
            return cacheFailFunc(failedList);
        }
    }


    

    

    public static void main(String[] args) throws Exception{
        try{
           GroovyUtil.put("a",1);
           GroovyUtil.put("b",2); 
           System.out.println(GroovyUtil.eval("a+b"));
        }finally{
            //每次调用计算都要清空作用域
            GroovyUtil.clearScope();
        } 
    }

    /**
     * 对表达式进行编译和缓存
     */
    private static CompiledScript getCompiledScript(String expression) throws ScriptException {
        CompiledScript script = scriptMap.get(expression);
        if(script == null){
            script = ((Compilable) engine).compile(expression);
            scriptMap.put(expression,script);
        }
        return script;
    }

    /**
     * 根据当前线程得到引擎
     */
    public static Bindings getEngineBinding(){
        return bindingMap.computeIfAbsent(Thread.currentThread().getName(), k -> engine.createBindings());
    }
    /**
     * 根据当前线程得到引擎
     */
    public static void put(String k , Object v){
        getEngineBinding().put(k,v);
    }
    /**
     * 计算表达式
     */
    public static Object eval(String expression) throws Exception {
        return eval(expression,getEngineBinding());
    }
    /**
     * 计算表达式
     */
    public static Object eval(String expression,Bindings binding) throws Exception {
        CompiledScript script = getCompiledScript(expression);
        return script.eval(binding);
    }
    /**
     * 计算表达式得到布尔值
     */
    public static boolean valid(String expression) throws Exception {
        return valid(expression,getEngineBinding());
    }
    /**
     * 计算表达式得到布尔值
     */
    public static boolean valid(String expression, Bindings binding) throws Exception {
        return (boolean) eval(expression,binding);
    }
    
    /**
     * 将函数编译到一个新的脚本引擎(用于保存函数前的编译的校验)
     */
    private static void loadDbFunc(ScriptEngine se,FuncService funcService) {
        List funcs = funcService.list();
        List failList = new ArrayList<>();
        for (Func func : funcs) {
            try{
                se.eval(func.getFuncBody());
            }catch (Exception e){
                failList.add(func);
            }
        }
        if(failList.size() > 0){
            List errorFuncList = loadFailFunc(se,failList);
            if(errorFuncList != null && errorFuncList.size() > 0){
                for (Func func : errorFuncList) {
                    log.error("函数【{}】配置错误",func.getFuncName());
                }
            }
        }
    }
    private static List loadFailFunc(ScriptEngine se,List failedList){
        int n = failedList.size();
        for (Iterator iterator = failedList.iterator(); iterator.hasNext(); ) {
            Func func = iterator.next();
            try{
                se.eval(func.getFuncBody());
                iterator.remove();
            }catch (Exception e){
                //
            }
        }
        if(failedList.size() == 0 || failedList.size() == n){
            return null;
        }else{
            return loadFailFunc(se,failedList);
        }
    }

    /**
     * 测试一个自定义函数(保存函数时校验)
     */
    public static Object test(Func func) throws Exception {
        ScriptEngine se = ScriptLoader.createSecureScript();
        FuncService service = SpringFactory.getBean(FuncService.class);
        loadDbFunc(se,service);
        se.eval(func.getFuncBody());
        if(StringUtils.isNotBlank(func.getTestExpress())){
            return se.eval(func.getTestExpress());
        }
        return null;
    }

    /**
     * 加载一个自定义函数(用于函数修改后进行动态编译,使函数生效)
     */
    public static void load(Func func) throws ScriptException {
        scriptMap.remove(func.getFuncBody());
        CompiledScript script = getCompiledScript(func.getFuncBody());
        script.eval();
    }

    /**
     * 使用传入的参数执行函数(这里将参数名固定化,可避免参数命名不同而绕过缓存,产生过多的动态表达式计算)
     */
    public static Object execute(Func func, Object... args) throws Exception {
        Bindings binding = getEngineBinding();
        StringJoiner j = new StringJoiner(",");
        for (int i = 0; i < args.length; i++) {
            j.add("argv" + i);
            binding.put("argv" + i,args[i]);
        }
        String exp = func.getFuncName() + "(" + j.toString() + ")";

        CompiledScript script = getCompiledScript(exp);

        return script.eval(binding);
    }

    /**
     * 清空变量作用域(每次计算后必须调用)
     */
    public static void clearScope(){
        bindingMap.remove(Thread.currentThread().getName());
    }

ScriptLoader.java

package com.zhou.sct.common;

import com.zhou.sct.dao.Func;
import com.zhou.sct.service.FuncService;
import groovy.lang.GroovyClassLoader;
import lombok.extern.slf4j.Slf4j;
import org.codehaus.groovy.ast.stmt.Statement;
import org.codehaus.groovy.ast.stmt.WhileStatement;
import org.codehaus.groovy.control.CompilerConfiguration;
import org.codehaus.groovy.control.customizers.SecureASTCustomizer;
import org.codehaus.groovy.jsr223.GroovyScriptEngineImpl;
import org.codehaus.groovy.syntax.Types;

import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 脚本引擎创建
 * @author lang.zhou
 * @date 2022/7/20 15:48
 */
@Slf4j
public class ScriptLoader {

    private static final SecureASTCustomizer box = createSecureASTCustomizer();

    /**
     * 创建脚本安全运行沙盒
     */
    private static SecureASTCustomizer createSecureASTCustomizer() {
        SecureASTCustomizer box = new SecureASTCustomizer();
        //禁止闭包
        box.setClosuresAllowed(true);
        List blackList = new ArrayList<>(10);
        //blackList.add(Types.KEYWORD_WHILE);
        blackList.add(Types.KEYWORD_GOTO);
        box.setTokensBlacklist(blackList);
        //导入包检查
        box.setIndirectImportCheckEnabled(true);
        List list = new ArrayList<>(10);
        list.add("com.alibaba.fastjson.JSONObject");
        list.add("java.io.File");
        box.setImportsBlacklist(list);
        List> sl = new ArrayList<>();
        //不能使用while
        //sl.add(WhileStatement.class);
        box.setStatementsBlacklist(sl);
        return box;
    }

    /**
     * 脚本引擎加载自定义函数
     */
    public static void loadFunc(ScriptEngine engine){
        FuncService service = SpringFactory.getBean(FuncService.class);
        List funcs = service.list();
        List s = new ArrayList<>();
        for (Func func : funcs) {
            try{
                engine.eval(func.getFuncBody());
            }catch (Exception e){
                s.add(func.getFuncName());
            }
        }
        if(s.size() > 0){
            log.error("脚本函数加载失败:{}", Arrays.toString(s.toArray()));
        }
    }

    /**
     * 创建一个空的脚本引擎
     */
    public static GroovyScriptEngineImpl createBlankScript(){
        GroovyScriptEngineImpl engine = (GroovyScriptEngineImpl) new ScriptEngineManager().getEngineByName("groovy");

        return engine;
    }

    /**
     * 创建一个沙盒运行的脚本引擎
     */
    public static GroovyScriptEngineImpl createSecureScript(){
        CompilerConfiguration conf = new CompilerConfiguration();
        conf.addCompilationCustomizers(box);

        GroovyClassLoader loader = new GroovyClassLoader(ScriptLoader.class.getClassLoader(), conf);

        GroovyScriptEngineImpl engine = createBlankScript();

        engine.setClassLoader(loader);

        return engine;
    }
}

调用方式:

/**
 * Example: evaluate a dynamic expression once per data row.
 * Each row's entries become script variables in the calling thread's scope.
 */
public void test() throws Exception{
    // Expression to evaluate (placeholder: supply a real expression).
    String express = null;

    // Rows to evaluate; keys are variable names, values their values.
    List<Map<String, Object>> data = new ArrayList<>(0);
    for (Map<String, Object> map : data) {
        try {
            Bindings bindings = GroovyUtil.getEngineBinding();
            bindings.putAll(map);
            Object result = GroovyUtil.eval(express, bindings);
        } finally {
            // Always clear the scope so this row's variables do not leak into the next evaluation.
            GroovyUtil.clearScope();
        }
    }
}

/**
 * Example: call a function that was pre-compiled from the database.
 * The arguments are bound positionally as argv0, argv1, ... by GroovyUtil.execute.
 */
public void test() throws Exception{
    // Function to call (placeholder: supply a real Func).
    Func func = null;
    // Positional arguments for the function.
    Object[] args = {};

    try {
        Object o = GroovyUtil.execute(func, args);
    } finally {
        // The argvN variables live in this thread's scope until cleared.
        GroovyUtil.clearScope();
    }
}

你可能感兴趣的:(spring,boot,数据库,java)