Browse Source

支持多数据

mxd 5 years ago
parent
commit
e98faeee4e

+ 4 - 2
README.md

@@ -59,8 +59,10 @@ spring.datasource.driver-class-name=com.mysql.jdbc.Driver
 在`src/main/resources/ssssssss/`下建立`user.xml`文件
 ```xml
 <?xml version="1.0" encoding="utf-8" ?>
-<!DOCTYPE ssssssss PUBLIC "-//ssssssss.org//DTD ssssssss 0.1//EN" "http://ssssssss.org/dtd/0.0.x/ssssssss.dtd">
-<ssssssss request-mapping="/user">
+<ssssssss request-mapping="/user" 
+        xmlns="http://ssssssss.org/schema"
+        xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+        xsi:schemaLocation="http://ssssssss.org/schema http://ssssssss.org/schema/ssssssss-0.1.xsd">
     <!-- 访问地址/user/list,访问方法get,并开启分页 -->
     <select-list request-mapping="/list" request-method="get" page="true">
         select username,password from sys_user

+ 1 - 1
pom.xml

@@ -11,7 +11,7 @@
     </parent>
     <groupId>org.ssssssss</groupId>
     <artifactId>ssssssss-core</artifactId>
-    <version>0.0.2</version>
+    <version>0.1.0</version>
     <packaging>jar</packaging>
     <name>ssssssss-core</name>
     <description>auto generate http api based on xml</description>

+ 40 - 33
src/main/java/org/ssssssss/executor/SqlExecutor.java

@@ -4,8 +4,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.ssssssss.enums.SqlMode;
 import org.ssssssss.exception.S8Exception;
+import org.ssssssss.session.DynamicDataSource;
 
-import javax.sql.DataSource;
 import java.sql.*;
 import java.util.Date;
 import java.util.*;
@@ -15,7 +15,7 @@ import java.util.*;
  */
 public class SqlExecutor {
 
-    private DataSource dataSource;
+    private DynamicDataSource dynamicDataSource;
 
     private Logger logger = LoggerFactory.getLogger(SqlExecutor.class);
 
@@ -24,8 +24,8 @@ public class SqlExecutor {
      */
     private boolean mapUnderscoreToCamelCase;
 
-    public SqlExecutor(DataSource dataSource) {
-        this.dataSource = dataSource;
+    public SqlExecutor(DynamicDataSource dynamicDataSource) {
+        this.dynamicDataSource = dynamicDataSource;
     }
 
     public void setMapUnderscoreToCamelCase(boolean mapUnderscoreToCamelCase) {
@@ -42,33 +42,35 @@ public class SqlExecutor {
      * @return
      * @throws SQLException
      */
-    public Object execute(SqlMode mode, String sql, List<Object> parameters, Class<?> returnType) throws SQLException {
+    public Object execute(String dataSourceName, SqlMode mode, String sql, List<Object> parameters, Class<?> returnType) throws SQLException {
         if (SqlMode.SELECT_LIST == mode) {
-            return queryForList(sql, parameters, returnType == null ? Map.class : returnType);
+            return queryForList(dataSourceName, sql, parameters, returnType == null ? Map.class : returnType);
         } else if (SqlMode.UPDATE == mode || SqlMode.INSERT == mode || SqlMode.DELETE == mode) {
-            int value = update(sql, parameters);
+            int value = update(dataSourceName, sql, parameters);
             // 当设置返回值是boolean类型时,做>0比较
-            if(returnType == Boolean.class){
+            if (returnType == Boolean.class) {
                 return value > 0;
             }
             return value;
         } else if (SqlMode.SELECT_ONE == mode) {
-            return queryForOne(sql, parameters, returnType);
+            return queryForOne(dataSourceName, sql, parameters, returnType);
         } else {
             throw new S8Exception("暂时不支持[" + mode + "]模式");
         }
     }
 
-    private Connection getConnection() throws SQLException {
-        return dataSource.getConnection();
+    private Connection getConnection(String dataSourceName) throws SQLException {
+        return dynamicDataSource.getDataSource(dataSourceName).getConnection();
     }
 
     /**
      * 获取Connection并调用回调函数执行
-     * @param connectionCallback    回调函数
+     *
+     * @param dataSourceName     数据源名称
+     * @param connectionCallback 回调函数
      */
-    public <T> T doInConnection(ConnectionCallback<T> connectionCallback) throws SQLException {
-        Connection connection = getConnection();
+    public <T> T doInConnection(String dataSourceName, ConnectionCallback<T> connectionCallback) throws SQLException {
+        Connection connection = getConnection(dataSourceName);
         try {
             return connectionCallback.execute(connection);
         } finally {
@@ -76,8 +78,8 @@ public class SqlExecutor {
         }
     }
 
-    private int update(String sql, List<Object> params) throws SQLException {
-        Connection connection = getConnection();
+    private int update(String dataSourceName, String sql, List<Object> params) throws SQLException {
+        Connection connection = getConnection(dataSourceName);
         PreparedStatement stmt = null;
         try {
             stmt = createPreparedStatement(connection, sql, params);
@@ -90,10 +92,11 @@ public class SqlExecutor {
 
     /**
      * 查询一条
-     * @param connection    连接对象
-     * @param sql   SQL
-     * @param params   SQL参数
-     * @param returnType    返回值类型
+     *
+     * @param connection 连接对象
+     * @param sql        SQL
+     * @param params     SQL参数
+     * @param returnType 返回值类型
      */
     public <T> T queryForOne(Connection connection, String sql, List<Object> params, Class<T> returnType) throws SQLException {
         PreparedStatement stmt = null;
@@ -118,12 +121,13 @@ public class SqlExecutor {
 
     /**
      * 查询一条
-     * @param sql   SQL
-     * @param params   SQL参数
-     * @param returnType    返回值类型
+     *
+     * @param sql        SQL
+     * @param params     SQL参数
+     * @param returnType 返回值类型
      */
-    private <T> T queryForOne(String sql, List<Object> params, Class<T> returnType) throws SQLException {
-        Connection connection = getConnection();
+    private <T> T queryForOne(String dataSourceName, String sql, List<Object> params, Class<T> returnType) throws SQLException {
+        Connection connection = getConnection(dataSourceName);
         try {
             return queryForOne(connection, sql, params, returnType);
         } finally {
@@ -133,10 +137,11 @@ public class SqlExecutor {
 
     /**
      * 查询List
-     * @param connection    连接对象
-     * @param sql   SQL
-     * @param params   SQL参数
-     * @param returnType    返回值类型
+     *
+     * @param connection 连接对象
+     * @param sql        SQL
+     * @param params     SQL参数
+     * @param returnType 返回值类型
      */
     public List<Object> queryForList(Connection connection, String sql, List<Object> params, Class<?> returnType) throws SQLException {
         PreparedStatement stmt = null;
@@ -177,6 +182,7 @@ public class SqlExecutor {
 
     /**
      * 下划线转驼峰命名
+     *
      * @param columnName 列名
      * @return
      */
@@ -202,8 +208,8 @@ public class SqlExecutor {
     }
 
 
-    private List<Object> queryForList(String sql, List<Object> params, Class<?> returnType) throws SQLException {
-        Connection connection = getConnection();
+    private List<Object> queryForList(String dataSourceName, String sql, List<Object> params, Class<?> returnType) throws SQLException {
+        Connection connection = getConnection(dataSourceName);
         try {
             return queryForList(connection, sql, params, returnType);
         } finally {
@@ -213,8 +219,9 @@ public class SqlExecutor {
 
     /**
      * 统一创建PrepareStatement对象
-     * @param conn  连接对象
-     * @param sql   SQL
+     *
+     * @param conn   连接对象
+     * @param sql    SQL
      * @param params SQL参数
      */
     private PreparedStatement createPreparedStatement(Connection conn, String sql, List<Object> params) throws SQLException {

+ 3 - 3
src/main/java/org/ssssssss/executor/StatementExecutor.java

@@ -137,7 +137,7 @@ public class StatementExecutor {
         Object target = null;
         AbstractReflection reflection = AbstractReflection.getInstance();
         Method method = (Method) reflection.getMethod(clazz, methodName, args);
-        Assert.isTrue(method != null,String.format("在%s中找不到方法%s",className,methodName));
+        Assert.isTrue(method != null, String.format("在%s中找不到方法%s", className, methodName));
         try {
             target = applicationContext.getBean(clazz);
         } catch (BeansException ignored) {
@@ -153,7 +153,7 @@ public class StatementExecutor {
             // 从Request中提取Page对象
             Page page = pageProvider.getPage(context.getRequest());
             // 执行分页逻辑
-            return sqlExecutor.doInConnection(connection -> {
+            return sqlExecutor.doInConnection(sqlStatement.getDataSourceName(), connection -> {
                 PageResult<Object> pageResult = new PageResult<>();
                 // 获取数据库方言
                 Dialect dialect = DialectUtils.getDialectFromUrl(connection.getMetaData().getURL());
@@ -172,7 +172,7 @@ public class StatementExecutor {
             });
         } else {
             // 普通SQL执行
-            return sqlExecutor.execute(sqlStatement.getSqlMode(), sql, context.getParameters(), sqlStatement.getReturnType());
+            return sqlExecutor.execute(sqlStatement.getDataSourceName(), sqlStatement.getSqlMode(), sql, context.getParameters(), sqlStatement.getReturnType());
         }
     }
 }

+ 22 - 0
src/main/java/org/ssssssss/session/DynamicDataSource.java

@@ -0,0 +1,22 @@
+package org.ssssssss.session;
+
+import org.ssssssss.utils.Assert;
+
+import javax.sql.DataSource;
+import java.util.HashMap;
+import java.util.Map;
+
+public class DynamicDataSource {
+
+    private Map<String, DataSource> dataSourceMap = new HashMap<>();
+
+    public void put(String dataSourceName, DataSource dataSource) {
+        this.dataSourceMap.put(dataSourceName, dataSource);
+    }
+
+    public DataSource getDataSource(String dataSourceName) {
+        DataSource dataSource = dataSourceMap.get(dataSourceName);
+        Assert.isNotNull(dataSource, String.format("找不到数据源%s", dataSourceName));
+        return dataSource;
+    }
+}

+ 9 - 0
src/main/java/org/ssssssss/session/SqlStatement.java

@@ -25,6 +25,8 @@ public class SqlStatement extends Statement {
      */
     private Class<?> returnType;
 
+    private String dataSourceName;
+
     public SqlMode getSqlMode() {
         return sqlMode;
     }
@@ -57,4 +59,11 @@ public class SqlStatement extends Statement {
         this.pagination = pagination;
     }
 
+    public String getDataSourceName() {
+        return dataSourceName;
+    }
+
+    public void setDataSourceName(String dataSourceName) {
+        this.dataSourceName = dataSourceName;
+    }
 }

+ 32 - 40
src/main/java/org/ssssssss/utils/S8XMLFileParser.java

@@ -57,33 +57,48 @@ public class S8XMLFileParser {
         return statement;
     }
 
-    private static List<Statement> parseFunctionStatement(XMLStatement statement, NodeList nodeList) {
+    private static List<Statement> parseFunctionStatement(XMLStatement xmlStatement, NodeList nodeList) {
         List<Statement> statements = new ArrayList<>();
         for (int i = 0, len = nodeList.getLength(); i < len; i++) {
             Node node = nodeList.item(i);
             FunctionStatement functionStatement = new FunctionStatement();
-            // 设置是否支持RequestBody
-            functionStatement.setRequestBody("true".equalsIgnoreCase(DomUtils.getNodeAttributeValue(node,"request-body")));
-            // 设置请求路径
-            functionStatement.setRequestMapping(DomUtils.getNodeAttributeValue(node, "request-mapping"));
-            // 设置请求方法
-            functionStatement.setRequestMethod(DomUtils.getNodeAttributeValue(node, "request-method"));
-            // 设置节点
-            functionStatement.setNode(node);
-
-            // 设置ID
-            functionStatement.setId(DomUtils.getNodeAttributeValue(node, "id"));
-
+            parseStatement(functionStatement, node, xmlStatement);
             // TODO 这里后续需要改进
             // 设置子节点,不进行深层解析,执行时在解析
             functionStatement.setNodeList((NodeList) DomUtils.evaluate("*", node, XPathConstants.NODESET));
-
-            functionStatement.setXmlStatement(statement);
             statements.add(functionStatement);
         }
         return statements;
     }
 
+    private static void parseStatement(Statement statement, Node node, XMLStatement xmlStatement) {
+        // 设置是否支持RequestBody
+        statement.setRequestBody("true".equalsIgnoreCase(DomUtils.getNodeAttributeValue(node, "request-body")));
+
+        String requestMapping = DomUtils.getNodeAttributeValue(node, "request-mapping");
+        if (StringUtils.isNotBlank(requestMapping)) {
+            // 设置请求路径
+            statement.setRequestMapping(StringUtils.defaultString(xmlStatement.getRequestMapping()) + requestMapping);
+            // 设置请求方法
+            statement.setRequestMethod(DomUtils.getNodeAttributeValue(node, "request-method"));
+        }
+        // 设置节点
+        statement.setNode(node);
+        // 设置ID
+        statement.setId(DomUtils.getNodeAttributeValue(node, "id"));
+        // 设置XMLStatement
+        statement.setXmlStatement(xmlStatement);
+        // 解析验证
+        String validate = DomUtils.getNodeAttributeValue(node, "validate");
+        if (StringUtils.isNotBlank(validate)) {
+            // 支持多个验证
+            for (String validateId : validate.split(",")) {
+                Assert.isTrue(xmlStatement.containsValidateStatement(validateId), String.format("找不到验证节点[%s]", validateId));
+                statement.addValidate(validateId);
+            }
+        }
+    }
+
     /**
      * 解析Validate节点
      */
@@ -118,33 +133,10 @@ public class S8XMLFileParser {
         for (int i = 0, len = nodeList.getLength(); i < len; i++) {
             Node item = nodeList.item(i);
             SqlStatement sqlStatement = new SqlStatement();
-            // 设置ID
-            sqlStatement.setId(DomUtils.getNodeAttributeValue(item, "id"));
-            // 设置XmlStatement
-            sqlStatement.setXmlStatement(xmlStatement);
-            // 设置节点
-            sqlStatement.setNode(item);
-            // 设置是否支持RequestBody
-            sqlStatement.setRequestBody("true".equalsIgnoreCase(DomUtils.getNodeAttributeValue(item,"request-body")));
-            String validate = DomUtils.getNodeAttributeValue(item, "validate");
-            if (StringUtils.isNotBlank(validate)) {
-                // 支持多个验证
-                for (String validateId : validate.split(",")) {
-                    Assert.isTrue(xmlStatement.containsValidateStatement(validateId), String.format("找不到验证节点[%s]", validateId));
-                    sqlStatement.addValidate(validateId);
-                }
-            }
+            parseStatement(sqlStatement, item, xmlStatement);
+            sqlStatement.setDataSourceName(DomUtils.getNodeAttributeValue(item, "datasource"));
             // 设置SqlMode
             sqlStatement.setSqlMode(SqlMode.valueOf(item.getNodeName().toUpperCase().replace("-", "_")));
-
-            String requestMapping = DomUtils.getNodeAttributeValue(item, "request-mapping");
-            String id = DomUtils.getNodeAttributeValue(item, "id");
-            if (StringUtils.isNotBlank(requestMapping)) {
-                // 设置请求路径
-                sqlStatement.setRequestMapping(StringUtils.defaultString(xmlStatement.getRequestMapping()) + requestMapping);
-                // 设置请求方法
-                sqlStatement.setRequestMethod(DomUtils.getNodeAttributeValue(item, "request-method"));
-            }
             String returnType = DomUtils.getNodeAttributeValue(item, "return-type");
             if ("int".equalsIgnoreCase(returnType)) {
                 sqlStatement.setReturnType(Integer.class);