Ver código fonte

修复分页缓存BUG、优化代码

mxd 4 anos atrás
pai
commit
f8b0fff83a

+ 32 - 30
src/main/java/org/ssssssss/magicapi/cache/SqlCache.java

@@ -1,6 +1,5 @@
 package org.ssssssss.magicapi.cache;
 
-import org.ssssssss.magicapi.modules.BoundSql;
 import org.ssssssss.magicapi.utils.MD5Utils;
 
 import java.util.Arrays;
@@ -10,34 +9,37 @@ import java.util.Arrays;
  */
 public interface SqlCache {
 
-    /**
-     * 计算key
-     */
-    default String buildSqlCacheKey(BoundSql boundSql) {
-        return MD5Utils.encrypt(boundSql.getSql() + ":" + Arrays.toString(boundSql.getParameters()));
-    }
-
-    /**
-     * 存入缓存
-     * @param name 名字
-     * @param key   key
-     * @param value 值
-     * @param ttl 有效期
-     */
-    void put(String name, String key, Object value, long ttl);
-
-    /**
-     * 获取缓存
-     * @param name  名字
-     * @param key   key
-     * @return
-     */
-    <T> T get(String name, String key);
-
-    /**
-     * 删除缓存
-     * @param name  名字
-     */
-    void delete(String name);
+	/**
+	 * 计算key
+	 */
+	default String buildSqlCacheKey(String sql, Object[] params) {
+		return MD5Utils.encrypt(sql + ":" + Arrays.toString(params));
+	}
+
+	/**
+	 * 存入缓存
+	 *
+	 * @param name  名字
+	 * @param key   key
+	 * @param value 值
+	 * @param ttl   有效期
+	 */
+	void put(String name, String key, Object value, long ttl);
+
+	/**
+	 * 获取缓存
+	 *
+	 * @param name 名字
+	 * @param key  key
+	 * @return
+	 */
+	<T> T get(String name, String key);
+
+	/**
+	 * 删除缓存
+	 *
+	 * @param name 名字
+	 */
+	void delete(String name);
 
 }

+ 25 - 0
src/main/java/org/ssssssss/magicapi/config/MagicDynamicDataSource.java

@@ -5,9 +5,14 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.DataSourceTransactionManager;
+import org.springframework.jdbc.datasource.DataSourceUtils;
+import org.ssssssss.magicapi.adapter.DialectAdapter;
+import org.ssssssss.magicapi.dialect.Dialect;
+import org.ssssssss.magicapi.exception.MagicAPIException;
 import org.ssssssss.magicapi.utils.Assert;
 
 import javax.sql.DataSource;
+import java.sql.Connection;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -92,6 +97,8 @@ public class MagicDynamicDataSource {
 
 		private DataSource dataSource;
 
+		private Dialect dialect;
+
 		public DataSourceNode(DataSource dataSource) {
 			this.dataSource = dataSource;
 			this.dataSourceTransactionManager = new DataSourceTransactionManager(this.dataSource);
@@ -105,6 +112,24 @@ public class MagicDynamicDataSource {
 		public DataSourceTransactionManager getDataSourceTransactionManager() {
 			return dataSourceTransactionManager;
 		}
+
+		public Dialect getDialect(DialectAdapter dialectAdapter){
+			if(this.dialect == null){
+				Connection connection = null;
+				try {
+					connection = this.dataSource.getConnection();
+					this.dialect = dialectAdapter.getDialectFromUrl(connection.getMetaData().getURL());
+					if(this.dialect == null){
+						throw new MagicAPIException("自动获取数据库方言失败");
+					}
+				} catch (Exception e) {
+					throw new MagicAPIException("自动获取数据库方言失败", e);
+				} finally {
+					DataSourceUtils.releaseConnection(connection, this.dataSource);
+				}
+			}
+			return dialect;
+		}
 	}
 
 	public void setDefault(DataSource dataSource) {

+ 27 - 17
src/main/java/org/ssssssss/magicapi/modules/BoundSql.java

@@ -8,8 +8,8 @@ import org.ssssssss.script.parsing.ast.literal.BooleanLiteral;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
 import java.util.regex.Pattern;
 
 public class BoundSql {
@@ -28,10 +28,13 @@ public class BoundSql {
 
 	private List<Object> parameters = new ArrayList<>();
 
-	private String cacheKey;
+	private SqlCache sqlCache;
 
+	private String cacheName;
 
-	BoundSql(String sql) {
+	private long ttl;
+
+	BoundSql(String sql){
 		MagicScriptContext context = MagicScriptContext.get();
 		// 处理?{}参数
 		this.sql = ifTokenParser.parse(sql.trim(), text -> {
@@ -70,6 +73,13 @@ public class BoundSql {
 		this.sql = this.sql == null ? null : REPLACE_MULTI_WHITE_LINE.matcher(this.sql.trim()).replaceAll("\r\n");
 	}
 
+	BoundSql(String sql, SqlCache sqlCache, String cacheName, long ttl) {
+		this(sql);
+		this.sqlCache = sqlCache;
+		this.cacheName = cacheName;
+		this.ttl = ttl;
+	}
+
 	/**
 	 * 添加SQL参数
 	 */
@@ -91,28 +101,28 @@ public class BoundSql {
 		return parameters.toArray();
 	}
 
-	/**
-	 * 清空缓存key
-	 */
-	public BoundSql removeCacheKey() {
-		this.cacheKey = null;
-		return this;
-	}
 
 	/**
-	 * 获取缓存key
+	 * 获取缓存值
 	 */
-	public String getCacheKey(SqlCache sqlCache) {
-		if (cacheKey == null) {
-			cacheKey = sqlCache.buildSqlCacheKey(this);
+	<T> T getCacheValue(String sql, Object[] params, Supplier<T> supplier) {
+		if (cacheName == null) {
+			return null;
+		}
+		String cacheKey = sqlCache.buildSqlCacheKey(sql, params);
+		Object cacheValue = sqlCache.get(cacheName, cacheKey);
+		if (cacheValue != null) {
+			return (T) cacheValue;
 		}
-		return cacheKey;
+		T value = supplier.get();
+		sqlCache.put(cacheName, cacheKey, value, ttl);
+		return value;
 	}
 
 	/**
 	 * 获取缓存值
 	 */
-	public <T> Optional<T> getCacheValue(SqlCache sqlCache, String cacheName) {
-		return Optional.ofNullable(cacheName == null ? null : sqlCache.get(cacheName, getCacheKey(sqlCache)));
+	<T> T getCacheValue(Supplier<T> supplier) {
+		return getCacheValue(this.getSql(), this.getParameters(), supplier);
 	}
 }

+ 17 - 114
src/main/java/org/ssssssss/magicapi/modules/SQLModule.java

@@ -2,7 +2,6 @@ package org.ssssssss.magicapi.modules;
 
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.RowMapper;
-import org.springframework.jdbc.datasource.DataSourceUtils;
 import org.springframework.jdbc.support.GeneratedKeyHolder;
 import org.springframework.jdbc.support.KeyHolder;
 import org.ssssssss.magicapi.adapter.ColumnMapperAdapter;
@@ -12,7 +11,6 @@ import org.ssssssss.magicapi.config.MagicDynamicDataSource;
 import org.ssssssss.magicapi.config.MagicDynamicDataSource.DataSourceNode;
 import org.ssssssss.magicapi.config.MagicModule;
 import org.ssssssss.magicapi.dialect.Dialect;
-import org.ssssssss.magicapi.exception.MagicAPIException;
 import org.ssssssss.magicapi.model.Page;
 import org.ssssssss.magicapi.provider.PageProvider;
 import org.ssssssss.magicapi.provider.ResultProvider;
@@ -20,7 +18,6 @@ import org.ssssssss.script.MagicScriptContext;
 import org.ssssssss.script.annotation.Comment;
 import org.ssssssss.script.annotation.UnableCall;
 
-import java.sql.Connection;
 import java.sql.PreparedStatement;
 import java.sql.Statement;
 import java.util.HashMap;
@@ -171,19 +168,6 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 		return new Transaction(this.dataSourceNode.getDataSourceTransactionManager());
 	}
 
-	/**
-	 * 添加至缓存
-	 *
-	 * @param value 缓存名
-	 */
-	@UnableCall
-	private <T> T putCacheValue(T value, BoundSql boundSql) {
-		if (this.cacheName != null) {
-			this.sqlCache.put(this.cacheName, boundSql.getCacheKey(this.sqlCache), value, this.ttl);
-		}
-		return value;
-	}
-
 	/**
 	 * 使用缓存
 	 *
@@ -266,9 +250,8 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	 */
 	@Comment("查询SQL,返回List类型结果")
 	public List<Map<String, Object>> select(@Comment("`SQL`语句") String sql) {
-		BoundSql boundSql = new BoundSql(sql);
-		return (List<Map<String, Object>>) boundSql.getCacheValue(this.sqlCache, this.cacheName)
-				.orElseGet(() -> putCacheValue(dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters()), boundSql));
+		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
+		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters()));
 	}
 
 	/**
@@ -284,70 +267,6 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 		return value;
 	}
 
-//	public int save(String tableName,Map<String,Object> params){
-//		return save(tableName,params,"id");
-//	}
-
-	/**
-	 * 如果已存在就修改,否则增加
-	 */
-//	public int save(String tableName,Map<String,Object> data,String primaryKey){
-//		Object[] params = new Object[]{data.get(primaryKey)};
-//		Integer count = dataSourceNode.getJdbcTemplate().queryForObject("select count(1) from "+tableName+" where "+primaryKey+" =  ?", params, Integer.class);
-//		if(count > 0){
-//			return jdbcUpdate(tableName,data,primaryKey);
-//		}
-//		return 0;
-//		Object primaryKeyValue = data.get(primaryKey);
-//		if(null == primaryKeyValue){
-//			return jdbcInsert(tableName,data,primaryKey);
-//		}
-//		return jdbcUpdate(tableName,data,primaryKey);
-//	}
-
-//	public int jdbcUpdate(String tableName,Map<String,Object> data,String primaryKey){
-//		StringBuffer sb = new StringBuffer();
-//		sb.append("update ");
-//		sb.append(tableName);
-//		sb.append(" set ");
-//		List<Object> params = new ArrayList<>();
-//		for(Map.Entry<String, Object> entry : data.entrySet()){
-//			String key = entry.getKey();
-//			if(!key.equals(primaryKey)){
-//				sb.append(key + "=" + "?,");
-//				params.add(entry.getValue());
-//			}
-//		}
-//		sb.append(" where ");
-//		sb.append(primaryKey);
-//		sb.append("=?");
-//		params.add(data.get(primaryKey));
-//		return dataSourceNode.getJdbcTemplate().update(sb.toString().replace("?, ","? "),params.toArray());
-//	}
-//
-//	public int jdbcInsert(String tableName,Map<String,Object> data,String primaryKey){
-//		List<Object> params = new ArrayList<>();
-//		params.add("");
-//		List<String> fields = new ArrayList<>();
-//		List<String> valuePlaceholders = new ArrayList<>();
-//		StringBuffer sb = new StringBuffer();
-//		sb.append("insert into ");
-//		sb.append(tableName);
-//		for(Map.Entry<String, Object> entry : data.entrySet()){
-//			String key = entry.getKey();
-//			if(!key.equals(primaryKey)){
-//				fields.add(key);
-//				valuePlaceholders.add("?");
-//				params.add(entry.getValue());
-//			}
-//		}
-//		sb.append("("+ primaryKey + "," + StringUtils.join(fields,",") +")");
-//		sb.append(" values(?,"+StringUtils.join(valuePlaceholders,",")+")");
-//		String id = UUID.randomUUID().toString().replace("-","");
-//		params.set(0,id);
-//		return dataSourceNode.getJdbcTemplate().update(sb.toString(),params.toArray());
-//	}
-
 	/**
 	 * 插入并返回主键
 	 */
@@ -384,27 +303,14 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	 */
 	@Comment("执行分页查询,分页条件手动传入")
 	public Object page(@Comment("`SQL`语句") String sql, @Comment("限制条数") long limit, @Comment("跳过条数") long offset) {
-		BoundSql boundSql = new BoundSql(sql);
-		Connection connection = null;
-		Dialect dialect;
-		try {
-			connection = dataSourceNode.getJdbcTemplate().getDataSource().getConnection();
-			dialect = dialectAdapter.getDialectFromUrl(connection.getMetaData().getURL());
-		} catch (Exception e) {
-			throw new MagicAPIException("自动获取数据库方言失败", e);
-		} finally {
-			DataSourceUtils.releaseConnection(connection, dataSourceNode.getJdbcTemplate().getDataSource());
-		}
-		if (dialect == null) {
-			throw new MagicAPIException("自动获取数据库方言失败");
-		}
-		int count = (int) boundSql.getCacheValue(this.sqlCache, this.cacheName)
-				.orElseGet(() -> putCacheValue(dataSourceNode.getJdbcTemplate().queryForObject(dialect.getCountSql(boundSql.getSql()), Integer.class, boundSql.getParameters()), boundSql));
-		List<Object> list = null;
+		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
+		Dialect dialect = dataSourceNode.getDialect(dialectAdapter);
+		String countSql = dialect.getCountSql(boundSql.getSql());
+		int count = boundSql.getCacheValue(countSql, boundSql.getParameters(), () -> dataSourceNode.getJdbcTemplate().queryForObject(countSql, Integer.class, boundSql.getParameters()));
+		List<Map<String, Object>> list = null;
 		if (count > 0) {
 			String pageSql = dialect.getPageSql(boundSql.getSql(), boundSql, offset, limit);
-			list = (List<Object>) boundSql.removeCacheKey().getCacheValue(this.sqlCache, this.cacheName)
-					.orElseGet(() -> putCacheValue(dataSourceNode.getJdbcTemplate().query(pageSql, this.columnMapRowMapper, boundSql.getParameters()), boundSql));
+			list = boundSql.getCacheValue(pageSql, boundSql.getParameters(), () -> dataSourceNode.getJdbcTemplate().query(pageSql, this.columnMapRowMapper, boundSql.getParameters()));
 		}
 		return resultProvider.buildPageResult(count, list);
 	}
@@ -414,9 +320,8 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	 */
 	@Comment("查询int值,适合单行单列int的结果")
 	public Integer selectInt(@Comment("`SQL`语句") String sql) {
-		BoundSql boundSql = new BoundSql(sql);
-		return (Integer) boundSql.getCacheValue(this.sqlCache, this.cacheName)
-				.orElseGet(() -> putCacheValue(dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Integer.class), boundSql));
+		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
+		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Integer.class));
 	}
 
 	/**
@@ -424,12 +329,11 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	 */
 	@Comment("查询单条结果,查不到返回null")
 	public Map<String, Object> selectOne(@Comment("`SQL`语句") String sql) {
-		BoundSql boundSql = new BoundSql(sql);
-		return (Map<String, Object>) boundSql.getCacheValue(this.sqlCache, this.cacheName)
-				.orElseGet(() -> {
-					List<Map<String, Object>> list = dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters());
-					return list != null && list.size() > 0 ? list.get(0) : null;
-				});
+		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
+		return boundSql.getCacheValue(() -> {
+			List<Map<String, Object>> list = dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters());
+			return list != null && list.size() > 0 ? list.get(0) : null;
+		});
 	}
 
 	/**
@@ -437,9 +341,8 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	 */
 	@Comment("查询单行单列的值")
 	public Object selectValue(@Comment("`SQL`语句") String sql) {
-		BoundSql boundSql = new BoundSql(sql);
-		return boundSql.getCacheValue(this.sqlCache, this.cacheName)
-				.orElseGet(() -> putCacheValue(dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Object.class), boundSql));
+		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
+		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Object.class));
 	}
 
 	@UnableCall

+ 2 - 1
src/main/java/org/ssssssss/magicapi/provider/ResultProvider.java

@@ -7,6 +7,7 @@ import org.ssssssss.script.exception.MagicScriptAssertException;
 import org.ssssssss.script.exception.MagicScriptException;
 
 import java.util.List;
+import java.util.Map;
 
 /**
  * 结果构建接口
@@ -69,7 +70,7 @@ public interface ResultProvider {
 	 * @param total 总数
 	 * @param data  数据内容
 	 */
-	default Object buildPageResult(long total, List<Object> data) {
+	default Object buildPageResult(long total, List<Map<String,Object>> data) {
 		return new PageResult<>(total, data);
 	}
 }