소스 검색

增加`SQL`拦截器

mxd 4 년 전
부모
커밋
7147a5e805

+ 9 - 0
src/main/java/org/ssssssss/magicapi/interceptor/SQLInterceptor.java

@@ -0,0 +1,9 @@
+package org.ssssssss.magicapi.interceptor;
+
+import org.ssssssss.magicapi.modules.BoundSql;
+
+public interface SQLInterceptor {
+
+	void preHandle(BoundSql boundSql);
+
+}

+ 34 - 3
src/main/java/org/ssssssss/magicapi/modules/BoundSql.java

@@ -1,6 +1,7 @@
 package org.ssssssss.magicapi.modules;
 
 import org.ssssssss.magicapi.cache.SqlCache;
+import org.ssssssss.magicapi.interceptor.SQLInterceptor;
 import org.ssssssss.script.MagicScriptContext;
 import org.ssssssss.script.functions.StreamExtension;
 import org.ssssssss.script.parsing.GenericTokenParser;
@@ -34,7 +35,7 @@ public class BoundSql {
 
 	private long ttl;
 
-	BoundSql(String sql){
+	BoundSql(String sql) {
 		MagicScriptContext context = MagicScriptContext.get();
 		// 处理?{}参数
 		this.sql = ifTokenParser.parse(sql.trim(), text -> {
@@ -80,6 +81,21 @@ public class BoundSql {
 		this.ttl = ttl;
 	}
 
+	private BoundSql() {
+
+	}
+
+	BoundSql copy(String newSql) {
+		BoundSql boundSql = new BoundSql();
+		boundSql.setParameters(new ArrayList<>(this.parameters));
+		boundSql.setSql(this.sql);
+		boundSql.ttl = this.ttl;
+		boundSql.cacheName = this.cacheName;
+		boundSql.sqlCache = this.sqlCache;
+		boundSql.sql = newSql;
+		return boundSql;
+	}
+
 	/**
 	 * 添加SQL参数
 	 */
@@ -94,6 +110,20 @@ public class BoundSql {
 		return sql;
 	}
 
+	/**
+	 * 设置要执行的SQL
+	 */
+	public void setSql(String sql) {
+		this.sql = sql;
+	}
+
+	/**
+	 * 设置要执行的参数
+	 */
+	public void setParameters(List<Object> parameters) {
+		this.parameters = parameters;
+	}
+
 	/**
 	 * 获取要执行的参数
 	 */
@@ -105,7 +135,7 @@ public class BoundSql {
 	/**
 	 * 获取缓存值
 	 */
-	<T> T getCacheValue(String sql, Object[] params, Supplier<T> supplier) {
+	private <T> T getCacheValue(String sql, Object[] params, Supplier<T> supplier) {
 		if (cacheName == null) {
 			return supplier.get();
 		}
@@ -122,7 +152,8 @@ public class BoundSql {
 	/**
 	 * 获取缓存值
 	 */
-	<T> T getCacheValue(Supplier<T> supplier) {
+	<T> T getCacheValue(List<SQLInterceptor> interceptors, Supplier<T> supplier) {
+		interceptors.forEach(interceptor -> interceptor.preHandle(this));
 		return getCacheValue(this.getSql(), this.getParameters(), supplier);
 	}
 }

+ 21 - 8
src/main/java/org/ssssssss/magicapi/modules/SQLModule.java

@@ -11,6 +11,7 @@ import org.ssssssss.magicapi.config.MagicDynamicDataSource;
 import org.ssssssss.magicapi.config.MagicDynamicDataSource.DataSourceNode;
 import org.ssssssss.magicapi.config.MagicModule;
 import org.ssssssss.magicapi.dialect.Dialect;
+import org.ssssssss.magicapi.interceptor.SQLInterceptor;
 import org.ssssssss.magicapi.model.Page;
 import org.ssssssss.magicapi.provider.PageProvider;
 import org.ssssssss.magicapi.provider.ResultProvider;
@@ -60,6 +61,9 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@UnableCall
 	private String cacheName;
 
+	@UnableCall
+	private List<SQLInterceptor> sqlInterceptors;
+
 	@UnableCall
 	private long ttl;
 
@@ -106,6 +110,11 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 		this.dynamicDataSource = dynamicDataSource;
 	}
 
+	@UnableCall
+	public void setSqlInterceptors(List<SQLInterceptor> sqlInterceptors) {
+		this.sqlInterceptors = sqlInterceptors;
+	}
+
 	@UnableCall
 	public void setSqlCache(SqlCache sqlCache) {
 		this.sqlCache = sqlCache;
@@ -136,6 +145,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 		sqlModule.setTtl(this.ttl);
 		sqlModule.setResultProvider(this.resultProvider);
 		sqlModule.setDialectAdapter(this.dialectAdapter);
+		sqlModule.setSqlInterceptors(this.sqlInterceptors);
 		return sqlModule;
 	}
 
@@ -251,7 +261,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("查询SQL,返回List类型结果")
 	public List<Map<String, Object>> select(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
-		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters()));
+		return boundSql.getCacheValue(this.sqlInterceptors, () -> dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters()));
 	}
 
 	/**
@@ -260,6 +270,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("执行update操作,返回受影响行数")
 	public int update(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql);
+		sqlInterceptors.forEach(sqlInterceptor -> sqlInterceptor.preHandle(boundSql));
 		int value = dataSourceNode.getJdbcTemplate().update(boundSql.getSql(), boundSql.getParameters());
 		if (this.cacheName != null) {
 			this.sqlCache.delete(this.cacheName);
@@ -273,6 +284,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("执行insert操作,返回插入条数")
 	public long insert(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql);
+		sqlInterceptors.forEach(sqlInterceptor -> sqlInterceptor.preHandle(boundSql));
 		KeyHolder keyHolder = new GeneratedKeyHolder();
 		dataSourceNode.getJdbcTemplate().update(con -> {
 			PreparedStatement ps = con.prepareStatement(boundSql.getSql(), Statement.RETURN_GENERATED_KEYS);
@@ -305,12 +317,13 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	public Object page(@Comment("`SQL`语句") String sql, @Comment("限制条数") long limit, @Comment("跳过条数") long offset) {
 		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
 		Dialect dialect = dataSourceNode.getDialect(dialectAdapter);
-		String countSql = dialect.getCountSql(boundSql.getSql());
-		int count = boundSql.getCacheValue(countSql, boundSql.getParameters(), () -> dataSourceNode.getJdbcTemplate().queryForObject(countSql, Integer.class, boundSql.getParameters()));
+		BoundSql countBoundSql = boundSql.copy(dialect.getCountSql(boundSql.getSql()));
+		int count = countBoundSql.getCacheValue(this.sqlInterceptors, () -> dataSourceNode.getJdbcTemplate().queryForObject(countBoundSql.getSql(), Integer.class, countBoundSql.getParameters()));
 		List<Map<String, Object>> list = null;
 		if (count > 0) {
 			String pageSql = dialect.getPageSql(boundSql.getSql(), boundSql, offset, limit);
-			list = boundSql.getCacheValue(pageSql, boundSql.getParameters(), () -> dataSourceNode.getJdbcTemplate().query(pageSql, this.columnMapRowMapper, boundSql.getParameters()));
+			BoundSql pageBoundSql = boundSql.copy(dialect.getCountSql(boundSql.getSql()));
+			list = pageBoundSql.getCacheValue(this.sqlInterceptors, () -> dataSourceNode.getJdbcTemplate().query(pageBoundSql.getSql(), this.columnMapRowMapper, pageBoundSql.getParameters()));
 		}
 		return resultProvider.buildPageResult(count, list);
 	}
@@ -321,7 +334,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("查询int值,适合单行单列int的结果")
 	public Integer selectInt(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
-		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Integer.class));
+		return boundSql.getCacheValue(this.sqlInterceptors, () -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Integer.class));
 	}
 
 	/**
@@ -330,7 +343,7 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("查询单条结果,查不到返回null")
 	public Map<String, Object> selectOne(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
-		return boundSql.getCacheValue(() -> {
+		return boundSql.getCacheValue(this.sqlInterceptors, () -> {
 			List<Map<String, Object>> list = dataSourceNode.getJdbcTemplate().query(boundSql.getSql(), this.columnMapRowMapper, boundSql.getParameters());
 			return list != null && list.size() > 0 ? list.get(0) : null;
 		});
@@ -342,12 +355,12 @@ public class SQLModule extends HashMap<String, SQLModule> implements MagicModule
 	@Comment("查询单行单列的值")
 	public Object selectValue(@Comment("`SQL`语句") String sql) {
 		BoundSql boundSql = new BoundSql(sql, this.sqlCache, this.cacheName, this.ttl);
-		return boundSql.getCacheValue(() -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Object.class));
+		return boundSql.getCacheValue(this.sqlInterceptors, () -> dataSourceNode.getJdbcTemplate().queryForObject(boundSql.getSql(), boundSql.getParameters(), Object.class));
 	}
 
 	@Comment("指定table,进行一系列操作")
 	public NamedTable table(String tableName) {
-		return new NamedTable(tableName,this.dataSourceNode);
+		return new NamedTable(tableName, this.dataSourceNode);
 	}
 
 	@UnableCall