StatementExecutor.java 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package org.ssssssss.executor;
  2. import org.apache.commons.lang3.StringUtils;
  3. import org.slf4j.Logger;
  4. import org.slf4j.LoggerFactory;
  5. import org.springframework.beans.BeansException;
  6. import org.springframework.context.ApplicationContext;
  7. import org.ssssssss.context.RequestContext;
  8. import org.ssssssss.dialect.Dialect;
  9. import org.ssssssss.enums.SqlMode;
  10. import org.ssssssss.expression.interpreter.AbstractReflection;
  11. import org.ssssssss.model.Page;
  12. import org.ssssssss.model.PageResult;
  13. import org.ssssssss.provider.PageProvider;
  14. import org.ssssssss.session.*;
  15. import org.ssssssss.utils.Assert;
  16. import org.ssssssss.utils.DomUtils;
  17. import org.w3c.dom.Node;
  18. import org.w3c.dom.NodeList;
  19. import javax.xml.xpath.XPathConstants;
  20. import java.lang.reflect.Method;
  21. import java.lang.reflect.Modifier;
  22. import java.sql.SQLException;
  23. import java.util.List;
  24. /**
  25. * SqlStatement执行器
  26. */
  27. public class StatementExecutor {
  28. private SqlExecutor sqlExecutor;
  29. /**
  30. * 分页提取器
  31. */
  32. private PageProvider pageProvider;
  33. private static Logger logger = LoggerFactory.getLogger(StatementExecutor.class);
  34. private ApplicationContext applicationContext;
  35. private Configuration configuration;
  36. public StatementExecutor(SqlExecutor sqlExecutor, PageProvider pageProvider, ApplicationContext applicationContext) {
  37. this.sqlExecutor = sqlExecutor;
  38. this.pageProvider = pageProvider;
  39. this.applicationContext = applicationContext;
  40. }
  41. public void setConfiguration(Configuration configuration) {
  42. this.configuration = configuration;
  43. }
  44. /**
  45. * 执行statement
  46. */
  47. public Object execute(Statement statement, RequestContext context) throws SQLException, ClassNotFoundException {
  48. if (statement instanceof SqlStatement) {
  49. return executeSqlStatement((SqlStatement) statement, context);
  50. } else if (statement instanceof FunctionStatement) {
  51. return executeFunctionStatement((FunctionStatement) statement, context);
  52. }
  53. return null;
  54. }
  55. private Object executeFunctionStatement(FunctionStatement functionStatement, RequestContext context) throws ClassNotFoundException, SQLException {
  56. NodeList nodeList = functionStatement.getNodeList();
  57. Object value = null;
  58. for (int i = 0, len = nodeList.getLength(); i < len; i++) {
  59. Node node = nodeList.item(i);
  60. if (node.getNodeType() == Node.COMMENT_NODE) {
  61. continue;
  62. }
  63. if ("java".equalsIgnoreCase(node.getNodeName())) {
  64. // 解析类名和方法名
  65. String className = DomUtils.getNodeAttributeValue(node, "class");
  66. Assert.isNotBlank(className, "class不能为空!");
  67. String method = DomUtils.getNodeAttributeValue(node, "method");
  68. Assert.isNotBlank(method, "method不能为空!");
  69. // 解析参数
  70. NodeList values = (NodeList) DomUtils.evaluate("value", node, XPathConstants.NODESET);
  71. Object[] args = new Object[0];
  72. if (values != null) {
  73. // 取出参数值
  74. args = new Object[values.getLength()];
  75. for (int j = 0; j < args.length; j++) {
  76. // 解析表达式
  77. String expression = values.item(j).getTextContent();
  78. if (StringUtils.isNotBlank(expression)) {
  79. args[j] = context.evaluate(expression.trim());
  80. }
  81. }
  82. }
  83. // 调用java方法
  84. value = executeJava(className, method, args);
  85. } else if ("execute-sql".equalsIgnoreCase(node.getNodeName())) {
  86. String sqlId = DomUtils.getNodeAttributeValue(node, "id");
  87. Statement statement = configuration.getStatementById(sqlId);
  88. Assert.isNotNull(statement, String.format("找不到SQL:%s", sqlId));
  89. // 解析参数
  90. NodeList params = (NodeList) DomUtils.evaluate("param", node, XPathConstants.NODESET);
  91. if (params != null) {
  92. for (int j = 0, l = params.getLength(); j < l; j++) {
  93. Node param = params.item(j);
  94. String paramName = DomUtils.getNodeAttributeValue(param, "name");
  95. String paramValue = DomUtils.getNodeAttributeValue(param, "value");
  96. Assert.isNotBlanks("execute-sql/param的参数名和值都不能为空", paramName, paramValue);
  97. // 重新覆盖值
  98. context.put(paramName, context.evaluate(paramValue));
  99. }
  100. }
  101. //执行SQL
  102. value = executeSqlStatement((SqlStatement) statement, context);
  103. } else {
  104. logger.warn("不支持节点{}", node.getNodeName());
  105. continue;
  106. }
  107. // 设置返回值重置到context中
  108. String returnVal = DomUtils.getNodeAttributeValue(node, "return");
  109. if (StringUtils.isNotBlank(returnVal)) {
  110. context.put(returnVal, value);
  111. }
  112. }
  113. return value;
  114. }
  115. /**
  116. * 调用java方法
  117. *
  118. * @param className 类名
  119. * @param methodName 方法名
  120. * @param args 参数
  121. */
  122. private Object executeJava(String className, String methodName, Object... args) throws ClassNotFoundException {
  123. Class<?> clazz = Class.forName(className);
  124. Object target = null;
  125. AbstractReflection reflection = AbstractReflection.getInstance();
  126. Method method = (Method) reflection.getMethod(clazz, methodName, args);
  127. Assert.isTrue(method != null, String.format("在%s中找不到方法%s", className, methodName));
  128. try {
  129. target = applicationContext.getBean(clazz);
  130. } catch (BeansException ignored) {
  131. Assert.isTrue(Modifier.isStatic(method.getModifiers()), String.format("%s不在spring容器中时%s必须是静态方法", className, methodName));
  132. }
  133. return reflection.callMethod(target, method, args);
  134. }
  135. private Object executeSqlStatement(SqlStatement sqlStatement, RequestContext context) throws SQLException {
  136. if (sqlStatement.isPagination()) { //判断是否是分页语句
  137. // 获取要执行的SQL
  138. String sql = sqlStatement.getSqlNode().getSql(context).trim();
  139. // 从Request中提取Page对象
  140. Page page = pageProvider.getPage(context.getRequest());
  141. // 获取数据库方言
  142. Dialect dialect = sqlExecutor.getDialect(sqlStatement.getDataSourceName());
  143. PageResult<Object> pageResult = new PageResult<>();
  144. ExecuteSqlStatement statement = sqlStatement.buildExecuteSqlStatement(dialect.getCountSql(sql), context.getParameters());
  145. statement.setReturnType(Long.class);
  146. statement.setSqlMode(SqlMode.SELECT_ONE);
  147. // 获取总条数
  148. long total = (long) sqlExecutor.execute(statement);
  149. pageResult.setTotal(total);
  150. // 当条数>0时,执行查询语句,否则不查询以提高性能
  151. if (total > 0) {
  152. // 获取分页语句
  153. String pageSql = dialect.getPageSql(sql, context, page.getOffset(), page.getLimit());
  154. // 执行查询
  155. pageResult.setList((List) sqlExecutor.execute(sqlStatement.buildExecuteSqlStatement(pageSql, context.getParameters())));
  156. }
  157. return pageResult;
  158. } else if (SqlMode.INSERT_WITH_PK == sqlStatement.getSqlMode()) { //插入返回主键
  159. return sqlExecutor.executeInsertWithPk(sqlStatement, context);
  160. } else {
  161. // 获取要执行的SQL
  162. String sql = sqlStatement.getSqlNode().getSql(context).trim();
  163. // 普通SQL执行
  164. return sqlExecutor.execute(sqlStatement.buildExecuteSqlStatement(sql, context.getParameters()));
  165. }
  166. }
  167. }