소스 검색

`<foreach>`新增`index`参数、新增`<elseif>`、`<else>`标签

mxd 3 년 전
부모
커밋
644d3c751e

+ 13 - 0
magic-api/src/main/java/org/ssssssss/magicapi/modules/db/mybatis/ForeachSqlNode.java

@@ -36,6 +36,11 @@ public class ForeachSqlNode extends SqlNode {
 	 */
 	private String separator;
 
+	/**
+	 * 序号
+	 */
+	private String index;
+
 	public void setCollection(String collection) {
 		this.collection = collection;
 	}
@@ -56,6 +61,10 @@ public class ForeachSqlNode extends SqlNode {
 		this.separator = separator;
 	}
 
+	public void setIndex(String index) {
+		this.index = index;
+	}
+
 	@Override
 	public String getSql(Map<String, Object> paramMap, List<Object> parameters) {
 		// 提取集合
@@ -74,11 +83,15 @@ public class ForeachSqlNode extends SqlNode {
 		}
 		// 开始拼接SQL,
 		StringBuilder sqlBuilder = new StringBuilder(StringUtils.defaultString(this.open));
+		boolean hasIndex = index != null && index.length() > 0;
 		// 获取数组长度
 		int len = Array.getLength(value);
 		for (int i = 0; i < len; i++) {
 			// 存入item对象
 			paramMap.put(this.item, Array.get(value, i));
+			if (hasIndex) {
+				paramMap.put(this.index, i);
+			}
 			// 拼接子节点
 			sqlBuilder.append(executeChildren(paramMap, parameters));
 			// 拼接分隔符

+ 8 - 2
magic-api/src/main/java/org/ssssssss/magicapi/modules/db/mybatis/IfSqlNode.java

@@ -7,7 +7,7 @@ import java.util.List;
 import java.util.Map;
 
 /**
- * 对应XML中 <if>
+ * 对应XML中 <if>、<elseif>
  *
  * @author jmxd
  * @version : 2020-05-18
@@ -18,8 +18,11 @@ public class IfSqlNode extends SqlNode {
 	 */
 	private final String test;
 
-	public IfSqlNode(String test) {
+	private final SqlNode nextNode;
+
+	public IfSqlNode(String test, SqlNode nextNode) {
 		this.test = test;
+		this.nextNode = nextNode;
 	}
 
 	@Override
@@ -30,6 +33,9 @@ public class IfSqlNode extends SqlNode {
 		if (BooleanLiteral.isTrue(value)) {
 			return executeChildren(paramMap, parameters);
 		}
+		if (nextNode != null) {
+			return nextNode.getSql(paramMap, parameters);
+		}
 		return "";
 	}
 }

+ 50 - 33
magic-api/src/main/java/org/ssssssss/magicapi/modules/db/mybatis/MybatisParser.java

@@ -3,7 +3,6 @@ package org.ssssssss.magicapi.modules.db.mybatis;
 import org.ssssssss.magicapi.core.exception.MagicAPIException;
 import org.w3c.dom.Document;
 import org.w3c.dom.Node;
-import org.w3c.dom.NodeList;
 
 import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
@@ -22,7 +21,7 @@ public class MybatisParser {
 			DocumentBuilder documentBuilder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
 			Document document = documentBuilder.parse(new ByteArrayInputStream(xml.getBytes()));
 			SqlNode sqlNode = new TextSqlNode("");
-			parseNodeList(sqlNode, document.getDocumentElement().getChildNodes());
+			parseNodeList(sqlNode, new NodeStream(document.getDocumentElement().getChildNodes()));
 			return sqlNode;
 		} catch (Exception e) {
 			throw new MagicAPIException("SQL解析错误", e);
@@ -33,72 +32,90 @@ public class MybatisParser {
 		return ESCAPE_LT_PATTERN.matcher(xml).replaceAll(ESCAPE_LT_REPLACEMENT);
 	}
 
-	private static void parseNodeList(SqlNode sqlNode, NodeList nodeList) {
-		for (int i = 0, len = nodeList.getLength(); i < len; i++) {
-			Node node = nodeList.item(i);
-			if (node.getNodeType() == Node.TEXT_NODE) {
-				sqlNode.addChildNode(new TextSqlNode(node.getNodeValue().trim()));
-			} else if (node.getNodeType() != Node.COMMENT_NODE) {
-				String nodeName = node.getNodeName();
-				SqlNode childNode;
-				if ("foreach".equalsIgnoreCase(nodeName)) {
-					childNode = parseForeachSqlNode(node);
-				} else if ("if".equalsIgnoreCase(nodeName)) {
-					childNode = new IfSqlNode(getNodeAttributeValue(node, "test"));
-				} else if ("trim".equalsIgnoreCase(nodeName)) {
-					childNode = parseTrimSqlNode(node);
-				} else if ("set".equalsIgnoreCase(nodeName)) {
-					childNode = parseSetSqlNode();
-				} else if ("where".equalsIgnoreCase(nodeName)) {
-					childNode = parseWhereSqlNode();
+	private static void parseNodeList(SqlNode sqlNode, NodeStream stream) {
+		while (stream.hasMore()) {
+			SqlNode childNode;
+			if (stream.match(Node.TEXT_NODE)) {
+				childNode = new TextSqlNode(stream.consume().getNodeValue().trim());
+			} else {
+				if (stream.match("foreach")) {
+					childNode = parseForeachSqlNode(stream);
+				} else if (stream.match("if")) {
+					childNode = parseIfSqlNode(stream);
+				} else if (stream.match("trim")) {
+					childNode = parseTrimSqlNode(stream);
+				} else if (stream.match("set")) {
+					childNode = parseSetSqlNode(stream);
+				} else if (stream.match("where")) {
+					childNode = parseWhereSqlNode(stream);
 				} else {
-					throw new UnsupportedOperationException("Unsupported tags :" + nodeName);
-				}
-				sqlNode.addChildNode(childNode);
-				if (node.hasChildNodes()) {
-					parseNodeList(childNode, node.getChildNodes());
+					throw new UnsupportedOperationException("Unsupported tags :" + stream.consume().getNodeName());
 				}
 			}
+			sqlNode.addChildNode(childNode);
+		}
+	}
+
+	private static IfSqlNode parseIfSqlNode(NodeStream stream) {
+		Node ifNode = stream.consume();
+		String test = getNodeAttributeValue(ifNode, "test");
+		SqlNode nextNode = null;
+		if (stream.match("else")) {
+			nextNode = new TextSqlNode("");
+			parseNodeList(nextNode, new NodeStream(stream.consume().getChildNodes()));
+		} else if (stream.match("elseif")) {
+			nextNode = parseIfSqlNode(stream);
+		}
+		return processChildren(new IfSqlNode(test, nextNode), ifNode);
+	}
+
+	private static <T extends SqlNode> T processChildren(T sqlNode, Node node) {
+		if (node.hasChildNodes()) {
+			parseNodeList(sqlNode, new NodeStream(node.getChildNodes()));
 		}
+		return sqlNode;
 	}
 
 	/**
 	 * 解析foreach节点
 	 */
-	private static ForeachSqlNode parseForeachSqlNode(Node node) {
+	private static ForeachSqlNode parseForeachSqlNode(NodeStream stream) {
+		Node node = stream.consume();
 		ForeachSqlNode foreachSqlNode = new ForeachSqlNode();
 		foreachSqlNode.setCollection(getNodeAttributeValue(node, "collection"));
 		foreachSqlNode.setSeparator(getNodeAttributeValue(node, "separator"));
 		foreachSqlNode.setClose(getNodeAttributeValue(node, "close"));
 		foreachSqlNode.setOpen(getNodeAttributeValue(node, "open"));
 		foreachSqlNode.setItem(getNodeAttributeValue(node, "item"));
-		return foreachSqlNode;
+		foreachSqlNode.setIndex(getNodeAttributeValue(node, "index"));
+		return processChildren(foreachSqlNode, node);
 	}
 
 	/**
 	 * 解析trim节点
 	 */
-	private static TrimSqlNode parseTrimSqlNode(Node node) {
+	private static TrimSqlNode parseTrimSqlNode(NodeStream stream) {
+		Node node = stream.consume();
 		TrimSqlNode trimSqlNode = new TrimSqlNode();
 		trimSqlNode.setPrefix(getNodeAttributeValue(node, "prefix"));
 		trimSqlNode.setPrefixOverrides(getNodeAttributeValue(node, "prefixOverrides"));
 		trimSqlNode.setSuffix(getNodeAttributeValue(node, "suffix"));
 		trimSqlNode.setSuffixOverrides(getNodeAttributeValue(node, "suffixOverrides"));
-		return trimSqlNode;
+		return processChildren(trimSqlNode, node);
 	}
 
 	/**
 	 * 解析set节点
 	 */
-	private static SetSqlNode parseSetSqlNode() {
-		return new SetSqlNode();
+	private static SetSqlNode parseSetSqlNode(NodeStream stream) {
+		return processChildren(new SetSqlNode(), stream.consume());
 	}
 
 	/**
 	 * 解析where节点
 	 */
-	private static WhereSqlNode parseWhereSqlNode() {
-		return new WhereSqlNode();
+	private static WhereSqlNode parseWhereSqlNode(NodeStream stream) {
+		return processChildren(new WhereSqlNode(), stream.consume());
 	}
 
 	private static String getNodeAttributeValue(Node node, String attributeKey) {

+ 50 - 0
magic-api/src/main/java/org/ssssssss/magicapi/modules/db/mybatis/NodeStream.java

@@ -0,0 +1,50 @@
+package org.ssssssss.magicapi.modules.db.mybatis;
+
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class NodeStream {
+
+	private final List<Node> nodes;
+
+	private int index = 0;
+
+	private final int len;
+
+	public NodeStream(NodeList nodeList) {
+		this.nodes = filterCommentAndBlankNodes(nodeList);
+		this.len = this.nodes.size();
+	}
+
+	private static List<Node> filterCommentAndBlankNodes(NodeList nodeList) {
+		List<Node> nodes = new ArrayList<>();
+		for (int i = 0, len = nodeList.getLength(); i < len; i++) {
+			Node node = nodeList.item(i);
+			short nodeType = node.getNodeType();
+			if (nodeType != Node.COMMENT_NODE && (nodeType != Node.TEXT_NODE || node.getNodeValue().trim().length() > 0)) {
+				nodes.add(node);
+			}
+		}
+		return nodes;
+	}
+
+	public boolean match(String nodeName) {
+		return hasMore() && nodeName.equalsIgnoreCase(this.nodes.get(this.index).getNodeName());
+	}
+
+
+	public boolean match(short nodeType) {
+		return hasMore() && nodeType == this.nodes.get(this.index).getNodeType();
+	}
+
+	public Node consume() {
+		return this.nodes.get(this.index++);
+	}
+
+	public boolean hasMore() {
+		return this.index < this.len;
+	}
+}