Ver código fonte

支持在xml中调用java

mxd 5 anos atrás
pai
commit
d7dbca5101

+ 14 - 7
src/main/java/com/ssssssss/executor/RequestExecutor.java

@@ -4,7 +4,7 @@ import com.ssssssss.context.RequestContext;
 import com.ssssssss.expression.ExpressionEngine;
 import com.ssssssss.model.JsonBean;
 import com.ssssssss.session.Configuration;
-import com.ssssssss.session.SqlStatement;
+import com.ssssssss.session.Statement;
 import com.ssssssss.session.ValidateStatement;
 import com.ssssssss.session.XMLStatement;
 import com.ssssssss.utils.Assert;
@@ -75,14 +75,14 @@ public class RequestExecutor {
             if (pathVariables != null) {
                 requestContext.putAll(pathVariables);
             }
-            SqlStatement sqlStatement = configuration.getStatement(requestMapping);
+            Statement statement = configuration.getStatement(requestMapping);
             // 执行校验
-            Object value = validate(sqlStatement, requestContext);
+            Object value = validate(statement, requestContext);
             if (value != null) {
                 return value;
             }
             // 执行SQL
-            value = statementExecutor.execute(sqlStatement, requestContext);
+            value = statementExecutor.execute(statement, requestContext);
             return new JsonBean<>(value);
         } catch (Exception e) {
             logger.error("系统出现错误", e);
@@ -90,9 +90,16 @@ public class RequestExecutor {
         }
     }
 
-    private JsonBean<Void> validate(SqlStatement sqlStatement, RequestContext requestContext) {
-        List<String> validates = sqlStatement.getValidates();
-        XMLStatement xmlStatement = sqlStatement.getXmlStatement();
+    /**
+     * 验证节点
+     *
+     * @param statement
+     * @param requestContext
+     * @return
+     */
+    private JsonBean<Void> validate(Statement statement, RequestContext requestContext) {
+        List<String> validates = statement.getValidates();
+        XMLStatement xmlStatement = statement.getXmlStatement();
         for (String validateId : validates) {
             ValidateStatement validateStatement = xmlStatement.getValidateStatement(validateId);
             NodeList nodeList = validateStatement.getNodes();

+ 118 - 3
src/main/java/com/ssssssss/executor/StatementExecutor.java

@@ -3,11 +3,27 @@ package com.ssssssss.executor;
 import com.ssssssss.context.RequestContext;
 import com.ssssssss.dialect.Dialect;
 import com.ssssssss.dialect.DialectUtils;
+import com.ssssssss.expression.interpreter.AbstractReflection;
 import com.ssssssss.model.Page;
 import com.ssssssss.model.PageResult;
 import com.ssssssss.provider.PageProvider;
+import com.ssssssss.session.Configuration;
+import com.ssssssss.session.FunctionStatement;
 import com.ssssssss.session.SqlStatement;
+import com.ssssssss.session.Statement;
+import com.ssssssss.utils.Assert;
+import com.ssssssss.utils.DomUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
 
+import javax.xml.xpath.XPathConstants;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
 import java.sql.SQLException;
 
 /**
@@ -22,15 +38,114 @@ public class StatementExecutor {
      */
     private PageProvider pageProvider;
 
-    public StatementExecutor(SqlExecutor sqlExecutor, PageProvider pageProvider) {
+    private static Logger logger = LoggerFactory.getLogger(StatementExecutor.class);
+    private ApplicationContext applicationContext;
+    private Configuration configuration;
+
+    public StatementExecutor(SqlExecutor sqlExecutor, PageProvider pageProvider, ApplicationContext applicationContext) {
         this.sqlExecutor = sqlExecutor;
         this.pageProvider = pageProvider;
+        this.applicationContext = applicationContext;
+    }
+
+    public void setConfiguration(Configuration configuration) {
+        this.configuration = configuration;
+    }
+
+    /**
+     * 执行statement
+     */
+    public Object execute(Statement statement, RequestContext context) throws SQLException, ClassNotFoundException {
+        if (statement instanceof SqlStatement) {
+            return executeSqlStatement((SqlStatement) statement, context);
+        } else if (statement instanceof FunctionStatement) {
+            return executeFunctionStatement((FunctionStatement) statement, context);
+        }
+        return null;
+    }
+
+    private Object executeFunctionStatement(FunctionStatement functionStatement, RequestContext context) throws ClassNotFoundException, SQLException {
+        NodeList nodeList = functionStatement.getNodeList();
+        Object value = null;
+        for (int i = 0, len = nodeList.getLength(); i < len; i++) {
+            Node node = nodeList.item(i);
+            if (node.getNodeType() == Node.COMMENT_NODE) {
+                continue;
+            }
+            if ("java".equalsIgnoreCase(node.getNodeName())) {
+                // 解析类名和方法名
+                String className = DomUtils.getNodeAttributeValue(node, "class");
+                Assert.isNotBlank(className, "class不能为空!");
+                String method = DomUtils.getNodeAttributeValue(node, "method");
+                Assert.isNotBlank(method, "method不能为空!");
+                // 解析参数
+                NodeList values = (NodeList) DomUtils.evaluate("value", node, XPathConstants.NODESET);
+                Object[] args = new Object[0];
+                if (values != null) {
+                    // 取出参数值
+                    args = new Object[values.getLength()];
+                    for (int j = 0; j < args.length; j++) {
+                        // 解析表达式
+                        String expression = values.item(j).getTextContent();
+                        if (StringUtils.isNotBlank(expression)) {
+                            args[j] = context.evaluate(expression.trim());
+                        }
+                    }
+                }
+                // 调用java方法
+                value = executeJava(className, method, args);
+            } else if ("execute-sql".equalsIgnoreCase(node.getNodeName())) {
+                String sqlId = DomUtils.getNodeAttributeValue(node, "id");
+                Statement statement = configuration.getStatementById(sqlId);
+                Assert.isNotNull(statement, String.format("找不到SQL:%s", sqlId));
+                // 解析参数
+                NodeList params = (NodeList) DomUtils.evaluate("param", node, XPathConstants.NODESET);
+                for (int j = 0, l = params.getLength(); j < l; j++) {
+                    Node param = params.item(j);
+                    String paramName = DomUtils.getNodeAttributeValue(param, "name");
+                    String paramValue = DomUtils.getNodeAttributeValue(param, "value");
+                    Assert.isNotBlanks("execute-sql/param的参数名和值都不能为空", paramName, paramValue);
+                    // 重新覆盖值
+                    context.put(paramName, context.evaluate(paramValue));
+                }
+                //执行SQL
+                value = executeSqlStatement((SqlStatement) statement, context);
+            } else {
+                logger.warn("不支持节点{}", node.getNodeName());
+                continue;
+            }
+            // 设置返回值重置到context中
+            String returnVal = DomUtils.getNodeAttributeValue(node, "return");
+            if (StringUtils.isNotBlank(returnVal)) {
+                context.put(returnVal, value);
+            }
+        }
+        return value;
     }
 
     /**
-     * 执行SqlStatement
+     * 调用java方法
+     *
+     * @param className  类名
+     * @param methodName 方法名
+     * @param args       参数
+     * @return
+     * @throws ClassNotFoundException
      */
-    public Object execute(SqlStatement sqlStatement, RequestContext context) throws SQLException {
+    private Object executeJava(String className, String methodName, Object... args) throws ClassNotFoundException {
+        Class<?> clazz = Class.forName(className);
+        Object target = null;
+        AbstractReflection reflection = AbstractReflection.getInstance();
+        Method method = (Method) reflection.getMethod(clazz, methodName, args);
+        try {
+            target = applicationContext.getBean(clazz);
+        } catch (BeansException ignored) {
+            Assert.isTrue(Modifier.isStatic(method.getModifiers()), String.format("%s不在spring容器中时%s必须是静态方法", className, methodName));
+        }
+        return reflection.callMethod(target, method, args);
+    }
+
+    private Object executeSqlStatement(SqlStatement sqlStatement, RequestContext context) throws SQLException {
         // 获取要执行的SQL
         String sql = sqlStatement.getSqlNode().getSql(context).trim();
         if (sqlStatement.isPagination()) {  //判断是否是分页语句

+ 40 - 18
src/main/java/com/ssssssss/session/Configuration.java

@@ -49,34 +49,53 @@ public class Configuration implements InitializingBean {
     private boolean banner;
 
     /**
-     * 缓存已加载的SqlStatement
+     * 缓存已加载的statement(request-mapping映射)
      */
-    private Map<String,SqlStatement> statementMap = new ConcurrentHashMap<>();
+    private Map<String, Statement> statementMappingMap = new ConcurrentHashMap<>();
+
+    /**
+     * 缓存已加载的statement(ID映射)
+     */
+    private Map<String, Statement> statementIdMap = new ConcurrentHashMap<>();
 
     private static Logger logger = LoggerFactory.getLogger(Configuration.class);
 
     /**
-     * 根据RequestMapping获取SqlStatement对象
+     * 根据RequestMapping获取statement对象
+     */
+    public Statement getStatement(String requestMapping) {
+        return statementMappingMap.get(requestMapping);
+    }
+
+    /**
+     * 根据RequestMapping获取statement对象
      */
-    public SqlStatement getStatement(String requestMapping){
-        return statementMap.get(requestMapping);
+    public Statement getStatementById(String id) {
+        return statementIdMap.get(id);
     }
 
     /**
-     * 注册sql语句成接口,当已存在时,刷新其配置
+     * 注册Statement成接口,当已存在时,刷新其配置
      */
-    public void addStatement(SqlStatement sqlStatement){
-        RequestMappingInfo requestMappingInfo = getRequestMappingInfo(sqlStatement);
+    public void addStatement(Statement statement) {
+        RequestMappingInfo requestMappingInfo = getRequestMappingInfo(statement);
+        if (StringUtils.isNotBlank(statement.getId())) {
+            // 设置ID与statement的映射
+            statementIdMap.put(statement.getId(), statement);
+        }
+        if (requestMappingInfo == null) {
+            return;
+        }
         // 如果已经注册过,则先取消注册
-        if(statementMap.containsKey(sqlStatement.getRequestMapping())){
-            logger.debug("刷新接口:{}",sqlStatement.getRequestMapping());
+        if (statementMappingMap.containsKey(statement.getRequestMapping())) {
+            logger.debug("刷新接口:{}", statement.getRequestMapping());
             // 取消注册
             requestMappingHandlerMapping.unregisterMapping(requestMappingInfo);
         }else{
-            logger.debug("注册接口:{}",sqlStatement.getRequestMapping());
+            logger.debug("注册接口:{}", statement.getRequestMapping());
         }
         // 添加至缓存
-        statementMap.put(sqlStatement.getRequestMapping(),sqlStatement);
+        statementMappingMap.put(statement.getRequestMapping(), statement);
         // 注册接口
         requestMappingHandlerMapping.registerMapping(requestMappingInfo,requestHandler,requestHandleMethod);
     }
@@ -84,13 +103,15 @@ public class Configuration implements InitializingBean {
     /**
      * 获取RequestMappingInfo对象
      */
-    private RequestMappingInfo getRequestMappingInfo(SqlStatement sqlStatement){
-        String requestMapping = sqlStatement.getRequestMapping();
-        Assert.isNotBlank(requestMapping,"request-mapping 不能为空!");
+    private RequestMappingInfo getRequestMappingInfo(Statement statement) {
+        String requestMapping = statement.getRequestMapping();
+        if (StringUtils.isBlank(requestMapping)) {
+            return null;
+        }
         RequestMappingInfo.Builder builder = RequestMappingInfo.paths(requestMapping);
-        if(StringUtils.isNotBlank(sqlStatement.getRequestMethod())){
-            RequestMethod requestMethod = RequestMethod.valueOf(sqlStatement.getRequestMethod().toUpperCase());
-            Assert.isNotNull(requestMethod,String.format("不支持的请求方法:%s",sqlStatement.getRequestMethod()));
+        if (StringUtils.isNotBlank(statement.getRequestMethod())) {
+            RequestMethod requestMethod = RequestMethod.valueOf(statement.getRequestMethod().toUpperCase());
+            Assert.isNotNull(requestMethod, String.format("不支持的请求方法:%s", statement.getRequestMethod()));
             builder.methods(requestMethod);
         }
         return builder.build();
@@ -136,6 +157,7 @@ public class Configuration implements InitializingBean {
             loader.run();
             // 如果启动刷新则定时重新加载
             if(enableRefresh){
+                logger.info("启动自动刷新ssssssss");
                 Executors.newScheduledThreadPool(1).scheduleAtFixedRate(loader,3,3, TimeUnit.SECONDS);
             }
         }

+ 16 - 0
src/main/java/com/ssssssss/session/FunctionStatement.java

@@ -0,0 +1,16 @@
+package com.ssssssss.session;
+
+import org.w3c.dom.NodeList;
+
+public class FunctionStatement extends Statement {
+
+    private NodeList nodeList;
+
+    public NodeList getNodeList() {
+        return nodeList;
+    }
+
+    public void setNodeList(NodeList nodeList) {
+        this.nodeList = nodeList;
+    }
+}

+ 1 - 52
src/main/java/com/ssssssss/session/SqlStatement.java

@@ -3,20 +3,7 @@ package com.ssssssss.session;
 import com.ssssssss.enums.SqlMode;
 import com.ssssssss.scripts.SqlNode;
 
-import java.util.ArrayList;
-import java.util.List;
-
-public class SqlStatement {
-
-    /**
-     * 请求路径
-     */
-    private String requestMapping;
-
-    /**
-     * 请求方法
-     */
-    private String requestMethod;
+public class SqlStatement extends Statement {
 
     /**
      * SQL模式
@@ -38,29 +25,6 @@ public class SqlStatement {
      */
     private Class<?> returnType;
 
-    private List<String> validates = new ArrayList<>();
-
-    /**
-     * XMLStatement对象
-     */
-    private XMLStatement xmlStatement;
-
-    public String getRequestMapping() {
-        return requestMapping;
-    }
-
-    public void setRequestMapping(String requestMapping) {
-        this.requestMapping = requestMapping;
-    }
-
-    public String getRequestMethod() {
-        return requestMethod;
-    }
-
-    public void setRequestMethod(String requestMethod) {
-        this.requestMethod = requestMethod;
-    }
-
     public SqlMode getSqlMode() {
         return sqlMode;
     }
@@ -77,14 +41,6 @@ public class SqlStatement {
         this.sqlNode = sqlNode;
     }
 
-    public XMLStatement getXmlStatement() {
-        return xmlStatement;
-    }
-
-    public void setXmlStatement(XMLStatement xmlStatement) {
-        this.xmlStatement = xmlStatement;
-    }
-
     public Class<?> getReturnType() {
         return returnType;
     }
@@ -101,11 +57,4 @@ public class SqlStatement {
         this.pagination = pagination;
     }
 
-    public List<String> getValidates() {
-        return validates;
-    }
-
-    public void addValidate(String id) {
-        this.validates.add(id);
-    }
 }

+ 69 - 0
src/main/java/com/ssssssss/session/Statement.java

@@ -0,0 +1,69 @@
+package com.ssssssss.session;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class Statement {
+
+    /**
+     * ID
+     */
+    private String id;
+
+    /**
+     * 请求路径
+     */
+    private String requestMapping;
+
+    /**
+     * 请求方法
+     */
+    private String requestMethod;
+
+    private List<String> validates = new ArrayList<>();
+
+    /**
+     * XMLStatement对象
+     */
+    private XMLStatement xmlStatement;
+
+    public String getId() {
+        return id;
+    }
+
+    public void setId(String id) {
+        this.id = id;
+    }
+
+    public String getRequestMapping() {
+        return requestMapping;
+    }
+
+    public void setRequestMapping(String requestMapping) {
+        this.requestMapping = requestMapping;
+    }
+
+    public String getRequestMethod() {
+        return requestMethod;
+    }
+
+    public void setRequestMethod(String requestMethod) {
+        this.requestMethod = requestMethod;
+    }
+
+    public XMLStatement getXmlStatement() {
+        return xmlStatement;
+    }
+
+    public void setXmlStatement(XMLStatement xmlStatement) {
+        this.xmlStatement = xmlStatement;
+    }
+
+    public List<String> getValidates() {
+        return validates;
+    }
+
+    public void addValidate(String id) {
+        this.validates.add(id);
+    }
+}

+ 8 - 8
src/main/java/com/ssssssss/session/XMLStatement.java

@@ -21,9 +21,9 @@ public class XMLStatement {
     private Map<String, ValidateStatement> validateStatements = new HashMap<>();
 
     /**
-     * xml文件中sql语句包括select-list/select-one/insert/update/delete
+     * xml文件中function,以及sql语句包括select-list/select-one/insert/update/delete
      */
-    private List<SqlStatement> sqlStatements = new ArrayList<>();
+    private List<Statement> statements = new ArrayList<>();
 
     public String getRequestMapping() {
         return requestMapping;
@@ -33,17 +33,17 @@ public class XMLStatement {
         this.requestMapping = requestMapping;
     }
 
-    public List<SqlStatement> getSqlStatements() {
-        return sqlStatements;
+    public List<Statement> getStatements() {
+        return statements;
     }
 
     /**
-     * 添加一个SQL节点
+     * 添加statement
      *
-     * @param sqlStatements
+     * @param statements
      */
-    public void addSqlStatement(List<SqlStatement> sqlStatements) {
-        this.sqlStatements.addAll(sqlStatements);
+    public void addStatement(List<Statement> statements) {
+        this.statements.addAll(statements);
     }
 
     /**

+ 11 - 0
src/main/java/com/ssssssss/utils/Assert.java

@@ -33,4 +33,15 @@ public class Assert {
         }
     }
 
+    /**
+     * 断言值不能为空字符串
+     */
+    public static void isNotBlanks(String message, String... values) {
+        if (values != null) {
+            for (String value : values) {
+                isNotBlank(value, message);
+            }
+        }
+    }
+
 }

+ 36 - 11
src/main/java/com/ssssssss/utils/S8XMLFileParser.java

@@ -5,9 +5,7 @@ import com.ssssssss.scripts.ForeachSqlNode;
 import com.ssssssss.scripts.IfSqlNode;
 import com.ssssssss.scripts.SqlNode;
 import com.ssssssss.scripts.TextSqlNode;
-import com.ssssssss.session.SqlStatement;
-import com.ssssssss.session.ValidateStatement;
-import com.ssssssss.session.XMLStatement;
+import com.ssssssss.session.*;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.math.NumberUtils;
 import org.slf4j.Logger;
@@ -49,14 +47,38 @@ public class S8XMLFileParser {
             parseValidateStatement(document.getElementsByTagName("validate"), statement);
             // 解析select/insert/update/delete节点
             for (String tagName : TAG_NAMES) {
-                statement.addSqlStatement(parseSqlStatement(statement, tagName, document));
+                statement.addStatement(parseSqlStatement(statement, tagName, document));
             }
+            // 解析functionStatement
+            statement.addStatement(parseFunctionStatement(statement, document.getElementsByTagName("function")));
         } catch (SAXException | IOException | ParserConfigurationException e) {
             logger.error("解析S8XML文件出错", e);
         }
         return statement;
     }
 
+    private static List<Statement> parseFunctionStatement(XMLStatement statement, NodeList nodeList) {
+        List<Statement> statements = new ArrayList<>();
+        for (int i = 0, len = nodeList.getLength(); i < len; i++) {
+            Node node = nodeList.item(i);
+            FunctionStatement functionStatement = new FunctionStatement();
+            // 设置请求路径
+            functionStatement.setRequestMapping(DomUtils.getNodeAttributeValue(node, "request-mapping"));
+            // 设置请求方法
+            functionStatement.setRequestMethod(DomUtils.getNodeAttributeValue(node, "request-method"));
+            // 设置ID
+            functionStatement.setId(DomUtils.getNodeAttributeValue(node, "id"));
+
+            // TODO 这里后续需要改进
+            // 设置子节点,不进行深层解析,执行时在解析
+            functionStatement.setNodeList((NodeList) DomUtils.evaluate("*", node, XPathConstants.NODESET));
+
+            functionStatement.setXmlStatement(statement);
+            statements.add(functionStatement);
+        }
+        return statements;
+    }
+
     /**
      * 解析Validate节点
      */
@@ -85,12 +107,13 @@ public class S8XMLFileParser {
     /**
      * 解析节点
      */
-    private static List<SqlStatement> parseSqlStatement(XMLStatement xmlStatement, String tagName, Document document) {
-        List<SqlStatement> sqlStatements = new ArrayList<>();
+    private static List<Statement> parseSqlStatement(XMLStatement xmlStatement, String tagName, Document document) {
+        List<Statement> sqlStatements = new ArrayList<>();
         NodeList nodeList = document.getElementsByTagName(tagName);
         for (int i = 0, len = nodeList.getLength(); i < len; i++) {
             Node item = nodeList.item(i);
             SqlStatement sqlStatement = new SqlStatement();
+            sqlStatement.setId(DomUtils.getNodeAttributeValue(item, "id"));
             sqlStatement.setXmlStatement(xmlStatement);
             String validate = DomUtils.getNodeAttributeValue(item, "validate");
             if (StringUtils.isNotBlank(validate)) {
@@ -104,11 +127,13 @@ public class S8XMLFileParser {
             sqlStatement.setSqlMode(SqlMode.valueOf(item.getNodeName().toUpperCase().replace("-", "_")));
 
             String requestMapping = DomUtils.getNodeAttributeValue(item, "request-mapping");
-            Assert.isNotBlank(requestMapping, "请求方法不能为空!");
-            // 设置请求路径
-            sqlStatement.setRequestMapping(StringUtils.defaultString(xmlStatement.getRequestMapping()) + requestMapping);
-            // 设置请求方法
-            sqlStatement.setRequestMethod(DomUtils.getNodeAttributeValue(item, "request-method"));
+            String id = DomUtils.getNodeAttributeValue(item, "id");
+            if (StringUtils.isNotBlank(requestMapping)) {
+                // 设置请求路径
+                sqlStatement.setRequestMapping(StringUtils.defaultString(xmlStatement.getRequestMapping()) + requestMapping);
+                // 设置请求方法
+                sqlStatement.setRequestMethod(DomUtils.getNodeAttributeValue(item, "request-method"));
+            }
             String returnType = DomUtils.getNodeAttributeValue(item, "return-type");
             if ("int".equalsIgnoreCase(returnType)) {
                 sqlStatement.setReturnType(Integer.class);

+ 1 - 1
src/main/java/com/ssssssss/utils/XmlFileLoader.java

@@ -53,7 +53,7 @@ public class XmlFileLoader implements Runnable{
                     //判断是否更新
                     if (lastModified == null || lastModified < file.lastModified()) {
                         XMLStatement xmlStatement = S8XMLFileParser.parse(file);
-                        xmlStatement.getSqlStatements().forEach(configuration::addStatement);
+                        xmlStatement.getStatements().forEach(configuration::addStatement);
                     }
                 }
             }