Ver Fonte

增加Oracle、DB2、PostgreSQL、SQLServer方言

mxd há 5 anos atrás
pai
commit
fcbb185288

+ 13 - 0
src/main/java/org/ssssssss/dialect/DB2Dialect.java

@@ -0,0 +1,13 @@
+package org.ssssssss.dialect;
+
+import org.ssssssss.context.RequestContext;
+
+public class DB2Dialect implements Dialect {
+    @Override
+    public String getPageSql(String sql, RequestContext context, long offset, long limit) {
+        context.addParameter(offset + 1);
+        context.addParameter(offset + limit);
+        return "SELECT * FROM (SELECT TMP_PAGE.*,ROWNUMBER() OVER() AS ROW_ID FROM ( " + sql +
+                " ) AS TMP_PAGE) TMP_PAGE WHERE ROW_ID BETWEEN ? AND ?";
+    }
+}

+ 3 - 1
src/main/java/org/ssssssss/dialect/Dialect.java

@@ -1,5 +1,7 @@
 package org.ssssssss.dialect;
 
+import org.ssssssss.context.RequestContext;
+
 public interface Dialect {
 
     /**
@@ -12,5 +14,5 @@ public interface Dialect {
     /**
      * 获取分页sql
      */
-    String getPageSql(String sql, long offset, long limit);
+    String getPageSql(String sql, RequestContext context, long offset, long limit);
 }

+ 18 - 4
src/main/java/org/ssssssss/dialect/DialectUtils.java

@@ -1,10 +1,15 @@
 package org.ssssssss.dialect;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
 public class DialectUtils {
 
+    private static Logger logger = LoggerFactory.getLogger(DialectUtils.class);
+
     /**
      * 缓存已解析的方言
      */
@@ -15,10 +20,19 @@ public class DialectUtils {
      */
     public static Dialect getDialectFromUrl(String fromUrl) {
         Dialect dialect = dialectMap.get(fromUrl);
-        if (dialect == null) {
-            //判断mysql
-            if (fromUrl.startsWith("jdbc:mysql:") || fromUrl.startsWith("jdbc:cobar:") || fromUrl.startsWith("jdbc:log4jdbc:mysql:")) {
-                dialect = new MySqlDialect();
+        if (dialect == null && !dialectMap.containsKey(fromUrl)) {
+            if (fromUrl.startsWith("jdbc:mysql:") || fromUrl.startsWith("jdbc:cobar:") || fromUrl.startsWith("jdbc:log4jdbc:mysql:") || fromUrl.startsWith("jdbc:mariadb:")) {
+                dialect = new MySQLDialect();
+            } else if (fromUrl.startsWith("jdbc:oracle:") || fromUrl.startsWith("jdbc:log4jdbc:oracle:")) {
+                dialect = new OracleDialect();
+            } else if (fromUrl.startsWith("jdbc:sqlserver2012:")) {
+                dialect = new SQLServerDialect();
+            } else if (fromUrl.startsWith("jdbc:postgresql:") || fromUrl.startsWith("jdbc:log4jdbc:postgresql:")) {
+                dialect = new PostgreSQLDialect();
+            } else if (fromUrl.startsWith("jdbc:db2:")) {
+                dialect = new DB2Dialect();
+            } else {
+                logger.warn(String.format("ssssssss在%s中无法获取dialect", fromUrl));
             }
             dialectMap.put(fromUrl, dialect);
         }

+ 13 - 0
src/main/java/org/ssssssss/dialect/MySQLDialect.java

@@ -0,0 +1,13 @@
+package org.ssssssss.dialect;
+
+import org.ssssssss.context.RequestContext;
+
+public class MySQLDialect implements Dialect {
+
+    @Override
+    public String getPageSql(String sql, RequestContext context, long offset, long limit) {
+        context.addParameter(limit);
+        context.addParameter(offset);
+        return sql + " limit ?,?";
+    }
+}

+ 0 - 9
src/main/java/org/ssssssss/dialect/MySqlDialect.java

@@ -1,9 +0,0 @@
-package org.ssssssss.dialect;
-
-public class MySqlDialect implements Dialect {
-
-    @Override
-    public String getPageSql(String sql, long offset, long limit) {
-        return sql + " limit ?,?";
-    }
-}

+ 15 - 0
src/main/java/org/ssssssss/dialect/OracleDialect.java

@@ -0,0 +1,15 @@
+package org.ssssssss.dialect;
+
+import org.ssssssss.context.RequestContext;
+
+public class OracleDialect implements Dialect {
+
+    @Override
+    public String getPageSql(String sql, RequestContext context, long offset, long limit) {
+        limit = (offset >= 1) ? (offset + limit) : limit;
+        context.addParameter(limit);
+        context.addParameter(offset);
+        return "SELECT * FROM ( SELECT TMP.*, ROWNUM ROW_ID FROM ( " +
+                sql + " ) TMP WHERE ROWNUM <= ? ) WHERE ROW_ID > ?";
+    }
+}

+ 12 - 0
src/main/java/org/ssssssss/dialect/PostgreSQLDialect.java

@@ -0,0 +1,12 @@
+package org.ssssssss.dialect;
+
+import org.ssssssss.context.RequestContext;
+
+public class PostgreSQLDialect implements Dialect {
+    @Override
+    public String getPageSql(String sql, RequestContext context, long offset, long limit) {
+        context.addParameter(limit);
+        context.addParameter(offset);
+        return sql + " limit ? offset ?";
+    }
+}

+ 12 - 0
src/main/java/org/ssssssss/dialect/SQLServerDialect.java

@@ -0,0 +1,12 @@
+package org.ssssssss.dialect;
+
+import org.ssssssss.context.RequestContext;
+
+public class SQLServerDialect implements Dialect {
+    @Override
+    public String getPageSql(String sql, RequestContext context, long offset, long limit) {
+        context.addParameter(offset);
+        context.addParameter(limit);
+        return sql + " OFFSET ? ROWS FETCH NEXT ? ROWS ONLY";
+    }
+}

+ 1 - 3
src/main/java/org/ssssssss/executor/StatementExecutor.java

@@ -163,10 +163,8 @@ public class StatementExecutor {
                 // 当条数>0时,执行查询语句,否则不查询以提高性能
                 if (total > 0) {
                     // 获取分页语句
-                    String pageSql = dialect.getPageSql(sql, page.getOffset(), page.getLimit());
+                    String pageSql = dialect.getPageSql(sql, context, page.getOffset(), page.getLimit());
                     // 设置分页参数
-                    context.addParameter(page.getLimit());
-                    context.addParameter(page.getOffset());
                     // 执行查询
                     pageResult.setList(sqlExecutor.queryForList(connection, pageSql, context.getParameters(), sqlStatement.getReturnType()));
                 }