diff --git a/kyuubi-hive-jdbc/src/main/java/org/apache/kyuubi/jdbc/hive/Utils.java b/kyuubi-hive-jdbc/src/main/java/org/apache/kyuubi/jdbc/hive/Utils.java index 2a0462aed..881a96027 100644 --- a/kyuubi-hive-jdbc/src/main/java/org/apache/kyuubi/jdbc/hive/Utils.java +++ b/kyuubi-hive-jdbc/src/main/java/org/apache/kyuubi/jdbc/hive/Utils.java @@ -106,25 +106,47 @@ public class Utils { */ static List splitSqlStatement(String sql) { List parts = new ArrayList<>(); - int apCount = 0; + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + boolean inComment = false; int off = 0; boolean skip = false; for (int i = 0; i < sql.length(); i++) { char c = sql.charAt(i); + if (inComment) { + inComment = (c != '\n'); + continue; + } if (skip) { skip = false; continue; } switch (c) { case '\'': - apCount++; + if (!inDoubleQuote) { + inSingleQuote = !inSingleQuote; + } + break; + case '\"': + if (!inSingleQuote) { + inDoubleQuote = !inDoubleQuote; + } + break; + case '-': + if (!inSingleQuote && !inDoubleQuote) { + if (i < sql.length() - 1 && sql.charAt(i + 1) == '-') { + inComment = true; + } + } break; case '\\': - skip = true; + if (!inSingleQuote && !inDoubleQuote) { + skip = true; + } break; case '?': - if ((apCount & 1) == 0) { + if (!inSingleQuote && !inDoubleQuote) { parts.add(sql.substring(off, i)); off = i + 1; } diff --git a/kyuubi-hive-jdbc/src/test/java/org/apache/kyuubi/jdbc/hive/UtilsTest.java b/kyuubi-hive-jdbc/src/test/java/org/apache/kyuubi/jdbc/hive/UtilsTest.java index fc4a55d9f..87f1a78de 100644 --- a/kyuubi-hive-jdbc/src/test/java/org/apache/kyuubi/jdbc/hive/UtilsTest.java +++ b/kyuubi-hive-jdbc/src/test/java/org/apache/kyuubi/jdbc/hive/UtilsTest.java @@ -25,10 +25,7 @@ import com.google.common.collect.ImmutableMap; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Collection; -import java.util.Map; -import java.util.Properties; +import java.util.*; import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; @@ -156,4 +153,49 @@ public class UtilsTest { Pattern pattern = Pattern.compile("^\\d+\\.\\d+\\.\\d+.*"); assert pattern.matcher(Utils.getVersion()).matches(); } + + @Test + public void testSplitSqlStatement() { + String simpleSql = "select 1 from ? where a = ?"; + List splitSql = Utils.splitSqlStatement(simpleSql); + assertEquals(3, splitSql.size()); + assertEquals("select 1 from ", splitSql.get(0)); + assertEquals(" where a = ", splitSql.get(1)); + assertEquals("", splitSql.get(2)); + + String placeHolderWithinSingleQuote = "select '?' from ? where a = ?"; + splitSql = Utils.splitSqlStatement(placeHolderWithinSingleQuote); + assertEquals(3, splitSql.size()); + assertEquals("select '?' from ", splitSql.get(0)); + assertEquals(" where a = ", splitSql.get(1)); + assertEquals("", splitSql.get(2)); + + String escapePlaceHolder = "select \\? from ? where a = ?"; + splitSql = Utils.splitSqlStatement(escapePlaceHolder); + assertEquals(3, splitSql.size()); + assertEquals("select \\? from ", splitSql.get(0)); + assertEquals(" where a = ", splitSql.get(1)); + assertEquals("", splitSql.get(2)); + + String inQuoteLikeRegexFunction = + "select " + + "regexp_extract(field_a, \"[a-zA-Z]+?\", 0) as extracted_a," + + "regexp_extract(field_b, '[a-zA-Z]+?', 0) as extracted_b" + + " from ?"; + splitSql = Utils.splitSqlStatement(inQuoteLikeRegexFunction); + assertEquals(2, splitSql.size()); + assertEquals( + "select " + + "regexp_extract(field_a, \"[a-zA-Z]+?\", 0) as extracted_a," + + "regexp_extract(field_b, '[a-zA-Z]+?', 0) as extracted_b from ", + splitSql.get(0)); + assertEquals("", splitSql.get(1)); + + String inCommentBlock = "--comments\n" + "select --? \n" + "? from ?"; + splitSql = Utils.splitSqlStatement(inCommentBlock); + assertEquals(3, splitSql.size()); + assertEquals("--comments\n" + "select --? \n", splitSql.get(0)); + assertEquals(" from ", splitSql.get(1)); + assertEquals("", splitSql.get(2)); + } }