diff --git a/.gitignore b/.gitignore
index e705c32..d8d23e0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,3 @@
*.class
*.log
-/target
-/project
\ No newline at end of file
+**/target
diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 0000000..f825c99
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1,12 @@
+version = 3.9.8
+style = default
+runner.dialect=scala212
+maxColumn = 120
+continuationIndent.defnSite = 2
+continuationIndent.callSite = 2
+align.preset = "none"
+danglingParentheses.preset = true
+optIn.configStyleArguments = false
+docstrings.style = SpaceAsterisk
+spaces.beforeContextBoundColon = true
+rewrite.rules = [SortImports]
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..6d1098c
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,56 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]
+
+### Added
+- Multi-dialect support for SQL parsing:
+ - Oracle SQL dialect
+ - PostgreSQL dialect
+ - ANSI SQL dialect
+- CLI application (`Main.scala`) with comprehensive options:
+ - Multiple output formats: text, JSON, JQ queries
+ - File pattern matching (glob support)
+ - Configurable output destinations (console or file)
+- Error recovery for unparseable statements
+- Circe-based JSON encoding for all AST types
+- Assembly plugin support for building standalone JARs
+- Comprehensive test suite with Oracle-specific tests
+- Support for parsing multiple statements from a single file with accurate line number tracking
+
+### Changed
+- **BREAKING**: Major refactoring of package structure:
+ - Moved core types to `sparsity.common.*` package
+ - Parser base classes now in `sparsity.common.parser.*`
+ - Statement types in `sparsity.common.statement.*`
+ - Expression types in `sparsity.common.expression.*`
+- Refactored parser architecture:
+ - Moved logic from `Elements` into `SQLBase` for better subclass extensibility
+ - Created dialect-specific parser objects
+- Enhanced build configuration:
+ - Updated SBT to 1.11.7
+ - Cross-compilation for Scala 2.12, 2.13, and 3.3
+ - Added assembly, scalafmt, and cross-project plugins
+- Updated dependencies:
+ - fastparse to 3.1.1
+ - Added scribe for logging
+ - Added mainargs for CLI argument parsing
+ - Added monocle for optics
+ - Added java-jq for JQ query support
+
+### Fixed
+- Improved error messages with better context
+- Better handling of parsing failures with fallback mechanisms
+
+## [1.7.0] - Previous Release
+
+### Added
+- Cross-compiled support for Scala 2.12, 2.13, and 3.3
+- Updated fastparse dependency
+
+[Unreleased]: https://github.com/UBOdin/sparsity/compare/v1.7.0...HEAD
+[1.7.0]: https://github.com/UBOdin/sparsity/releases/tag/v1.7.0
diff --git a/README.md b/README.md
index ad5d268..56b5b60 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,6 @@ val tree = SQL("SELECT * FROM R")
libraryDependencies ++= Seq("info.mimirdb" %% "sparsity" % "1.5")
```
-
## Version History
-
-#### 1.7 (Breaking)
-- Added cross-compiled support for Scala 2.12
\ No newline at end of file
+See [CHANGELOG.md](CHANGELOG.md) for detailed version history and release notes.
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index a39b0c1..2bd3f2c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,49 +1,109 @@
import scala.sys.process._
+import scala.scalanative.build._
+import sbtcrossproject.CrossPlugin.autoImport._
+import sbtassembly.AssemblyPlugin.autoImport._
-name := "Sparsity"
-version := "1.7.1-SNAPSHOT"
-organization := "info.mimirdb"
-scalaVersion := "2.12.20"
-crossScalaVersions := Seq("2.12.20", "2.13.17", "3.3.6")
+val Scala3 = "3.3.6"
+val Scala2_13 = "2.13.17"
+val Scala2_12 = "2.12.20"
-dependencyOverrides ++= {
+val circeVersion = "0.14.14"
+val circeOpticsVersion = "0.15.1"
+val jsonPathVersion = "0.2.0"
+val specs2Version = "4.23.0"
+val fastparseVersion = "3.1.1"
+val javaJqVersion = "2.0.0"
+val scribeVersion = "3.17.0"
+val monocleVersion = "3.3.0"
+val mainArgsVersion = "0.7.7"
+
+ThisBuild / scalafmtOnCompile := true
+ThisBuild / name := "Sparsity"
+ThisBuild / version := "1.7.1-SNAPSHOT"
+ThisBuild / organization := "info.mimirdb"
+ThisBuild / scalaVersion := Scala2_13
+ThisBuild / crossScalaVersions := Seq(Scala2_12, Scala2_13, Scala3)
+
+ThisBuild / dependencyOverrides ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((3, _)) => Seq.empty // Scala 3 doesn't need override
case _ => Seq("org.scala-lang" % "scala-library" % scalaVersion.value)
}
}
-resolvers += "MimirDB" at "https://maven.mimirdb.info/"
-resolvers ++= Seq("snapshots", "releases").map(Resolver.sonatypeRepo)
+ThisBuild / resolvers += "MimirDB" at "https://maven.mimirdb.info/"
+ThisBuild / resolvers ++= Seq("snapshots", "releases").map(Resolver.sonatypeRepo)
-libraryDependencies ++= Seq(
- "com.lihaoyi" %% "fastparse" % "3.1.1",
- "com.typesafe.scala-logging" %% "scala-logging" % "3.9.5",
- "ch.qos.logback" % "logback-classic" % "1.5.14",
- "org.specs2" %% "specs2-core" % "4.23.0" % "test",
- "org.specs2" %% "specs2-junit" % "4.23.0" % "test"
+val scalacheckVersion = "1.19.0"
+
+val commonSettings = Seq(
+ libraryDependencies ++= Seq(
+ "com.lihaoyi" %%% "mainargs" % mainArgsVersion,
+ "com.lihaoyi" %%% "fastparse" % fastparseVersion,
+ "com.outr" %%% "scribe" % scribeVersion,
+ "io.circe" %%% "circe-core" % circeVersion,
+ "io.circe" %%% "circe-generic" % circeVersion,
+ "io.circe" %%% "circe-generic-extras" % "0.14.5-RC1",
+ "io.circe" %%% "circe-parser" % circeVersion,
+ "io.circe" %%% "circe-optics" % circeOpticsVersion,
+ "dev.optics" %% "monocle-core" % monocleVersion,
+ "dev.optics" %% "monocle-macro" % monocleVersion,
+ "com.quincyjo" %%% "scala-json-path-circe" % jsonPathVersion,
+ "com.arakelian" % "java-jq" % javaJqVersion,
+ "org.specs2" %%% "specs2-core" % specs2Version % "test",
+ "org.specs2" %%% "specs2-junit" % specs2Version % "test",
+ "org.scalacheck" %%% "scalacheck" % scalacheckVersion % "test"
+ )
)
-////// Publishing Metadata //////
-// use `sbt publish makePom` to generate
-// a publishable jar artifact and its POM metadata
-
-publishMavenStyle := true
-
-pomExtra := <url>http://github.com/UBOdin/sparsity/</url>
-  <licenses>
-    <license>
-      <name>Apache License 2.0</name>
-      <url>http://www.apache.org/licenses/</url>
-      <distribution>repo</distribution>
-    </license>
-  </licenses>
-  <scm>
-    <url>git@github.com:ubodin/sparsity.git</url>
-    <connection>scm:git:git@github.com:ubodin/sparsity.git</connection>
-  </scm>
-
-/////// Publishing Options ////////
-// use `sbt publish` to update the package in
-// your own local ivy cache
-publishTo := Some(Resolver.file("file", new File("/var/www/maven_repo/")))
+val nativeSettings = Seq(
+ nativeConfig ~= {
+ _.withLTO(LTO.full)
+ .withMode(Mode.releaseFast)
+ .withGC(GC.commix)
+ }
+)
+
+// lazy val root = project
+// .in(file("."))
+// .aggregate(sparsity.jvm, sparsity.native)
+
+lazy val sparsity =
+ //crossProject(JVMPlatform, NativePlatform)
+ //.crossType(CrossType.Pure)
+ project
+ .in(file("."))
+ .settings(
+ name := "Sparsity",
+ version := "1.7.1-SNAPSHOT",
+ organization := "info.mimirdb",
+ commonSettings,
+ // Assembly settings
+ assembly / mainClass := Some("sparsity.Main"),
+ assembly / assemblyJarName := s"${name.value}-${version.value}.jar",
+ assembly / assemblyMergeStrategy := {
+ case PathList("META-INF", xs @ _*) => MergeStrategy.discard
+ case x => MergeStrategy.first
+ },
+ assembly / artifact := {
+ val art = (assembly / artifact).value
+ art.withClassifier(Some("assembly"))
+ },
+ addArtifact(assembly / artifact, assembly),
+ // Publishing settings
+ publishMavenStyle := true,
+    pomExtra := <url>http://github.com/UBOdin/sparsity/</url>
+      <licenses>
+        <license>
+          <name>Apache License 2.0</name>
+          <url>http://www.apache.org/licenses/</url>
+          <distribution>repo</distribution>
+        </license>
+      </licenses>
+      <scm>
+        <url>git@github.com:ubodin/sparsity.git</url>
+        <connection>scm:git:git@github.com:ubodin/sparsity.git</connection>
+      </scm>,
+ )
+ //.nativeSettings(nativeSettings)
+ .enablePlugins(AssemblyPlugin)
diff --git a/project/build.properties b/project/build.properties
new file mode 100644
index 0000000..01a16ed
--- /dev/null
+++ b/project/build.properties
@@ -0,0 +1 @@
+sbt.version=1.11.7
diff --git a/project/plugins.sbt b/project/plugins.sbt
new file mode 100644
index 0000000..1cd828c
--- /dev/null
+++ b/project/plugins.sbt
@@ -0,0 +1,5 @@
+addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.3")
+addSbtPlugin("org.portable-scala" % "sbt-scala-native-crossproject" % "1.3.2")
+addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.5.8")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1")
+
diff --git a/src/main/scala/sparsity/Main.scala b/src/main/scala/sparsity/Main.scala
new file mode 100644
index 0000000..1eb62fe
--- /dev/null
+++ b/src/main/scala/sparsity/Main.scala
@@ -0,0 +1,597 @@
+package sparsity
+
+import sparsity.oracle.OracleSQL
+import sparsity.ansi.ANSISQL
+import sparsity.postgres.PostgresSQL
+import sparsity.common.parser.ErrorMessage
+import sparsity.common.codec.CirceCodecs._
+import fastparse._
+import scala.io.Source
+import java.io.{File, FileWriter, PrintWriter}
+import java.nio.file.{FileSystems, Files, Paths}
+import java.util.logging.{Level => JLevel, LogManager, Logger => JLogger}
+import scribe.Logger
+import scribe.Level
+import io.circe._
+import io.circe.generic.semiauto._
+import io.circe.syntax._
+import io.circe.Printer
+import io.circe.Encoder
+import sparsity.common.statement.Statement
+import com.arakelian.jq.{ImmutableJqLibrary, ImmutableJqRequest}
+import mainargs.{arg, main, Leftover, Parser, TokensReader}
+import sparsity.common.parser.SQLBaseObject
+
+case class Annotation(file: String, line: Int, title: String, message: String, annotation_level: String)
+
+object Annotation {
+ implicit val encoder: Encoder[Annotation] = deriveEncoder[Annotation]
+}
+
+// Maps for dialect lookup - built dynamically using the name method
+object SQLDialectMaps {
+ val dialects: Seq[SQLBaseObject] = Seq(OracleSQL, ANSISQL, PostgresSQL)
+ val dialectNames: Seq[String] = dialects.map(_.name)
+ val nameToDialect: Map[String, SQLBaseObject] = dialects.map(d => d.name -> d).toMap
+ val validDialects: String = dialects.map(_.name).mkString(", ")
+}
+
+// Custom parser for SQLBaseObject
+object SQLBaseObjectReader extends TokensReader.Simple[SQLBaseObject] {
+ def shortName = "dialect"
+ def read(strs: Seq[String]): Either[String, SQLBaseObject] = {
+ if (strs.isEmpty) {
+ Left("Missing dialect name")
+ } else {
+ SQLDialectMaps.nameToDialect.get(strs.head.toLowerCase) match {
+ case Some(dialect) => Right(dialect)
+ case None => Left(s"Unknown dialect: ${strs.head}. Valid dialects are: ${SQLDialectMaps.validDialects}")
+ }
+ }
+ }
+}
+
+// Custom parser for OutputFormat
+object OutputFormatReader extends TokensReader.Simple[OutputFormat] {
+ def shortName = "format"
+ def read(strs: Seq[String]): Either[String, OutputFormat] = {
+ if (strs.isEmpty) {
+ Left("Missing output format")
+ } else {
+ val s = strs.head
+ s match {
+ case "text" => Right(TextFormat)
+ case "json" => Right(StatementsJsonFormat)
+ case s"jq($query)" => Right(JqFormat(query))
+ case s"jq-file:$filePath" => Right(JqFileFormat(filePath))
+ case _ => Left(s"Unknown output format: $s. Valid formats: text, json, jq(query), jq-file:path")
+ }
+ }
+ }
+}
+
+// Output format types
+sealed trait OutputFormat {
+ def name: String
+ def format(result: ProcessingResult): String
+}
+
+case object TextFormat extends OutputFormat {
+ def name: String = "text"
+ def format(result: ProcessingResult): String = {
+ val textOutputLines = result.fileResults.flatMap { fileResult =>
+ (s"\n=== Processing: ${fileResult.fileName}" :: fileResult.statementResults.flatMap { stmtResult =>
+ val idx = stmtResult.index
+ stmtResult.parseResult match {
+ case Parsed.Success(_, _) =>
+ List(s"✓ Statement ${idx + 1}: OK")
+ case Parsed.Failure(label, index, _) =>
+ val stmtIndex = index.toInt
+ val absoluteLineNum = stmtResult.absoluteLineNumber.getOrElse(0)
+ // For unparseable statements, show the full statement text instead of a substring
+ val context = if (label == "unparseable") {
+ // Show full statement text (up to 200 chars for readability)
+ val fullText = stmtResult.context.statement
+ if (fullText.length <= 200) fullText else fullText.take(200) + "..."
+ } else {
+ stmtResult.context.statement
+ .substring(Math.max(0, stmtIndex - 50), Math.min(stmtResult.context.statement.length, stmtIndex + 50))
+ }
+ List(
+ s"❌ ${stmtResult.context.relativePath}:${absoluteLineNum}: Syntax ERROR",
+ s"Statement ${idx + 1} failed: $label",
+ s"Context: $context"
+ )
+ }
+ }) :+ fileResult.fileSummary
+ }
+
+ val overallSummary = s"\n=== Overall Summary ==="
+ val summary2 = s"Files processed: ${result.fileResults.length}"
+ val summary3 =
+ s"Total: ${result.totalSuccessCount} successful, ${result.totalErrorCount} failed out of ${result.totalStatements} statements"
+
+ (textOutputLines :+ overallSummary :+ summary2 :+ summary3).mkString("\n")
+ }
+}
+case object StatementsJsonFormat extends OutputFormat {
+ import ParsedEncoder._
+ def name: String = "json"
+ def format(result: ProcessingResult): String = {
+ val statementsJson = result.asJson
+ Printer.spaces2.copy(dropNullValues = true).print(statementsJson)
+ }
+}
+case class JqFormat(query: String) extends OutputFormat {
+ def name: String = "jq"
+ def format(result: ProcessingResult): String = {
+ // Generate statements JSON for JQ query
+ val statementsJsonString = StatementsJsonFormat.format(result)
+ try {
+ // Suppress Java JQ library logging before initializing the library
+ // This must be done before ImmutableJqLibrary.of() is called
+ // Suppress all JQ-related loggers and their handlers
+ val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary")
+
+ loggersToSuppress.foreach { loggerName =>
+ val logger = JLogger.getLogger(loggerName)
+ logger.setLevel(JLevel.OFF) // Completely disable logging
+ logger.setUseParentHandlers(false)
+ // Remove all handlers
+ val handlers = logger.getHandlers
+ handlers.foreach(logger.removeHandler)
+ }
+
+ val library = ImmutableJqLibrary.of()
+ val request = ImmutableJqRequest
+ .builder()
+ .lib(library)
+ .input(statementsJsonString)
+ .filter(query)
+ .build()
+ request.execute().getOutput
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error executing JQ query: ${e.getMessage}", e)
+ }
+ }
+}
+case class JqFileFormat(filePath: String) extends OutputFormat {
+ def name: String = "jq-file"
+ def format(result: ProcessingResult): String = {
+ // Read the JQ query from the file
+ val query =
+ try {
+ val currentDir = Paths.get(".").toAbsolutePath.normalize
+ val queryFile = if (Paths.get(filePath).isAbsolute) {
+ Paths.get(filePath).toFile
+ } else {
+ currentDir.resolve(filePath).toFile
+ }
+
+ if (!queryFile.exists() || !queryFile.isFile) {
+ throw new RuntimeException(s"JQ query file not found: $filePath")
+ }
+
+ Source.fromFile(queryFile).mkString.trim
+ } catch {
+ case e: RuntimeException => throw e
+ case e: Exception =>
+ throw new RuntimeException(s"Error reading JQ query file '$filePath': ${e.getMessage}", e)
+ }
+
+ // Generate statements JSON for JQ query
+ val statementsJsonString = StatementsJsonFormat.format(result)
+ try {
+ // Suppress Java JQ library logging before initializing the library
+ // This must be done before ImmutableJqLibrary.of() is called
+ // Suppress all JQ-related loggers and their handlers
+ val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary")
+
+ loggersToSuppress.foreach { loggerName =>
+ val logger = JLogger.getLogger(loggerName)
+ logger.setLevel(JLevel.OFF) // Completely disable logging
+ logger.setUseParentHandlers(false)
+ // Remove all handlers
+ val handlers = logger.getHandlers
+ handlers.foreach(logger.removeHandler)
+ }
+
+ val library = ImmutableJqLibrary.of()
+ val request = ImmutableJqRequest
+ .builder()
+ .lib(library)
+ .input(statementsJsonString)
+ .filter(query)
+ .build()
+ request.execute().getOutput
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error executing JQ query from file '$filePath': ${e.getMessage}", e)
+ }
+ }
+}
+
+// Output destination types
+sealed trait OutputDestination {
+ def write(content: String): Unit
+}
+case object ConsoleDestination extends OutputDestination {
+ def write(content: String): Unit = println(content)
+}
+case class FileDestination(path: String) extends OutputDestination {
+ def write(content: String): Unit = {
+ val writer = new PrintWriter(new File(path))
+ try
+ writer.println(content)
+ finally
+ writer.close()
+ }
+}
+
+object OutputDestination {
+ implicit val tokensReader: TokensReader.Simple[OutputDestination] = new TokensReader.Simple[OutputDestination] {
+ def shortName = "destination"
+ def read(strs: Seq[String]): Either[String, OutputDestination] = {
+ if (strs.isEmpty) {
+ Left("Missing output destination")
+ } else {
+ strs.head match {
+ case "console" => Right(ConsoleDestination)
+ case path: String => Right(FileDestination(path))
+ case _ => Left(s"Unknown output destination: ${strs.head}. Valid destinations: console, file path")
+ }
+ }
+ }
+ }
+}
+
+@main(doc = "Parse SQL files using the specified dialect")
+case class Config(
+ @arg(positional = true, doc = "SQL dialect to use")
+ dialect: SQLBaseObject,
+ @arg(doc =
+ "Output format. Format: text|json|jq(query)|jq-file:path. Example: --output jq(.files[].statements[]) or --output jq-file:pr-checker.jq"
+ )
+ output: OutputFormat = TextFormat,
+ @arg(doc = "Output destination. Example: --write-to console or --write-to output.json")
+ writeTo: OutputDestination = ConsoleDestination,
+ @arg(doc = "SQL file(s) or glob pattern(s) to parse (e.g., *.sql)")
+ filePatterns: Leftover[String]
+)
+
+case class StatementContext(
+ statement: String,
+ rawStatement: String,
+ relativePath: String,
+ dialectName: String,
+ statementStartInContent: Int
+)
+
+object StatementContext {
+ implicit val encoder: Encoder[StatementContext] = deriveEncoder[StatementContext]
+}
+
+case class StatementResult(
+ index: Int,
+ parseResult: Parsed[Statement],
+ context: StatementContext,
+ absoluteLineNumber: Option[Int] = None
+) {
+ def success: Boolean = parseResult match {
+ case Parsed.Success(_, _) => true
+ case _: Parsed.Failure => false
+ }
+}
+
+// Custom encoder for Parsed[Statement] to make it serializable
+object ParsedEncoder {
+ implicit def parsedEncoder[T : Encoder]: Encoder[Parsed[T]] = Encoder.instance {
+ case Parsed.Success(value, _) =>
+ Json.obj("success" -> Json.fromBoolean(true), "value" -> value.asJson)
+ case Parsed.Failure(label, index, extra) =>
+ Json.obj("success" -> Json.fromBoolean(false), "label" -> Json.fromString(label), "index" -> Json.fromLong(index))
+ }
+}
+
+object StatementResult {
+ import ParsedEncoder._
+ implicit val encoder: Encoder[StatementResult] = deriveEncoder[StatementResult]
+}
+
+case class FileResult(
+ fileName: String,
+ relativePath: String,
+ statementResults: List[StatementResult],
+ fileSummary: String
+)
+
+object FileResult {
+ implicit val encoder: Encoder[FileResult] = deriveEncoder[FileResult]
+}
+
+case class ProcessingResult(
+ fileResults: List[FileResult],
+ totalSuccessCount: Int,
+ totalErrorCount: Int,
+ totalStatements: Int
+)
+
+object ProcessingResult {
+ implicit val encoder: Encoder[ProcessingResult] = deriveEncoder[ProcessingResult]
+}
+
+object Main {
+ // Configure scribe logger
+ private val logger = Logger("sparsity.Main")
+
+ // Make the custom parsers available implicitly
+ implicit val sqlBaseObjectReader: TokensReader.Simple[SQLBaseObject] = SQLBaseObjectReader
+ implicit val outputFormatReader: TokensReader.Simple[OutputFormat] = OutputFormatReader
+
+ def main(args: scala.Array[String]): Unit = {
+ // Parse command-line arguments using mainargs first to check output destination
+ val config = Parser[Config].constructOrExit(args)
+
+ // Configure logging - suppress logs when writing to console
+ val shouldSuppressLogs = config.writeTo == ConsoleDestination
+ val logLevel = if (shouldSuppressLogs) {
+ // When writing to console, suppress INFO/DEBUG logs
+ Option(System.getProperty("log.level"))
+ .orElse(Option(System.getenv("LOG_LEVEL")))
+ .getOrElse("WARN")
+ } else {
+ // When writing to file, respect user's log level preference
+ Option(System.getProperty("log.level"))
+ .orElse(Option(System.getenv("LOG_LEVEL")))
+ .getOrElse("DEBUG")
+ }
+
+ val level = logLevel.toUpperCase match {
+ case "TRACE" => Level.Trace
+ case "DEBUG" => Level.Debug
+ case "INFO" => Level.Info
+ case "WARN" => Level.Warn
+ case "ERROR" => Level.Error
+ case _ => Level.Debug
+ }
+
+ // Configure root logger with console output and minimum level
+ Logger.root
+ .clearHandlers()
+ .withHandler(minimumLevel = Some(level))
+ .replace()
+
+ // Suppress Java JQ library logging when writing to console
+ if (shouldSuppressLogs) {
+ // Suppress all JQ-related loggers early
+ val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary")
+
+ loggersToSuppress.foreach { loggerName =>
+ val logger = JLogger.getLogger(loggerName)
+ logger.setLevel(JLevel.OFF)
+ logger.setUseParentHandlers(false)
+ }
+ }
+
+ // Expand glob patterns and collect all files
+ val files = config.filePatterns.value.flatMap { pattern =>
+ val currentDir = Paths.get(".").toAbsolutePath.normalize
+
+ // If it's a direct file path and exists, return it
+ val directFile = if (Paths.get(pattern).isAbsolute) {
+ Paths.get(pattern).toFile
+ } else {
+ currentDir.resolve(pattern).toFile
+ }
+
+ if (directFile.exists() && directFile.isFile && !pattern.contains('*') && !pattern.contains('?')) {
+ Seq(directFile)
+ } else if (pattern.contains('*') || pattern.contains('?')) {
+ // Normalize path separators to forward slashes for glob matching
+ val normalizedPattern = pattern.replace('\\', '/')
+
+ // Determine the base directory to start searching from
+ // For patterns like "../**/*oracle*/**/*.sql", find the deepest non-glob directory
+ val parts = normalizedPattern.split("/").filter(_.nonEmpty)
+ def findBaseRec(path: java.nio.file.Path, index: Int): (java.nio.file.Path, Int) =
+ if (index >= parts.length) {
+ (path, index)
+ } else {
+ val part = parts(index)
+ if (part == "..") {
+ findBaseRec(path.getParent, index + 1)
+ } else if (part == ".") {
+ findBaseRec(path, index + 1)
+ } else if (part.contains('*') || part.contains('?')) {
+ (path, index) // Stop here, found glob
+ } else {
+ val nextPath = path.resolve(part)
+ if (nextPath.toFile.exists() && nextPath.toFile.isDirectory) {
+ findBaseRec(nextPath, index + 1)
+ } else {
+ (path, index) // Stop here, directory doesn't exist
+ }
+ }
+ }
+
+ val (baseDir, baseIndex) = findBaseRec(currentDir, 0)
+
+ // Build the relative pattern from the base directory
+ val relativePattern = if (baseIndex < parts.length) {
+ parts.slice(baseIndex, parts.length).mkString("/")
+ } else {
+ "**" // If we consumed everything, search recursively
+ }
+
+ val basePath = baseDir.normalize
+
+ if (basePath.toFile.exists() && basePath.toFile.isDirectory) {
+ val fs = FileSystems.getDefault
+ val matcher = fs.getPathMatcher(s"glob:$relativePattern")
+
+ // Walk the directory tree recursively using Stream
+ try {
+ import scala.jdk.CollectionConverters._
+ Files
+ .walk(basePath)
+ .iterator()
+ .asScala
+ .filter(_.toFile.isFile)
+ .filter { path =>
+ val relativePath = basePath.relativize(path).toString.replace('\\', '/')
+ matcher.matches(Paths.get(relativePath))
+ }
+ .map(_.toFile)
+ .toSeq
+ } catch {
+ case _: Exception =>
+ Seq.empty
+ }
+ } else {
+ Seq.empty
+ }
+ } else {
+ // Not a glob pattern, but file doesn't exist
+ Seq.empty
+ }
+ }.distinct
+
+ if (files.isEmpty) {
+ logger.warn(s"No files found matching: ${config.filePatterns.value.mkString(", ")}")
+ System.exit(1)
+ }
+
+ // Step 1: Run the parser
+ val result = parseFiles(files, config.dialect)
+
+ // Step 2: Convert results into outputs
+ generateOutputs(result, config)
+
+ // Exit with error code if there were errors
+ if (result.totalErrorCount > 0) {
+ System.exit(1)
+ }
+ }
+
+ /** Parses a list of SQL files using the specified dialect and returns processing results. This method is public to
+ * enable testing.
+ *
+ * @param files
+ * The list of files to parse
+ * @param dialect
+ * The SQL dialect object to use
+ * @return
+ * ProcessingResult containing file results and summary statistics
+ */
+ def parseFiles(files: Seq[File], dialect: SQLBaseObject): ProcessingResult = {
+ val fileResults = files.map { file =>
+ val relativePath = {
+ val currentDir = Paths.get(".").toAbsolutePath.normalize
+ val filePathObj = file.toPath.toAbsolutePath.normalize
+ try
+ currentDir.relativize(filePathObj).toString.replace('\\', '/')
+ catch {
+ case _: Exception => file.getPath
+ }
+ }
+
+ val content = Source.fromFile(file).mkString
+
+ val dialectName = dialect.name
+
+ // Use parseAll to parse all statements with error recovery (now returns positions)
+ val parseResults: Seq[sparsity.common.statement.StatementParseResult] = dialect.parseAll(content)
+
+ // Convert StatementParseResult to StatementResult
+ val statementResults = parseResults.zipWithIndex.map { case (parseResult, idx) =>
+ val (statementText, isUnparseable) = parseResult.result match {
+ case Right(stmt) => (stmt.toString, false)
+ case Left(unparseable) => (unparseable.content, true)
+ }
+
+ // Calculate absolute line number from start position
+ val absoluteLineNum = if (parseResult.startPos >= 0 && parseResult.startPos < content.length) {
+ val contentBeforeStmt = content.substring(0, parseResult.startPos)
+ val linesBeforeStmt = contentBeforeStmt.count(_ == '\n')
+ Some(linesBeforeStmt + 1) // +1 because line numbers are 1-based
+ } else {
+ None
+ }
+
+ // Create parse result object
+ val finalParseResult = parseResult.result match {
+ case Right(stmt) => Parsed.Success(stmt, 0)
+ case Left(_) => Parsed.Failure("unparseable", 0, null)
+ }
+
+ val stmtContext = StatementContext(
+ statement = statementText,
+ rawStatement = statementText,
+ relativePath = relativePath,
+ dialectName = dialectName,
+ statementStartInContent = parseResult.startPos
+ )
+
+ val stmtResult = StatementResult(
+ index = idx,
+ parseResult = finalParseResult,
+ context = stmtContext,
+ absoluteLineNumber = absoluteLineNum
+ )
+
+ // Log parsing details for failures
+ parseResult.result match {
+ case Left(unparseable) =>
+ val stmtLines = statementText.split("\n")
+ logger.info(s" Label: unparseable")
+ logger.info(s" Context: ${statementText.take(100)}")
+ logger.debug(s" Full statement:")
+ logger.debug(s" ${stmtLines.mkString("\n ")}")
+ case Right(_) =>
+ }
+
+ stmtResult
+ }.toList
+
+ val successCount = statementResults.count(_.success)
+ val errorCount = statementResults.count(!_.success)
+ val summary =
+ s"File summary: $successCount successful, $errorCount failed out of ${statementResults.length} statements"
+
+ FileResult(
+ fileName = file.getName,
+ relativePath = relativePath,
+ statementResults = statementResults,
+ fileSummary = summary
+ )
+ }.toList
+
+ val totals = fileResults.foldLeft((0, 0, 0)) { case ((success, error, total), fileResult) =>
+ val stmtCounts = fileResult.statementResults.foldLeft((0, 0)) { case ((s, e), stmt) =>
+ if (stmt.success) (s + 1, e) else (s, e + 1)
+ }
+ (success + stmtCounts._1, error + stmtCounts._2, total + fileResult.statementResults.length)
+ }
+
+ ProcessingResult(
+ fileResults = fileResults,
+ totalSuccessCount = totals._1,
+ totalErrorCount = totals._2,
+ totalStatements = totals._3
+ )
+ }
+
+ /** Generates all outputs from parsing results. This method is public to enable testing.
+ *
+ * @param result
+ * The processing result from parsing files
+ * @param config
+ * The configuration containing output options
+ */
+ def generateOutputs(result: ProcessingResult, config: Config): Unit = {
+ // Process each output format (all outputs go to console)
+ val formattedOutput = config.output.format(result)
+ config.writeTo.write(formattedOutput)
+ }
+}
diff --git a/src/main/scala/sparsity/Name.scala b/src/main/scala/sparsity/Name.scala
deleted file mode 100644
index 622f24f..0000000
--- a/src/main/scala/sparsity/Name.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-package sparsity
-
-case class Name(name: String, quoted: Boolean = false)
-{
- override def equals(other: Any): Boolean =
- {
- other match {
- case s: String => equals(s)
- case n: Name => equals(n)
- case _ => false
- }
- }
- def equals(other: Name): Boolean =
- {
- if(quoted || other.quoted){ name.equals(other.name) }
- else { name.equalsIgnoreCase(other.name) }
- }
- def equals(other: String): Boolean =
- {
- if(quoted){ name.equals(other) }
- else { name.equalsIgnoreCase(other) }
- }
- override def toString = if(quoted){ "`"+name+"`" } else { name }
- def +(other: Name) = Name(name+other.name, quoted || other.quoted)
- def +(other: String) = Name(name+other, quoted)
-
- def lower = if(quoted){ name } else { name.toLowerCase }
- def upper = if(quoted){ name } else { name.toUpperCase }
-
- def extendUpper(template:String) =
- Name(template.replace("\\$", upper), quoted)
- def extendLower(template:String) =
- Name(template.replace("\\$", lower), quoted)
-
- def withPrefix(other: String) = Name(other+name, quoted)
- def withSuffix(other: String) = Name(name+other, quoted)
-}
-
-class StringNameMatch(cmp:String)
-{
- def unapply(name: Name):Option[Unit] =
- if(name.equals(cmp)) { return Some(()) }
- else { return None }
-}
-
-object NameMatch
-{
- def apply(cmp: String) = new StringNameMatch(cmp)
-}
\ No newline at end of file
diff --git a/src/main/scala/sparsity/alter/AlterView.scala b/src/main/scala/sparsity/alter/AlterView.scala
deleted file mode 100644
index b63daf2..0000000
--- a/src/main/scala/sparsity/alter/AlterView.scala
+++ /dev/null
@@ -1,8 +0,0 @@
-package sparsity.alter
-
-sealed abstract class AlterViewAction
-
-case class Materialize(add: Boolean) extends AlterViewAction
-{
- override def toString = (if(add){ "" } else { "DROP " })+"MATERIALIZE"
-}
\ No newline at end of file
diff --git a/src/main/scala/sparsity/ansi/ANSISQL.scala b/src/main/scala/sparsity/ansi/ANSISQL.scala
new file mode 100644
index 0000000..7419a67
--- /dev/null
+++ b/src/main/scala/sparsity/ansi/ANSISQL.scala
@@ -0,0 +1,73 @@
+package sparsity.ansi
+
+import fastparse._
+import scala.io._
+import java.io._
+import sparsity.common.statement._
+import sparsity.common.{
+ BigInt,
+ Boolean,
+ Char,
+ DataType,
+ Date,
+ Decimal,
+ Double,
+ Float,
+ Integer,
+ Text,
+ Timestamp,
+ VarChar
+}
+import sparsity.common.parser.{SQLBase, SQLBaseObject, StreamParser}
+import sparsity.common.statement.Unparseable
+
+/** ANSI SQL dialect implementation. Provides baseline SQL parsing with standard ANSI SQL features.
+ */
+class ANSISQL extends SQLBase {
+ // Configuration
+ type Stmt = Statement
+ def caseSensitive = false
+ def statementTerminator = ";"
+ def supportsCTEs = true
+ def supportsReturning = false
+ def supportsIfExists = true
+ def stringConcatOp = "||"
+
+ // Parser implementations
+ def createStatement[$ : P]: P[Statement] = P(&(keyword("CREATE")) ~/ (createTable | createView))
+
+ def dialectSpecificStatement[$ : P]: P[Statement] = P(fastparse.Fail)
+
+ def dataType[$ : P]: P[DataType] = P(
+ (keyword("VARCHAR") ~ "(" ~ integer.map(_.toInt) ~ ")").map(VarChar(_))
+ | (keyword("CHAR") ~ "(" ~ integer.map(_.toInt) ~ ")").map(Char(_))
+ | keyword("INTEGER").map(_ => Integer())
+ | keyword("INT").map(_ => Integer())
+ | keyword("BIGINT").map(_ => BigInt())
+ | (keyword("DECIMAL") ~ "(" ~ integer.map(_.toInt) ~ "," ~ integer.map(_.toInt) ~ ")")
+ .map { case (p, s) => Decimal(p, s) }
+ | keyword("FLOAT").map(_ => Float())
+ | keyword("DOUBLE").map(_ => Double())
+ | keyword("DATE").map(_ => Date())
+ | (keyword("TIMESTAMP") ~ ("(" ~ integer.map(_.toInt) ~ ")").?)
+ .map(p => Timestamp(p))
+ | keyword("BOOLEAN").map(_ => Boolean())
+ | keyword("TEXT").map(_ => Text())
+ )
+}
+
+object ANSISQL extends SQLBaseObject {
+ type Stmt = Statement
+ import fastparse.Parsed
+ val instance = new ANSISQL()
+
+ def name: String = "ansi"
+
+ protected def statementTerminator: String = ";"
+
+ def apply(input: String): Parsed[Stmt] =
+ parse(input, instance.terminatedStatement(_))
+
+ def apply(input: Reader): StreamParser[Stmt] =
+ new StreamParser[Stmt](parse(_: Iterator[String], instance.terminatedStatement(_), verboseFailures = true), input)
+}
diff --git a/src/main/scala/sparsity/common/DataType.scala b/src/main/scala/sparsity/common/DataType.scala
new file mode 100644
index 0000000..2ab0bc6
--- /dev/null
+++ b/src/main/scala/sparsity/common/DataType.scala
@@ -0,0 +1,37 @@
+package sparsity.common
+
+/** Base trait for SQL data types. Supports common types shared across dialects and dialect-specific extensions.
+ *
+ * Zero-argument types are case classes (not case objects) so all callers
+ * uniformly construct them with parentheses, e.g. `Integer()`.
+ */
+sealed trait DataType
+
+// Common data types (shared across dialects)
+case class VarChar(length: Int) extends DataType
+case class Char(length: Int) extends DataType
+case class Integer() extends DataType
+case class BigInt() extends DataType
+case class Decimal(precision: Int, scale: Int) extends DataType
+case class Float() extends DataType
+case class Double() extends DataType
+case class Date() extends DataType
+// precision = optional fractional-seconds precision, e.g. TIMESTAMP(6)
+case class Timestamp(precision: Option[Int]) extends DataType
+case class Boolean() extends DataType
+case class Text() extends DataType
+
+// Oracle-specific data types
+// byteOrChar: presumably the VARCHAR2 length qualifier ("BYTE"/"CHAR") — confirm in the Oracle parser
+case class VarChar2(length: Int, byteOrChar: Option[String]) extends DataType
+// NUMBER allows precision and scale to be independently omitted
+case class Number(precision: Option[Int], scale: Option[Int]) extends DataType
+case class Clob() extends DataType
+case class Blob() extends DataType
+case class NVarChar2(length: Int) extends DataType
+case class Raw(length: Int) extends DataType
+
+// Postgres-specific data types
+case class Serial() extends DataType
+case class BigSerial() extends DataType
+case class Json() extends DataType
+case class Jsonb() extends DataType
+case class Uuid() extends DataType
+case class Inet() extends DataType
+case class Cidr() extends DataType
+case class MacAddr() extends DataType
+case class Array(elementType: DataType) extends DataType
diff --git a/src/main/scala/sparsity/common/Name.scala b/src/main/scala/sparsity/common/Name.scala
new file mode 100644
index 0000000..ddbc48e
--- /dev/null
+++ b/src/main/scala/sparsity/common/Name.scala
@@ -0,0 +1,45 @@
+package sparsity.common
+
+import sparsity.common.expression.ToSql
+
+/** A (possibly quoted) SQL identifier. Unquoted names compare case-insensitively. */
+case class Name(name: String, quoted: scala.Boolean = false) extends ToSql {
+ override def equals(other: Any): scala.Boolean =
+ other match {
+ case s: String => equals(s)
+ case n: Name => equals(n)
+ case _ => false
+ }
+ /** Case-sensitive when either side is quoted, otherwise case-insensitive. */
+ def equals(other: Name): scala.Boolean =
+ if (quoted || other.quoted) { name.equals(other.name) }
+ else { name.equalsIgnoreCase(other.name) }
+ def equals(other: String): scala.Boolean =
+ if (quoted) { name.equals(other) }
+ else { name.equalsIgnoreCase(other) }
+ /** Fix: a hashCode consistent with the custom equals above, so Name works as a
+ * key in hash-based collections. Lower-casing also hashes names equal under
+ * the case-sensitive (quoted) branch to the same value.
+ */
+ override def hashCode: scala.Int = name.toLowerCase.hashCode
+ override def toSql = if (quoted) { "`" + name + "`" }
+ else { name }
+ def +(other: Name) = Name(name + other.name, quoted || other.quoted)
+ def +(other: String) = Name(name + other, quoted)
+
+ def lower = if (quoted) { name }
+ else { name.toLowerCase }
+ def upper = if (quoted) { name }
+ else { name.toUpperCase }
+
+ /** Replaces each `$` in `template` with the upper-cased name.
+ * Fix: `String.replace` is literal (not regex), so the previous pattern
+ * "\\$" only matched a backslash-dollar pair and left ordinary templates
+ * like "PREFIX_$" unchanged.
+ */
+ def extendUpper(template: String) =
+ Name(template.replace("$", upper), quoted)
+ /** Replaces each `$` in `template` with the lower-cased name. */
+ def extendLower(template: String) =
+ Name(template.replace("$", lower), quoted)
+
+ def withPrefix(other: String) = Name(other + name, quoted)
+ def withSuffix(other: String) = Name(name + other, quoted)
+}
+
+/** Extractor that succeeds when a Name matches `cmp` under Name's equality rules. */
+class StringNameMatch(cmp: String) {
+ def unapply(name: Name): Option[Unit] =
+ if (name.equals(cmp)) Some(()) else None
+}
+
+object NameMatch {
+ /** Builds a reusable extractor matching the given string. */
+ def apply(cmp: String): StringNameMatch = new StringNameMatch(cmp)
+}
diff --git a/src/main/scala/sparsity/common/alter/AlterView.scala b/src/main/scala/sparsity/common/alter/AlterView.scala
new file mode 100644
index 0000000..99bb6d5
--- /dev/null
+++ b/src/main/scala/sparsity/common/alter/AlterView.scala
@@ -0,0 +1,10 @@
+package sparsity.common.alter
+
+import sparsity.common.expression.ToSql
+
+/** Actions applicable to an ALTER VIEW statement. */
+sealed abstract class AlterViewAction extends ToSql
+
+/** Toggles view materialization: `MATERIALIZE` when add, `DROP MATERIALIZE` otherwise. */
+case class Materialize(add: Boolean) extends AlterViewAction {
+ override def toSql = if (add) "MATERIALIZE" else "DROP MATERIALIZE"
+}
diff --git a/src/main/scala/sparsity/common/codec/CirceCodecs.scala b/src/main/scala/sparsity/common/codec/CirceCodecs.scala
new file mode 100644
index 0000000..3626606
--- /dev/null
+++ b/src/main/scala/sparsity/common/codec/CirceCodecs.scala
@@ -0,0 +1,581 @@
+package sparsity.common.codec
+
+import io.circe.{Decoder, DecodingFailure, Encoder, HCursor}
+import io.circe.{Json => CirceJson}
+import io.circe.generic.extras.Configuration
+import io.circe.generic.extras.semiauto.{deriveConfiguredDecoder, deriveConfiguredEncoder}
+import io.circe.generic.semiauto.{deriveDecoder, deriveEncoder}
+import io.circe.syntax._
+import cats.syntax.traverse._
+import sparsity.common.{
+ Array => SArray,
+ BigInt => SBigInt,
+ BigSerial => SBigSerial,
+ Blob => SBlob,
+ Boolean => SBoolean,
+ Char => SChar,
+ Cidr => SCidr,
+ Clob => SClob,
+ DataType,
+ Date => SDate,
+ Decimal => SDecimal,
+ Double => SDouble,
+ Float => SFloat,
+ Inet => SInet,
+ Integer => SInteger,
+ Json => SJson,
+ Jsonb => SJsonb,
+ MacAddr => SMacAddr,
+ NVarChar2 => SNVarChar2,
+ Name,
+ Number => SNumber,
+ Raw => SRaw,
+ Serial => SSerial,
+ Text => SText,
+ Timestamp => STimestamp,
+ Uuid => SUuid,
+ VarChar => SVarChar,
+ VarChar2 => SVarChar2
+}
+import sparsity.common.expression._
+import sparsity.common.statement._
+import sparsity.common.select._
+import sparsity.common.alter._
+import sparsity.oracle._
+
+/** Circe codecs for all parser result datatypes.
+ *
+ * Sealed hierarchies are encoded with a "type" discriminator field. Recursive
+ * structures (Expression, SelectBody) use hand-written Encoder.instance /
+ * Decoder.instance lambdas, which defer implicit lookups to (en/de)code time.
+ */
+object CirceCodecs {
+ // Configuration for sealed trait/class hierarchies with "type" discriminator
+ implicit val config: Configuration = Configuration.default.withDiscriminator("type")
+
+ // Name codec
+ implicit val nameEncoder: Encoder[Name] = deriveEncoder[Name]
+ implicit val nameDecoder: Decoder[Name] = deriveDecoder[Name]
+
+ // Function codec (needs explicit codec due to generic derivation issues)
+ implicit val functionEncoder: Encoder[Function] = deriveEncoder[Function]
+ implicit val functionDecoder: Decoder[Function] = deriveDecoder[Function]
+
+ // DataType codecs
+ implicit val dataTypeEncoder: Encoder[DataType] = Encoder.instance {
+ case v: SVarChar =>
+ CirceJson.obj("type" -> CirceJson.fromString("VarChar"), "length" -> CirceJson.fromInt(v.length))
+ case c: SChar => CirceJson.obj("type" -> CirceJson.fromString("Char"), "length" -> CirceJson.fromInt(c.length))
+ case SInteger() => CirceJson.obj("type" -> CirceJson.fromString("Integer"))
+ case SBigInt() => CirceJson.obj("type" -> CirceJson.fromString("BigInt"))
+ case d: SDecimal =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Decimal"),
+ "precision" -> CirceJson.fromInt(d.precision),
+ "scale" -> CirceJson.fromInt(d.scale)
+ )
+ case SFloat() => CirceJson.obj("type" -> CirceJson.fromString("Float"))
+ case SDouble() => CirceJson.obj("type" -> CirceJson.fromString("Double"))
+ case SDate() => CirceJson.obj("type" -> CirceJson.fromString("Date"))
+ case t: STimestamp => CirceJson.obj("type" -> CirceJson.fromString("Timestamp"), "precision" -> t.precision.asJson)
+ case SBoolean() => CirceJson.obj("type" -> CirceJson.fromString("Boolean"))
+ case SText() => CirceJson.obj("type" -> CirceJson.fromString("Text"))
+ case v: SVarChar2 =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("VarChar2"),
+ "length" -> CirceJson.fromInt(v.length),
+ "byteOrChar" -> v.byteOrChar.asJson
+ )
+ case n: SNumber =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Number"),
+ "precision" -> n.precision.asJson,
+ "scale" -> n.scale.asJson
+ )
+ case SClob() => CirceJson.obj("type" -> CirceJson.fromString("Clob"))
+ case SBlob() => CirceJson.obj("type" -> CirceJson.fromString("Blob"))
+ case n: SNVarChar2 =>
+ CirceJson.obj("type" -> CirceJson.fromString("NVarChar2"), "length" -> CirceJson.fromInt(n.length))
+ case r: SRaw => CirceJson.obj("type" -> CirceJson.fromString("Raw"), "length" -> CirceJson.fromInt(r.length))
+ case SSerial() => CirceJson.obj("type" -> CirceJson.fromString("Serial"))
+ case SBigSerial() => CirceJson.obj("type" -> CirceJson.fromString("BigSerial"))
+ case SJson() => CirceJson.obj("type" -> CirceJson.fromString("Json"))
+ case SJsonb() => CirceJson.obj("type" -> CirceJson.fromString("Jsonb"))
+ case SUuid() => CirceJson.obj("type" -> CirceJson.fromString("Uuid"))
+ case SInet() => CirceJson.obj("type" -> CirceJson.fromString("Inet"))
+ case SCidr() => CirceJson.obj("type" -> CirceJson.fromString("Cidr"))
+ case SMacAddr() => CirceJson.obj("type" -> CirceJson.fromString("MacAddr"))
+ // Arrays recurse on the element type via an explicit call (not implicit lookup).
+ case a: SArray =>
+ CirceJson.obj("type" -> CirceJson.fromString("Array"), "elementType" -> dataTypeEncoder(a.elementType))
+ }
+
+ implicit val dataTypeDecoder: Decoder[DataType] = Decoder.instance { c =>
+ c.get[String]("type").flatMap {
+ case "VarChar" => c.get[Int]("length").map(SVarChar.apply)
+ case "Char" => c.get[Int]("length").map(SChar.apply)
+ case "Integer" => Right(SInteger())
+ case "BigInt" => Right(SBigInt())
+ case "Decimal" =>
+ for {
+ p <- c.get[Int]("precision")
+ s <- c.get[Int]("scale")
+ } yield SDecimal(p, s)
+ case "Float" => Right(SFloat())
+ case "Double" => Right(SDouble())
+ case "Date" => Right(SDate())
+ case "Timestamp" => c.get[Option[Int]]("precision").map(STimestamp.apply)
+ case "Boolean" => Right(SBoolean())
+ case "Text" => Right(SText())
+ case "VarChar2" =>
+ for {
+ l <- c.get[Int]("length")
+ boc <- c.get[Option[String]]("byteOrChar")
+ } yield SVarChar2(l, boc)
+ case "Number" =>
+ for {
+ p <- c.get[Option[Int]]("precision")
+ s <- c.get[Option[Int]]("scale")
+ } yield SNumber(p, s)
+ case "Clob" => Right(SClob())
+ case "Blob" => Right(SBlob())
+ case "NVarChar2" => c.get[Int]("length").map(SNVarChar2.apply)
+ case "Raw" => c.get[Int]("length").map(SRaw.apply)
+ case "Serial" => Right(SSerial())
+ case "BigSerial" => Right(SBigSerial())
+ case "Json" => Right(SJson())
+ case "Jsonb" => Right(SJsonb())
+ case "Uuid" => Right(SUuid())
+ case "Inet" => Right(SInet())
+ case "Cidr" => Right(SCidr())
+ case "MacAddr" => Right(SMacAddr())
+ case "Array" => c.get[DataType]("elementType")(dataTypeDecoder).map(SArray.apply)
+ case other => Left(DecodingFailure(s"Unknown DataType: $other", c.history))
+ }
+ }
+
+ // Enum codecs
+ // NOTE(review): withName returns Enumeration#Value; the asInstanceOf to the
+ // path-dependent Op/Type presumably holds because each enum object aliases
+ // its Value type — confirm against the enum definitions.
+ implicit val arithmeticOpEncoder: Encoder[Arithmetic.Op] = Encoder.encodeString.contramap(_.toString)
+ implicit val arithmeticOpDecoder: Decoder[Arithmetic.Op] = Decoder.decodeString.emap { s =>
+ try Right(Arithmetic.withName(s).asInstanceOf[Arithmetic.Op])
+ catch { case _: NoSuchElementException => Left(s"Unknown Arithmetic.Op: $s") }
+ }
+
+ implicit val comparisonOpEncoder: Encoder[Comparison.Op] = Encoder.encodeString.contramap(_.toString)
+ implicit val comparisonOpDecoder: Decoder[Comparison.Op] = Decoder.decodeString.emap { s =>
+ try Right(Comparison.withName(s).asInstanceOf[Comparison.Op])
+ catch { case _: NoSuchElementException => Left(s"Unknown Comparison.Op: $s") }
+ }
+
+ implicit val joinTypeEncoder: Encoder[Join.Type] = Encoder.encodeString.contramap(_.toString)
+ implicit val joinTypeDecoder: Decoder[Join.Type] = Decoder.decodeString.emap { s =>
+ try Right(Join.withName(s).asInstanceOf[Join.Type])
+ catch { case _: NoSuchElementException => Left(s"Unknown Join.Type: $s") }
+ }
+
+ implicit val unionTypeEncoder: Encoder[Union.Type] = Encoder.encodeString.contramap(_.toString)
+ implicit val unionTypeDecoder: Decoder[Union.Type] = Decoder.decodeString.emap { s =>
+ try Right(Union.withName(s).asInstanceOf[Union.Type])
+ catch { case _: NoSuchElementException => Left(s"Unknown Union.Type: $s") }
+ }
+
+ // Expression codecs (recursive)
+ implicit val expressionEncoder: Encoder[Expression] = Encoder.instance {
+ case lp: LongPrimitive =>
+ CirceJson.obj("type" -> CirceJson.fromString("LongPrimitive"), "value" -> CirceJson.fromLong(lp.v))
+ case dp: DoublePrimitive =>
+ CirceJson.obj("type" -> CirceJson.fromString("DoublePrimitive"), "value" -> CirceJson.fromDoubleOrNull(dp.v))
+ case sp: StringPrimitive =>
+ CirceJson.obj("type" -> CirceJson.fromString("StringPrimitive"), "value" -> CirceJson.fromString(sp.v))
+ case bp: BooleanPrimitive =>
+ CirceJson.obj("type" -> CirceJson.fromString("BooleanPrimitive"), "value" -> CirceJson.fromBoolean(bp.v))
+ case NullPrimitive() => CirceJson.obj("type" -> CirceJson.fromString("NullPrimitive"))
+ case c: Column =>
+ CirceJson.obj("type" -> CirceJson.fromString("Column"), "column" -> c.column.asJson, "table" -> c.table.asJson)
+ case a: Arithmetic =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Arithmetic"),
+ "lhs" -> a.lhs.asJson,
+ "op" -> a.op.asJson,
+ "rhs" -> a.rhs.asJson
+ )
+ case c: Comparison =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Comparison"),
+ "lhs" -> c.lhs.asJson,
+ "op" -> c.op.asJson,
+ "rhs" -> c.rhs.asJson
+ )
+ case f: Function =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Function"),
+ "name" -> f.name.asJson,
+ "params" -> f.params.asJson,
+ "distinct" -> CirceJson.fromBoolean(f.distinct)
+ )
+ case wf: WindowFunction =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("WindowFunction"),
+ "function" -> wf.function.asJson,
+ "partitionBy" -> wf.partitionBy.asJson,
+ "orderBy" -> wf.orderBy.asJson
+ )
+ case JDBCVar() => CirceJson.obj("type" -> CirceJson.fromString("JDBCVar"))
+ case sq: Subquery => CirceJson.obj("type" -> CirceJson.fromString("Subquery"), "query" -> sq.query.asJson)
+ case cwe: CaseWhenElse =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("CaseWhenElse"),
+ "target" -> cwe.target.asJson,
+ "cases" -> cwe.cases.map { case (e1, e2) => CirceJson.obj("when" -> e1.asJson, "then" -> e2.asJson) }.asJson,
+ "otherwise" -> cwe.otherwise.asJson
+ )
+ case in: IsNull => CirceJson.obj("type" -> CirceJson.fromString("IsNull"), "target" -> in.target.asJson)
+ case n: Not => CirceJson.obj("type" -> CirceJson.fromString("Not"), "target" -> n.target.asJson)
+ case cast: Cast =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Cast"),
+ "expression" -> cast.expression.asJson,
+ "typeName" -> cast.t.asJson
+ )
+ case ie: InExpression =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("InExpression"),
+ "expression" -> ie.expression.asJson,
+ // Either an explicit value list or a subquery; tagged with its own "type".
+ "source" -> (ie.source match {
+ case Left(exprs) => CirceJson.obj("type" -> CirceJson.fromString("Expressions"), "values" -> exprs.asJson)
+ case Right(query) => CirceJson.obj("type" -> CirceJson.fromString("SelectBody"), "query" -> query.asJson)
+ })
+ )
+ case _: NegatableExpression => // Handled by IsNull and InExpression cases above
+ throw new IllegalStateException("Unhandled NegatableExpression subtype")
+ }
+
+ implicit val expressionDecoder: Decoder[Expression] = Decoder.instance { c =>
+ c.get[String]("type").flatMap {
+ case "LongPrimitive" => c.get[Long]("value").map(LongPrimitive.apply)
+ case "DoublePrimitive" => c.get[Double]("value").map(DoublePrimitive.apply)
+ case "StringPrimitive" => c.get[String]("value").map(StringPrimitive.apply)
+ case "BooleanPrimitive" => c.get[scala.Boolean]("value").map(BooleanPrimitive.apply)
+ case "NullPrimitive" => Right(NullPrimitive())
+ case "Column" =>
+ for {
+ col <- c.get[Name]("column")
+ table <- c.get[Option[Name]]("table")
+ } yield Column(col, table)
+ case "Arithmetic" =>
+ for {
+ lhs <- c.get[Expression]("lhs")
+ op <- c.get[Arithmetic.Op]("op")
+ rhs <- c.get[Expression]("rhs")
+ } yield Arithmetic(lhs, op, rhs)
+ case "Comparison" =>
+ for {
+ lhs <- c.get[Expression]("lhs")
+ op <- c.get[Comparison.Op]("op")
+ rhs <- c.get[Expression]("rhs")
+ } yield Comparison(lhs, op, rhs)
+ case "Function" =>
+ for {
+ name <- c.get[Name]("name")
+ params <- c.get[Option[Seq[Expression]]]("params")
+ distinct <- c.get[scala.Boolean]("distinct")
+ } yield Function(name, params, distinct)
+ case "WindowFunction" =>
+ for {
+ func <- c.get[Function]("function")
+ partitionBy <- c.get[Option[Seq[Expression]]]("partitionBy")
+ orderBy <- c.get[Seq[OrderBy]]("orderBy")
+ } yield WindowFunction(func, partitionBy, orderBy)
+ case "JDBCVar" => Right(JDBCVar())
+ case "Subquery" => c.get[SelectBody]("query").map(Subquery.apply)
+ case "CaseWhenElse" =>
+ for {
+ target <- c.get[Option[Expression]]("target")
+ cases <- c.get[Seq[CirceJson]]("cases").flatMap { jsons =>
+ jsons.traverse { json =>
+ for {
+ when <- json.hcursor.get[Expression]("when")
+ thenExpr <- json.hcursor.get[Expression]("then")
+ } yield (when, thenExpr)
+ }
+ }
+ otherwise <- c.get[Expression]("otherwise")
+ } yield CaseWhenElse(target, cases, otherwise)
+ case "IsNull" => c.get[Expression]("target").map(IsNull.apply)
+ case "Not" => c.get[Expression]("target").map(Not.apply)
+ case "Cast" =>
+ for {
+ expr <- c.get[Expression]("expression")
+ typeName <- c.get[Name]("typeName")
+ } yield Cast(expr, typeName)
+ case "InExpression" =>
+ for {
+ expr <- c.get[Expression]("expression")
+ source <- c.get[CirceJson]("source").flatMap { json =>
+ json.hcursor.get[String]("type").flatMap {
+ case "Expressions" => json.hcursor.get[Seq[Expression]]("values").map(Left.apply)
+ case "SelectBody" => json.hcursor.get[SelectBody]("query").map(Right.apply)
+ case other => Left(DecodingFailure(s"Unknown InExpression source type: $other", json.hcursor.history))
+ }
+ }
+ } yield InExpression(expr, source)
+ case other => Left(DecodingFailure(s"Unknown Expression type: $other", c.history))
+ }
+ }
+
+ // SelectTarget codecs (using circe-generic-extras)
+ implicit val selectTargetEncoder: Encoder[SelectTarget] = deriveConfiguredEncoder[SelectTarget]
+ implicit val selectTargetDecoder: Decoder[SelectTarget] = deriveConfiguredDecoder[SelectTarget]
+
+ // OrderBy codec
+ implicit val orderByEncoder: Encoder[OrderBy] = deriveEncoder[OrderBy]
+ implicit val orderByDecoder: Decoder[OrderBy] = deriveDecoder[OrderBy]
+
+ // FromElement codecs (using circe-generic-extras)
+ implicit val fromElementEncoder: Encoder[FromElement] = deriveConfiguredEncoder[FromElement]
+ implicit val fromElementDecoder: Decoder[FromElement] = deriveConfiguredDecoder[FromElement]
+
+ // SelectBody codec (recursive)
+ implicit val selectBodyEncoder: Encoder[SelectBody] = deriveEncoder[SelectBody]
+ implicit val selectBodyDecoder: Decoder[SelectBody] = deriveDecoder[SelectBody]
+
+ // ColumnAnnotation codecs (using circe-generic-extras)
+ implicit val columnAnnotationEncoder: Encoder[ColumnAnnotation] = deriveConfiguredEncoder[ColumnAnnotation]
+ implicit val columnAnnotationDecoder: Decoder[ColumnAnnotation] = deriveConfiguredDecoder[ColumnAnnotation]
+
+ // PrimitiveValue codec (needed for ColumnDefinition)
+ // Encodes via the general Expression encoder; decoding re-parses as an
+ // Expression and then narrows, failing on non-primitive results.
+ implicit val primitiveValueEncoder: Encoder[PrimitiveValue] = expressionEncoder.contramap(identity[Expression])
+ implicit val primitiveValueDecoder: Decoder[PrimitiveValue] = expressionDecoder.emap {
+ case pv: PrimitiveValue => Right(pv)
+ case other => Left(s"Expected PrimitiveValue but got ${other.getClass.getSimpleName}")
+ }
+
+ // ColumnDefinition codec
+ implicit val columnDefinitionEncoder: Encoder[ColumnDefinition] = deriveEncoder[ColumnDefinition]
+ implicit val columnDefinitionDecoder: Decoder[ColumnDefinition] = deriveDecoder[ColumnDefinition]
+
+ // TableAnnotation codecs (using circe-generic-extras)
+ implicit val tableAnnotationEncoder: Encoder[TableAnnotation] = deriveConfiguredEncoder[TableAnnotation]
+ implicit val tableAnnotationDecoder: Decoder[TableAnnotation] = deriveConfiguredDecoder[TableAnnotation]
+
+ // InsertValues codecs (using circe-generic-extras)
+ implicit val insertValuesEncoder: Encoder[InsertValues] = deriveConfiguredEncoder[InsertValues]
+ implicit val insertValuesDecoder: Decoder[InsertValues] = deriveConfiguredDecoder[InsertValues]
+
+ // WithClause codec
+ implicit val withClauseEncoder: Encoder[WithClause] = deriveEncoder[WithClause]
+ implicit val withClauseDecoder: Decoder[WithClause] = deriveDecoder[WithClause]
+
+ // AlterViewAction codecs (using circe-generic-extras)
+ implicit val alterViewActionEncoder: Encoder[AlterViewAction] = deriveConfiguredEncoder[AlterViewAction]
+ implicit val alterViewActionDecoder: Decoder[AlterViewAction] = deriveConfiguredDecoder[AlterViewAction]
+
+ // Statement codecs (base and common)
+ // NOTE(review): the OracleStatement case below forward-references
+ // oracleStatementEncoder, declared near the end of this object. This is safe
+ // only because the Encoder.instance lambda defers the access until encode
+ // time — do not encode during object initialization.
+ implicit val statementEncoder: Encoder[Statement] = Encoder.instance {
+ case EmptyStatement() => CirceJson.obj("type" -> CirceJson.fromString("EmptyStatement"))
+ case u: Unparseable =>
+ CirceJson.obj("type" -> CirceJson.fromString("Unparseable"), "content" -> CirceJson.fromString(u.content))
+ case s: Select =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Select"),
+ "body" -> s.body.asJson,
+ "withClause" -> s.withClause.asJson
+ )
+ case u: Update =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Update"),
+ "table" -> u.table.asJson,
+ "set" -> u.set.map { case (n, e) => CirceJson.obj("name" -> n.asJson, "expression" -> e.asJson) }.asJson,
+ "where" -> u.where.asJson
+ )
+ case d: Delete =>
+ CirceJson.obj("type" -> CirceJson.fromString("Delete"), "table" -> d.table.asJson, "where" -> d.where.asJson)
+ case i: Insert =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("Insert"),
+ "table" -> i.table.asJson,
+ "columns" -> i.columns.asJson,
+ "values" -> i.values.asJson,
+ "orReplace" -> CirceJson.fromBoolean(i.orReplace)
+ )
+ case ct: CreateTable =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("CreateTable"),
+ "name" -> ct.name.asJson,
+ "orReplace" -> CirceJson.fromBoolean(ct.orReplace),
+ "columns" -> ct.columns.asJson,
+ "annotations" -> ct.annotations.asJson
+ )
+ case cta: CreateTableAs =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("CreateTableAs"),
+ "name" -> cta.name.asJson,
+ "orReplace" -> CirceJson.fromBoolean(cta.orReplace),
+ "query" -> cta.query.asJson
+ )
+ case cv: CreateView =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("CreateView"),
+ "name" -> cv.name.asJson,
+ "orReplace" -> CirceJson.fromBoolean(cv.orReplace),
+ "query" -> cv.query.asJson,
+ "materialized" -> CirceJson.fromBoolean(cv.materialized),
+ "temporary" -> CirceJson.fromBoolean(cv.temporary)
+ )
+ case av: AlterView =>
+ CirceJson.obj("type" -> CirceJson.fromString("AlterView"), "name" -> av.name.asJson, "action" -> av.action.asJson)
+ case dt: DropTable =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("DropTable"),
+ "name" -> dt.name.asJson,
+ "ifExists" -> CirceJson.fromBoolean(dt.ifExists)
+ )
+ case dv: DropView =>
+ CirceJson.obj(
+ "type" -> CirceJson.fromString("DropView"),
+ "name" -> dv.name.asJson,
+ "ifExists" -> CirceJson.fromBoolean(dv.ifExists)
+ )
+ case e: Explain => CirceJson.obj("type" -> CirceJson.fromString("Explain"), "query" -> e.query.asJson)
+ // Oracle statements
+ case os: OracleStatement => oracleStatementEncoder(os)
+ }
+
+ implicit val statementDecoder: Decoder[Statement] = Decoder.instance { c =>
+ c.get[String]("type").flatMap {
+ case "EmptyStatement" => Right(EmptyStatement())
+ case "Unparseable" => c.get[String]("content").map(Unparseable.apply)
+ case "Select" =>
+ for {
+ body <- c.get[SelectBody]("body")
+ withClause <- c.get[Seq[WithClause]]("withClause")
+ } yield Select(body, withClause)
+ case "Update" =>
+ for {
+ table <- c.get[Name]("table")
+ set <- c.get[Seq[CirceJson]]("set").flatMap { jsons =>
+ jsons.traverse { json =>
+ for {
+ name <- json.hcursor.get[Name]("name")
+ expr <- json.hcursor.get[Expression]("expression")
+ } yield (name, expr)
+ }
+ }
+ where <- c.get[Option[Expression]]("where")
+ } yield Update(table, set, where)
+ case "Delete" =>
+ for {
+ table <- c.get[Name]("table")
+ where <- c.get[Option[Expression]]("where")
+ } yield Delete(table, where)
+ case "Insert" =>
+ for {
+ table <- c.get[Name]("table")
+ columns <- c.get[Option[Seq[Name]]]("columns")
+ values <- c.get[InsertValues]("values")
+ orReplace <- c.get[Boolean]("orReplace")
+ } yield Insert(table, columns, values, orReplace)
+ case "CreateTable" =>
+ for {
+ name <- c.get[Name]("name")
+ orReplace <- c.get[Boolean]("orReplace")
+ columns <- c.get[Seq[ColumnDefinition]]("columns")
+ annotations <- c.get[Seq[TableAnnotation]]("annotations")
+ } yield CreateTable(name, orReplace, columns, annotations)
+ case "CreateTableAs" =>
+ for {
+ name <- c.get[Name]("name")
+ orReplace <- c.get[Boolean]("orReplace")
+ query <- c.get[SelectBody]("query")
+ } yield CreateTableAs(name, orReplace, query)
+ case "CreateView" =>
+ for {
+ name <- c.get[Name]("name")
+ orReplace <- c.get[Boolean]("orReplace")
+ query <- c.get[SelectBody]("query")
+ materialized <- c.get[Boolean]("materialized")
+ temporary <- c.get[Boolean]("temporary")
+ } yield CreateView(name, orReplace, query, materialized, temporary)
+ case "AlterView" =>
+ for {
+ name <- c.get[Name]("name")
+ action <- c.get[AlterViewAction]("action")
+ } yield AlterView(name, action)
+ case "DropTable" =>
+ for {
+ name <- c.get[Name]("name")
+ ifExists <- c.get[Boolean]("ifExists")
+ } yield DropTable(name, ifExists)
+ case "DropView" =>
+ for {
+ name <- c.get[Name]("name")
+ ifExists <- c.get[Boolean]("ifExists")
+ } yield DropView(name, ifExists)
+ case "Explain" => c.get[SelectBody]("query").map(Explain.apply)
+ // Unknown discriminators fall through to the Oracle statement decoder.
+ case other => oracleStatementDecoderHelper(c, other)
+ }
+ }
+
+ // Oracle-specific codecs
+ // ReferentialAction codecs
+ implicit val referentialActionEncoder: Encoder[ReferentialAction] = Encoder.encodeString.contramap(_.sql)
+ implicit val referentialActionDecoder: Decoder[ReferentialAction] = Decoder.decodeString.emap {
+ case "NO ACTION" => Right(ReferentialAction.NoAction)
+ case "RESTRICT" => Right(ReferentialAction.Restrict)
+ case "CASCADE" => Right(ReferentialAction.Cascade)
+ case "SET NULL" => Right(ReferentialAction.SetNull)
+ case "SET DEFAULT" => Right(ReferentialAction.SetDefault)
+ case other => Left(s"Unknown ReferentialAction: $other")
+ }
+
+ // TableConstraint codecs
+ implicit val tableConstraintEncoder: Encoder[TableConstraint] = deriveConfiguredEncoder[TableConstraint]
+ implicit val tableConstraintDecoder: Decoder[TableConstraint] = deriveConfiguredDecoder[TableConstraint]
+
+ // AlterTableAction codecs (using circe-generic-extras)
+ implicit val alterTableActionEncoder: Encoder[AlterTableAction] = deriveConfiguredEncoder[AlterTableAction]
+ implicit val alterTableActionDecoder: Decoder[AlterTableAction] = deriveConfiguredDecoder[AlterTableAction]
+
+ implicit val columnModificationEncoder: Encoder[ColumnModification] = deriveEncoder[ColumnModification]
+ implicit val columnModificationDecoder: Decoder[ColumnModification] = deriveDecoder[ColumnModification]
+
+ implicit val storageClauseEncoder: Encoder[StorageClause] = deriveEncoder[StorageClause]
+ implicit val storageClauseDecoder: Decoder[StorageClause] = deriveDecoder[StorageClause]
+
+ // Custom codec for (Name, Expression) tuples to encode as objects with "name" and "expression" fields
+ implicit val nameExpressionTupleEncoder: Encoder[(Name, Expression)] = Encoder.instance { case (name, expr) =>
+ CirceJson.obj("name" -> name.asJson, "expression" -> expr.asJson)
+ }
+ implicit val nameExpressionTupleDecoder: Decoder[(Name, Expression)] = Decoder.instance { c =>
+ for {
+ name <- c.get[Name]("name")
+ expr <- c.get[Expression]("expression")
+ } yield (name, expr)
+ }
+
+ // Custom codec for Either[SelectBody, Seq[ColumnDefinition]] used in OracleCreateTable.schema
+ // Encodes as an object with "type" field ("SelectBody" or "Columns") and corresponding data
+ implicit val oracleCreateTableSchemaEncoder: Encoder[Either[SelectBody, Seq[ColumnDefinition]]] =
+ Encoder.instance {
+ case Left(query) =>
+ CirceJson.obj("type" -> CirceJson.fromString("SelectBody"), "query" -> query.asJson)
+ case Right(columns) =>
+ CirceJson.obj("type" -> CirceJson.fromString("Columns"), "columns" -> columns.asJson)
+ }
+ implicit val oracleCreateTableSchemaDecoder: Decoder[Either[SelectBody, Seq[ColumnDefinition]]] =
+ Decoder.instance { c =>
+ c.get[String]("type").flatMap {
+ case "SelectBody" => c.get[SelectBody]("query").map(Left.apply)
+ case "Columns" => c.get[Seq[ColumnDefinition]]("columns").map(Right.apply)
+ case other =>
+ Left(DecodingFailure(s"Unknown OracleCreateTable schema type: $other", c.history))
+ }
+ }
+
+ // OracleStatement codecs (using circe-generic-extras)
+ // Note: The tuple codec above handles Update.set and OracleUpdate.set fields
+ implicit val oracleStatementEncoder: Encoder[OracleStatement] = deriveConfiguredEncoder[OracleStatement]
+ implicit val oracleStatementDecoder: Decoder[OracleStatement] = deriveConfiguredDecoder[OracleStatement]
+
+ // Helper function to decode OracleStatement from Statement decoder
+ def oracleStatementDecoderHelper(c: HCursor, typeName: String): Decoder.Result[Statement] = {
+ // Use the automatic decoder for OracleStatement
+ oracleStatementDecoder(c).left.map { failure =>
+ DecodingFailure(s"Unknown Statement type: $typeName - ${failure.message}", c.history)
+ }
+ }
+}
diff --git a/src/main/scala/sparsity/common/expression/Expression.scala b/src/main/scala/sparsity/common/expression/Expression.scala
new file mode 100644
index 0000000..cb63679
--- /dev/null
+++ b/src/main/scala/sparsity/common/expression/Expression.scala
@@ -0,0 +1,367 @@
+package sparsity.common.expression
+
+import sparsity.common.Name
+import sparsity.common.select.{OrderBy, SelectBody}
+
+/** Mixin for AST nodes that can render themselves back to SQL text. */
+trait ToSql {
+  def toSql: String
+}
+
+/** Root of the SQL expression AST.
+ *
+ * Every expression can render itself back to SQL (via [[ToSql]]), report whether it must be
+ * parenthesized when embedded in a larger expression, and expose/rebuild its child expressions
+ * so callers can perform generic tree rewrites.
+ */
+sealed abstract class Expression extends ToSql {
+  def needsParenthesis: Boolean
+  def children: Seq[Expression]
+  def rebuild(newChildren: Seq[Expression]): Expression
+}
+
+object Expression {
+  /** Wrap `e.toSql` in parentheses when the expression needs them for correct precedence. */
+  def parenthesize(e: Expression) =
+    if (e.needsParenthesis) { "(" + e.toSql + ")" }
+    else { e.toSql }
+  /** Escape embedded single quotes for a SQL string literal ('' is the SQL escape for '). */
+  def escapeString(s: String) =
+    s.replaceAll("'", "''")
+}
+
+/** Base class for literal values: leaves of the expression tree (no children, never parenthesized). */
+sealed abstract class PrimitiveValue extends Expression {
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq()
+  def rebuild(newChildren: Seq[Expression]): Expression = this
+}
+case class LongPrimitive(v: Long) extends PrimitiveValue { override def toSql = v.toString }
+case class DoublePrimitive(v: Double) extends PrimitiveValue { override def toSql = v.toString }
+// String literals are single-quoted with embedded quotes escaped on output
+case class StringPrimitive(v: String) extends PrimitiveValue {
+  override def toSql = "'" + Expression.escapeString(v.toString) + "'"
+}
+case class BooleanPrimitive(v: Boolean) extends PrimitiveValue {
+  override def toSql = v.toString
+}
+case class NullPrimitive() extends PrimitiveValue { override def toSql = "NULL" }
+
+/** A column reference 'Table.Col'
+ */
+case class Column(column: Name, table: Option[Name] = None) extends Expression {
+  // Renders as `table.column` when the optional table qualifier is present, else just `column`
+  override def toSql = (table.toSeq ++ Seq(column)).map(_.toSql).mkString(".")
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq()
+  def rebuild(newChildren: Seq[Expression]): Expression = this
+}
+
+/** Any simple binary arithmetic expression See Arithmetic.scala for an enumeration of the possibilities *not* a case
+ * class to avoid namespace collisions. Arithmetic defines apply/unapply explicitly
+ */
+class Arithmetic(val lhs: Expression, val op: Arithmetic.Op, val rhs: Expression) extends Expression {
+  override def toSql =
+    Expression.parenthesize(lhs) + " " +
+      Arithmetic.opString(op) + " " +
+      Expression.parenthesize(rhs)
+  def needsParenthesis = true
+  override def equals(other: Any): Boolean =
+    other match {
+      case Arithmetic(otherlhs, otherop, otherrhs) =>
+        lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
+      case _ => false
+    }
+  // equals is overridden above, so hashCode must be overridden consistently with it
+  // (otherwise equal Arithmetic nodes hash differently in sets/maps)
+  override def hashCode: Int = (lhs, op, rhs).##
+  def children: Seq[Expression] = Seq(lhs, rhs)
+  def rebuild(c: Seq[Expression]): Expression = Arithmetic(c(0), op, c(1))
+}
+
+object Arithmetic extends Enumeration {
+  type Op = Value
+  val Add, Sub, Mult, Div, And, Or, BitAnd, BitOr, ShiftLeft, ShiftRight = Value
+
+  def apply(lhs: Expression, op: Op, rhs: Expression) =
+    new Arithmetic(lhs, op, rhs)
+  def apply(lhs: Expression, op: String, rhs: Expression) =
+    new Arithmetic(lhs, fromString(op), rhs)
+
+  def unapply(e: Arithmetic): Option[(Expression, Op, Expression)] =
+    Some((e.lhs, e.op, e.rhs))
+
+  /** Regular expression to match any and all binary operations.
+   * Multi-character operators (<<, >>, ||, &&) are listed first so the alternation
+   * prefers them over their single-character prefixes (|, &).
+   */
+  def matchRegex = """<<|>>|\|\||&&|\+|-|\*|/|\||&""".r
+
+  /** Convert from the operator's string encoding to its Arith.Op rep
+   */
+  def fromString(a: String) =
+    a.toUpperCase match {
+      case "+" => Add
+      case "-" => Sub
+      case "*" => Mult
+      case "/" => Div
+      case "&" => BitAnd
+      case "|" => BitOr
+      case "<<" => ShiftLeft
+      case ">>" => ShiftRight
+      case "&&" => And
+      case "||" => Or
+      case "AND" => And
+      case "OR" => Or // fixed: previously mapped "OR" to And
+      case x => throw new Exception("Invalid operand '" + x + "'")
+    }
+
+  /** Convert from the operator's Arith.Op representation to a string
+   */
+  def opString(v: Op): String =
+    v match {
+      case Add => "+"
+      case Sub => "-"
+      case Mult => "*"
+      case Div => "/"
+      case BitAnd => "&"
+      case BitOr => "|"
+      case ShiftLeft => "<<"
+      case ShiftRight => ">>"
+      case And => "AND"
+      case Or => "OR"
+    }
+
+  /** Is this binary operation a boolean operator (AND/OR)
+   */
+  def isBool(v: Op): Boolean =
+    v match {
+      case And | Or => true
+      case _ => false
+    }
+
+  /** Is this binary operation a numeric operator (+, -, *, /, &, |, <<, >>)
+   */
+  def isNumeric(v: Op): Boolean = !isBool(v)
+
+}
+
+/** Any simple comparison expression See Comparison.scala for an enumeration of the possibilities *not* a case class to
+ * avoid namespace collisions. Comparison defines apply/unapply explicitly
+ */
+class Comparison(val lhs: Expression, val op: Comparison.Op, val rhs: Expression) extends Expression {
+  override def toSql =
+    Expression.parenthesize(lhs) + " " +
+      Comparison.opString(op) + " " +
+      Expression.parenthesize(rhs)
+  def needsParenthesis = true
+  override def equals(other: Any): Boolean =
+    other match {
+      case Comparison(otherlhs, otherop, otherrhs) =>
+        lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
+      case _ => false
+    }
+  // equals is overridden above, so hashCode must be overridden consistently with it
+  // (otherwise equal Comparison nodes hash differently in sets/maps)
+  override def hashCode: Int = (lhs, op, rhs).##
+  def children: Seq[Expression] = Seq(lhs, rhs)
+  def rebuild(c: Seq[Expression]): Expression = Comparison(c(0), op, c(1))
+}
+object Comparison extends Enumeration {
+  type Op = Value
+  val Eq, Neq, Gt, Lt, Gte, Lte, Like, NotLike, RLike, NotRLike = Value
+
+  /** Mapping from (upper-cased) SQL operator spellings to their Op values. */
+  val strings = Map(
+    "=" -> Eq,
+    "==" -> Eq,
+    "!=" -> Neq,
+    "<>" -> Neq,
+    ">" -> Gt,
+    "<" -> Lt,
+    ">=" -> Gte,
+    "<=" -> Lte,
+    "LIKE" -> Like, // SQL-style LIKE expression
+    "RLIKE" -> RLike, // Regular expression lookup
+    "NOT LIKE" -> NotLike, // Inverse LIKE
+    "NOT RLIKE" -> NotRLike // Inverse RLIKE (fixed: previously mapped to RLike)
+  )
+
+  def apply(lhs: Expression, op: Op, rhs: Expression) =
+    new Comparison(lhs, op, rhs)
+  def apply(lhs: Expression, op: String, rhs: Expression) =
+    new Comparison(lhs, strings(op.toUpperCase), rhs)
+
+  def unapply(e: Comparison): Option[(Expression, Op, Expression)] =
+    Some((e.lhs, e.op, e.rhs))
+
+  /** Logical negation of a comparison operator (e.g. Eq <-> Neq, Gt <-> Lte). */
+  def negate(v: Op): Op =
+    v match {
+      case Eq => Neq
+      case Neq => Eq
+      case Gt => Lte
+      case Gte => Lt
+      case Lt => Gte
+      case Lte => Gt
+      case Like => NotLike
+      case NotLike => Like
+      case RLike => NotRLike
+      case NotRLike => RLike
+    }
+
+  /** The operator obtained by swapping operands, when one exists (the LIKE family has none). */
+  def flip(v: Op): Option[Op] =
+    v match {
+      case Eq => Some(Eq)
+      case Neq => Some(Neq)
+      case Gt => Some(Lt)
+      case Gte => Some(Lte)
+      case Lt => Some(Gt)
+      case Lte => Some(Gte)
+      case Like => None
+      case NotLike => None
+      case RLike => None
+      case NotRLike => None
+    }
+
+  /** SQL spelling of each operator. */
+  def opString(v: Op): String =
+    v match {
+      case Eq => "="
+      case Neq => "<>"
+      case Gt => ">"
+      case Gte => ">="
+      case Lt => "<"
+      case Lte => "<="
+      case Like => "LIKE"
+      case NotLike => "NOT LIKE"
+      case RLike => "RLIKE"
+      case NotRLike => "NOT RLIKE"
+    }
+}
+
+/** A function call: name([DISTINCT] arg, ...).
+ *
+ * `params = None` represents the `*` argument form (e.g. COUNT(*)); `distinct` prefixes the
+ * argument list with a DISTINCT qualifier.
+ */
+case class Function(name: Name, params: Option[Seq[Expression]], distinct: Boolean = false) extends Expression {
+  override def toSql =
+    name.toSql + "(" +
+      (if (distinct) { "DISTINCT " }
+       else { "" }) +
+      params.map(_.map(_.toSql).mkString(", ")).getOrElse("*") +
+      ")"
+  def needsParenthesis = false
+  def children: Seq[Expression] = params.getOrElse(Seq())
+  def rebuild(c: Seq[Expression]): Expression =
+    Function(name, params.map(_ => c), distinct)
+}
+
+/** Window function expression: function(...) OVER (PARTITION BY ... ORDER BY ...)
+ */
+case class WindowFunction(
+    function: Function,
+    partitionBy: Option[Seq[Expression]] = None,
+    orderBy: Seq[OrderBy] = Seq()
+) extends Expression {
+  /** Renders `f(args) OVER ([PARTITION BY e1, e2] [ORDER BY o1, o2])`.
+   * Fixed relative to the original: partition expressions and ORDER BY entries are
+   * comma-separated, and no comma is emitted between the PARTITION BY and ORDER BY
+   * clauses (standard window-clause syntax).
+   */
+  override def toSql = {
+    val partitionClause =
+      partitionBy.map { exprs => "PARTITION BY " + exprs.map(_.toSql).mkString(", ") }
+    val orderClause =
+      if (orderBy.nonEmpty) { Some("ORDER BY " + orderBy.map(_.toSql).mkString(", ")) }
+      else { None }
+    function.toSql + " OVER (" + (partitionClause.toSeq ++ orderClause.toSeq).mkString(" ") + ")"
+  }
+  def needsParenthesis = false
+  // Child layout: function arguments, then partition expressions, then ORDER BY expressions.
+  // rebuild() slices its input assuming exactly this layout, so keep the two in sync.
+  def children: Seq[Expression] =
+    function.children ++ partitionBy.getOrElse(Seq()) ++ orderBy.map(_.expression)
+  def rebuild(c: Seq[Expression]): Expression = {
+    val funcChildren = function.children
+    val funcRebuilt = Function(function.name, function.params.map(_ => c.take(funcChildren.length)), function.distinct)
+    val partitionRebuilt =
+      partitionBy.map(_ => c.slice(funcChildren.length, funcChildren.length + partitionBy.get.length))
+    val orderByRebuilt = orderBy.zipWithIndex.map { case (ob, idx) =>
+      OrderBy(c(funcChildren.length + partitionBy.map(_.length).getOrElse(0) + idx), ob.ascending)
+    }
+    WindowFunction(funcRebuilt, partitionRebuilt, orderByRebuilt)
+  }
+}
+
+/** A JDBC positional parameter placeholder (`?`); a leaf with no children. */
+case class JDBCVar() extends Expression {
+  override def toSql = "?"
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq()
+  def rebuild(c: Seq[Expression]): Expression = this
+}
+
+/** A scalar subquery expression: (SELECT ...) Used in contexts like VALUES clauses where a subquery can be used as a
+ * value.
+ */
+case class Subquery(query: SelectBody) extends Expression {
+  override def toSql = "(" + query.toSql + ")"
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq() // Subquery is a leaf expression
+  def rebuild(c: Seq[Expression]): Expression = this
+}
+
+/** `CASE [target] WHEN c1 THEN v1 ... ELSE otherwise END`.
+ * `target` distinguishes the "simple" form (CASE x WHEN 1 THEN ...) from the "searched"
+ * form (CASE WHEN x = 1 THEN ...).
+ */
+case class CaseWhenElse(target: Option[Expression], cases: Seq[(Expression, Expression)], otherwise: Expression)
+    extends Expression {
+  override def toSql =
+    "CASE " +
+      // fixed: trailing space after the target so it does not run into the first WHEN
+      target.map(_.toSql + " ").getOrElse("") +
+      cases
+        .map { clause =>
+          "WHEN " + Expression.parenthesize(clause._1) +
+            " THEN " + Expression.parenthesize(clause._2)
+        }
+        .mkString(" ") + " " +
+      "ELSE " + otherwise.toSql + " END"
+  def needsParenthesis = false
+  // Child layout: [target?] ++ [otherwise] ++ flattened (when, then) pairs.
+  // rebuild() infers whether a target is present from parity: pairs contribute an even
+  // count, otherwise adds one, so an even total means the optional target is present.
+  def children: Seq[Expression] =
+    target.toSeq ++ Seq(otherwise) ++ cases.flatMap(x => Seq(x._1, x._2))
+  def rebuild(c: Seq[Expression]): Expression = {
+    val (newTarget, newOtherwise, newCases) =
+      if (c.length % 2 == 0) { (Some(c.head), c.tail.head, c.tail.tail) }
+      else { (None, c.head, c.tail) }
+    CaseWhenElse(newTarget, newCases.grouped(2).map(x => (x(0), x(1))).toSeq, newOtherwise)
+  }
+}
+
+/** An expression with a dedicated negated SQL rendering (e.g. IS NULL / IS NOT NULL).
+ * [[Not]] delegates to `toNegatedString` so negation can use idiomatic SQL instead of `NOT (...)`.
+ */
+abstract class NegatableExpression extends Expression {
+  def toNegatedString: String
+}
+
+/** `target IS NULL`; negates to `target IS NOT NULL`. */
+case class IsNull(target: Expression) extends NegatableExpression {
+  override def toSql =
+    Expression.parenthesize(target) + " IS NULL"
+  def toNegatedString =
+    Expression.parenthesize(target) + " IS NOT NULL"
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq(target)
+  def rebuild(c: Seq[Expression]): Expression = IsNull(c(0))
+}
+
+/** Logical negation; uses the operand's dedicated negated form when one exists. */
+case class Not(target: Expression) extends Expression {
+  override def toSql =
+    target match {
+      case neg: NegatableExpression => neg.toNegatedString
+      case _ => "NOT " + Expression.parenthesize(target)
+    }
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq(target)
+  def rebuild(c: Seq[Expression]): Expression = Not(c(0))
+}
+
+/** `CAST(expression AS t)` where `t` is the target type name. */
+case class Cast(expression: Expression, t: Name) extends Expression {
+  override def toSql = "CAST(" + expression.toSql + " AS " + t.toSql + ")"
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq(expression)
+  def rebuild(c: Seq[Expression]): Expression = Cast(c(0), t)
+}
+
+/** `expr IN (...)`, where the right-hand side is either an explicit value list (Left)
+ * or a subquery (Right).  Negation (`expr NOT IN ...`) comes via NegatableExpression.
+ */
+case class InExpression(expression: Expression, source: Either[Seq[Expression], SelectBody])
+    extends NegatableExpression {
+  override def toSql =
+    Expression.parenthesize(expression) + " IN " + sourceString
+  override def toNegatedString =
+    Expression.parenthesize(expression) + " NOT IN " + sourceString
+  def needsParenthesis = false
+  def sourceString =
+    source match {
+      // fixed: the value list must itself be parenthesized -- `x IN (1, 2)`, not `x IN 1, 2`
+      case Left(elems) => "(" + elems.map(Expression.parenthesize(_)).mkString(", ") + ")"
+      case Right(query) => "(" + query.toSql + ")"
+    }
+  // A subquery source contributes no children; only value-list elements are rewritable
+  def children: Seq[Expression] =
+    Seq(expression) ++ (source match {
+      case Left(expr) => expr
+      case Right(_) => Seq()
+    })
+  def rebuild(c: Seq[Expression]): Expression =
+    InExpression(
+      c.head,
+      source match {
+        case Left(_) => Left(c.tail)
+        case Right(query) => Right(query)
+      }
+    )
+}
diff --git a/src/main/scala/sparsity/common/parser/ErrorMessage.scala b/src/main/scala/sparsity/common/parser/ErrorMessage.scala
new file mode 100644
index 0000000..1d9ddc3
--- /dev/null
+++ b/src/main/scala/sparsity/common/parser/ErrorMessage.scala
@@ -0,0 +1,36 @@
+package sparsity.common.parser
+
+import fastparse.Parsed
+
+object ErrorMessage {
+
+  /** Translate a character offset in the original input into a (lineNumber, columnOffset) pair.
+   *
+   * `lines` is the input split on '\n', so each line contributes `length + 1` characters to
+   * the running offset (the line itself plus the newline the split removed).  The original
+   * implementation omitted the +1, drifting by one column per preceding line.
+   */
+  def indexToLine(lines: Seq[String], index: Int): (Int, Int) = {
+    val cumulativeLengths = lines.scanLeft(0)(_ + _.length + 1)
+    cumulativeLengths.zipWithIndex
+      .sliding(2)
+      .collectFirst {
+        case Seq((prev, _), (curr, i)) if index < curr =>
+          (i - 1, index - prev)
+      }
+      .getOrElse((lines.length - 1, 0))
+  }
+
+  /** Render a parse failure as the offending source line plus a caret marking the failure
+   * position and the parser's expected-token label.
+   */
+  def format(query: String, msg: Parsed.Failure): String = {
+    val lines = query.split("\n")
+    val expected: String = "expected " + msg.label
+    val index: Int = msg.extra.index
+    val (lineNumber, linePosition) = indexToLine(lines, index)
+    lines(lineNumber) + "\n" +
+      " " * linePosition +
+      "^--- " + expected
+  }
+
+}
diff --git a/src/main/scala/sparsity/parser/ParseException.scala b/src/main/scala/sparsity/common/parser/ParseException.scala
similarity index 76%
rename from src/main/scala/sparsity/parser/ParseException.scala
rename to src/main/scala/sparsity/common/parser/ParseException.scala
index 78ef80f..20b37ab 100644
--- a/src/main/scala/sparsity/parser/ParseException.scala
+++ b/src/main/scala/sparsity/common/parser/ParseException.scala
@@ -1,4 +1,4 @@
-package sparsity.parser
+package sparsity.common.parser
import fastparse.Parsed
diff --git a/src/main/scala/sparsity/common/parser/SQL.scala b/src/main/scala/sparsity/common/parser/SQL.scala
new file mode 100644
index 0000000..bed4cf5
--- /dev/null
+++ b/src/main/scala/sparsity/common/parser/SQL.scala
@@ -0,0 +1,17 @@
+package sparsity.common.parser
+
+import fastparse._
+import scala.io._
+import java.io._
+import sparsity.common.statement._
+import sparsity.ansi.ANSISQL
+
+/** Backward-compatible SQL parser object. Delegates to ANSISQL for parsing.
+ *
+ * @deprecated
+ * Consider using ANSISQL, OracleSQL, or PostgresSQL directly for dialect-specific parsing
+ */
+object SQL {
+  /** Parse a single statement from a string using the ANSI dialect. */
+  def apply(input: String): Parsed[Statement] = ANSISQL(input)
+  /** Parse statements incrementally from a Reader using the ANSI dialect. */
+  def apply(input: Reader): StreamParser[Statement] = ANSISQL(input)
+}
diff --git a/src/main/scala/sparsity/common/parser/SQLBase.scala b/src/main/scala/sparsity/common/parser/SQLBase.scala
new file mode 100644
index 0000000..5b9b75b
--- /dev/null
+++ b/src/main/scala/sparsity/common/parser/SQLBase.scala
@@ -0,0 +1,893 @@
+package sparsity.common.parser
+
+import fastparse._
+import scala.io._
+import java.io._
+import sparsity.common.Name
+import sparsity.common.statement._
+import sparsity.common.select._
+import sparsity.common.alter._
+import sparsity.common.expression.{
+ Arithmetic,
+ BooleanPrimitive,
+ CaseWhenElse,
+ Cast,
+ Column,
+ Comparison,
+ DoublePrimitive,
+ Expression,
+ Function,
+ InExpression,
+ IsNull,
+ JDBCVar,
+ LongPrimitive,
+ Not,
+ NullPrimitive,
+ StringPrimitive
+}
+import sparsity.common.DataType
+
+/** Base trait for SQL dialect parsers. Provides shared parser methods and configuration hooks for dialect-specific
+ * behavior.
+ */
+trait SQLBase {
+ // Type members allow dialects to extend the Statement ADT
+ type Stmt <: Statement
+ // Configuration for simple variations (abstract, implemented by subclasses)
+ def caseSensitive: Boolean
+ def statementTerminator: String // ";" or "/"
+ def supportsCTEs: Boolean
+ def supportsReturning: Boolean
+ def supportsIfExists: Boolean
+ def stringConcatOp: String // "||" or "+"
+ implicit val whitespaceImplicit: fastparse.Whitespace = MultiLineWhitespace.whitespace
+ // Expression parsing methods
+ // These use `this` to access the SQL parser instance, removing hard-coded ANSISQL.instance references
+
+ // Elements methods - inlined from Elements.scala to allow overriding
+ def anyKeyword[$ : P] = P(
+ StringInIgnoreCase(
+ "ADD",
+ "ALL",
+ "ALTER",
+ "AND",
+ "AS",
+ "ASC",
+ "BEGIN",
+ "BETWEEN",
+ "BY",
+ "CASE",
+ "CAST",
+ "COLUMN",
+ "COMMENT",
+ "COMMIT",
+ "CONSTRAINT",
+ "CREATE",
+ "DECLARE",
+ "DEFAULT",
+ "DEFINE",
+ "DELETE",
+ "DESC",
+ "DISABLE",
+ "DISTINCT",
+ "DOUBLE",
+ "DROP",
+ "ELSE",
+ "ELSIF",
+ "ENABLE",
+ "END",
+ "EXCEPTION",
+ "EXISTS",
+ "EXPLAIN",
+ "FALSE",
+ "FOR",
+ "FROM",
+ "FULL",
+ "GROUP",
+ "GRANT",
+ "HAVING",
+ "IF",
+ "IN",
+ "INDEX",
+ "INNER",
+ "INSERT",
+ "INTO",
+ "IS",
+ "JOIN",
+ "KEY",
+ "LEFT",
+ "LIMIT",
+ "LOOP",
+ "MATERIALIZE",
+ "MATERIALIZED",
+ "MODIFY",
+ "NATURAL",
+ "NOT",
+ "NULL",
+ "OFF",
+ "OFFSET",
+ "ON",
+ "OR",
+ "ORDER",
+ "OVER",
+ "OUTER",
+ "PARTITION",
+ "PRECISION",
+ "PRIMARY",
+ "PROMPT",
+ "RENAME",
+ "REPLACE",
+ "REVERSE",
+ "REVOKE",
+ "RIGHT",
+ "SELECT",
+ "SET",
+ "SMALLINT",
+ "SYNONYM",
+ "TABLE",
+ "TEMPORARY",
+ "THEN",
+ "TO",
+ "TRUE",
+ "UNIQUE",
+ "UPDATE",
+ "UNION",
+ "VALUES",
+ "VIEW",
+ "WHEN",
+ "WHERE",
+ "WHILE",
+ "WITH",
+ "FOREIGN",
+ "REFERENCES",
+ "CASCADE",
+ "RESTRICT",
+ "ACTION",
+ "PURGE"
+ // avoid dropping keyword prefixes
+ // (e.g., 'int' matched by 'in')
+ ).! ~~ !CharIn("a-zA-Z0-9_")
+ )
+
+ def keyword[$ : P](expected: String*) = P[Unit](
+ anyKeyword
+ .opaque(expected.mkString(" or "))
+ .filter(kw => expected.exists(_.equalsIgnoreCase(kw)))
+ .map(_ => ())
+ )
+
+ def avoidReservedKeywords[$ : P] = P(!anyKeyword)
+
+ def rawIdentifier[$ : P] = P(
+ avoidReservedKeywords ~~
+ (CharIn("_a-zA-Z") ~~ CharsWhileIn("a-zA-Z0-9_").?).!.map(Name(_))
+ )
+
+ def quotedIdentifier[$ : P] = P(
+ (("`" ~~/ CharsWhile(_ != '`').! ~ "`")
+ | ("\"" ~~/ CharsWhile(_ != '"').! ~~ "\"")).map(Name(_, true: scala.Boolean))
+ )
+
+ def identifier[$ : P]: P[Name] = P(rawIdentifier | quotedIdentifier)
+
+ def dottedPair[$ : P]: P[(Option[Name], Name)] = P((identifier ~~ ("." ~~ identifier).?).map {
+ case (x, None) => (None, x)
+ case (x, Some(y)) => (Some(x), y)
+ })
+
+ def dottedWildcard[$ : P]: P[Name] = P(identifier ~~ ".*")
+
+ // Schema-qualified table name (schema.table) as a single Name
+ def qualifiedTableName[$ : P]: P[Name] = P((identifier ~~ ("." ~~ identifier).?).map {
+ case (x, None) => x
+ case (x, Some(y)) => Name(x.name + "." + y.name, x.quoted || y.quoted)
+ })
+
+ def digits[$ : P] = P(CharsWhileIn("0-9"))
+ def plusMinus[$ : P] = P("-" | "+")
+ def integral[$ : P] = "0" | CharIn("1-9") ~~ digits.?
+
+ def integer[$ : P] = (plusMinus.? ~~ digits).!.map(_.toLong) ~~ !"." // Fail on a trailing period
+ def decimal[$ : P] =
+ (plusMinus.? ~~ digits ~~ ("." ~~ digits).? ~~ ("e" ~~ plusMinus.? ~~ digits).?).!.map(_.toDouble)
+
+ def escapeQuote[$ : P] = P("''".!.map(_.replaceAll("''", "'")))
+ def escapedString[$ : P] = P((CharsWhile(_ != '\'') | escapeQuote).repX.!.map {
+ _.replaceAll("''", "'")
+ })
+ def quotedString[$ : P] = P("'" ~~ escapedString ~~ "'")
+
+ def comma[$ : P] = P(",")
+
+ def expressionList[$ : P]: P[Seq[Expression]] = P(expression.rep(sep = comma))
+
+ def expression[$ : P]: P[Expression] = P(disjunction)
+
+ def disjunction[$ : P] = P((conjunction ~ (!keyword("ORDER") ~ keyword("OR") ~ conjunction).rep).map { x =>
+ x._2.fold(x._1) { (accum, current) =>
+ Arithmetic(accum, Arithmetic.Or, current)
+ }
+ })
+
+ def conjunction[$ : P] = P((negation ~ (keyword("AND") ~ negation).rep).map { x =>
+ x._2.fold(x._1) { (accum, current) =>
+ Arithmetic(accum, Arithmetic.And, current)
+ }
+ })
+
+ def negation[$ : P] = P((keyword("NOT").!.? ~ comparison).map {
+ case (None, expression) => expression
+ case (_, expression) => Not(expression)
+ })
+
+ def comparison[$ : P] = P(
+ (isNullBetweenIn ~
+ (StringInIgnoreCase("=", "==", "!=", "<>", ">", "<", ">=", "<=", "LIKE", "NOT LIKE", "RLIKE", "NOT RLIKE").! ~
+ addSub).?).map {
+ case (expression, None) => expression
+ case (lhs, Some((op, rhs))) => Comparison(lhs, op, rhs)
+ }
+ )
+
+ def optionalNegation[$ : P]: P[Expression => Expression] = P(keyword("NOT").!.?.map { x =>
+ if (x.isDefined) { y => Not(y) }
+ else { y => y }
+ })
+
+ def isNullBetweenIn[$ : P] = P(
+ (addSub ~ (
+ // IS [NOT] NULL -> IsNull(...)
+ (keyword("IS") ~ optionalNegation ~
+ keyword("NULL").map(_ => IsNull(_))) | (
+ // [IS] [NOT] BETWEEN low AND high
+ keyword("IS").? ~ optionalNegation ~
+ (
+ keyword("BETWEEN") ~
+ addSub ~ keyword("AND") ~ addSub
+ ).map { case (low, high) =>
+ (lhs: Expression) =>
+ Arithmetic(Comparison(lhs, Comparison.Gte, low), Arithmetic.And, Comparison(lhs, Comparison.Lte, high))
+ }
+ ) | (
+ optionalNegation ~ keyword("IN") ~/ (
+ (
+ // IN ( SELECT ... )
+ &("(" ~ keyword("SELECT")) ~/
+ "(" ~/ this.select.map { query =>
+ InExpression(_: Expression, Right(query))
+ } ~ ")"
+ ) | (
+ // IN ('list', 'of', 'items')
+ "(" ~/
+ expressionList.map(exprs => InExpression(_: Expression, Left(exprs))) ~
+ ")"
+ )
+ )
+ )
+ ).?).map {
+ case (expression, None) => expression
+ case (expression, Some((neg, build))) =>
+ val f: Expression => Expression = neg.asInstanceOf[Expression => Expression]
+ f(build(expression))
+ }
+ )
+
+ def addSub[$ : P] = P((multDiv ~ ((CharIn("+\\-&\\|") | StringIn("<<", ">>")).! ~ multDiv).rep).map { x =>
+ x._2.foldLeft(x._1: Expression) { (accum, current) =>
+ Arithmetic(accum, current._1, current._2)
+ }
+ })
+
+ def multDivOp[$ : P] = P(CharIn("*/"))
+
+ def multDiv[$ : P] = P((leaf ~ (multDivOp.! ~ leaf).rep).map { x =>
+ x._2.foldLeft(x._1) { (accum, current) =>
+ Arithmetic(accum, current._1, current._2)
+ }
+ })
+
+ def leaf[$ : P]: P[Expression] = P(
+ parens |
+ primitive |
+ jdbcvar |
+ caseWhen | ifThenElse |
+ cast |
+ nullLiteral |
+ // need to lookahead `function` to avoid conflicts with `column`
+ &(identifier ~ "(") ~ function |
+ column
+ )
+
+ def parens[$ : P] = P("(" ~ expression ~ ")")
+
+ def primitive[$ : P] = P(
+ integer.map(v => LongPrimitive(v))
+ | decimal.map(v => DoublePrimitive(v))
+ | quotedString.map(v => StringPrimitive(v))
+ | keyword("TRUE").map(_ => BooleanPrimitive(true))
+ | keyword("FALSE").map(_ => BooleanPrimitive(false))
+ )
+
+ def column[$ : P] = P(dottedPair.map(x => Column(x._2, x._1)))
+
+ def nullLiteral[$ : P] = P(keyword("NULL").map(_ => NullPrimitive()))
+
+ def function[$ : P] = P(
+ (identifier ~ "(" ~/
+ keyword("DISTINCT").!.?.map {
+ _ != None
+ } ~
+ ("*".!.map(_ => None)
+ | expressionList.map(Some(_))) ~ ")").map { case (name, distinct, args) =>
+ Function(name, args, distinct)
+ }
+ )
+
+ def jdbcvar[$ : P] = P("?".!.map(_ => JDBCVar()))
+
+ def caseWhen[$ : P] = P(
+ keyword("CASE") ~/
+ (!keyword("WHEN") ~ expression).? ~
+ (
+ keyword("WHEN") ~/
+ expression ~
+ keyword("THEN") ~/
+ expression
+ ).rep ~
+ keyword("ELSE") ~/
+ expression ~
+ keyword("END")
+ ).map { case (target, whenThen, orElse) =>
+ CaseWhenElse(target, whenThen, orElse)
+ }
+
+ def ifThenElse[$ : P] = P(
+ keyword("IF") ~/
+ expression ~/
+ keyword("THEN") ~/
+ expression ~/
+ keyword("ELSE") ~/
+ expression ~/
+ keyword("END")
+ ).map { case (condition, thenClause, elseClause) =>
+ CaseWhenElse(None, Seq(condition -> thenClause), elseClause)
+ }
+
+ def cast[$ : P] = P(
+ (
+ keyword("CAST") ~/ "(" ~/
+ expression ~ keyword("AS") ~/
+ identifier ~ ")"
+ ).map { case (expression, t) =>
+ Cast(expression, t)
+ }
+ )
+
+ def statementTerminatorParser[$ : P]: P[Unit] = P(statementTerminator)
+
+  // Parser for unparseable statements (error recovery)
+  // Consumes everything up to and including the next statement terminator
+  // IMPORTANT: Must match at least one char to avoid infinite loops with .rep
+  def unparseableStatement[$ : P]: P[Unparseable] = P(
+    // Consume one char unconditionally, then zero or more chars up to the terminator.
+    // fixed: CharsWhile defaults to min = 1, which made a one-character statement followed
+    // immediately by the terminator (e.g. "x;") fail to recover; min = 0 accepts it.
+    AnyChar.! ~ CharsWhile(c => c != statementTerminator.head, 0).! ~ statementTerminatorParser.!
+  ).map { case (first, rest, terminator) => Unparseable((first + rest + terminator).trim) }
+
+ // Shared parsers with conditional logic
+ // Base implementation includes unparseable fallback - can be overridden by dialects that need special handling
+ def terminatedStatement[$ : P]: P[Stmt] = P(
+ (statement ~/ statementTerminatorParser) |
+ unparseableStatement.map(_.asInstanceOf[Stmt]) // Fallback for parse errors (requires at least one char)
+ )
+
+ def statement[$ : P]: P[Stmt] =
+ P(
+ Pass ~ // This trims off leading whitespace
+ (parenthesizedSelect.map(s => Select(s).asInstanceOf[Stmt])
+ | update
+ | delete
+ | insert
+ | createStatement // Delegates to subclass
+ | (&(keyword("ALTER")) ~/
+ alterView)
+ | dropTableOrView
+ | explainStatement
+ | dialectSpecificStatement // Hook for dialect extensions
+ )
+ )
+
+ def explainStatement[$ : P]: P[Stmt] = P(keyword("EXPLAIN") ~ select.map(s => Explain(s).asInstanceOf[Stmt]))
+
+ // Methods that can be overridden for dialect-specific behavior
+ def createStatement[$ : P]: P[Stmt]
+ def dialectSpecificStatement[$ : P]: P[Stmt]
+ def dataType[$ : P]: P[DataType]
+
+ // Shared logic with conditional branches
+ def ifExists[$ : P]: P[Boolean] =
+ if (supportsIfExists) {
+ P((keyword("IF") ~ keyword("EXISTS")).!.?.map(_ != None))
+ } else {
+ P(Pass).map(_ => false)
+ }
+
+ def orReplace[$ : P]: P[Boolean] = P((keyword("OR") ~ keyword("REPLACE")).!.?.map(_ != None))
+
+ def alterView[$ : P]: P[Stmt] = P(
+ (
+ keyword("ALTER") ~
+ keyword("VIEW") ~/
+ identifier ~
+ (
+ (keyword("MATERIALIZE").!.map(_ => Materialize(true)))
+ | (keyword("DROP") ~
+ keyword("MATERIALIZE").!.map(_ => Materialize(false)))
+ )
+ ).map { case (name, op) => AlterView(name, op).asInstanceOf[Stmt] }
+ )
+
+ def dropTableOrView[$ : P]: P[Stmt] = P(
+ (
+ keyword("DROP") ~
+ keyword("TABLE", "VIEW").!.map(_.toUpperCase) ~/
+ ifExists ~
+ identifier
+ ).map {
+ case ("TABLE", ifExists, name) =>
+ DropTable(name, ifExists).asInstanceOf[Stmt]
+ case ("VIEW", ifExists, name) =>
+ DropView(name, ifExists).asInstanceOf[Stmt]
+ case (_, _, _) =>
+ throw new Exception("Internal Error")
+ }
+ )
+
+ def createView[$ : P]: P[Stmt] = P(
+ (
+ keyword("CREATE") ~
+ orReplace ~
+ keyword("MATERIALIZED", "TEMPORARY").!.?.map {
+ _.map(_.toUpperCase) match {
+ case Some("MATERIALIZED") => (true, false)
+ case Some("TEMPORARY") => (false, true)
+ case _ => (false, false)
+ }
+ } ~
+ keyword("VIEW") ~/
+ identifier ~
+ keyword("AS") ~/
+ select
+ ).map { case (orReplace, (materialized, temporary), name, query) =>
+ CreateView(name, orReplace, query, materialized, temporary).asInstanceOf[Stmt]
+ }
+ )
+
+ // Parser for function calls without parentheses (e.g., CURRENT_TIMESTAMP, SYSDATE)
+ def functionCallNoParens[$ : P]: P[Expression] = P(identifier.map { name =>
+ // Treat as a function call with no arguments
+ Function(name, None, false)
+ })
+
+ def columnAnnotation[$ : P]: P[ColumnAnnotation] = P(
+ (
+ keyword("PRIMARY") ~/
+ keyword("KEY").map(_ => ColumnIsPrimaryKey())
+ ) | (
+ keyword("NOT") ~/
+ keyword("NULL").map(_ => ColumnIsNotNullable())
+ ) | (
+ keyword("DEFAULT") ~/
+ (
+ ("(" ~ expression ~ ")")
+ | function // Function calls with parentheses
+ | functionCallNoParens // Function calls without parentheses (CURRENT_TIMESTAMP, etc.)
+ | primitive
+ ).map(ColumnDefaultValue(_))
+ )
+ )
+
+ def oneOrMoreAttributes[$ : P]: P[Seq[Name]] = P(
+ ("(" ~/ identifier.rep(sep = comma, min = 1) ~ ")")
+ | identifier.map(Seq(_))
+ )
+
+ def tableField[$ : P]: P[Either[TableAnnotation, ColumnDefinition]] = P(
+ (
+ keyword("PRIMARY") ~/
+ keyword("KEY") ~
+ oneOrMoreAttributes.map(attrs => Left(TablePrimaryKey(attrs)))
+ ) | (
+ keyword("INDEX") ~/
+ keyword("ON") ~
+ oneOrMoreAttributes.map(attrs => Left(TableIndexOn(attrs)))
+ ) | (
+ (
+ identifier ~/
+ identifier ~
+ ("(" ~
+ primitive.rep(sep = ",") ~
+ ")").?.map(_.getOrElse(Seq())) ~
+ columnAnnotation.rep
+ ).map { case (name, t, args, annotations) =>
+ Right(ColumnDefinition(name, t, args, annotations))
+ }
+ )
+ )
+
+ def createTable[$ : P]: P[Stmt] = P(
+ (
+ keyword("CREATE") ~
+ orReplace ~
+ keyword("TABLE") ~/
+ identifier ~
+ (
+ (keyword("AS") ~/ select).map(Left(_))
+ | ("(" ~/
+ tableField.rep(sep = comma) ~
+ ")").map(Right(_))
+ )
+ ).map {
+ case (orReplace, table, Left(query)) =>
+ CreateTableAs(table, orReplace, query).asInstanceOf[Stmt]
+ case (orReplace, table, Right(fields)) =>
+ val columns = fields.collect { case Right(r) => r }
+ val annotations = fields.collect { case Left(l) => l }
+ CreateTable(table, orReplace, columns, annotations).asInstanceOf[Stmt]
+ }
+ )
+
+ def valueList[$ : P]: P[InsertValues] = P(
+ (
+ keyword("VALUES") ~/
+ ("(" ~/ expressionList ~ ")").rep(sep = comma)
+ ).map(ExplicitInsert(_))
+ )
+
+ def insert[$ : P]: P[Stmt] = P(
+ (
+ keyword("INSERT") ~/
+ (
+ keyword("OR") ~/
+ keyword("REPLACE")
+ ).!.?.map { case None => false; case _ => true } ~
+ keyword("INTO") ~/
+ identifier ~
+ ("(" ~/
+ identifier ~
+ (comma ~/ identifier).rep ~
+ ")").map(x => Seq(x._1) ++ x._2).? ~
+ (
+ (&(keyword("SELECT")) ~/ select.map(SelectInsert(_)))
+ | (&(keyword("VALUES")) ~/ valueList)
+ )
+ ).map { case (orReplace, table, columns, values) =>
+ Insert(table, columns, values, orReplace).asInstanceOf[Stmt]
+ }
+ )
+
+ def delete[$ : P]: P[Stmt] = P(
+ (
+ keyword("DELETE") ~/
+ keyword("FROM") ~/
+ identifier ~
+ (
+ keyword("WHERE") ~/
+ expression
+ ).?
+ ).map { case (table, where) => Delete(table, where).asInstanceOf[Stmt] }
+ )
+
+ def update[$ : P]: P[Stmt] = P(
+ (
+ keyword("UPDATE") ~/
+ identifier ~
+ keyword("SET") ~/
+ (
+ identifier ~
+ "=" ~/
+ expression
+ ).rep(sep = comma, min = 1) ~
+ (
+ StringInIgnoreCase("WHERE") ~/
+ expression
+ ).?
+ ).map { case (table, set, where) =>
+ Update(table, set, where).asInstanceOf[Stmt]
+ }
+ )
+
+ def alias[$ : P]: P[Name] =
+ P(keyword("AS").? ~ identifier)
+
+ def selectTarget[$ : P]: P[SelectTarget] = P(
+ P("*").map(_ => SelectAll())
+ // Dotted wildcard needs a lookahead since a single token isn't
+ // enough to distinguish between `foo`.* and `foo` AS `bar`
+ | (&(dottedWildcard) ~
+ dottedWildcard.map(SelectTable(_)))
+ | (expression ~ alias.?).map(x => SelectExpression(x._1, x._2))
+ )
+
+ def simpleFromElement[$ : P]: P[FromElement] = P(
+ (("(" ~ select ~ ")" ~ alias).map(x => FromSelect(x._1, x._2)))
+ | ((dottedPair ~ alias.?).map { case (schema, table, alias) =>
+ FromTable(schema, table, alias)
+ })
+ | (("(" ~ fromElement ~ ")" ~ alias.?).map {
+ case (from, None) => from
+ case (from, Some(alias)) => from.withAlias(alias)
+ })
+ )
+
+  /** Parse a join qualifier: JOIN, NATURAL/INNER JOIN, or [LEFT|RIGHT|FULL] OUTER JOIN.
+   * NOTE(review): a bare `OUTER JOIN` defaults to FullOuter here -- confirm that is intended.
+   * fixed: lambda parameters were named `Unit`, shadowing the `Unit` type; use `_` instead.
+   */
+  def joinWith[$ : P]: P[Join.Type] = P(
+    keyword("JOIN").map(_ => Join.Inner)
+    | ((
+      keyword("NATURAL").!.map(_ => Join.Natural)
+      | keyword("INNER").map(_ => Join.Inner)
+      | ((
+        keyword("LEFT").map(_ => Join.LeftOuter)
+        | keyword("RIGHT").map(_ => Join.RightOuter)
+        | keyword("FULL").map(_ => Join.FullOuter)
+      ).?.map(_.getOrElse(Join.FullOuter)) ~/
+        keyword("OUTER"))
+    ) ~/ keyword("JOIN"))
+  )
+
+ def fromElement[$ : P]: P[FromElement] = P(
+ (
+ simpleFromElement ~ (
+ &(joinWith) ~
+ joinWith ~/
+ simpleFromElement ~/
+ (
+ keyword("ON") ~/
+ expression
+ ).? ~
+ alias.?
+ ).rep
+ ).map { case (lhs, rest) =>
+ rest.foldLeft(lhs) { (lhs, next) =>
+ val (t, rhs, onClause, alias) = next
+ FromJoin(lhs, rhs, t, onClause.getOrElse(BooleanPrimitive(true)), alias)
+ }
+ }
+ )
+
+ def fromClause[$ : P] = P(
+ keyword("FROM") ~/
+ fromElement.rep(sep = comma, min = 1)
+ )
+
+ def whereClause[$ : P] = P(keyword("WHERE") ~/ expression)
+
+ def groupByClause[$ : P] = P(
+ keyword("GROUP") ~/
+ keyword("BY") ~/
+ expressionList
+ )
+
+ def havingClause[$ : P] = P(keyword("HAVING") ~ expression)
+
+ def options[A](default: A, options: Map[String, A]): (Option[String] => A) =
+ _.map(_.toUpperCase).map(options(_)).getOrElse(default)
+
+ def ascOrDesc[$ : P] = P(keyword("ASC", "DESC").!.?.map {
+ options(true, Map("ASC" -> true, "DESC" -> false))
+ })
+
+ def orderBy[$ : P] = P((expression ~ ascOrDesc).map(x => OrderBy(x._1, x._2)))
+
+ def orderByClause[$ : P] = P(
+ keyword("ORDER") ~/
+ keyword("BY") ~/
+ orderBy.rep(sep = comma, min = 1)
+ )
+
+ def limitClause[$ : P] = P(
+ keyword("LIMIT") ~/
+ integer
+ )
+
+ def offsetClause[$ : P] = P(
+ keyword("OFFSET") ~/
+ integer
+ )
+
+ def allOrDistinct[$ : P] = P(keyword("ALL", "DISTINCT").!.?.map {
+ options(Union.Distinct, Map("ALL" -> Union.All, "DISTINCT" -> Union.Distinct))
+ })
+
+ def unionClause[$ : P] = P(keyword("UNION") ~/ allOrDistinct ~/ parenthesizedSelect)
+
+ /** A select body, optionally wrapped in parentheses.  A parenthesized
+   * select may itself be followed by a UNION clause, which is attached to
+   * the parsed body via unionWith; a bare select is parsed as-is.
+   */
+ def parenthesizedSelect[$ : P]: P[SelectBody] = P(
+ (
+ "(" ~/ select ~ ")" ~/ unionClause.?
+ ).map {
+ case (body, Some((unionType, unionBody))) => body.unionWith(unionType, unionBody)
+ case (body, None) => body
+ } | select
+ )
+
+  /** Parses a full SELECT body: target list, then the optional FROM, WHERE,
+    * GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET and UNION clauses, assembled
+    * into a [[SelectBody]].
+    *
+    * NOTE(review): the target list uses `sep = ","` (and no minimum) where
+    * sibling clauses use `sep = comma, min = 1` -- confirm whether an empty
+    * target list should really be accepted before tightening it.
+    */
+  def select[$ : P]: P[SelectBody] = P(
+    (
+      keyword("SELECT") ~/
+        // `.!.?` captures an optional DISTINCT; only its presence matters.
+        keyword("DISTINCT").!.?.map(_.isDefined) ~/
+        selectTarget.rep(sep = ",") ~
+        fromClause.?.map(_.toSeq.flatten) ~
+        whereClause.? ~
+        groupByClause.? ~
+        havingClause.? ~
+        orderByClause.?.map(_.toSeq.flatten) ~
+        limitClause.? ~
+        offsetClause.? ~
+        unionClause.?
+    ).map { case (distinct, targets, froms, where, groupBy, having, orderBy, limit, offset, union) =>
+      SelectBody(
+        distinct = distinct,
+        target = targets,
+        from = froms,
+        where = where,
+        groupBy = groupBy,
+        having = having,
+        orderBy = orderBy,
+        limit = limit,
+        offset = offset,
+        union = union
+      )
+    }
+  )
+
+ /** Parses an entire input as a sequence of statements, anchored with
+   * Start/End so that trailing unconsumed input causes a parse failure.
+   * (Error recovery for unparseable statements lives in
+   * SQLBaseObject.parseAll, not in this rule.)
+   */
+ def allStatements[$ : P]: P[Seq[Stmt]] = P(Start ~ terminatedStatement.rep ~ End).map(_.toSeq)
+}
+
+trait SQLBaseObject {
+  type Stmt <: Statement
+  def name: String
+  def apply(input: String): Parsed[Stmt]
+  def apply(input: Reader): StreamParser[Stmt]
+
+  /** The statement terminator for this dialect (e.g. ";"). */
+  protected def statementTerminator: String
+
+  /** All characters that may terminate a statement (e.g. both ';' and '/').
+    * Defaults to the characters of [[statementTerminator]].
+    */
+  protected def statementTerminatorChars: Set[Char] = statementTerminator.toSet
+
+  /** Parses every statement in `input`, recovering from parse errors.
+    *
+    * Each result carries the absolute start/end character offsets of the
+    * statement within `input`.  A statement that fails to parse is captured
+    * as `Left(Unparseable(...))` spanning up to and including the next
+    * statement terminator (or end of input), and parsing resumes after it.
+    */
+  def parseAll(input: String): Seq[StatementParseResult] = {
+    import sparsity.common.statement.StatementParseResult
+    import fastparse._
+
+    // Strips leading whitespace, `--` line comments and `/* */` block
+    // comments; returns the remaining text plus the number of characters
+    // skipped.  (Hoisted out of the recursion so the closure is built once.)
+    def skipWhitespaceAndComments(s: String): (String, Int) = {
+      var pos = 0
+      while (pos < s.length) {
+        val char = s(pos)
+        if (char.isWhitespace) {
+          pos += 1
+        } else if (pos + 1 < s.length && char == '-' && s(pos + 1) == '-') {
+          // Line comment: skip to, and past, the terminating newline.
+          while (pos < s.length && s(pos) != '\n') pos += 1
+          if (pos < s.length) pos += 1
+        } else if (pos + 1 < s.length && char == '/' && s(pos + 1) == '*') {
+          // Block comment: skip past the closing */ (or to EOF if unclosed).
+          pos += 2
+          while (pos + 1 < s.length && !(s(pos) == '*' && s(pos + 1) == '/')) pos += 1
+          if (pos + 1 < s.length) pos += 2
+        } else {
+          // Found real content; chars skipped always equals `pos` here.
+          return (s.substring(pos), pos)
+        }
+      }
+      ("", pos)
+    }
+
+    // Index of the next statement terminator in `s`, skipping string
+    // literals (with '' escapes) and comments; -1 when none is found.
+    // Iterative scan: the previous version recursed on substring after each
+    // closed literal, which was quadratic on literal-heavy input.
+    def findNextTerminator(s: String, terminatorChars: Set[Char]): Int = {
+      var pos = 0
+      while (pos < s.length) {
+        val char = s(pos)
+        if (char == '\'') {
+          // String literal: advance to the closing quote, treating '' as an
+          // escaped quote.  An unclosed literal means no terminator exists.
+          pos += 1
+          var closed = false
+          while (!closed && pos < s.length) {
+            if (s(pos) == '\'') {
+              if (pos + 1 < s.length && s(pos + 1) == '\'') pos += 2
+              else { pos += 1; closed = true }
+            } else {
+              pos += 1
+            }
+          }
+          if (!closed) return -1
+        } else if (pos + 1 < s.length && char == '-' && s(pos + 1) == '-') {
+          // Skip -- comment until newline.
+          while (pos < s.length && s(pos) != '\n') pos += 1
+        } else if (pos + 1 < s.length && char == '/' && s(pos + 1) == '*') {
+          // Skip /* */ comment.
+          pos += 2
+          while (pos + 1 < s.length && !(s(pos) == '*' && s(pos + 1) == '/')) pos += 1
+          if (pos + 1 < s.length) pos += 2
+        } else if (terminatorChars.contains(char)) {
+          return pos
+        } else {
+          pos += 1
+        }
+      }
+      -1
+    }
+
+    @scala.annotation.tailrec
+    def parseStatementsSequentially(
+      remaining: String,
+      currentPos: Int,
+      results: List[StatementParseResult]
+    ): List[StatementParseResult] = {
+      val (trimmed, skipped) = skipWhitespaceAndComments(remaining)
+      val statementStartPos = currentPos + skipped
+
+      if (trimmed.isEmpty) {
+        results.reverse
+      } else {
+        val (result, endPos) = apply(trimmed) match {
+          case Parsed.Success(stmt, index) =>
+            // `stmt` is already a Stmt, so no cast is needed; Unparseable
+            // results produced by the grammar itself surface as Left.
+            val parsed: Either[Unparseable, Stmt] = stmt match {
+              case u: Unparseable => Left(u)
+              case other => Right(other)
+            }
+            (parsed, statementStartPos + index)
+
+          case _: Parsed.Failure =>
+            // Recovery: swallow everything up to and including the next
+            // terminator (or end of input) as an Unparseable statement.
+            val terminatorPos = findNextTerminator(trimmed, statementTerminatorChars)
+            val unparseableEnd = if (terminatorPos >= 0) terminatorPos + 1 else trimmed.length
+            val unparseableText = trimmed.substring(0, unparseableEnd).trim
+            (Left(Unparseable(unparseableText)), statementStartPos + unparseableEnd)
+        }
+
+        // NOTE(review): if a dialect parser could succeed while consuming
+        // zero characters of non-empty input, endPos would not advance --
+        // confirm each dialect's statement rule consumes at least one char.
+        val entry = StatementParseResult(result, statementStartPos, endPos)
+        val nextRemaining = if (endPos < input.length) input.substring(endPos) else ""
+        parseStatementsSequentially(nextRemaining, endPos, entry :: results)
+      }
+    }
+
+    parseStatementsSequentially(input, 0, Nil)
+  }
+}
diff --git a/src/main/scala/sparsity/parser/StreamParser.scala b/src/main/scala/sparsity/common/parser/StreamParser.scala
similarity index 61%
rename from src/main/scala/sparsity/parser/StreamParser.scala
rename to src/main/scala/sparsity/common/parser/StreamParser.scala
index e3f1b35..9c8fc34 100644
--- a/src/main/scala/sparsity/parser/StreamParser.scala
+++ b/src/main/scala/sparsity/common/parser/StreamParser.scala
@@ -1,42 +1,36 @@
-package sparsity.parser
+package sparsity.common.parser
import scala.collection.mutable.Buffer
import scala.io._
import java.io._
-import fastparse._, NoWhitespace._
-import com.typesafe.scalalogging.LazyLogging
+import fastparse._
+import scribe.Logger
-class StreamParser[R](parser:(Iterator[String] => Parsed[R]), source: Reader)
- extends Iterator[Parsed[R]]
- with LazyLogging
-{
+class StreamParser[R](parser: (Iterator[String] => Parsed[R]), source: Reader) extends Iterator[Parsed[R]] {
+ private val logger = Logger("sparsity.common.parser.StreamParser")
val bufferedSource = source match {
- case b:BufferedReader => b
+ case b: BufferedReader => b
case _ => new BufferedReader(source)
}
val buffer = Buffer[String]()
- def load(): Unit = {
- while(bufferedSource.ready){
+ def load(): Unit =
+ while (bufferedSource.ready)
buffer += bufferedSource.readLine.replace("\\n", " ");
- }
- }
- def loadBlocking(): Unit = {
+ def loadBlocking(): Unit =
buffer += bufferedSource.readLine.replace("\\n", " ");
- }
def hasNext: Boolean = { load(); buffer.size > 0 }
- def next(): Parsed[R] =
- {
+ def next(): Parsed[R] = {
load();
- if(buffer.size > 0) {
+ if (buffer.size > 0) {
parser(buffer.iterator) match {
- case r@Parsed.Success(result, index) =>
+ case r @ Parsed.Success(result, index) =>
logger.debug(s"Parsed(index = $index): $result")
skipBytes(index)
return r
- case f:Parsed.Failure =>
+ case f: Parsed.Failure =>
buffer.clear()
return f
}
@@ -45,29 +39,29 @@ class StreamParser[R](parser:(Iterator[String] => Parsed[R]), source: Reader)
}
}
- def skipBytes(offset: Int) : Unit = {
+ def skipBytes(offset: Int): Unit = {
var dropped = 0
- while(offset > dropped && !buffer.isEmpty){
+ while (offset > dropped && !buffer.isEmpty) {
logger.trace(s"Checking for drop: $dropped / $offset: Next unit: ${buffer.head.length}")
- if(buffer.head.length < (offset - dropped)){
+ if (buffer.head.length < (offset - dropped)) {
dropped = dropped + buffer.head.length
logger.trace(s"Dropping '${buffer.head}' ($dropped / $offset dropped so far)")
buffer.remove(0)
} else {
logger.trace(s"Trimming '${buffer.head}' (by ${offset - dropped})")
var head = buffer.head
- .substring(offset - dropped)
- .replace("^\\w+", "")
+ .substring(offset - dropped)
+ .replace("^\\w+", "")
logger.trace(s"Trimming leading whitespace")
- while(head.length <= 0 && !buffer.isEmpty){
+ while (head.length <= 0 && !buffer.isEmpty) {
logger.trace(s"Nothing but whitespace left. Dropping and trying again")
buffer.remove(0)
- if(!buffer.isEmpty) {
+ if (!buffer.isEmpty) {
head = buffer.head.replace("^\\w+", "")
}
}
logger.trace(s"Remaining = '$head'")
- if(head.length > 0){ buffer.update(0, head) }
+ if (head.length > 0) { buffer.update(0, head) }
return
}
}
diff --git a/src/main/scala/sparsity/common/select/FromElement.scala b/src/main/scala/sparsity/common/select/FromElement.scala
new file mode 100644
index 0000000..6926a41
--- /dev/null
+++ b/src/main/scala/sparsity/common/select/FromElement.scala
@@ -0,0 +1,64 @@
+package sparsity.common.select
+
+import sparsity.common.Name
+import sparsity.common.expression.{BooleanPrimitive, Expression}
+
+/** Base of everything that can appear in a FROM clause. */
+sealed abstract class FromElement {
+ // Names by which this element can be referenced (its alias when present).
+ def aliases: Seq[Name]
+ // Returns a copy of this element carrying the given alias.
+ def withAlias(newAlias: Name): FromElement
+ // Like toSql, but parenthesized when needed (overridden by FromJoin).
+ def toSqlWithParensIfNeeded: String = toSql
+ def toSql: String
+}
+
+/** A (possibly schema-qualified) table reference with an optional alias. */
+case class FromTable(schema: Option[Name], table: Name, alias: Option[Name]) extends FromElement {
+  def aliases = Seq(alias.getOrElse(table))
+  /** Renders as `[schema.]table[ AS alias]`. */
+  override def toSql = {
+    val qualifier = schema.fold("")(_.toSql + ".")
+    val aliasSuffix = alias.fold("")(a => " AS " + a.toSql)
+    qualifier + table.toSql + aliasSuffix
+  }
+  def withAlias(newAlias: Name) = copy(alias = Some(newAlias))
+}
+/** A parenthesized sub-query in a FROM clause; SQL requires it to carry an
+  * alias.  (The redundant `val` on the case-class parameter is removed --
+  * case-class parameters are vals already.)
+  */
+case class FromSelect(body: SelectBody, alias: Name) extends FromElement {
+  def aliases = Seq(alias)
+  /** Renders as `(<subquery>) AS <alias>`. */
+  override def toSql =
+    "(" + body.toSql + ") AS " + alias.toSql
+  def withAlias(newAlias: Name) = FromSelect(body, newAlias)
+}
+/** A join of two from-elements.  Defaults to an inner join with an
+  * always-true ON condition (i.e. a cross join).
+  */
+case class FromJoin(
+ lhs: FromElement,
+ rhs: FromElement,
+ t: Join.Type = Join.Inner,
+ on: Expression = BooleanPrimitive(true),
+ alias: Option[Name] = None
+) extends FromElement {
+ // With an explicit alias the join is referenced only by it; otherwise it
+ // exposes the aliases of both sides.
+ def aliases =
+ alias.map(Seq(_)).getOrElse(lhs.aliases ++ rhs.aliases)
+ // Joins always need parentheses when nested inside another join.
+ override def toSqlWithParensIfNeeded: String = "(" + toSql + ")"
+ override def toSql = {
+ // An always-true ON condition is elided from the rendered SQL.
+ val baseString =
+ lhs.toSqlWithParensIfNeeded + " " +
+ Join.toSql(t) + " " +
+ rhs.toSqlWithParensIfNeeded +
+ (on match { case BooleanPrimitive(true) => ""; case _ => " ON " + on.toSql })
+ alias match {
+ case Some(a) => "(" + baseString + ") AS " + a.toSql
+ case None => baseString
+ }
+ }
+ def withAlias(newAlias: Name) = FromJoin(lhs, rhs, t, on, Some(newAlias))
+
+}
+
+/** Join types.  (Kept as scala.Enumeration for source compatibility; the
+  * values are referenced pervasively as Join.Inner etc.)
+  */
+object Join extends Enumeration {
+ type Type = Value
+ val Inner, Natural, LeftOuter, RightOuter, FullOuter = Value
+
+ // Renders the SQL keyword(s) for a join type.
+ def toSql(t: Type) = t match {
+ case Inner => "JOIN"
+ case Natural => "NATURAL JOIN"
+ case LeftOuter => "LEFT OUTER JOIN"
+ case RightOuter => "RIGHT OUTER JOIN"
+ case FullOuter => "FULL OUTER JOIN"
+ }
+}
diff --git a/src/main/scala/sparsity/common/select/OrderBy.scala b/src/main/scala/sparsity/common/select/OrderBy.scala
new file mode 100644
index 0000000..87b8896
--- /dev/null
+++ b/src/main/scala/sparsity/common/select/OrderBy.scala
@@ -0,0 +1,9 @@
+package sparsity.common.select
+
+import sparsity.common.expression.{Expression, ToSql}
+
+/** One ORDER BY term: an expression plus its sort direction. */
+case class OrderBy(expression: Expression, ascending: Boolean) extends ToSql {
+  /** True when this term sorts descending. */
+  def descending = !ascending
+  /** Renders the expression, appending " DESC" only for descending terms. */
+  override def toSql = {
+    val suffix = if (descending) " DESC" else ""
+    expression.toSql + suffix
+  }
+}
diff --git a/src/main/scala/sparsity/common/select/SelectBody.scala b/src/main/scala/sparsity/common/select/SelectBody.scala
new file mode 100644
index 0000000..3ce138e
--- /dev/null
+++ b/src/main/scala/sparsity/common/select/SelectBody.scala
@@ -0,0 +1,43 @@
+package sparsity.common.select
+
+import sparsity.common.expression.{Expression, ToSql}
+
+/** The body of a SELECT statement: target list plus all optional clauses. */
+case class SelectBody(
+  distinct: Boolean = false,
+  target: Seq[SelectTarget] = Seq(),
+  from: Seq[FromElement] = Seq(),
+  where: Option[Expression] = None,
+  groupBy: Option[Seq[Expression]] = None,
+  having: Option[Expression] = None,
+  orderBy: Seq[OrderBy] = Seq(),
+  limit: Option[Long] = None,
+  offset: Option[Long] = None,
+  union: Option[(Union.Type, SelectBody)] = None
+) extends ToSql {
+  /** The rendered SQL fragments, in clause order; empty clauses are elided. */
+  def stringElements: Seq[String] =
+    Seq("SELECT") ++
+      (if (distinct) { Some("DISTINCT") }
+       else { None }) ++
+      Seq(target.map(_.toSql).mkString(", ")) ++
+      (if (from.isEmpty) { None }
+       else { Seq("FROM", from.map(_.toSql).mkString(", ")) }) ++
+      (where.map(x => Seq("WHERE", x.toSql)).toSeq.flatten) ++
+      (groupBy.map(x => Seq("GROUP BY") ++ x.map(_.toSql)).toSeq.flatten) ++
+      (having.map(x => Seq("HAVING", x.toSql)).toSeq.flatten) ++
+      (if (orderBy.isEmpty) { Seq() }
+       else { Seq("ORDER BY", orderBy.map(_.toSql).mkString(", ")) }) ++
+      (limit.map(x => Seq("LIMIT", x.toString)).toSeq.flatten) ++
+      (offset.map(x => Seq("OFFSET", x.toString)).toSeq.flatten) ++
+      (union.map { case (t, b) => Seq(Union.toSql(t)) ++ b.stringElements }.toSeq.flatten)
+  override def toSql = stringElements.mkString(" ")
+
+  /** Appends `body` (with union type `t`) at the *end* of this body's union
+    * chain, preserving any unions already attached.
+    */
+  def unionWith(t: Union.Type, body: SelectBody): SelectBody = {
+    val replacementUnion =
+      union match {
+        // Already unioned: recurse so the new body lands at the tail.
+        case Some((existingType, existingBody)) => (existingType, existingBody.unionWith(t, body))
+        case None => (t, body)
+      }
+    // copy() instead of re-listing every field positionally: safer if fields
+    // are ever added or reordered.
+    copy(union = Some(replacementUnion))
+  }
+}
diff --git a/src/main/scala/sparsity/common/select/SelectTarget.scala b/src/main/scala/sparsity/common/select/SelectTarget.scala
new file mode 100644
index 0000000..738785d
--- /dev/null
+++ b/src/main/scala/sparsity/common/select/SelectTarget.scala
@@ -0,0 +1,17 @@
+package sparsity.common.select
+
+import sparsity.common.Name
+import sparsity.common.expression.{Expression, ToSql}
+
+/** One entry in a SELECT's target list. */
+sealed abstract class SelectTarget extends ToSql
+
+/** `*`: every column of every FROM element. */
+case class SelectAll() extends SelectTarget {
+  override def toSql = "*"
+}
+/** `table.*`: every column of a single table. */
+case class SelectTable(table: Name) extends SelectTarget {
+  override def toSql = table.toSql + ".*"
+}
+/** A single expression, optionally aliased with AS. */
+case class SelectExpression(expression: Expression, alias: Option[Name] = None) extends SelectTarget {
+  override def toSql =
+    expression.toSql + alias.fold("")(" AS " + _.toSql)
+}
diff --git a/src/main/scala/sparsity/select/Union.scala b/src/main/scala/sparsity/common/select/Union.scala
similarity index 70%
rename from src/main/scala/sparsity/select/Union.scala
rename to src/main/scala/sparsity/common/select/Union.scala
index a5cfb33..c0b688a 100644
--- a/src/main/scala/sparsity/select/Union.scala
+++ b/src/main/scala/sparsity/common/select/Union.scala
@@ -1,11 +1,11 @@
-package sparsity.select
+package sparsity.common.select
object Union extends Enumeration {
type Type = Value
val All, Distinct = Value
- def toString(t: Type) = t match {
+ def toSql(t: Type) = t match {
case All => "UNION ALL"
case Distinct => "UNION DISTINCT"
}
-}
\ No newline at end of file
+}
diff --git a/src/main/scala/sparsity/common/statement/ColumnDefinition.scala b/src/main/scala/sparsity/common/statement/ColumnDefinition.scala
new file mode 100644
index 0000000..1ad4834
--- /dev/null
+++ b/src/main/scala/sparsity/common/statement/ColumnDefinition.scala
@@ -0,0 +1,32 @@
+package sparsity.common.statement
+
+import sparsity.common.expression.{Expression, PrimitiveValue, ToSql}
+import sparsity.common.Name
+
+/** Base of per-column annotations in a column definition. */
+sealed abstract class ColumnAnnotation extends ToSql
+
+/** Marks the column as the table's primary key. */
+case class ColumnIsPrimaryKey() extends ColumnAnnotation {
+ override def toSql = "PRIMARY KEY"
+}
+/** Marks the column NOT NULL. */
+case class ColumnIsNotNullable() extends ColumnAnnotation {
+ override def toSql = "NOT NULL"
+}
+/** Explicitly marks the column as nullable. */
+case class ColumnIsNullable() extends ColumnAnnotation {
+ override def toSql = "NULL"
+}
+/** Default value for the column.
+  * NOTE(review): renders as "DEFAULT VALUE <expr>" where standard SQL is
+  * "DEFAULT <expr>" -- confirm the extra "VALUE" token is what the dialect
+  * parsers accept before changing it.
+  */
+case class ColumnDefaultValue(v: Expression) extends ColumnAnnotation {
+ override def toSql = "DEFAULT VALUE " + v.toSql
+}
+
+/** A column in a CREATE TABLE: name, type (with optional arguments), and
+  * any per-column annotations.
+  */
+case class ColumnDefinition(
+  name: Name,
+  t: Name,
+  args: Seq[PrimitiveValue] = Seq(),
+  annotations: Seq[ColumnAnnotation] = Seq()
+) {
+  /** Renders as `name type[(args)][ annotation ...]`. */
+  def toSql = {
+    val argList =
+      if (args.isEmpty) ""
+      else args.map(_.toSql).mkString("(", ", ", ")")
+    val annotationSuffix = annotations.map(" " + _.toSql).mkString
+    name.toSql + " " + t.toSql + argList + annotationSuffix
+  }
+}
diff --git a/src/main/scala/sparsity/common/statement/InsertTarget.scala b/src/main/scala/sparsity/common/statement/InsertTarget.scala
new file mode 100644
index 0000000..8e6d7a2
--- /dev/null
+++ b/src/main/scala/sparsity/common/statement/InsertTarget.scala
@@ -0,0 +1,15 @@
+package sparsity.common.statement
+
+import sparsity.common.expression.Expression
+import sparsity.common.select.SelectBody
+import sparsity.common.expression.ToSql
+
+/** The value source of an INSERT statement. */
+sealed abstract class InsertValues extends ToSql
+
+/** Literal VALUES rows: `VALUES (a, b), (c, d), ...`. */
+case class ExplicitInsert(values: Seq[Seq[Expression]]) extends InsertValues {
+  override def toSql = {
+    val rows = values.map(row => row.map(_.toSql).mkString("(", ", ", ")"))
+    "VALUES " + rows.mkString(", ")
+  }
+}
+/** INSERT ... SELECT: rows produced by a sub-query. */
+case class SelectInsert(query: SelectBody) extends InsertValues {
+  override def toSql = query.toSql
+}
diff --git a/src/main/scala/sparsity/common/statement/Statement.scala b/src/main/scala/sparsity/common/statement/Statement.scala
new file mode 100644
index 0000000..fa87cf1
--- /dev/null
+++ b/src/main/scala/sparsity/common/statement/Statement.scala
@@ -0,0 +1,123 @@
+package sparsity.common.statement
+
+import sparsity.common.Name
+import sparsity.common.expression.Expression
+import sparsity.common.select.SelectBody
+import sparsity.common.alter._
+
+/** Base of every parsed SQL statement; toSql renders it back to SQL text. */
+abstract class Statement {
+ def toSql: String
+}
+
+/** Empty statement (whitespace/comments only). Used for comment-only statements or empty statements with just
+ * terminators.
+ */
+case class EmptyStatement() extends Statement {
+ // Renders to nothing: an empty statement has no SQL text.
+ override def toSql = ""
+}
+
+/** Unparseable statement (parse error recovery). Captures statements that failed to parse, allowing the parser to
+ * continue.
+ */
+case class Unparseable(content: String) extends Statement {
+ // Echoes the raw, unparsed text back verbatim.
+ override def toSql = content
+}
+
+/** A top-level SELECT statement, with optional WITH (CTE) clauses. */
+case class Select(body: SelectBody, withClause: Seq[WithClause] = Seq()) extends Statement {
+  override def toSql = {
+    val withPrefix =
+      if (withClause.isEmpty) ""
+      else "WITH " + withClause.map(_.toSql).mkString(",\n") + "\n"
+    withPrefix + body.toSql + ";"
+  }
+}
+
+/** UPDATE <table> SET col = expr, ... [WHERE cond]. */
+case class Update(table: Name, set: Seq[(Name, Expression)], where: Option[Expression]) extends Statement {
+  override def toSql = {
+    val assignments = set.map { case (col, value) => col.toSql + " = " + value.toSql }.mkString(", ")
+    val whereSuffix = where.fold("")(" WHERE " + _.toSql)
+    s"UPDATE ${table.toSql} SET $assignments$whereSuffix;"
+  }
+}
+
+/** DELETE FROM <table> [WHERE cond]. */
+case class Delete(table: Name, where: Option[Expression]) extends Statement {
+  override def toSql = {
+    val whereSuffix = where.fold("")(" WHERE " + _.toSql)
+    s"DELETE FROM ${table.toSql}$whereSuffix;"
+  }
+}
+
+/** INSERT [OR REPLACE] INTO <table> [(cols)] <values>. */
+case class Insert(table: Name, columns: Option[Seq[Name]], values: InsertValues, orReplace: Boolean) extends Statement {
+  override def toSql = {
+    // No trailing space after "OR REPLACE": the " INTO" that follows already
+    // separates the keywords (the old form emitted "OR REPLACE  INTO").
+    val orReplaceClause = if (orReplace) " OR REPLACE" else ""
+    val columnList = columns.fold("")(cols => "(" + cols.map(_.toSql).mkString(", ") + ")")
+    s"INSERT$orReplaceClause INTO ${table.toSql}" + columnList + " " + values.toSql + ";"
+  }
+}
+
+/** CREATE [OR REPLACE] TABLE with explicit column definitions and
+  * table-level annotations.
+  */
+case class CreateTable(
+  name: Name,
+  orReplace: Boolean,
+  columns: Seq[ColumnDefinition],
+  annotations: Seq[TableAnnotation]
+) extends Statement {
+  override def toSql = {
+    val orReplaceClause = if (orReplace) "OR REPLACE " else ""
+    val elements = columns.map(_.toSql) ++ annotations.map(_.toSql)
+    s"CREATE ${orReplaceClause}TABLE ${name.toSql}(" + elements.mkString(", ") + ");"
+  }
+}
+
+/** CREATE [OR REPLACE] TABLE ... AS <select>. */
+case class CreateTableAs(name: Name, orReplace: Boolean, query: SelectBody) extends Statement {
+  override def toSql = {
+    val orReplaceClause = if (orReplace) "OR REPLACE " else ""
+    s"CREATE ${orReplaceClause}TABLE ${name.toSql} AS ${query.toSql};"
+  }
+}
+
+/** CREATE [OR REPLACE] [MATERIALIZED] VIEW ... AS <select>.
+  * NOTE(review): `temporary` is carried on the AST but never rendered by
+  * toSql -- confirm whether TEMPORARY output is intended.
+  */
+case class CreateView(
+  name: Name,
+  orReplace: Boolean,
+  query: SelectBody,
+  materialized: Boolean = false,
+  temporary: Boolean = false
+) extends Statement {
+  override def toSql = {
+    val orReplaceClause = if (orReplace) "OR REPLACE " else ""
+    val materializedClause = if (materialized) "MATERIALIZED " else ""
+    "CREATE " + orReplaceClause + materializedClause + s"VIEW ${name.toSql} AS ${query.toSql};"
+  }
+}
+
+/** ALTER VIEW <name> <action>.
+  * NOTE(review): unlike the other statements, toSql emits no trailing ";"
+  * -- confirm whether the action's toSql supplies the terminator.
+  */
+case class AlterView(name: Name, action: AlterViewAction) extends Statement {
+ override def toSql = s"ALTER VIEW ${name.toSql} ${action.toSql}"
+}
+
+/** DROP TABLE [IF EXISTS] <name>. */
+case class DropTable(name: Name, ifExists: Boolean) extends Statement {
+  override def toSql = {
+    val ifExistsClause = if (ifExists) "IF EXISTS " else ""
+    "DROP TABLE " + ifExistsClause + name.toSql + ";"
+  }
+}
+
+/** DROP VIEW [IF EXISTS] <name>. */
+case class DropView(name: Name, ifExists: Boolean) extends Statement {
+  override def toSql = {
+    val ifExistsClause = if (ifExists) "IF EXISTS " else ""
+    "DROP VIEW " + ifExistsClause + name.toSql + ";"
+  }
+}
+
+/** EXPLAIN <query>. */
+case class Explain(query: SelectBody) extends Statement {
+  override def toSql = "EXPLAIN " + query.toSql + ";"
+}
+
+/** CREATE [UNIQUE] INDEX <name> ON <table>(cols). */
+class CreateIndex(val name: Name, val table: Name, val columns: Seq[Name], val unique: Boolean = false)
+  extends Statement {
+  override def toSql = {
+    val uniquePrefix = if (unique) "UNIQUE " else ""
+    val columnList = columns.map(_.toSql).mkString(", ")
+    s"CREATE ${uniquePrefix}INDEX ${name.toSql} ON ${table.toSql}($columnList);"
+  }
+}
+
+/** Result of parsing a single statement with position information.
+  * `result` is either the recovered Unparseable or the parsed statement;
+  * `startPos`/`endPos` are absolute character offsets into the input.
+  */
+case class StatementParseResult(result: Either[Unparseable, Statement], startPos: Int, endPos: Int)
diff --git a/src/main/scala/sparsity/common/statement/TableAnnotation.scala b/src/main/scala/sparsity/common/statement/TableAnnotation.scala
new file mode 100644
index 0000000..a17dcc2
--- /dev/null
+++ b/src/main/scala/sparsity/common/statement/TableAnnotation.scala
@@ -0,0 +1,16 @@
+package sparsity.common.statement
+
+import sparsity.common.Name
+import sparsity.common.expression.ToSql
+
+/** Table-level constraints/annotations inside a CREATE TABLE. */
+sealed abstract class TableAnnotation extends ToSql
+
+/** PRIMARY KEY constraint over one or more columns. */
+case class TablePrimaryKey(columns: Seq[Name]) extends TableAnnotation {
+  override def toSql = "PRIMARY KEY (" + columns.map(_.toSql).mkString(", ") + ")"
+}
+/** INDEX annotation over one or more columns. */
+case class TableIndexOn(columns: Seq[Name]) extends TableAnnotation {
+  override def toSql = "INDEX (" + columns.map(_.toSql).mkString(", ") + ")"
+}
+/** UNIQUE constraint over one or more columns. */
+case class TableUnique(columns: Seq[Name]) extends TableAnnotation {
+  override def toSql = "UNIQUE (" + columns.map(_.toSql).mkString(", ") + ")"
+}
diff --git a/src/main/scala/sparsity/common/statement/WithClause.scala b/src/main/scala/sparsity/common/statement/WithClause.scala
new file mode 100644
index 0000000..a4c468a
--- /dev/null
+++ b/src/main/scala/sparsity/common/statement/WithClause.scala
@@ -0,0 +1,9 @@
+package sparsity.common.statement
+
+import sparsity.common.Name
+import sparsity.common.select.SelectBody
+import sparsity.common.expression.ToSql
+
+/** A single CTE binding: renders as `<name> AS (<subquery>)`. */
+case class WithClause(body: SelectBody, name: Name) extends ToSql {
+  override def toSql = name.toSql + " AS (" + body.toSql + ")"
+}
diff --git a/src/main/scala/sparsity/expression/Expression.scala b/src/main/scala/sparsity/expression/Expression.scala
deleted file mode 100644
index d89f198..0000000
--- a/src/main/scala/sparsity/expression/Expression.scala
+++ /dev/null
@@ -1,363 +0,0 @@
-package sparsity.expression
-
-import sparsity.Name
-
-sealed abstract class Expression
-{
- def needsParenthesis: Boolean
- def children: Seq[Expression]
- def rebuild(newChildren: Seq[Expression]): Expression
-}
-
-object Expression
-{
- def parenthesize(e: Expression) =
- if(e.needsParenthesis) { "(" +e.toString +")" } else { e.toString }
- def escapeString(s: String) =
- s.replaceAll("'", "''")
-}
-
-sealed abstract class PrimitiveValue extends Expression
-{
- def needsParenthesis = false
- def children:Seq[Expression] = Seq()
- def rebuild(newChildren: Seq[Expression]):Expression = this
-}
-case class LongPrimitive(v: Long) extends PrimitiveValue
- { override def toString = v.toString }
-case class DoublePrimitive(v: Double) extends PrimitiveValue
- { override def toString = v.toString }
-case class StringPrimitive(v: String) extends PrimitiveValue
- { override def toString = "'"+Expression.escapeString(v.toString)+"'" }
-case class BooleanPrimitive(v: Boolean) extends PrimitiveValue
- { override def toString = v.toString }
-case class NullPrimitive() extends PrimitiveValue
- { override def toString = "NULL" }
-
-/**
- * A column reference 'Table.Col'
- */
-case class Column(column: Name, table: Option[Name] = None) extends Expression
-{
- override def toString = (table.toSeq ++ Seq(column)).mkString(".")
- def needsParenthesis = false
- def children:Seq[Expression] = Seq()
- def rebuild(newChildren: Seq[Expression]):Expression = this
-}
-/**
- * Any simple binary arithmetic expression
- * See Arithmetic.scala for an enumeration of the possibilities
- * *not* a case class to avoid namespace collisions.
- * Arithmetic defines apply/unapply explicitly
- */
-class Arithmetic(
- val lhs: Expression,
- val op: Arithmetic.Op,
- val rhs: Expression
-) extends Expression
-{
- override def toString =
- Expression.parenthesize(lhs)+" "+
- Arithmetic.opString(op)+" "+
- Expression.parenthesize(rhs)
- def needsParenthesis = true
- override def equals(other:Any): Boolean =
- other match {
- case Arithmetic(otherlhs, otherop, otherrhs) => lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
- case _ => false
- }
- def children:Seq[Expression] = Seq(lhs, rhs)
- def rebuild(c: Seq[Expression]): Expression = Arithmetic(c(0), op, c(1))
-}
-
-object Arithmetic extends Enumeration {
- type Op = Value
- val Add, Sub, Mult, Div, And, Or, BitAnd, BitOr, ShiftLeft, ShiftRight = Value
-
- def apply(lhs: Expression, op: Op, rhs: Expression) =
- new Arithmetic(lhs, op, rhs)
- def apply(lhs: Expression, op: String, rhs: Expression) =
- new Arithmetic(lhs, fromString(op), rhs)
-
- def unapply(e:Arithmetic): Option[(Expression, Op, Expression)] =
- Some( (e.lhs, e.op, e.rhs) )
-
- /**
- * Regular expresion to match any and all binary operations
- */
- def matchRegex = """\+|-|\*|/|\|\||&&|\||&""".r
-
- /**
- * Convert from the operator's string encoding to its Arith.Op rep
- */
- def fromString(a: String) = {
- a.toUpperCase match {
- case "+" => Add
- case "-" => Sub
- case "*" => Mult
- case "/" => Div
- case "&" => BitAnd
- case "|" => BitOr
- case "<<" => ShiftLeft
- case ">>" => ShiftRight
- case "&&" => And
- case "||" => Or
- case "AND" => And
- case "OR" => And
- case x => throw new Exception("Invalid operand '"+x+"'")
- }
- }
- /**
- * Convert from the operator's Arith.Op representation to a string
- */
- def opString(v: Op): String =
- {
- v match {
- case Add => "+"
- case Sub => "-"
- case Mult => "*"
- case Div => "/"
- case BitAnd => "&"
- case BitOr => "|"
- case ShiftLeft => "<<"
- case ShiftRight => ">>"
- case And => "AND"
- case Or => "OR"
- }
- }
- /**
- * Is this binary operation a boolean operator (AND/OR)
- */
- def isBool(v: Op): Boolean =
- {
- v match {
- case And | Or => true
- case _ => false
- }
- }
-
- /**
- * Is this binary operation a numeric operator (+, -, *, /, & , |)
- */
- def isNumeric(v: Op): Boolean = !isBool(v)
-
-}
-
-/**
- * Any simple comparison expression
- * See Comparison.scala for an enumeration of the possibilities
- * *not* a case class to avoid namespace collisions.
- * Comparison defines apply/unapply explicitly
- */
-class Comparison(
- val lhs: Expression,
- val op: Comparison.Op,
- val rhs: Expression
-) extends Expression
-{
- override def toString =
- Expression.parenthesize(lhs)+" "+
- Comparison.opString(op)+" "+
- Expression.parenthesize(rhs)
- def needsParenthesis = true
- override def equals(other:Any): Boolean =
- other match {
- case Comparison(otherlhs, otherop, otherrhs) => lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
- case _ => false
- }
- def children:Seq[Expression] = Seq(lhs, rhs)
- def rebuild(c: Seq[Expression]): Expression = Comparison(c(0), op, c(1))
-}
-object Comparison extends Enumeration {
- type Op = Value
- val Eq, Neq, Gt, Lt, Gte, Lte, Like, NotLike, RLike, NotRLike = Value
-
- val strings = Map(
- "=" -> Eq,
- "==" -> Eq,
- "!=" -> Neq,
- "<>" -> Neq,
- ">" -> Gt,
- "<" -> Lt,
- ">=" -> Gte,
- "<=" -> Lte,
- "LIKE" -> Like, // SQL-style LIKE expression
- "RLIKE" -> RLike, // Regular expression lookup
- "NOT LIKE" -> NotLike, // Inverse LIKE
- "NOT RLIKE" -> RLike // Inverse NOT LIKE
- )
-
- def apply(lhs: Expression, op: Op, rhs: Expression) =
- new Comparison(lhs, op, rhs)
- def apply(lhs: Expression, op: String, rhs: Expression) =
- new Comparison(lhs, strings(op.toUpperCase), rhs)
-
- def unapply(e:Comparison): Option[(Expression, Op, Expression)] =
- Some( (e.lhs, e.op, e.rhs) )
-
- def negate(v: Op): Op =
- {
- v match {
- case Eq => Neq
- case Neq => Eq
- case Gt => Lte
- case Gte => Lt
- case Lt => Gte
- case Lte => Gt
- case Like => NotLike
- case NotLike => Like
- case RLike => NotRLike
- case NotRLike => RLike
- }
- }
-
- def flip(v: Op): Option[Op] =
- {
- v match {
- case Eq => Some(Eq)
- case Neq => Some(Neq)
- case Gt => Some(Lt)
- case Gte => Some(Lte)
- case Lt => Some(Gt)
- case Lte => Some(Gte)
- case Like => None
- case NotLike => None
- case RLike => None
- case NotRLike => None
- }
- }
-
- def opString(v: Op): String =
- {
- v match {
- case Eq => "="
- case Neq => "<>"
- case Gt => ">"
- case Gte => ">="
- case Lt => "<"
- case Lte => "<="
- case Like => "LIKE"
- case NotLike => "NOT LIKE"
- case RLike => "RLIKE"
- case NotRLike => "NOT RLIKE"
- }
- }
-}
-
-case class Function(name: Name, params: Option[Seq[Expression]], distinct: Boolean = false) extends Expression
-{
- override def toString =
- name.toString + "(" +
- (if(distinct){ "DISTINCT " } else { "" }) +
- params.map{ _.mkString(", ")}.getOrElse("*")+
- ")"
- def needsParenthesis = false
- def children:Seq[Expression] = params.getOrElse { Seq() }
- def rebuild(c: Seq[Expression]): Expression =
- Function(name, params.map { _ => c }, distinct)
-}
-
-case class JDBCVar() extends Expression
-{
- override def toString = "?"
- def needsParenthesis = false
- def children: Seq[Expression] = Seq()
- def rebuild(c: Seq[Expression]): Expression = this
-}
-
-case class CaseWhenElse(
- target:Option[Expression],
- cases:Seq[(Expression, Expression)],
- otherwise:Expression
-) extends Expression
-{
- override def toString =
- {
- "CASE "+
- target.map { _.toString }.getOrElse("")+
- cases.map { clause =>
- "WHEN "+Expression.parenthesize(clause._1)+
- " THEN "+Expression.parenthesize(clause._2)
- }.mkString(" ")+" " +
- "ELSE "+ otherwise.toString+" END"
- }
- def needsParenthesis = false
- def children: Seq[Expression] =
- target.toSeq ++ Seq(otherwise) ++ cases.flatMap { x => Seq(x._1, x._2) }
- def rebuild(c: Seq[Expression]): Expression =
- {
- val (newTarget, newOtherwise, newCases) =
- if(c.length % 2 == 0){ (Some(c.head), c.tail.head, c.tail.tail) }
- else { (None, c.head, c.tail) }
- CaseWhenElse(
- newTarget,
- newCases.grouped(2).map { x => (x(0), x(1)) }.toSeq,
- newOtherwise
- )
- }
-}
-
-abstract class NegatableExpression extends Expression
-{
- def toNegatedString: String
-}
-
-case class IsNull(target: Expression) extends NegatableExpression
-{
- override def toString =
- Expression.parenthesize(target)+" IS NULL"
- def toNegatedString =
- Expression.parenthesize(target)+" IS NOT NULL"
- def needsParenthesis = false
- def children: Seq[Expression] = Seq(target)
- def rebuild(c: Seq[Expression]): Expression = IsNull(c(0))
-}
-
-case class Not(target: Expression) extends Expression
-{
- override def toString =
- target match {
- case neg:NegatableExpression => neg.toNegatedString
- case _ => "NOT "+Expression.parenthesize(target)
- }
- def needsParenthesis = false
- def children: Seq[Expression] = Seq(target)
- def rebuild(c: Seq[Expression]): Expression = Not(c(0))
-}
-
-case class Cast(expression: Expression, t: Name) extends Expression
-{
- override def toString = "CAST("+expression.toString+" AS "+t+")"
- def needsParenthesis = false
- def children: Seq[Expression] = Seq(expression)
- def rebuild(c: Seq[Expression]): Expression = Cast(c(0), t)
-}
-
-case class InExpression(
- expression: Expression,
- source: Either[Seq[Expression], sparsity.select.SelectBody]
-) extends NegatableExpression
-{
- override def toString =
- Expression.parenthesize(expression) + " IN " + sourceString
- override def toNegatedString =
- Expression.parenthesize(expression) + " NOT IN " + sourceString
- def needsParenthesis = false
- def sourceString =
- source match {
- case Left(elems) => elems.map { Expression.parenthesize(_) }.mkString(", ")
- case Right(query) => "("+query.toString+")"
- }
- def children: Seq[Expression] =
- Seq(expression) ++ (source match {
- case Left(expr) => expr
- case Right(_) => Seq()
- })
- def rebuild(c: Seq[Expression]): Expression =
- InExpression(c.head,
- source match {
- case Left(_) => Left(c.tail)
- case Right(query) => Right(query)
- }
- )
-}
\ No newline at end of file
diff --git a/src/main/scala/sparsity/oracle/OracleSQL.scala b/src/main/scala/sparsity/oracle/OracleSQL.scala
new file mode 100644
index 0000000..a4d6b91
--- /dev/null
+++ b/src/main/scala/sparsity/oracle/OracleSQL.scala
@@ -0,0 +1,1225 @@
+package sparsity.oracle
+
+import fastparse._
+import fastparse.ParsingRun
+import fastparse.internal.{Msgs, Util}
+import scala.io._
+import java.io._
+import sparsity.common.statement._
+import sparsity.oracle._
+import sparsity.common.select.{SelectBody, _}
+import sparsity.common.alter.Materialize
+import sparsity.common.expression._
+import sparsity.common.{
+ BigInt,
+ Blob,
+ Boolean,
+ Char,
+ Clob,
+ DataType,
+ Date,
+ Decimal,
+ Double,
+ Float,
+ Integer,
+ NVarChar2,
+ Name,
+ Number,
+ Raw,
+ Text,
+ Timestamp,
+ VarChar,
+ VarChar2
+}
+import sparsity.common.parser.{SQLBase, SQLBaseObject, StreamParser}
+import fastparse.CharIn
+
+/** Custom whitespace parser that includes SQL comments (-- ... and /* ... */). This allows comments to be automatically
+ * skipped anywhere whitespace is allowed. Based on FastParse's JavaWhitespace pattern.
+ */
+object OracleWhitespace {
+ implicit object whitespace extends Whitespace {
+ def apply(ctx: ParsingRun[_]) = {
+ val input = ctx.input
+ @scala.annotation.tailrec
+ def rec(current: Int, state: Int): ParsingRun[Unit] =
+ if (!input.isReachable(current)) {
+ if (state == 0 || state == 2) {
+ // Normal whitespace or inside -- comment - both are valid at EOF
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current, Msgs.empty)
+ ctx.freshSuccessUnit(current)
+ } else if (state == 1 || state == 3) {
+ // After first '-' or '/' but not a comment - return success at previous position
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty)
+ ctx.freshSuccessUnit(current - 1)
+ } else {
+ // Inside /* */ comment but reached EOF - this is an error
+ ctx.cut = true
+ val res = ctx.freshFailure(current)
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current, () => Util.literalize("*/"))
+ res
+ }
+ } else {
+ val currentChar = input(current)
+ (state: @scala.annotation.switch) match {
+ case 0 =>
+ // Normal whitespace mode
+ (currentChar: @scala.annotation.switch) match {
+ case ' ' | '\t' | '\n' | '\r' => rec(current + 1, state)
+ case '-' => rec(current + 1, state = 1) // Might be start of -- comment
+ case '/' => rec(current + 1, state = 3) // Might be start of /* comment
+ case _ =>
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current, Msgs.empty)
+ ctx.freshSuccessUnit(current)
+ }
+ case 1 =>
+ // After first '-', check if second '-' follows
+ if (currentChar == '-') {
+ rec(current + 1, state = 2) // Start of -- comment
+ } else {
+ // Not a comment, just a single '-', return success at previous position
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty)
+ ctx.freshSuccessUnit(current - 1)
+ }
+ case 2 =>
+ // Inside -- comment, consume until newline
+ rec(current + 1, state = if (currentChar == '\n') 0 else state)
+ case 3 =>
+ // After first '/', check if '*' follows
+ (currentChar: @scala.annotation.switch) match {
+ case '*' => rec(current + 1, state = 4) // Start of /* comment
+ case _ =>
+ // Not a comment, just a single '/', return success at previous position
+ if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty)
+ ctx.freshSuccessUnit(current - 1)
+ }
+ case 4 =>
+ // Inside /* */ comment, waiting for '*'
+ rec(current + 1, state = if (currentChar == '*') 5 else state)
+ case 5 =>
+ // Inside /* */ comment, after '*', checking for '/'
+ (currentChar: @scala.annotation.switch) match {
+ case '/' => rec(current + 1, state = 0) // End of /* comment
+ case '*' => rec(current + 1, state = 5) // Stay in state 5 if another '*'
+ case _ => rec(current + 1, state = 4) // Go back to state 4 if not '/'
+ }
+ }
+ }
+ rec(current = ctx.index, state = 0)
+ }
+ }
+}
+
+/** Oracle SQL dialect implementation. Supports Oracle-specific features including PROMPT, VARCHAR2, NUMBER types, and
+ * Oracle CREATE TABLE clauses (TABLESPACE, STORAGE, etc.). For reference, see
+ * https://github.com/antlr/grammars-v4/blob/master/sql/plsql/PlSqlParser.g4
+ */
+class OracleSQL extends SQLBase {
+ // Configuration
+ type Stmt = Statement
+ def caseSensitive = false
+ def statementTerminator = "/" // Oracle supports both "/" and ";"
+ def supportsCTEs = true
+ def supportsReturning = true
+ def supportsIfExists = true
+ def stringConcatOp = "||"
+
+ // Oracle-specific config
+ def supportsPrompt = true
+ def supportsVarchar2 = true
+
+ override implicit val whitespaceImplicit: fastparse.Whitespace = OracleWhitespace.whitespace
+
+ // Override multDivOp to prevent "/" from being parsed as division when it's a statement terminator.
+ // In Oracle, "/" on its own line is a statement terminator, not division.
+ // Division "/" must be followed by an expression operand (identifier, number, paren, etc.)
+ // Statement terminator "/" is followed by newline/whitespace leading to EOF or next statement.
+ // We use negative lookahead: "/" is NOT division if followed by whitespace then a statement keyword.
+ override def multDivOp[$ : P]: P[Unit] = P(
+ "*" | ("/" ~ !(&(
+ // After whitespace (handled by implicit whitespace parser), check we're NOT at:
+ // - End of input
+ // - A statement-starting keyword
+ End |
+ keyword("CREATE") | keyword("SELECT") | keyword("INSERT") | keyword("UPDATE") |
+ keyword("DELETE") | keyword("DROP") | keyword("ALTER") | keyword("GRANT") |
+ keyword("REVOKE") | keyword("COMMIT") | keyword("EXPLAIN") | keyword("WITH") |
+ keyword("BEGIN") | keyword("DECLARE") | keyword("SET") | keyword("PROMPT") |
+ keyword("RENAME") | keyword("COMMENT")
+ )))
+ )
+
+ // Helper parser for digits that allows whitespace consumption
+ // Note: We don't override base digits since it returns P[Unit], but we need P[String] here
+ def oracleDigits[$ : P]: P[String] = P(CharsWhileIn("0-9").!)
+
+ // Helper parser for NUMBER precision/scale that allows * for unspecified precision
+ def numberPrecision[$ : P]: P[String] = P("*".! | oracleDigits)
+
+ // Oracle-specific identifier parser that works with OracleWhitespace
+ // Override base identifier to use Oracle-specific parsing
+ // Note: In Oracle, keywords can be used as identifiers in many contexts (especially after dots)
+ // So we allow keywords as identifiers, but prefer quoted identifiers when ambiguous
+ override def identifier[$ : P]: P[Name] = P(
+ (("`" ~~/ CharsWhile(_ != '`').! ~ "`")
+ | ("\"" ~~/ CharsWhile(_ != '"').! ~~ "\"")).map(Name(_, true: scala.Boolean))
+ | (CharIn("_a-zA-Z") ~~ CharsWhileIn("a-zA-Z0-9_").?).!.map(Name(_))
+ )
+
+ // Oracle-specific parser for table names with optional database links (schema.table@link)
+ def oracleQualifiedTableName[$ : P]: P[Name] = P((identifier ~~ ("." ~/ identifier).? ~~ ("@" ~/ identifier).?).map {
+ case (x: Name, y: Option[Name], link: Option[Name]) =>
+ (x, y, link) match {
+ case (x, None, None) => x
+ case (x, Some(y), None) => Name(x.name + "." + y.name, x.quoted || y.quoted)
+ case (x, None, Some(link)) => Name(x.name + "@" + link.name, x.quoted || link.quoted)
+ case (x, Some(y), Some(link)) =>
+ Name(x.name + "." + y.name + "@" + link.name, x.quoted || y.quoted || link.quoted)
+ }
+ })
+
+ // Oracle-specific parser for dotted pairs with optional database links (schema.table@link)
+ def oracleDottedPair[$ : P]: P[(Option[Name], Name)] = P(
+ (identifier ~~ ("." ~/ identifier).? ~~ ("@" ~/ identifier).?).map {
+ case (x: Name, y: Option[Name], link: Option[Name]) =>
+ (x, y, link) match {
+ case (x, None, None) => (None, x)
+ case (x, Some(y), None) => (Some(x), y)
+ case (x, None, Some(link)) => (None, Name(x.name + "@" + link.name, x.quoted || link.quoted))
+ case (x, Some(y), Some(link)) =>
+ (Some(x), Name(y.name + "@" + link.name, y.quoted || link.quoted))
+ }
+ }
+ )
+
+ // Override alias to use rawIdentifier which checks for keywords
+ // This prevents clause keywords like WHERE from being parsed as aliases
+ override def alias[$ : P]: P[Name] = P(keyword("AS").? ~ rawIdentifier)
+
+ // Override simpleFromElement to support Oracle database links
+ // Note: For subqueries, alias is optional in Oracle (unlike base class which requires it)
+ override def simpleFromElement[$ : P]: P[FromElement] = P(
+ (("(" ~ select ~ ")" ~ alias.?).map {
+ case (body, Some(alias)) => FromSelect(body, alias)
+ case (body, None) =>
+ FromSelect(body, Name("")) // Anonymous alias for subqueries without alias
+ })
+ | ((oracleDottedPair ~ alias.?).map { case (schema, table, alias) =>
+ FromTable(schema, table, alias)
+ })
+ | (("(" ~ fromElement ~ ")" ~ alias.?).map {
+ case (from, None) => from
+ case (from, Some(alias)) => from.withAlias(alias)
+ })
+ )
+
+ // Override column to use oracleDottedPair for proper Oracle identifier parsing
+ override def column[$ : P]: P[Column] = P(oracleDottedPair.map(x => Column(x._2, x._1)))
+
+ // Override leaf to allow keywords as function names (e.g., REPLACE(...))
+ override def leaf[$ : P]: P[Expression] = P(
+ parens |
+ primitive |
+ jdbcvar |
+ caseWhen | ifThenElse |
+ cast |
+ nullLiteral |
+ // Window functions: function(...) OVER (...) - parse function then check for OVER
+ (function ~ (keyword("OVER") ~/
+ "(" ~/
+ (
+ (keyword("PARTITION") ~/ keyword("BY") ~/ expressionList).? ~
+ (keyword("ORDER") ~/ keyword("BY") ~/ orderBy
+ .rep(sep = comma, min = 1)
+ .map(_.toSeq)).?
+ ) ~
+ ")").?).map {
+ case (func, Some((partitionByOpt, orderByOpt))) =>
+ WindowFunction(func, partitionByOpt, orderByOpt.getOrElse(Seq()))
+ case (func, None) =>
+ func
+ } |
+ // Try column first to avoid conflicts with quoted identifiers
+ column |
+ // Allow keywords as function names when followed by (
+ // Use lookahead that explicitly excludes quoted identifiers (starts with " or `)
+ &(!CharIn("\"`") ~ CharIn("_a-zA-Z") ~ CharsWhileIn("a-zA-Z0-9_").? ~ "(") ~ function
+ )
+
+ // Override parens to handle subqueries: (SELECT ...)
+ override def parens[$ : P]: P[Expression] = P(
+ // Try subquery first: (SELECT ...)
+ ("(" ~/ &(keyword("SELECT")) ~/ select ~ ")").map(Subquery(_)) |
+ // Fall back to regular parenthesized expression: (expression)
+ ("(" ~ expression ~ ")")
+ )
+
+ // Window function parser: function(...) OVER (PARTITION BY ... ORDER BY ...)
+ def windowFunction[$ : P]: P[WindowFunction] = P(
+ function ~
+ keyword("OVER") ~/
+ "(" ~/
+ (
+ (keyword("PARTITION") ~/ keyword("BY") ~/ expressionList).? ~
+ (keyword("ORDER") ~/ keyword("BY") ~/ orderBy
+ .rep(sep = comma, min = 1)
+ .map(_.toSeq)).?
+ ) ~
+ ")"
+ ).map { case (func, (partitionByOpt, orderByOpt)) =>
+ WindowFunction(func, partitionByOpt, orderByOpt.getOrElse(Seq()))
+ }
+
+ // Override function to allow keywords as function names
+ // But explicitly reject quoted identifiers (they should be columns, not functions)
+ override def function[$ : P]: P[Function] = P(
+ // First check that we don't have a quoted identifier
+ !CharIn("\"`") ~
+ (identifier ~ "(" ~/
+ keyword("DISTINCT").!.?.map {
+ _ != None
+ } ~
+ ("*".!.map(_ => None)
+ | expressionList.map(Some(_))) ~ ")").map { case (name, distinct, args) =>
+ Function(name, args, distinct)
+ }
+ )
+
+ // Parser implementations
+ override def dialectSpecificStatement[$ : P]: P[OracleStatement] = P(
+ prompt | commit | setScanOff | setScanOn | setDefineOff | grant | revoke | dropSynonym | rename | comment
+ )
+
+ def prompt[$ : P]: P[Prompt] = P(
+ keyword("PROMPT") ~/ CharsWhile(c => c != '\n' && c != '/' && c != ';').!.map(_.trim)
+ .map(Prompt(_))
+ )
+
+ def commit[$ : P]: P[Commit] = P(keyword("COMMIT") ~/ Pass.map(_ => Commit()))
+
+ def setDefineOff[$ : P]: P[SetDefineOff] = P(
+ keyword("SET") ~ &(keyword("DEFINE")) ~ keyword("DEFINE") ~ keyword("OFF") ~ Pass.map(_ => SetDefineOff())
+ )
+
+ def setScanOff[$ : P]: P[SetScanOff] = P(
+ keyword("SET") ~ StringInIgnoreCase("SCAN") ~ StringInIgnoreCase("OFF") ~ Pass.map(_ => SetScanOff())
+ )
+
+ def setScanOn[$ : P]: P[SetScanOn] = P(
+ keyword("SET") ~ StringInIgnoreCase("SCAN") ~ StringInIgnoreCase("ON") ~ Pass.map(_ => SetScanOn())
+ )
+
+ def grant[$ : P]: P[Grant] = P(
+ keyword("GRANT") ~/
+ grantPrivilegeList ~
+ keyword("ON") ~/
+ oracleQualifiedTableName ~
+ keyword("TO") ~/
+ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~
+ (keyword("WITH") ~/ keyword("GRANT") ~/ StringInIgnoreCase("OPTION")).!.?.map(_.isDefined)
+ ).map { case (privileges, onObject, toUsers, withGrantOption) =>
+ Grant(privileges, onObject, toUsers, withGrantOption)
+ }
+
+ def revoke[$ : P]: P[Revoke] = P(
+ keyword("REVOKE") ~/
+ grantPrivilegeList ~
+ keyword("ON") ~/
+ oracleQualifiedTableName ~
+ keyword("FROM") ~/
+ identifier.rep(sep = comma, min = 1).map(_.toSeq)
+ ).map { case (privileges, onObject, fromUsers) =>
+ Revoke(privileges, onObject, fromUsers)
+ }
+
+ def grantPrivilegeList[$ : P]: P[Seq[String]] = P(grantPrivilege ~ (comma ~/ grantPrivilege).rep).map {
+ case (first, rest) => Seq(first) ++ rest.toSeq
+ }
+
+ def grantPrivilege[$ : P]: P[String] = P(
+ // Multi-word privileges first (must come before single-word to avoid partial matches)
+ // Parse "ON COMMIT REFRESH" as a single unit
+ (StringInIgnoreCase("ON") ~/ StringInIgnoreCase("COMMIT") ~/ StringInIgnoreCase("REFRESH")).!.map(_ =>
+ "ON COMMIT REFRESH"
+ ) |
+ (StringInIgnoreCase("QUERY") ~/ StringInIgnoreCase("REWRITE")).!.map(_ => "QUERY REWRITE") |
+ // Single-word privileges (exclude "ON" and "QUERY" to avoid conflicts)
+ StringInIgnoreCase(
+ "ALTER",
+ "DELETE",
+ "INDEX",
+ "INSERT",
+ "REFERENCES",
+ "SELECT",
+ "UPDATE",
+ "ALL",
+ "DEBUG",
+ "FLASHBACK"
+ ).!
+ )
+
+ // Parser for Oracle type names that may be multi-word (e.g., DOUBLE PRECISION, SMALLINT)
+ def oracleTypeName[$ : P]: P[Name] = P(
+ (keyword("DOUBLE") ~ keyword("PRECISION")).map(_ => Name("DOUBLE PRECISION")) |
+ keyword("DOUBLE").map(_ => Name("DOUBLE")) |
+ keyword("SMALLINT").map(_ => Name("SMALLINT")) |
+ identifier
+ )
+
+ // Override oneOrMoreAttributes to use OracleWhitespace instead of MultiLineWhitespace
+ override def oneOrMoreAttributes[$ : P]: P[Seq[Name]] = P(
+ ("(" ~/ identifier.rep(sep = comma, min = 1) ~ ")")
+ | identifier.map(Seq(_))
+ )
+
+ // Helper parser for Oracle type parameters (NUMBER(p,s), VARCHAR2(size BYTE|CHAR), etc.)
+ // Returns the parameter string (e.g., "(10,2)", "(15 BYTE)")
+ def typeParameters[$ : P]: P[String] = P(
+ "(" ~/
+ (
+ // NUMBER(*,0) or NUMBER(3,2) or NUMBER(3,2 BYTE) - precision and scale with optional unit
+ // Try this FIRST to avoid backtracking issues with NUMBER(10, 2)
+ (numberPrecision ~
+ "," ~ numberPrecision ~
+ (StringInIgnoreCase("BYTE", "CHAR").!.?)).map { case (size, scale, unit) =>
+ val unitStr = unit.map(" " + _).getOrElse("")
+ s"($size,$scale$unitStr)"
+ } |
+ // VARCHAR2(15 BYTE) or NUMBER(3) or NUMBER(3 BYTE) - size with optional unit
+ // Parse digits directly to allow whitespace consumption
+ (oracleDigits ~
+ (StringInIgnoreCase("BYTE", "CHAR").!.?)).map { case (size, unit) =>
+ val unitStr = unit.map(" " + _).getOrElse("")
+ s"($size$unitStr)"
+ }
+ ) ~
+ ")"
+ ).!
+
+ // Override tableField to handle Oracle-specific data type syntax
+ override def tableField[$ : P]: P[Either[TableAnnotation, ColumnDefinition]] = P(
+ (
+ keyword("CONSTRAINT") ~
+ identifier ~
+ keyword("UNIQUE") ~/
+ oneOrMoreAttributes
+ ).map { case (_, attrs) => Left(TableUnique(attrs)) } | (
+ keyword("CONSTRAINT") ~
+ identifier ~
+ keyword("PRIMARY") ~/
+ keyword("KEY") ~/
+ oneOrMoreAttributes
+ ).map { case (_, attrs) => Left(TablePrimaryKey(attrs)) } | (
+ keyword("UNIQUE") ~/
+ oneOrMoreAttributes.map(attrs => Left(TableUnique(attrs)))
+ ) | (
+ keyword("PRIMARY") ~/
+ keyword("KEY") ~/
+ oneOrMoreAttributes.map(attrs => Left(TablePrimaryKey(attrs)))
+ ) | (
+ keyword("INDEX") ~/
+ keyword("ON") ~
+ oneOrMoreAttributes.map(attrs => Left(TableIndexOn(attrs)))
+ ) | (
+ (
+ identifier ~/
+ oracleTypeName ~
+ typeParameters.? ~
+ oracleColumnAnnotation.rep
+ ).map { case (name, typeName, paramsOpt, annotations) =>
+ val fullTypeName = Name(typeName.name + paramsOpt.getOrElse(""))
+ Right(ColumnDefinition(name, fullTypeName, Seq(), annotations))
+ }
+ )
+ )
+
+ // Override columnAnnotation to handle Oracle-specific ENABLE/DISABLE keywords and DEFAULT ON NULL
+ def oracleColumnAnnotation[$ : P]: P[ColumnAnnotation] = P(
+ // Try DEFAULT ON NULL first (Oracle-specific) - use lookahead to check full sequence
+ (&(keyword("DEFAULT") ~ keyword("ON") ~ keyword("NULL")) ~
+ keyword("DEFAULT") ~/
+ keyword("ON") ~/
+ keyword("NULL") ~/
+ (
+ ("(" ~ expression ~ ")")
+ | function // Function calls with parentheses
+ | functionCallNoParens // Function calls without parentheses (CURRENT_TIMESTAMP, etc.)
+ | primitive
+ )).map(ColumnDefaultValue(_)) |
+ // DEFAULT NULL (without ON) - use lookahead to ensure NULL follows
+ (&(keyword("DEFAULT") ~ keyword("NULL")) ~
+ keyword("DEFAULT") ~/
+ keyword("NULL")).map(_ => ColumnDefaultValue(NullPrimitive())) |
+ // Standalone NULL annotation (Oracle allows explicit NULL, though it's redundant)
+ keyword("NULL").map(_ => ColumnIsNullable()) |
+ // Fall back to standard column annotations (including regular DEFAULT and NOT NULL)
+ // Oracle allows ENABLE/DISABLE followed by VALIDATE/NOVALIDATE
+ columnAnnotation ~ (keyword("ENABLE") | keyword("DISABLE")).? ~ (StringInIgnoreCase(
+ "VALIDATE"
+ ) | StringInIgnoreCase("NOVALIDATE")).?
+ )
+
+ // Override common statements to return OracleStatement wrappers
+ override def insert[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("INSERT") ~/
+ (
+ keyword("OR") ~/
+ keyword("REPLACE")
+ ).!.?.map { case None => false; case _ => true } ~
+ keyword("INTO") ~/
+ oracleQualifiedTableName ~
+ ("(" ~/
+ identifier ~
+ (comma ~/ identifier).rep ~
+ ")").map(x => Seq(x._1) ++ x._2).? ~
+ (
+ // WITH ... SELECT
+ (&(keyword("WITH")) ~/ withSelect.map { case (ctes, body) => SelectInsert(body) }) |
+ // SELECT (without WITH)
+ (&(keyword("SELECT")) ~/ select.map(SelectInsert(_))) |
+ // VALUES
+ (&(keyword("VALUES")) ~/ valueList)
+ )
+ ).map { case (orReplace, table, columns, values) =>
+ OracleInsert(table, columns, values, orReplace)
+ }
+ )
+
+ // Parser for SELECT with optional WITH clause
+ def withSelect[$ : P]: P[(Seq[WithClause], SelectBody)] = P(withClauseList ~ select).map { case (ctes, body) =>
+ (ctes, body)
+ }
+
+ def withClauseList[$ : P]: P[Seq[WithClause]] = P(
+ keyword("WITH") ~/
+ withClause.rep(sep = comma, min = 1).map(_.toSeq)
+ )
+
+ def withClause[$ : P]: P[WithClause] = P(
+ identifier ~
+ keyword("AS") ~/
+ "(" ~/ select ~ ")"
+ ).map { case (name, body) => WithClause(body, name) }
+
+ override def update[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("UPDATE") ~/
+ oracleQualifiedTableName ~
+ // Optional table alias: identifier that is not part of a dotted name and is followed by SET
+ // Use negative lookahead to ensure we don't match "SET" as an alias
+ (!keyword("SET") ~~ identifier).? ~
+ keyword("SET") ~/
+ (
+ dottedPair.map { case (schema, col) =>
+ val colName =
+ schema.map(s => Name(s.name + "." + col.name, s.quoted || col.quoted)).getOrElse(col)
+ colName
+ } ~
+ "=" ~/
+ expression
+ ).rep(sep = comma, min = 1) ~
+ (
+ StringInIgnoreCase("WHERE") ~/
+ expression
+ ).?
+ ).map { case (table, _, set, where) =>
+ OracleUpdate(table, set, where)
+ }
+ )
+
+ override def delete[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("DELETE") ~/
+ keyword("FROM").!.?.map(_ => ()) ~/ // FROM is optional in Oracle
+ oracleQualifiedTableName ~
+ (
+ keyword("WHERE") ~/
+ expression
+ ).?
+ ).map { case (table, where) => OracleDelete(table, where) }
+ )
+
+ override def alterView[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("ALTER") ~
+ keyword("VIEW") ~/
+ oracleQualifiedTableName ~
+ (
+ (keyword("MATERIALIZE").!.map(_ => Materialize(true)))
+ | (keyword("DROP") ~
+ keyword("MATERIALIZE").!.map(_ => Materialize(false)))
+ )
+ ).map { case (name, op) => OracleAlterView(name, op) }
+ )
+
+ def alterTable[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("ALTER") ~
+ keyword("TABLE") ~/
+ oracleQualifiedTableName ~
+ (
+ alterTableModify |
+ alterTableRenameColumn |
+ alterTableAdd |
+ alterTableDrop // Combined DROP COLUMN and DROP CONSTRAINT parser
+ )
+ ).map { case (name, action) => OracleAlterTable(name, action) }
+ )
+
+ def alterTableModify[$ : P]: P[AlterTableAction] = P(
+ keyword("MODIFY") ~/
+ (
+ // MODIFY(column type, ...) - with parentheses
+ ("(" ~/
+ columnModification.rep(sep = comma, min = 1) ~
+ ")").map(mods => AlterTableModify(mods)) |
+ // MODIFY column type - without parentheses (single or multiple columns)
+ (columnModification ~
+ (comma ~/ keyword("MODIFY") ~/ columnModification).rep).map { case (first, rest) =>
+ AlterTableModify(Seq(first) ++ rest)
+ }
+ )
+ )
+
+ def alterTableRenameColumn[$ : P]: P[AlterTableAction] = P(
+ keyword("RENAME") ~/
+ keyword("COLUMN") ~/
+ identifier ~
+ keyword("TO") ~/
+ identifier
+ ).map { case (oldName, newName) => AlterTableRenameColumn(oldName, newName) }
+
+ def alterTableAdd[$ : P]: P[AlterTableAction] = P(
+ keyword("ADD") ~/
+ (
+ // ADD CONSTRAINT constraint_name UNIQUE/PRIMARY KEY/FOREIGN KEY (...)
+ (keyword("CONSTRAINT") ~/
+ identifier ~
+ tableConstraint).map { case (constraintName, constraint) =>
+ AlterTableAddConstraint(constraintName, constraint)
+ } |
+ // ADD(column, ...) - with parentheses
+ ("(" ~/
+ tableField.rep(sep = comma, min = 1) ~
+ ")").map { fields =>
+ val columns = fields.collect { case Right(colDef) => colDef }
+ val annotations = fields.collect { case Left(ann) => ann }
+ AlterTableAdd(columns.toSeq, annotations.toSeq)
+ } |
+ // ADD column - without parentheses (single or multiple columns)
+ (tableField ~
+ (keyword("ADD") ~/ tableField).rep).map { case (first, rest) =>
+ val allFields = Seq(first) ++ rest
+ val columns = allFields.collect { case Right(colDef) => colDef }
+ val annotations = allFields.collect { case Left(ann) => ann }
+ AlterTableAdd(columns, annotations)
+ }
+ )
+ )
+
+ def tableConstraint[$ : P]: P[TableConstraint] = P(
+ // UNIQUE (col1, col2, ...)
+ (keyword("UNIQUE") ~/
+ "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")").map { columns =>
+ UniqueConstraint(columns)
+ } |
+ // PRIMARY KEY (col1, col2, ...)
+ (keyword("PRIMARY") ~/
+ keyword("KEY") ~/
+ "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")").map { columns =>
+ PrimaryKeyConstraint(columns)
+ } |
+ // FOREIGN KEY (col1, ...) REFERENCES table(col1, ...)
+ (keyword("FOREIGN") ~/
+ keyword("KEY") ~/
+ "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")" ~
+ keyword("REFERENCES") ~/
+ oracleQualifiedTableName ~
+ "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")" ~
+ (keyword("ON") ~/ keyword("DELETE") ~/ referentialAction).? ~
+ (keyword("ON") ~/ keyword("UPDATE") ~/ referentialAction).?).map {
+ case (columns, refTable, refColumns, onDelete, onUpdate) =>
+ ForeignKeyConstraint(columns, refTable, refColumns, onDelete, onUpdate)
+ }
+ )
+
+ def referentialAction[$ : P]: P[ReferentialAction] = P(
+ (keyword("NO") ~/ keyword("ACTION")).map(_ => ReferentialAction.NoAction) |
+ keyword("RESTRICT").map(_ => ReferentialAction.Restrict) |
+ keyword("CASCADE").map(_ => ReferentialAction.Cascade) |
+ (keyword("SET") ~/ keyword("NULL")).map(_ => ReferentialAction.SetNull) |
+ (keyword("SET") ~/ keyword("DEFAULT")).map(_ => ReferentialAction.SetDefault)
+ )
+
+ def alterTableDrop[$ : P]: P[AlterTableAction] = P(
+ keyword("DROP") ~/
+ (
+ // DROP (col1, col2, ...) - multiple columns in parentheses
+ ("(" ~/
+ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~
+ ")").map(cols => AlterTableDropColumns(cols)) |
+ // DROP COLUMN col_name - single column
+ keyword("COLUMN") ~/ identifier.map { colName =>
+ AlterTableDropColumn(colName)
+ } |
+ // DROP CONSTRAINT constraint_name
+ keyword("CONSTRAINT") ~/ identifier.map { constraintName =>
+ AlterTableDropConstraint(constraintName)
+ }
+ )
+ )
+
+ def alterTableDropColumn[$ : P]: P[AlterTableAction] = P(
+ keyword("DROP") ~/
+ keyword("COLUMN") ~/
+ identifier
+ ).map(colName => AlterTableDropColumn(colName))
+
+ def alterTableDropConstraint[$ : P]: P[AlterTableAction] = P(
+ keyword("DROP") ~/
+ keyword("CONSTRAINT") ~/
+ identifier
+ ).map(constraintName => AlterTableDropConstraint(constraintName))
+
+ def columnModification[$ : P]: P[ColumnModification] = P(
+ (
+ identifier ~/
+ // Type is optional - MODIFY column DEFAULT/NOT NULL is valid without type
+ // Use lookahead to check if this looks like a type (not a keyword like DEFAULT/NOT)
+ (
+ // Only parse type if NOT followed by annotation keywords
+ !(&(keyword("DEFAULT") | keyword("NOT") | keyword("PRIMARY"))) ~
+ (oracleTypeName ~
+ typeParameters.?).map { case (typeName, paramsOpt) =>
+ Some(Name(typeName.name + paramsOpt.getOrElse("")))
+ }
+ ).? ~
+ oracleColumnAnnotation.rep
+ ).map { case (name: Name, typeOpt, annotations) =>
+ ColumnModification(name, typeOpt.flatten, annotations)
+ }
+ )
+
+ override def dropTableOrView[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("DROP") ~
+ keyword("TABLE", "VIEW").!.map(_.toUpperCase) ~/
+ ifExists ~
+ oracleQualifiedTableName ~
+ (keyword("CASCADE") ~/ StringInIgnoreCase("CONSTRAINTS")).!.?.map(_.isDefined) ~
+ keyword("PURGE").!.?.map(_.isDefined)
+ ).map {
+ case ("TABLE", ifExists, name, cascadeConstraints, purge) =>
+ OracleDropTable(name, ifExists, cascadeConstraints, purge)
+ case ("VIEW", ifExists, name, _, _) =>
+ OracleDropView(name, ifExists)
+ case (_, _, _, _, _) =>
+ throw new Exception("Internal Error")
+ }
+ )
+
+ def dropSynonym[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("DROP") ~
+ keyword("SYNONYM") ~/
+ ifExists ~
+ oracleQualifiedTableName
+ ).map { case (ifExists, name) => OracleDropSynonym(name, ifExists) }
+ )
+
+ def rename[$ : P]: P[OracleStatement] = P(
+ (
+ keyword("RENAME") ~/
+ oracleQualifiedTableName ~
+ keyword("TO") ~/
+ oracleQualifiedTableName
+ ).map { case (oldName, newName) => OracleRename(oldName, newName) }
+ )
+
+ // Parser for fully qualified column names (schema.table.column)
+ def oracleQualifiedColumnName[$ : P]: P[Name] = P(
+ (identifier ~ ("." ~/ identifier).rep(min = 1, max = 2)).map { case (first, rest) =>
+ val parts = Seq(first) ++ rest.toSeq
+ val nameStr = parts.map(_.name).mkString(".")
+ val quoted = parts.exists(_.quoted)
+ Name(nameStr, quoted)
+ }
+ )
+
+ def comment[$ : P]: P[Comment] = P(
+ (
+ keyword("COMMENT") ~/
+ keyword("ON") ~/
+ keyword("COLUMN") ~/
+ oracleQualifiedColumnName ~
+ keyword("IS") ~/
+ quotedString
+ ).map { case (columnName, commentText) => Comment(columnName, commentText) }
+ )
+
+ // PL/SQL recursive parser combinators for blocks and loops
+ // Uses FastParse's recursive descent to naturally handle nesting
+
+ // Parse a PL/SQL label: <