Browse Source

实现缓存、修复多数据源BUG

mxd 5 years ago
parent
commit
556073e7fd

+ 1 - 1
pom.xml

@@ -11,7 +11,7 @@
     </parent>
     <groupId>org.ssssssss</groupId>
     <artifactId>ssssssss-core</artifactId>
-    <version>0.1.0</version>
+    <version>0.1.1</version>
     <packaging>jar</packaging>
     <name>ssssssss-core</name>
     <description>auto generate http api based on xml</description>

+ 59 - 0
src/main/java/org/ssssssss/cache/DefaultSqlCache.java

@@ -0,0 +1,59 @@
+package org.ssssssss.cache;
+
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+public class DefaultSqlCache implements SqlCache {
+
+    private String separator = ":";
+
+    private LinkedHashMap<String, Object> cached;
+
+    public DefaultSqlCache(int maxSize) {
+        this.cached = new LinkedHashMap<String, Object>(maxSize, 0.75f, true) {
+            @Override
+            protected boolean removeEldestEntry(Map.Entry<String, Object> eldest) {
+                return size() > maxSize;
+            }
+        };
+    }
+    @Override
+    public void put(String name, String key, Object value) {
+        this.cached.put(name + separator + key,value);
+    }
+    @Override
+    public Object get(String name, String key) {
+        return cached.get(name + separator + key);
+    }
+    @Override
+    public void remove(String name) {
+        Iterator<Map.Entry<String, Object>> iterator = cached.entrySet().iterator();
+        String prefix = name + separator;
+        while(iterator.hasNext()){
+            Map.Entry<String, Object> entry = iterator.next();
+            if(entry.getKey().startsWith(prefix)){
+                iterator.remove();
+            }
+        }
+    }
+
+    public static void main(String[] args) {
+        DefaultSqlCache sqlCache = new DefaultSqlCache(10);
+        for (int i = 0; i < 10; i++) {
+            sqlCache.put("test",i+"",i);
+        }
+        for (int i = 0; i < 5; i++) {
+            sqlCache.get("test",i+"");
+        }
+        for (int i = 10; i < 15; i++) {
+            sqlCache.put("test",i+"",i);
+        }
+        for (int i = 10; i < 15; i++) {
+            sqlCache.put("test1",i+"",i);
+        }
+        System.out.println(sqlCache.cached);
+        sqlCache.remove("test");
+        System.out.println(sqlCache.cached);
+    }
+}

+ 43 - 0
src/main/java/org/ssssssss/cache/SqlCache.java

@@ -0,0 +1,43 @@
+package org.ssssssss.cache;
+
+import org.ssssssss.utils.MD5Utils;
+
+import java.util.Arrays;
+
+/**
+ * SQL缓存接口
+ */
+public interface SqlCache {
+
+    /**
+     * 计算key
+     * @param sql   sql
+     * @param parameters sql参数
+     */
+    default String buildSqlCacheKey(String sql, Object[] parameters) {
+        return MD5Utils.encrypt(sql + ":" + Arrays.toString(parameters));
+    }
+
+    /**
+     * 存入缓存
+     * @param name 名字
+     * @param key   key
+     * @param value 值
+     */
+    void put(String name, String key, Object value);
+
+    /**
+     * 获取缓存
+     * @param name  名字
+     * @param key   key
+     * @return
+     */
+    Object get(String name,String key);
+
+    /**
+     * 删除缓存
+     * @param name  名字
+     */
+    void remove(String name);
+
+}

+ 66 - 24
src/main/java/org/ssssssss/executor/SqlExecutor.java

@@ -1,5 +1,6 @@
 package org.ssssssss.executor;
 
+import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
@@ -9,6 +10,7 @@ import org.springframework.jdbc.datasource.DataSourceUtils;
 import org.springframework.jdbc.support.GeneratedKeyHolder;
 import org.springframework.jdbc.support.JdbcUtils;
 import org.springframework.jdbc.support.KeyHolder;
+import org.ssssssss.cache.SqlCache;
 import org.ssssssss.context.RequestContext;
 import org.ssssssss.dialect.Dialect;
 import org.ssssssss.dialect.DialectUtils;
@@ -17,6 +19,7 @@ import org.ssssssss.exception.S8Exception;
 import org.ssssssss.provider.KeyProvider;
 import org.ssssssss.scripts.SqlNode;
 import org.ssssssss.session.DynamicDataSource;
+import org.ssssssss.session.ExecuteSqlStatement;
 import org.ssssssss.session.SqlStatement;
 import org.ssssssss.utils.Assert;
 import org.ssssssss.utils.DomUtils;
@@ -38,16 +41,21 @@ public class SqlExecutor {
     private DynamicDataSource dynamicDataSource;
 
     private ColumnMapRowMapper columnMapRowMapper = new ColumnMapRowMapper();
-    ;
 
     private Map<String, KeyProvider> keyProviders = new HashMap<>();
 
     private Map<String, Dialect> cachedDialects = new ConcurrentHashMap<>();
 
+    private SqlCache sqlCache;
+
     public SqlExecutor(DynamicDataSource dynamicDataSource) {
         this.dynamicDataSource = dynamicDataSource;
     }
 
+    public void setSqlCache(SqlCache sqlCache) {
+        this.sqlCache = sqlCache;
+    }
+
     /**
      * 设置是否是驼峰命名
      *
@@ -86,41 +94,70 @@ public class SqlExecutor {
 
     /**
      * 执行SQL
-     *
-     * @param mode       SQL模式
-     * @param sql        SQL
-     * @param parameters SQL参数
-     * @param returnType 返回值类型
-     * @return
      */
-    public Object execute(String dataSourceName, SqlMode mode, String sql, Object[] parameters, Class<?> returnType) {
-        JdbcTemplate jdbcTemplate = getJdbcTemplate(dataSourceName);
-        printLog(dataSourceName, sql, parameters);
-        if (SqlMode.SELECT_LIST == mode) {
+    public Object execute(ExecuteSqlStatement statement) {
+        // 获取SQL
+        String sql = statement.getSql();
+        // 获取参数
+        Object[] parameters = statement.getParameters();
+        // 获取SQL模式
+        SqlMode mode = statement.getSqlMode();
+        // 获取返回值类型
+        Class<?> returnType = statement.getReturnType();
+        // 缓存Key
+        String sqlCacheKey = null;
+        // 返回值
+        Object value;
+        // 判断是否使用缓存
+        if (this.sqlCache != null && StringUtils.isNotBlank(statement.getUseCache())) {
+            // 构建key
+            sqlCacheKey = this.sqlCache.buildSqlCacheKey(sql, parameters);
+            // 查询缓存
+            value = this.sqlCache.get(statement.getUseCache(), sqlCacheKey);
+            if (value != null) {
+                return value;
+            }
+        }
+        JdbcTemplate jdbcTemplate = getJdbcTemplate(statement.getDataSourceName());
+        // 打印SQL日志
+        printLog(statement.getDataSourceName(), sql, parameters);
+        if (SqlMode.SELECT_LIST == mode) {  //查询List
             if (returnType == null || returnType == Map.class) {
-                return jdbcTemplate.query(sql, parameters, columnMapRowMapper);
+                value = jdbcTemplate.query(sql, parameters, columnMapRowMapper);
+            } else {
+                value = jdbcTemplate.queryForList(sql, parameters, returnType);
+            }
+
+        } else if (SqlMode.UPDATE == mode || SqlMode.INSERT == mode || SqlMode.DELETE == mode) {    //增删改
+            int retVal = jdbcTemplate.update(sql, parameters);
+            // 删除缓存
+            if (retVal > 0 && this.sqlCache != null && StringUtils.isNotBlank(statement.getDeleteCache())) {
+                this.sqlCache.remove(statement.getDeleteCache());
             }
-            return jdbcTemplate.queryForList(sql, parameters, returnType);
-        } else if (SqlMode.UPDATE == mode || SqlMode.INSERT == mode || SqlMode.DELETE == mode) {
-            int value = jdbcTemplate.update(sql, parameters);
             // 当设置返回值是boolean类型时,做>0比较
             if (returnType == Boolean.class) {
-                return value > 0;
+                return retVal > 0;
             }
-            return value;
-        } else if (SqlMode.SELECT_ONE == mode) {
+            return retVal;
+        } else if (SqlMode.SELECT_ONE == mode) {    //查询一条
             Collection collection;
             if (returnType == null || returnType == Map.class) {
                 collection = jdbcTemplate.query(sql, columnMapRowMapper, parameters);
             } else {
                 collection = jdbcTemplate.queryForList(sql, returnType, parameters);
             }
-            return collection != null && collection.size() >= 1 ? collection.iterator().next() : null;
+            value = collection != null && collection.size() >= 1 ? collection.iterator().next() : null;
         } else {
             throw new S8Exception("暂时不支持[" + mode + "]模式");
         }
+        // 判断是否使用了缓存
+        if (sqlCacheKey != null && value != null) {
+            this.sqlCache.put(statement.getUseCache(), sqlCacheKey, value);
+        }
+        return value;
     }
 
+
     public Object executeInsertWithPk(SqlStatement statement, RequestContext requestContext) throws SQLException {
         String dataSourceName = statement.getDataSourceName();
         JdbcTemplate jdbcTemplate = getJdbcTemplate(dataSourceName);
@@ -148,7 +185,8 @@ public class SqlExecutor {
                     // 获取插入SQL
                     String insertSQL = statement.getSqlNode().getSql(requestContext);
                     // 执行插入
-                    executeUpdate(dataSourceName, connection, insertSQL, requestContext.getParameters());
+                    executeUpdate(dataSourceName, connection, insertSQL, requestContext.getParameters(),statement.getDeleteCache());
+
                     // 清空参数
                     requestContext.getParameters().clear();
                     if (!before) {
@@ -166,7 +204,7 @@ public class SqlExecutor {
                     // 获取插入SQL
                     String insertSQL = statement.getSqlNode().getSql(requestContext);
                     // 执行插入
-                    executeUpdate(dataSourceName, connection, insertSQL, requestContext.getParameters());
+                    executeUpdate(dataSourceName, connection, insertSQL, requestContext.getParameters(),statement.getDeleteCache());
                 }
                 return value;
             } finally {
@@ -188,13 +226,17 @@ public class SqlExecutor {
     /**
      * 执行插入
      */
-    private int executeUpdate(String dataSourceName, Connection connection, String sql, List<Object> parameters) throws SQLException {
+    private int executeUpdate(String dataSourceName, Connection connection, String sql, List<Object> parameters,String deleteCache) throws SQLException {
         PreparedStatement ps = null;
         try {
             printLog(dataSourceName, sql, parameters);
             ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
             new ArgumentPreparedStatementSetter(parameters.toArray()).setValues(ps);
-            return ps.executeUpdate();
+            int val = ps.executeUpdate();
+            if (this.sqlCache != null && StringUtils.isNotBlank(deleteCache)) {
+                this.sqlCache.remove(deleteCache);
+            }
+            return val;
         } finally {
             JdbcUtils.closeStatement(ps);
         }
@@ -254,7 +296,7 @@ public class SqlExecutor {
      * 获取数据库方言
      */
     public Dialect getDialect(String dataSourceName) throws SQLException {
-        Dialect dialect = cachedDialects.get(cachedDialects);
+        Dialect dialect = cachedDialects.get(dataSourceName);
         if (dialect == null && !cachedDialects.containsKey(dataSourceName)) {
             JdbcTemplate jdbcTemplate = getJdbcTemplate(dataSourceName);
             Connection connection = jdbcTemplate.getDataSource().getConnection();

+ 7 - 7
src/main/java/org/ssssssss/executor/StatementExecutor.java

@@ -12,10 +12,7 @@ import org.ssssssss.expression.interpreter.AbstractReflection;
 import org.ssssssss.model.Page;
 import org.ssssssss.model.PageResult;
 import org.ssssssss.provider.PageProvider;
-import org.ssssssss.session.Configuration;
-import org.ssssssss.session.FunctionStatement;
-import org.ssssssss.session.SqlStatement;
-import org.ssssssss.session.Statement;
+import org.ssssssss.session.*;
 import org.ssssssss.utils.Assert;
 import org.ssssssss.utils.DomUtils;
 import org.w3c.dom.Node;
@@ -156,15 +153,18 @@ public class StatementExecutor {
             // 获取数据库方言
             Dialect dialect = sqlExecutor.getDialect(sqlStatement.getDataSourceName());
             PageResult<Object> pageResult = new PageResult<>();
+            ExecuteSqlStatement statement = sqlStatement.buildExecuteSqlStatement(dialect.getCountSql(sql), context.getParameters());
+            statement.setReturnType(Long.class);
+            statement.setSqlMode(SqlMode.SELECT_ONE);
             // 获取总条数
-            long total = (long) sqlExecutor.execute(sqlStatement.getDataSourceName(), SqlMode.SELECT_ONE, dialect.getCountSql(sql), context.getParameters().toArray(), Long.class);
+            long total = (long) sqlExecutor.execute(statement);
             pageResult.setTotal(total);
             // 当条数>0时,执行查询语句,否则不查询以提高性能
             if (total > 0) {
                 // 获取分页语句
                 String pageSql = dialect.getPageSql(sql, context, page.getOffset(), page.getLimit());
                 // 执行查询
-                pageResult.setList((List) sqlExecutor.execute(sqlStatement.getDataSourceName(), SqlMode.SELECT_LIST, pageSql, context.getParameters().toArray(), sqlStatement.getReturnType()));
+                pageResult.setList((List) sqlExecutor.execute(sqlStatement.buildExecuteSqlStatement(pageSql, context.getParameters())));
             }
             return pageResult;
         } else if (SqlMode.INSERT_WITH_PK == sqlStatement.getSqlMode()) {   //插入返回主键
@@ -173,7 +173,7 @@ public class StatementExecutor {
             // 获取要执行的SQL
             String sql = sqlStatement.getSqlNode().getSql(context).trim();
             // 普通SQL执行
-            return sqlExecutor.execute(sqlStatement.getDataSourceName(), sqlStatement.getSqlMode(), sql, context.getParameters().toArray(), sqlStatement.getReturnType());
+            return sqlExecutor.execute(sqlStatement.buildExecuteSqlStatement(sql, context.getParameters()));
         }
     }
 }

+ 6 - 0
src/main/java/org/ssssssss/session/DynamicDataSource.java

@@ -11,10 +11,16 @@ public class DynamicDataSource {
     private Map<String, DataSource> dataSourceMap = new HashMap<>();
 
     public void put(String dataSourceName, DataSource dataSource) {
+        if(dataSourceName == null){
+            dataSourceName = "";
+        }
         this.dataSourceMap.put(dataSourceName, dataSource);
     }
 
     public DataSource getDataSource(String dataSourceName) {
+        if(dataSourceName == null){
+            dataSourceName = "";
+        }
         DataSource dataSource = dataSourceMap.get(dataSourceName);
         Assert.isNotNull(dataSource, String.format("找不到数据源%s", dataSourceName));
         return dataSource;

+ 138 - 0
src/main/java/org/ssssssss/session/ExecuteSqlStatement.java

@@ -0,0 +1,138 @@
+package org.ssssssss.session;
+
+import org.ssssssss.enums.SqlMode;
+import org.ssssssss.scripts.SqlNode;
+import org.w3c.dom.Node;
+
+public class ExecuteSqlStatement {
+
+    /**
+     * ID
+     */
+    private String id;
+
+    /**
+     * SQL
+     */
+    private String sql;
+
+    /**
+     * SQL参数
+     */
+    private Object[] parameters;
+
+    /**
+     * SQL模式
+     */
+    private SqlMode sqlMode;
+
+    /**
+     * 返回值类型
+     */
+    private Class<?> returnType;
+
+    /**
+     * 数据源名称
+     */
+    private String dataSourceName;
+
+    /**
+     * selectKey节点
+     */
+    private Node selectKey;
+
+    /**
+     * selectKey转SqlNode
+     */
+    private SqlNode selectKeySqlNode;
+
+    /**
+     * 使用的缓存名称
+     */
+    private String useCache;
+
+    /**
+     * 删除的缓存名称
+     */
+    private String deleteCache;
+
+    public String getId() {
+        return id;
+    }
+
+    public void setId(String id) {
+        this.id = id;
+    }
+
+    public String getSql() {
+        return sql;
+    }
+
+    public void setSql(String sql) {
+        this.sql = sql;
+    }
+
+    public Object[] getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(Object[] parameters) {
+        this.parameters = parameters;
+    }
+
+    public SqlMode getSqlMode() {
+        return sqlMode;
+    }
+
+    public void setSqlMode(SqlMode sqlMode) {
+        this.sqlMode = sqlMode;
+    }
+
+    public Class<?> getReturnType() {
+        return returnType;
+    }
+
+    public void setReturnType(Class<?> returnType) {
+        this.returnType = returnType;
+    }
+
+    public String getDataSourceName() {
+        return dataSourceName;
+    }
+
+    public void setDataSourceName(String dataSourceName) {
+        this.dataSourceName = dataSourceName;
+    }
+
+    public Node getSelectKey() {
+        return selectKey;
+    }
+
+    public void setSelectKey(Node selectKey) {
+        this.selectKey = selectKey;
+    }
+
+    public SqlNode getSelectKeySqlNode() {
+        return selectKeySqlNode;
+    }
+
+    public void setSelectKeySqlNode(SqlNode selectKeySqlNode) {
+        this.selectKeySqlNode = selectKeySqlNode;
+    }
+
+    public String getUseCache() {
+        return useCache;
+    }
+
+    public void setUseCache(String useCache) {
+        this.useCache = useCache;
+    }
+
+    public String getDeleteCache() {
+        return deleteCache;
+    }
+
+    public void setDeleteCache(String deleteCache) {
+        this.deleteCache = deleteCache;
+    }
+}

+ 43 - 0
src/main/java/org/ssssssss/session/SqlStatement.java

@@ -4,6 +4,8 @@ import org.ssssssss.enums.SqlMode;
 import org.ssssssss.scripts.SqlNode;
 import org.w3c.dom.Node;
 
+import java.util.List;
+
 public class SqlStatement extends Statement {
 
     /**
@@ -41,6 +43,16 @@ public class SqlStatement extends Statement {
      */
     private SqlNode selectKeySqlNode;
 
+    /**
+     * 使用的缓存名称
+     */
+    private String useCache;
+
+    /**
+     * 删除的缓存名称
+     */
+    private String deleteCache;
+
     public SqlMode getSqlMode() {
         return sqlMode;
     }
@@ -96,4 +108,35 @@ public class SqlStatement extends Statement {
     public void setSelectKeySqlNode(SqlNode selectKeySqlNode) {
         this.selectKeySqlNode = selectKeySqlNode;
     }
+
+    public String getUseCache() {
+        return useCache;
+    }
+
+    public void setUseCache(String useCache) {
+        this.useCache = useCache;
+    }
+
+    public String getDeleteCache() {
+        return deleteCache;
+    }
+
+    public void setDeleteCache(String deleteCache) {
+        this.deleteCache = deleteCache;
+    }
+
+    public ExecuteSqlStatement buildExecuteSqlStatement(String sql, List<Object> parameters){
+        ExecuteSqlStatement executeSqlStatement = new ExecuteSqlStatement();
+        executeSqlStatement.setSql(sql);
+        executeSqlStatement.setParameters(parameters.toArray());
+        executeSqlStatement.setId(this.getId());
+        executeSqlStatement.setDataSourceName(this.dataSourceName);
+        executeSqlStatement.setSqlMode(this.sqlMode);
+        executeSqlStatement.setReturnType(this.returnType);
+        executeSqlStatement.setSelectKey(this.selectKey);
+        executeSqlStatement.setSelectKeySqlNode(this.selectKeySqlNode);
+        executeSqlStatement.setUseCache(this.useCache);
+        executeSqlStatement.setDeleteCache(this.deleteCache);
+        return executeSqlStatement;
+    }
 }

+ 35 - 0
src/main/java/org/ssssssss/utils/MD5Utils.java

@@ -0,0 +1,35 @@
+package org.ssssssss.utils;
+
+import org.ssssssss.exception.S8Exception;
+
+import java.security.MessageDigest;
+
+public class MD5Utils {
+
+    private static final char[] HEX_CHARS = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
+
+    /**
+     * MD5加密
+     */
+    public static String encrypt(String value){
+        return encrypt(value.getBytes());
+    }
+
+    /**
+     * MD5加密
+     */
+    public static String encrypt(byte[] value){
+        try {
+            byte[] bytes = MessageDigest.getInstance("MD5").digest(value);
+            char[] chars = new char[32];
+            for (int i = 0; i < chars.length; i = i + 2) {
+                byte b = bytes[i / 2];
+                chars[i] = HEX_CHARS[(b >>> 0x4) & 0xf];
+                chars[i + 1] = HEX_CHARS[b & 0xf];
+            }
+            return new String(chars);
+        } catch (Exception e) {
+            throw new S8Exception("md5 encrypt error",e);
+        }
+    }
+}

+ 16 - 6
src/main/java/org/ssssssss/utils/S8XMLFileParser.java

@@ -13,13 +13,14 @@ import org.ssssssss.session.*;
 import org.w3c.dom.Document;
 import org.w3c.dom.Node;
 import org.w3c.dom.NodeList;
+import org.xml.sax.InputSource;
 import org.xml.sax.SAXException;
 
+import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
 import javax.xml.parsers.ParserConfigurationException;
 import javax.xml.xpath.XPathConstants;
-import java.io.File;
-import java.io.IOException;
+import java.io.*;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -40,7 +41,9 @@ public class S8XMLFileParser {
     static XMLStatement parse(File file) {
         XMLStatement statement = null;
         try {
-            Document document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(file);
+            DocumentBuilder documentBuilder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
+            documentBuilder.setEntityResolver((publicId, systemId) -> new InputSource(new StringReader("")));
+            Document document = documentBuilder.parse(file);
             // 解析根节点
             statement = parseRoot(document);
             // 解析验证节点
@@ -134,9 +137,10 @@ public class S8XMLFileParser {
             Node item = nodeList.item(i);
             SqlStatement sqlStatement = new SqlStatement();
             parseStatement(sqlStatement, item, xmlStatement);
-            sqlStatement.setDataSourceName(DomUtils.getNodeAttributeValue(item, "datasource"));
+            sqlStatement.setDataSourceName(StringUtils.defaultString(DomUtils.getNodeAttributeValue(item, "datasource"),""));
+            SqlMode sqlMode = SqlMode.valueOf(item.getNodeName().toUpperCase().replace("-", "_"));
             // 设置SqlMode
-            sqlStatement.setSqlMode(SqlMode.valueOf(item.getNodeName().toUpperCase().replace("-", "_")));
+            sqlStatement.setSqlMode(sqlMode);
             String returnType = DomUtils.getNodeAttributeValue(item, "return-type");
             if ("int".equalsIgnoreCase(returnType)) {
                 sqlStatement.setReturnType(Integer.class);
@@ -161,7 +165,13 @@ public class S8XMLFileParser {
             } else {
                 sqlStatement.setReturnType(Map.class);
             }
-            if (SqlMode.SELECT_LIST == sqlStatement.getSqlMode()) {
+            String cacheName = DomUtils.getNodeAttributeValue(item, "cache-name");
+            if(SqlMode.SELECT_LIST == sqlMode || SqlMode.SELECT_ONE == sqlMode){
+                sqlStatement.setUseCache(cacheName);
+            }else{
+                sqlStatement.setDeleteCache(cacheName);
+            }
+            if (SqlMode.SELECT_LIST == sqlMode) {
                 //设置是否是分页
                 sqlStatement.setPagination("true".equalsIgnoreCase(DomUtils.getNodeAttributeValue(item, "page")));
             }