diff --git a/.gitignore b/.gitignore index e705c32..d8d23e0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ *.class *.log -/target -/project \ No newline at end of file +**/target diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..f825c99 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,12 @@ +version = 3.9.8 +style = default +runner.dialect=scala212 +maxColumn = 120 +continuationIndent.defnSite = 2 +continuationIndent.callSite = 2 +align.preset = "none" +danglingParentheses.preset = true +optIn.configStyleArguments = false +docstrings.style = SpaceAsterisk +spaces.beforeContextBoundColon = true +rewrite.rules = [SortImports] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..6d1098c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,56 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [Unreleased] + +### Added +- Multi-dialect support for SQL parsing: + - Oracle SQL dialect + - PostgreSQL dialect + - ANSI SQL dialect +- CLI application (`Main.scala`) with comprehensive options: + - Multiple output formats: text, JSON, JQ queries + - File pattern matching (glob support) + - Configurable output destinations (console or file) +- Error recovery for unparseable statements +- Circe-based JSON encoding for all AST types +- Assembly plugin support for building standalone JARs +- Comprehensive test suite with Oracle-specific tests +- Support for parsing multiple statements from a single file with accurate line number tracking + +### Changed +- **BREAKING**: Major refactoring of package structure: + - Moved core types to `sparsity.common.*` package + - Parser base classes now in `sparsity.common.parser.*` + - Statement types in `sparsity.common.statement.*` + - Expression types in `sparsity.common.expression.*` +- Refactored parser architecture: + - Moved logic from `Elements` into `SQLBase` for better subclass extensibility + - Created dialect-specific parser objects +- Enhanced build configuration: + - Updated SBT to 1.11.7 + - Cross-compilation for Scala 2.12, 2.13, and 3.3 + - Added assembly, scalafmt, and cross-project plugins +- Updated dependencies: + - fastparse to 3.1.1 + - Added scribe for logging + - Added mainargs for CLI argument parsing + - Added monocle for optics + - Added java-jq for JQ query support + +### Fixed +- Improved error messages with better context +- Better handling of parsing failures with fallback mechanisms + +## [1.7.0] - Previous Release + +### Added +- Cross-compiled support for Scala 2.12, 2.13, and 3.3 +- Updated fastparse dependency + +[Unreleased]: https://github.com/UBOdin/sparsity/compare/v1.7.0...HEAD +[1.7.0]: https://github.com/UBOdin/sparsity/releases/tag/v1.7.0 diff --git a/README.md b/README.md index ad5d268..56b5b60 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,6 @@ val tree = SQL("SELECT * FROM 
R") libraryDependencies ++= Seq("info.mimirdb" %% "sparsity" % "1.5") ``` - ## Version History - -#### 1.7 (Breaking) -- Added cross-compiled support for Scala 2.12 \ No newline at end of file +See [CHANGELOG.md](CHANGELOG.md) for detailed version history and release notes. \ No newline at end of file diff --git a/build.sbt b/build.sbt index a39b0c1..2bd3f2c 100644 --- a/build.sbt +++ b/build.sbt @@ -1,49 +1,109 @@ import scala.sys.process._ +import scala.scalanative.build._ +import sbtcrossproject.CrossPlugin.autoImport._ +import sbtassembly.AssemblyPlugin.autoImport._ -name := "Sparsity" -version := "1.7.1-SNAPSHOT" -organization := "info.mimirdb" -scalaVersion := "2.12.20" -crossScalaVersions := Seq("2.12.20", "2.13.17", "3.3.6") +val Scala3 = "3.3.6" +val Scala2_13 = "2.13.17" +val Scala2_12 = "2.12.20" -dependencyOverrides ++= { +val circeVersion = "0.14.14" +val circeOpticsVersion = "0.15.1" +val jsonPathVersion = "0.2.0" +val specs2Version = "4.23.0" +val fastparseVersion = "3.1.1" +val javaJqVersion = "2.0.0" +val scribeVersion = "3.17.0" +val monocleVersion = "3.3.0" +val mainArgsVersion = "0.7.7" + +ThisBuild / scalafmtOnCompile := true +ThisBuild / name := "Sparsity" +ThisBuild / version := "1.7.1-SNAPSHOT" +ThisBuild / organization := "info.mimirdb" +ThisBuild / scalaVersion := Scala2_13 +ThisBuild / crossScalaVersions := Seq(Scala2_12, Scala2_13, Scala3) + +ThisBuild / dependencyOverrides ++= { CrossVersion.partialVersion(scalaVersion.value) match { case Some((3, _)) => Seq.empty // Scala 3 doesn't need override case _ => Seq("org.scala-lang" % "scala-library" % scalaVersion.value) } } -resolvers += "MimirDB" at "https://maven.mimirdb.info/" -resolvers ++= Seq("snapshots", "releases").map(Resolver.sonatypeRepo) +ThisBuild / resolvers += "MimirDB" at "https://maven.mimirdb.info/" +ThisBuild / resolvers ++= Seq("snapshots", "releases").map(Resolver.sonatypeRepo) -libraryDependencies ++= Seq( - "com.lihaoyi" %% "fastparse" % "3.1.1", - 
"com.typesafe.scala-logging" %% "scala-logging" % "3.9.5", - "ch.qos.logback" % "logback-classic" % "1.5.14", - "org.specs2" %% "specs2-core" % "4.23.0" % "test", - "org.specs2" %% "specs2-junit" % "4.23.0" % "test" +val scalacheckVersion = "1.19.0" + +val commonSettings = Seq( + libraryDependencies ++= Seq( + "com.lihaoyi" %%% "mainargs" % mainArgsVersion, + "com.lihaoyi" %%% "fastparse" % fastparseVersion, + "com.outr" %%% "scribe" % scribeVersion, + "io.circe" %%% "circe-core" % circeVersion, + "io.circe" %%% "circe-generic" % circeVersion, + "io.circe" %%% "circe-generic-extras" % "0.14.5-RC1", + "io.circe" %%% "circe-parser" % circeVersion, + "io.circe" %%% "circe-optics" % circeOpticsVersion, + "dev.optics" %% "monocle-core" % monocleVersion, + "dev.optics" %% "monocle-macro" % monocleVersion, + "com.quincyjo" %%% "scala-json-path-circe" % jsonPathVersion, + "com.arakelian" % "java-jq" % javaJqVersion, + "org.specs2" %%% "specs2-core" % specs2Version % "test", + "org.specs2" %%% "specs2-junit" % specs2Version % "test", + "org.scalacheck" %%% "scalacheck" % scalacheckVersion % "test" + ) ) -////// Publishing Metadata ////// -// use `sbt publish makePom` to generate -// a publishable jar artifact and its POM metadata - -publishMavenStyle := true - -pomExtra := http://github.com/UBOdin/sparsity/ - - - Apache License 2.0 - http://www.apache.org/licenses/ - repo - - - - git@github.com:ubodin/sparsity.git - scm:git:git@github.com:ubodin/sparsity.git - - -/////// Publishing Options //////// -// use `sbt publish` to update the package in -// your own local ivy cache -publishTo := Some(Resolver.file("file", new File("/var/www/maven_repo/"))) +val nativeSettings = Seq( + nativeConfig ~= { + _.withLTO(LTO.full) + .withMode(Mode.releaseFast) + .withGC(GC.commix) + } +) + +// lazy val root = project +// .in(file(".")) +// .aggregate(sparsity.jvm, sparsity.native) + +lazy val sparsity = + //crossProject(JVMPlatform, NativePlatform) + //.crossType(CrossType.Pure) + project 
+ .in(file(".")) + .settings( + name := "Sparsity", + version := "1.7.1-SNAPSHOT", + organization := "info.mimirdb", + commonSettings, + // Assembly settings + assembly / mainClass := Some("sparsity.Main"), + assembly / assemblyJarName := s"${name.value}-${version.value}.jar", + assembly / assemblyMergeStrategy := { + case PathList("META-INF", xs @ _*) => MergeStrategy.discard + case x => MergeStrategy.first + }, + assembly / artifact := { + val art = (assembly / artifact).value + art.withClassifier(Some("assembly")) + }, + addArtifact(assembly / artifact, assembly), + // Publishing settings + publishMavenStyle := true, + pomExtra := http://github.com/UBOdin/sparsity/ + + + Apache License 2.0 + http://www.apache.org/licenses/ + repo + + + + git@github.com:ubodin/sparsity.git + scm:git:git@github.com:ubodin/sparsity.git + , + ) + //.nativeSettings(nativeSettings) + .enablePlugins(AssemblyPlugin) diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 0000000..01a16ed --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.11.7 diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..1cd828c --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1,5 @@ +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.3") +addSbtPlugin("org.portable-scala" % "sbt-scala-native-crossproject" % "1.3.2") +addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.5.8") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1") + diff --git a/src/main/scala/sparsity/Main.scala b/src/main/scala/sparsity/Main.scala new file mode 100644 index 0000000..1eb62fe --- /dev/null +++ b/src/main/scala/sparsity/Main.scala @@ -0,0 +1,597 @@ +package sparsity + +import sparsity.oracle.OracleSQL +import sparsity.ansi.ANSISQL +import sparsity.postgres.PostgresSQL +import sparsity.common.parser.ErrorMessage +import sparsity.common.codec.CirceCodecs._ +import fastparse._ +import scala.io.Source +import java.io.{File, 
FileWriter, PrintWriter} +import java.nio.file.{FileSystems, Files, Paths} +import java.util.logging.{Level => JLevel, LogManager, Logger => JLogger} +import scribe.Logger +import scribe.Level +import io.circe._ +import io.circe.generic.semiauto._ +import io.circe.syntax._ +import io.circe.Printer +import io.circe.Encoder +import sparsity.common.statement.Statement +import com.arakelian.jq.{ImmutableJqLibrary, ImmutableJqRequest} +import mainargs.{arg, main, Leftover, Parser, TokensReader} +import sparsity.common.parser.SQLBaseObject + +case class Annotation(file: String, line: Int, title: String, message: String, annotation_level: String) + +object Annotation { + implicit val encoder: Encoder[Annotation] = deriveEncoder[Annotation] +} + +// Maps for dialect lookup - built dynamically using the name method +object SQLDialectMaps { + val dialects: Seq[SQLBaseObject] = Seq(OracleSQL, ANSISQL, PostgresSQL) + val dialectNames: Seq[String] = dialects.map(_.name) + val nameToDialect: Map[String, SQLBaseObject] = dialects.map(d => d.name -> d).toMap + val validDialects: String = dialects.map(_.name).mkString(", ") +} + +// Custom parser for SQLBaseObject +object SQLBaseObjectReader extends TokensReader.Simple[SQLBaseObject] { + def shortName = "dialect" + def read(strs: Seq[String]): Either[String, SQLBaseObject] = { + if (strs.isEmpty) { + Left("Missing dialect name") + } else { + SQLDialectMaps.nameToDialect.get(strs.head.toLowerCase) match { + case Some(dialect) => Right(dialect) + case None => Left(s"Unknown dialect: ${strs.head}. 
Valid dialects are: ${SQLDialectMaps.validDialects}") + } + } + } +} + +// Custom parser for OutputFormat +object OutputFormatReader extends TokensReader.Simple[OutputFormat] { + def shortName = "format" + def read(strs: Seq[String]): Either[String, OutputFormat] = { + if (strs.isEmpty) { + Left("Missing output format") + } else { + val s = strs.head + s match { + case "text" => Right(TextFormat) + case "json" => Right(StatementsJsonFormat) + case s"jq($query)" => Right(JqFormat(query)) + case s"jq-file:$filePath" => Right(JqFileFormat(filePath)) + case _ => Left(s"Unknown output format: $s. Valid formats: text, json, jq(query), jq-file:path") + } + } + } +} + +// Output format types +sealed trait OutputFormat { + def name: String + def format(result: ProcessingResult): String +} + +case object TextFormat extends OutputFormat { + def name: String = "text" + def format(result: ProcessingResult): String = { + val textOutputLines = result.fileResults.flatMap { fileResult => + (s"\n=== Processing: ${fileResult.fileName}" :: fileResult.statementResults.flatMap { stmtResult => + val idx = stmtResult.index + stmtResult.parseResult match { + case Parsed.Success(_, _) => + List(s"✓ Statement ${idx + 1}: OK") + case Parsed.Failure(label, index, _) => + val stmtIndex = index.toInt + val absoluteLineNum = stmtResult.absoluteLineNumber.getOrElse(0) + // For unparseable statements, show the full statement text instead of a substring + val context = if (label == "unparseable") { + // Show full statement text (up to 200 chars for readability) + val fullText = stmtResult.context.statement + if (fullText.length <= 200) fullText else fullText.take(200) + "..." 
+ } else { + stmtResult.context.statement + .substring(Math.max(0, stmtIndex - 50), Math.min(stmtResult.context.statement.length, stmtIndex + 50)) + } + List( + s"❌ ${stmtResult.context.relativePath}:${absoluteLineNum}: Syntax ERROR", + s"Statement ${idx + 1} failed: $label", + s"Context: $context" + ) + } + }) :+ fileResult.fileSummary + } + + val overallSummary = s"\n=== Overall Summary ===" + val summary2 = s"Files processed: ${result.fileResults.length}" + val summary3 = + s"Total: ${result.totalSuccessCount} successful, ${result.totalErrorCount} failed out of ${result.totalStatements} statements" + + (textOutputLines :+ overallSummary :+ summary2 :+ summary3).mkString("\n") + } +} +case object StatementsJsonFormat extends OutputFormat { + import ParsedEncoder._ + def name: String = "json" + def format(result: ProcessingResult): String = { + val statementsJson = result.asJson + Printer.spaces2.copy(dropNullValues = true).print(statementsJson) + } +} +case class JqFormat(query: String) extends OutputFormat { + def name: String = "jq" + def format(result: ProcessingResult): String = { + // Generate statements JSON for JQ query + val statementsJsonString = StatementsJsonFormat.format(result) + try { + // Suppress Java JQ library logging before initializing the library + // This must be done before ImmutableJqLibrary.of() is called + // Suppress all JQ-related loggers and their handlers + val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary") + + loggersToSuppress.foreach { loggerName => + val logger = JLogger.getLogger(loggerName) + logger.setLevel(JLevel.OFF) // Completely disable logging + logger.setUseParentHandlers(false) + // Remove all handlers + val handlers = logger.getHandlers + handlers.foreach(logger.removeHandler) + } + + val library = ImmutableJqLibrary.of() + val request = ImmutableJqRequest + .builder() + .lib(library) + .input(statementsJsonString) + .filter(query) + .build() + 
request.execute().getOutput + } catch { + case e: Exception => + throw new RuntimeException(s"Error executing JQ query: ${e.getMessage}", e) + } + } +} +case class JqFileFormat(filePath: String) extends OutputFormat { + def name: String = "jq-file" + def format(result: ProcessingResult): String = { + // Read the JQ query from the file + val query = + try { + val currentDir = Paths.get(".").toAbsolutePath.normalize + val queryFile = if (Paths.get(filePath).isAbsolute) { + Paths.get(filePath).toFile + } else { + currentDir.resolve(filePath).toFile + } + + if (!queryFile.exists() || !queryFile.isFile) { + throw new RuntimeException(s"JQ query file not found: $filePath") + } + + Source.fromFile(queryFile).mkString.trim + } catch { + case e: RuntimeException => throw e + case e: Exception => + throw new RuntimeException(s"Error reading JQ query file '$filePath': ${e.getMessage}", e) + } + + // Generate statements JSON for JQ query + val statementsJsonString = StatementsJsonFormat.format(result) + try { + // Suppress Java JQ library logging before initializing the library + // This must be done before ImmutableJqLibrary.of() is called + // Suppress all JQ-related loggers and their handlers + val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary") + + loggersToSuppress.foreach { loggerName => + val logger = JLogger.getLogger(loggerName) + logger.setLevel(JLevel.OFF) // Completely disable logging + logger.setUseParentHandlers(false) + // Remove all handlers + val handlers = logger.getHandlers + handlers.foreach(logger.removeHandler) + } + + val library = ImmutableJqLibrary.of() + val request = ImmutableJqRequest + .builder() + .lib(library) + .input(statementsJsonString) + .filter(query) + .build() + request.execute().getOutput + } catch { + case e: Exception => + throw new RuntimeException(s"Error executing JQ query from file '$filePath': ${e.getMessage}", e) + } + } +} + +// Output destination types +sealed trait 
OutputDestination { + def write(content: String): Unit +} +case object ConsoleDestination extends OutputDestination { + def write(content: String): Unit = println(content) +} +case class FileDestination(path: String) extends OutputDestination { + def write(content: String): Unit = { + val writer = new PrintWriter(new File(path)) + try + writer.println(content) + finally + writer.close() + } +} + +object OutputDestination { + implicit val tokensReader: TokensReader.Simple[OutputDestination] = new TokensReader.Simple[OutputDestination] { + def shortName = "destination" + def read(strs: Seq[String]): Either[String, OutputDestination] = { + if (strs.isEmpty) { + Left("Missing output destination") + } else { + strs.head match { + case "console" => Right(ConsoleDestination) + case path: String => Right(FileDestination(path)) + case _ => Left(s"Unknown output destination: ${strs.head}. Valid destinations: console, file path") + } + } + } + } +} + +@main(doc = "Parse SQL files using the specified dialect") +case class Config( + @arg(positional = true, doc = "SQL dialect to use") + dialect: SQLBaseObject, + @arg(doc = + "Output format. Format: text|json|jq(query)|jq-file:path. Example: --output jq(.files[].statements[]) or --output jq-file:pr-checker.jq" + ) + output: OutputFormat = TextFormat, + @arg(doc = "Output destination. 
Example: --write-to console or --write-to output.json") + writeTo: OutputDestination = ConsoleDestination, + @arg(doc = "SQL file(s) or glob pattern(s) to parse (e.g., *.sql)") + filePatterns: Leftover[String] +) + +case class StatementContext( + statement: String, + rawStatement: String, + relativePath: String, + dialectName: String, + statementStartInContent: Int +) + +object StatementContext { + implicit val encoder: Encoder[StatementContext] = deriveEncoder[StatementContext] +} + +case class StatementResult( + index: Int, + parseResult: Parsed[Statement], + context: StatementContext, + absoluteLineNumber: Option[Int] = None +) { + def success: Boolean = parseResult match { + case Parsed.Success(_, _) => true + case _: Parsed.Failure => false + } +} + +// Custom encoder for Parsed[Statement] to make it serializable +object ParsedEncoder { + implicit def parsedEncoder[T : Encoder]: Encoder[Parsed[T]] = Encoder.instance { + case Parsed.Success(value, _) => + Json.obj("success" -> Json.fromBoolean(true), "value" -> value.asJson) + case Parsed.Failure(label, index, extra) => + Json.obj("success" -> Json.fromBoolean(false), "label" -> Json.fromString(label), "index" -> Json.fromLong(index)) + } +} + +object StatementResult { + import ParsedEncoder._ + implicit val encoder: Encoder[StatementResult] = deriveEncoder[StatementResult] +} + +case class FileResult( + fileName: String, + relativePath: String, + statementResults: List[StatementResult], + fileSummary: String +) + +object FileResult { + implicit val encoder: Encoder[FileResult] = deriveEncoder[FileResult] +} + +case class ProcessingResult( + fileResults: List[FileResult], + totalSuccessCount: Int, + totalErrorCount: Int, + totalStatements: Int +) + +object ProcessingResult { + implicit val encoder: Encoder[ProcessingResult] = deriveEncoder[ProcessingResult] +} + +object Main { + // Configure scribe logger + private val logger = Logger("sparsity.Main") + + // Make the custom parsers available implicitly + 
implicit val sqlBaseObjectReader: TokensReader.Simple[SQLBaseObject] = SQLBaseObjectReader + implicit val outputFormatReader: TokensReader.Simple[OutputFormat] = OutputFormatReader + + def main(args: scala.Array[String]): Unit = { + // Parse command-line arguments using mainargs first to check output destination + val config = Parser[Config].constructOrExit(args) + + // Configure logging - suppress logs when writing to console + val shouldSuppressLogs = config.writeTo == ConsoleDestination + val logLevel = if (shouldSuppressLogs) { + // When writing to console, suppress INFO/DEBUG logs + Option(System.getProperty("log.level")) + .orElse(Option(System.getenv("LOG_LEVEL"))) + .getOrElse("WARN") + } else { + // When writing to file, respect user's log level preference + Option(System.getProperty("log.level")) + .orElse(Option(System.getenv("LOG_LEVEL"))) + .getOrElse("DEBUG") + } + + val level = logLevel.toUpperCase match { + case "TRACE" => Level.Trace + case "DEBUG" => Level.Debug + case "INFO" => Level.Info + case "WARN" => Level.Warn + case "ERROR" => Level.Error + case _ => Level.Debug + } + + // Configure root logger with console output and minimum level + Logger.root + .clearHandlers() + .withHandler(minimumLevel = Some(level)) + .replace() + + // Suppress Java JQ library logging when writing to console + if (shouldSuppressLogs) { + // Suppress all JQ-related loggers early + val loggersToSuppress = Seq("com.arakelian.jq", "com.arakelian.jq.NativeLib", "com.arakelian.jq.JqLibrary") + + loggersToSuppress.foreach { loggerName => + val logger = JLogger.getLogger(loggerName) + logger.setLevel(JLevel.OFF) + logger.setUseParentHandlers(false) + } + } + + // Expand glob patterns and collect all files + val files = config.filePatterns.value.flatMap { pattern => + val currentDir = Paths.get(".").toAbsolutePath.normalize + + // If it's a direct file path and exists, return it + val directFile = if (Paths.get(pattern).isAbsolute) { + Paths.get(pattern).toFile + } else { + 
currentDir.resolve(pattern).toFile + } + + if (directFile.exists() && directFile.isFile && !pattern.contains('*') && !pattern.contains('?')) { + Seq(directFile) + } else if (pattern.contains('*') || pattern.contains('?')) { + // Normalize path separators to forward slashes for glob matching + val normalizedPattern = pattern.replace('\\', '/') + + // Determine the base directory to start searching from + // For patterns like "../**/*oracle*/**/*.sql", find the deepest non-glob directory + val parts = normalizedPattern.split("/").filter(_.nonEmpty) + def findBaseRec(path: java.nio.file.Path, index: Int): (java.nio.file.Path, Int) = + if (index >= parts.length) { + (path, index) + } else { + val part = parts(index) + if (part == "..") { + findBaseRec(path.getParent, index + 1) + } else if (part == ".") { + findBaseRec(path, index + 1) + } else if (part.contains('*') || part.contains('?')) { + (path, index) // Stop here, found glob + } else { + val nextPath = path.resolve(part) + if (nextPath.toFile.exists() && nextPath.toFile.isDirectory) { + findBaseRec(nextPath, index + 1) + } else { + (path, index) // Stop here, directory doesn't exist + } + } + } + + val (baseDir, baseIndex) = findBaseRec(currentDir, 0) + + // Build the relative pattern from the base directory + val relativePattern = if (baseIndex < parts.length) { + parts.slice(baseIndex, parts.length).mkString("/") + } else { + "**" // If we consumed everything, search recursively + } + + val basePath = baseDir.normalize + + if (basePath.toFile.exists() && basePath.toFile.isDirectory) { + val fs = FileSystems.getDefault + val matcher = fs.getPathMatcher(s"glob:$relativePattern") + + // Walk the directory tree recursively using Stream + try { + import scala.jdk.CollectionConverters._ + Files + .walk(basePath) + .iterator() + .asScala + .filter(_.toFile.isFile) + .filter { path => + val relativePath = basePath.relativize(path).toString.replace('\\', '/') + matcher.matches(Paths.get(relativePath)) + } + 
.map(_.toFile) + .toSeq + } catch { + case _: Exception => + Seq.empty + } + } else { + Seq.empty + } + } else { + // Not a glob pattern, but file doesn't exist + Seq.empty + } + }.distinct + + if (files.isEmpty) { + logger.warn(s"No files found matching: ${config.filePatterns.value.mkString(", ")}") + System.exit(1) + } + + // Step 1: Run the parser + val result = parseFiles(files, config.dialect) + + // Step 2: Convert results into outputs + generateOutputs(result, config) + + // Exit with error code if there were errors + if (result.totalErrorCount > 0) { + System.exit(1) + } + } + + /** Parses a list of SQL files using the specified dialect and returns processing results. This method is public to + * enable testing. + * + * @param files + * The list of files to parse + * @param dialect + * The SQL dialect object to use + * @return + * ProcessingResult containing file results and summary statistics + */ + def parseFiles(files: Seq[File], dialect: SQLBaseObject): ProcessingResult = { + val fileResults = files.map { file => + val relativePath = { + val currentDir = Paths.get(".").toAbsolutePath.normalize + val filePathObj = file.toPath.toAbsolutePath.normalize + try + currentDir.relativize(filePathObj).toString.replace('\\', '/') + catch { + case _: Exception => file.getPath + } + } + + val content = Source.fromFile(file).mkString + + val dialectName = dialect.name + + // Use parseAll to parse all statements with error recovery (now returns positions) + val parseResults: Seq[sparsity.common.statement.StatementParseResult] = dialect.parseAll(content) + + // Convert StatementParseResult to StatementResult + val statementResults = parseResults.zipWithIndex.map { case (parseResult, idx) => + val (statementText, isUnparseable) = parseResult.result match { + case Right(stmt) => (stmt.toString, false) + case Left(unparseable) => (unparseable.content, true) + } + + // Calculate absolute line number from start position + val absoluteLineNum = if (parseResult.startPos >= 0 
&& parseResult.startPos < content.length) { + val contentBeforeStmt = content.substring(0, parseResult.startPos) + val linesBeforeStmt = contentBeforeStmt.count(_ == '\n') + Some(linesBeforeStmt + 1) // +1 because line numbers are 1-based + } else { + None + } + + // Create parse result object + val finalParseResult = parseResult.result match { + case Right(stmt) => Parsed.Success(stmt, 0) + case Left(_) => Parsed.Failure("unparseable", 0, null) + } + + val stmtContext = StatementContext( + statement = statementText, + rawStatement = statementText, + relativePath = relativePath, + dialectName = dialectName, + statementStartInContent = parseResult.startPos + ) + + val stmtResult = StatementResult( + index = idx, + parseResult = finalParseResult, + context = stmtContext, + absoluteLineNumber = absoluteLineNum + ) + + // Log parsing details for failures + parseResult.result match { + case Left(unparseable) => + val stmtLines = statementText.split("\n") + logger.info(s" Label: unparseable") + logger.info(s" Context: ${statementText.take(100)}") + logger.debug(s" Full statement:") + logger.debug(s" ${stmtLines.mkString("\n ")}") + case Right(_) => + } + + stmtResult + }.toList + + val successCount = statementResults.count(_.success) + val errorCount = statementResults.count(!_.success) + val summary = + s"File summary: $successCount successful, $errorCount failed out of ${statementResults.length} statements" + + FileResult( + fileName = file.getName, + relativePath = relativePath, + statementResults = statementResults, + fileSummary = summary + ) + }.toList + + val totals = fileResults.foldLeft((0, 0, 0)) { case ((success, error, total), fileResult) => + val stmtCounts = fileResult.statementResults.foldLeft((0, 0)) { case ((s, e), stmt) => + if (stmt.success) (s + 1, e) else (s, e + 1) + } + (success + stmtCounts._1, error + stmtCounts._2, total + fileResult.statementResults.length) + } + + ProcessingResult( + fileResults = fileResults, + totalSuccessCount = totals._1, 
+ totalErrorCount = totals._2, + totalStatements = totals._3 + ) + } + + /** Generates all outputs from parsing results. This method is public to enable testing. + * + * @param result + * The processing result from parsing files + * @param config + * The configuration containing output options + */ + def generateOutputs(result: ProcessingResult, config: Config): Unit = { + // Process each output format (all outputs go to console) + val formattedOutput = config.output.format(result) + config.writeTo.write(formattedOutput) + } +} diff --git a/src/main/scala/sparsity/Name.scala b/src/main/scala/sparsity/Name.scala deleted file mode 100644 index 622f24f..0000000 --- a/src/main/scala/sparsity/Name.scala +++ /dev/null @@ -1,49 +0,0 @@ -package sparsity - -case class Name(name: String, quoted: Boolean = false) -{ - override def equals(other: Any): Boolean = - { - other match { - case s: String => equals(s) - case n: Name => equals(n) - case _ => false - } - } - def equals(other: Name): Boolean = - { - if(quoted || other.quoted){ name.equals(other.name) } - else { name.equalsIgnoreCase(other.name) } - } - def equals(other: String): Boolean = - { - if(quoted){ name.equals(other) } - else { name.equalsIgnoreCase(other) } - } - override def toString = if(quoted){ "`"+name+"`" } else { name } - def +(other: Name) = Name(name+other.name, quoted || other.quoted) - def +(other: String) = Name(name+other, quoted) - - def lower = if(quoted){ name } else { name.toLowerCase } - def upper = if(quoted){ name } else { name.toUpperCase } - - def extendUpper(template:String) = - Name(template.replace("\\$", upper), quoted) - def extendLower(template:String) = - Name(template.replace("\\$", lower), quoted) - - def withPrefix(other: String) = Name(other+name, quoted) - def withSuffix(other: String) = Name(name+other, quoted) -} - -class StringNameMatch(cmp:String) -{ - def unapply(name: Name):Option[Unit] = - if(name.equals(cmp)) { return Some(()) } - else { return None } -} - -object 
NameMatch -{ - def apply(cmp: String) = new StringNameMatch(cmp) -} \ No newline at end of file diff --git a/src/main/scala/sparsity/alter/AlterView.scala b/src/main/scala/sparsity/alter/AlterView.scala deleted file mode 100644 index b63daf2..0000000 --- a/src/main/scala/sparsity/alter/AlterView.scala +++ /dev/null @@ -1,8 +0,0 @@ -package sparsity.alter - -sealed abstract class AlterViewAction - -case class Materialize(add: Boolean) extends AlterViewAction -{ - override def toString = (if(add){ "" } else { "DROP " })+"MATERIALIZE" -} \ No newline at end of file diff --git a/src/main/scala/sparsity/ansi/ANSISQL.scala b/src/main/scala/sparsity/ansi/ANSISQL.scala new file mode 100644 index 0000000..7419a67 --- /dev/null +++ b/src/main/scala/sparsity/ansi/ANSISQL.scala @@ -0,0 +1,73 @@ +package sparsity.ansi + +import fastparse._ +import scala.io._ +import java.io._ +import sparsity.common.statement._ +import sparsity.common.{ + BigInt, + Boolean, + Char, + DataType, + Date, + Decimal, + Double, + Float, + Integer, + Text, + Timestamp, + VarChar +} +import sparsity.common.parser.{SQLBase, SQLBaseObject, StreamParser} +import sparsity.common.statement.Unparseable + +/** ANSI SQL dialect implementation. Provides baseline SQL parsing with standard ANSI SQL features. 
+ */ +class ANSISQL extends SQLBase { + // Configuration + type Stmt = Statement + def caseSensitive = false + def statementTerminator = ";" + def supportsCTEs = true + def supportsReturning = false + def supportsIfExists = true + def stringConcatOp = "||" + + // Parser implementations + def createStatement[$ : P]: P[Statement] = P(&(keyword("CREATE")) ~/ (createTable | createView)) + + def dialectSpecificStatement[$ : P]: P[Statement] = P(fastparse.Fail) + + def dataType[$ : P]: P[DataType] = P( + (keyword("VARCHAR") ~ "(" ~ integer.map(_.toInt) ~ ")").map(VarChar(_)) + | (keyword("CHAR") ~ "(" ~ integer.map(_.toInt) ~ ")").map(Char(_)) + | keyword("INTEGER").map(_ => Integer()) + | keyword("INT").map(_ => Integer()) + | keyword("BIGINT").map(_ => BigInt()) + | (keyword("DECIMAL") ~ "(" ~ integer.map(_.toInt) ~ "," ~ integer.map(_.toInt) ~ ")") + .map { case (p, s) => Decimal(p, s) } + | keyword("FLOAT").map(_ => Float()) + | keyword("DOUBLE").map(_ => Double()) + | keyword("DATE").map(_ => Date()) + | (keyword("TIMESTAMP") ~ ("(" ~ integer.map(_.toInt) ~ ")").?) + .map(p => Timestamp(p)) + | keyword("BOOLEAN").map(_ => Boolean()) + | keyword("TEXT").map(_ => Text()) + ) +} + +object ANSISQL extends SQLBaseObject { + type Stmt = Statement + import fastparse.Parsed + val instance = new ANSISQL() + + def name: String = "ansi" + + protected def statementTerminator: String = ";" + + def apply(input: String): Parsed[Stmt] = + parse(input, instance.terminatedStatement(_)) + + def apply(input: Reader): StreamParser[Stmt] = + new StreamParser[Stmt](parse(_: Iterator[String], instance.terminatedStatement(_), verboseFailures = true), input) +} diff --git a/src/main/scala/sparsity/common/DataType.scala b/src/main/scala/sparsity/common/DataType.scala new file mode 100644 index 0000000..2ab0bc6 --- /dev/null +++ b/src/main/scala/sparsity/common/DataType.scala @@ -0,0 +1,37 @@ +package sparsity.common + +/** Base trait for SQL data types. 
Supports common types shared across dialects and dialect-specific extensions. + */ +sealed trait DataType + +// Common data types (shared across dialects) +case class VarChar(length: Int) extends DataType +case class Char(length: Int) extends DataType +case class Integer() extends DataType +case class BigInt() extends DataType +case class Decimal(precision: Int, scale: Int) extends DataType +case class Float() extends DataType +case class Double() extends DataType +case class Date() extends DataType +case class Timestamp(precision: Option[Int]) extends DataType +case class Boolean() extends DataType +case class Text() extends DataType + +// Oracle-specific data types +case class VarChar2(length: Int, byteOrChar: Option[String]) extends DataType +case class Number(precision: Option[Int], scale: Option[Int]) extends DataType +case class Clob() extends DataType +case class Blob() extends DataType +case class NVarChar2(length: Int) extends DataType +case class Raw(length: Int) extends DataType + +// Postgres-specific data types +case class Serial() extends DataType +case class BigSerial() extends DataType +case class Json() extends DataType +case class Jsonb() extends DataType +case class Uuid() extends DataType +case class Inet() extends DataType +case class Cidr() extends DataType +case class MacAddr() extends DataType +case class Array(elementType: DataType) extends DataType diff --git a/src/main/scala/sparsity/common/Name.scala b/src/main/scala/sparsity/common/Name.scala new file mode 100644 index 0000000..ddbc48e --- /dev/null +++ b/src/main/scala/sparsity/common/Name.scala @@ -0,0 +1,45 @@ +package sparsity.common + +import sparsity.common.expression.ToSql + +case class Name(name: String, quoted: scala.Boolean = false) extends ToSql { + override def equals(other: Any): scala.Boolean = + other match { + case s: String => equals(s) + case n: Name => equals(n) + case _ => false + } + def equals(other: Name): scala.Boolean = + if (quoted || other.quoted) { 
name.equals(other.name) }
+    else { name.equalsIgnoreCase(other.name) }
+  def equals(other: String): scala.Boolean =
+    if (quoted) { name.equals(other) }
+    else { name.equalsIgnoreCase(other) }
+  override def toSql = if (quoted) { "`" + name + "`" }
+  else { name }
+  def +(other: Name) = Name(name + other.name, quoted || other.quoted)
+  def +(other: String) = Name(name + other, quoted)
+
+  def lower = if (quoted) { name }
+  else { name.toLowerCase }
+  def upper = if (quoted) { name }
+  else { name.toUpperCase }
+
+  def extendUpper(template: String) = // NOTE(review): String.replace is literal, so this substitutes the 2-char sequence \$ (not the regex "$"); confirm templates really use a literal \$ marker
+    Name(template.replace("\\$", upper), quoted)
+  def extendLower(template: String) = // NOTE(review): same literal-\$ caveat as extendUpper
+    Name(template.replace("\\$", lower), quoted)
+
+  def withPrefix(other: String) = Name(other + name, quoted)
+  def withSuffix(other: String) = Name(name + other, quoted)
+}
+
+// Extractor that matches a Name against a fixed string, honoring Name's quote-aware equality.
+class StringNameMatch(cmp: String) {
+  def unapply(name: Name): Option[Unit] =
+    if (name.equals(cmp)) { Some(()) } // idiomatic: `if` is an expression; the former `return`s and `: scala.Boolean` ascription were redundant
+    else { None }
+}
+
+object NameMatch {
+  def apply(cmp: String) = new StringNameMatch(cmp)
+}
diff --git a/src/main/scala/sparsity/common/alter/AlterView.scala b/src/main/scala/sparsity/common/alter/AlterView.scala
new file mode 100644
index 0000000..99bb6d5
--- /dev/null
+++ b/src/main/scala/sparsity/common/alter/AlterView.scala
@@ -0,0 +1,10 @@
+package sparsity.common.alter
+
+import sparsity.common.expression.ToSql
+
+sealed abstract class AlterViewAction extends ToSql
+
+case class Materialize(add: Boolean) extends AlterViewAction {
+  override def toSql = (if (add) { "" }
+  else { "DROP " }) + "MATERIALIZE"
+}
diff --git a/src/main/scala/sparsity/common/codec/CirceCodecs.scala b/src/main/scala/sparsity/common/codec/CirceCodecs.scala
new file mode 100644
index 0000000..3626606
--- /dev/null
+++ b/src/main/scala/sparsity/common/codec/CirceCodecs.scala
@@ -0,0 +1,581 @@
+package sparsity.common.codec
+
+import io.circe.{Decoder, DecodingFailure, Encoder, HCursor}
+import io.circe.{Json => CirceJson}
+import
io.circe.generic.extras.Configuration +import io.circe.generic.extras.semiauto.{deriveConfiguredDecoder, deriveConfiguredEncoder} +import io.circe.generic.semiauto.{deriveDecoder, deriveEncoder} +import io.circe.syntax._ +import cats.syntax.traverse._ +import sparsity.common.{ + Array => SArray, + BigInt => SBigInt, + BigSerial => SBigSerial, + Blob => SBlob, + Boolean => SBoolean, + Char => SChar, + Cidr => SCidr, + Clob => SClob, + DataType, + Date => SDate, + Decimal => SDecimal, + Double => SDouble, + Float => SFloat, + Inet => SInet, + Integer => SInteger, + Json => SJson, + Jsonb => SJsonb, + MacAddr => SMacAddr, + NVarChar2 => SNVarChar2, + Name, + Number => SNumber, + Raw => SRaw, + Serial => SSerial, + Text => SText, + Timestamp => STimestamp, + Uuid => SUuid, + VarChar => SVarChar, + VarChar2 => SVarChar2 +} +import sparsity.common.expression._ +import sparsity.common.statement._ +import sparsity.common.select._ +import sparsity.common.alter._ +import sparsity.oracle._ + +/** Circe codecs for all parser result datatypes */ +object CirceCodecs { + // Configuration for sealed trait/class hierarchies with "type" discriminator + implicit val config: Configuration = Configuration.default.withDiscriminator("type") + + // Name codec + implicit val nameEncoder: Encoder[Name] = deriveEncoder[Name] + implicit val nameDecoder: Decoder[Name] = deriveDecoder[Name] + + // Function codec (needs explicit codec due to generic derivation issues) + implicit val functionEncoder: Encoder[Function] = deriveEncoder[Function] + implicit val functionDecoder: Decoder[Function] = deriveDecoder[Function] + + // DataType codecs + implicit val dataTypeEncoder: Encoder[DataType] = Encoder.instance { + case v: SVarChar => + CirceJson.obj("type" -> CirceJson.fromString("VarChar"), "length" -> CirceJson.fromInt(v.length)) + case c: SChar => CirceJson.obj("type" -> CirceJson.fromString("Char"), "length" -> CirceJson.fromInt(c.length)) + case SInteger() => CirceJson.obj("type" -> 
CirceJson.fromString("Integer")) + case SBigInt() => CirceJson.obj("type" -> CirceJson.fromString("BigInt")) + case d: SDecimal => + CirceJson.obj( + "type" -> CirceJson.fromString("Decimal"), + "precision" -> CirceJson.fromInt(d.precision), + "scale" -> CirceJson.fromInt(d.scale) + ) + case SFloat() => CirceJson.obj("type" -> CirceJson.fromString("Float")) + case SDouble() => CirceJson.obj("type" -> CirceJson.fromString("Double")) + case SDate() => CirceJson.obj("type" -> CirceJson.fromString("Date")) + case t: STimestamp => CirceJson.obj("type" -> CirceJson.fromString("Timestamp"), "precision" -> t.precision.asJson) + case SBoolean() => CirceJson.obj("type" -> CirceJson.fromString("Boolean")) + case SText() => CirceJson.obj("type" -> CirceJson.fromString("Text")) + case v: SVarChar2 => + CirceJson.obj( + "type" -> CirceJson.fromString("VarChar2"), + "length" -> CirceJson.fromInt(v.length), + "byteOrChar" -> v.byteOrChar.asJson + ) + case n: SNumber => + CirceJson.obj( + "type" -> CirceJson.fromString("Number"), + "precision" -> n.precision.asJson, + "scale" -> n.scale.asJson + ) + case SClob() => CirceJson.obj("type" -> CirceJson.fromString("Clob")) + case SBlob() => CirceJson.obj("type" -> CirceJson.fromString("Blob")) + case n: SNVarChar2 => + CirceJson.obj("type" -> CirceJson.fromString("NVarChar2"), "length" -> CirceJson.fromInt(n.length)) + case r: SRaw => CirceJson.obj("type" -> CirceJson.fromString("Raw"), "length" -> CirceJson.fromInt(r.length)) + case SSerial() => CirceJson.obj("type" -> CirceJson.fromString("Serial")) + case SBigSerial() => CirceJson.obj("type" -> CirceJson.fromString("BigSerial")) + case SJson() => CirceJson.obj("type" -> CirceJson.fromString("Json")) + case SJsonb() => CirceJson.obj("type" -> CirceJson.fromString("Jsonb")) + case SUuid() => CirceJson.obj("type" -> CirceJson.fromString("Uuid")) + case SInet() => CirceJson.obj("type" -> CirceJson.fromString("Inet")) + case SCidr() => CirceJson.obj("type" -> CirceJson.fromString("Cidr")) 
+ case SMacAddr() => CirceJson.obj("type" -> CirceJson.fromString("MacAddr")) + case a: SArray => + CirceJson.obj("type" -> CirceJson.fromString("Array"), "elementType" -> dataTypeEncoder(a.elementType)) + } + + implicit val dataTypeDecoder: Decoder[DataType] = Decoder.instance { c => + c.get[String]("type").flatMap { + case "VarChar" => c.get[Int]("length").map(SVarChar.apply) + case "Char" => c.get[Int]("length").map(SChar.apply) + case "Integer" => Right(SInteger()) + case "BigInt" => Right(SBigInt()) + case "Decimal" => + for { + p <- c.get[Int]("precision") + s <- c.get[Int]("scale") + } yield SDecimal(p, s) + case "Float" => Right(SFloat()) + case "Double" => Right(SDouble()) + case "Date" => Right(SDate()) + case "Timestamp" => c.get[Option[Int]]("precision").map(STimestamp.apply) + case "Boolean" => Right(SBoolean()) + case "Text" => Right(SText()) + case "VarChar2" => + for { + l <- c.get[Int]("length") + boc <- c.get[Option[String]]("byteOrChar") + } yield SVarChar2(l, boc) + case "Number" => + for { + p <- c.get[Option[Int]]("precision") + s <- c.get[Option[Int]]("scale") + } yield SNumber(p, s) + case "Clob" => Right(SClob()) + case "Blob" => Right(SBlob()) + case "NVarChar2" => c.get[Int]("length").map(SNVarChar2.apply) + case "Raw" => c.get[Int]("length").map(SRaw.apply) + case "Serial" => Right(SSerial()) + case "BigSerial" => Right(SBigSerial()) + case "Json" => Right(SJson()) + case "Jsonb" => Right(SJsonb()) + case "Uuid" => Right(SUuid()) + case "Inet" => Right(SInet()) + case "Cidr" => Right(SCidr()) + case "MacAddr" => Right(SMacAddr()) + case "Array" => c.get[DataType]("elementType")(dataTypeDecoder).map(SArray.apply) + case other => Left(DecodingFailure(s"Unknown DataType: $other", c.history)) + } + } + + // Enum codecs + implicit val arithmeticOpEncoder: Encoder[Arithmetic.Op] = Encoder.encodeString.contramap(_.toString) + implicit val arithmeticOpDecoder: Decoder[Arithmetic.Op] = Decoder.decodeString.emap { s => + try 
Right(Arithmetic.withName(s).asInstanceOf[Arithmetic.Op]) + catch { case _: NoSuchElementException => Left(s"Unknown Arithmetic.Op: $s") } + } + + implicit val comparisonOpEncoder: Encoder[Comparison.Op] = Encoder.encodeString.contramap(_.toString) + implicit val comparisonOpDecoder: Decoder[Comparison.Op] = Decoder.decodeString.emap { s => + try Right(Comparison.withName(s).asInstanceOf[Comparison.Op]) + catch { case _: NoSuchElementException => Left(s"Unknown Comparison.Op: $s") } + } + + implicit val joinTypeEncoder: Encoder[Join.Type] = Encoder.encodeString.contramap(_.toString) + implicit val joinTypeDecoder: Decoder[Join.Type] = Decoder.decodeString.emap { s => + try Right(Join.withName(s).asInstanceOf[Join.Type]) + catch { case _: NoSuchElementException => Left(s"Unknown Join.Type: $s") } + } + + implicit val unionTypeEncoder: Encoder[Union.Type] = Encoder.encodeString.contramap(_.toString) + implicit val unionTypeDecoder: Decoder[Union.Type] = Decoder.decodeString.emap { s => + try Right(Union.withName(s).asInstanceOf[Union.Type]) + catch { case _: NoSuchElementException => Left(s"Unknown Union.Type: $s") } + } + + // Expression codecs (recursive) + implicit val expressionEncoder: Encoder[Expression] = Encoder.instance { + case lp: LongPrimitive => + CirceJson.obj("type" -> CirceJson.fromString("LongPrimitive"), "value" -> CirceJson.fromLong(lp.v)) + case dp: DoublePrimitive => + CirceJson.obj("type" -> CirceJson.fromString("DoublePrimitive"), "value" -> CirceJson.fromDoubleOrNull(dp.v)) + case sp: StringPrimitive => + CirceJson.obj("type" -> CirceJson.fromString("StringPrimitive"), "value" -> CirceJson.fromString(sp.v)) + case bp: BooleanPrimitive => + CirceJson.obj("type" -> CirceJson.fromString("BooleanPrimitive"), "value" -> CirceJson.fromBoolean(bp.v)) + case NullPrimitive() => CirceJson.obj("type" -> CirceJson.fromString("NullPrimitive")) + case c: Column => + CirceJson.obj("type" -> CirceJson.fromString("Column"), "column" -> c.column.asJson, "table" 
-> c.table.asJson) + case a: Arithmetic => + CirceJson.obj( + "type" -> CirceJson.fromString("Arithmetic"), + "lhs" -> a.lhs.asJson, + "op" -> a.op.asJson, + "rhs" -> a.rhs.asJson + ) + case c: Comparison => + CirceJson.obj( + "type" -> CirceJson.fromString("Comparison"), + "lhs" -> c.lhs.asJson, + "op" -> c.op.asJson, + "rhs" -> c.rhs.asJson + ) + case f: Function => + CirceJson.obj( + "type" -> CirceJson.fromString("Function"), + "name" -> f.name.asJson, + "params" -> f.params.asJson, + "distinct" -> CirceJson.fromBoolean(f.distinct) + ) + case wf: WindowFunction => + CirceJson.obj( + "type" -> CirceJson.fromString("WindowFunction"), + "function" -> wf.function.asJson, + "partitionBy" -> wf.partitionBy.asJson, + "orderBy" -> wf.orderBy.asJson + ) + case JDBCVar() => CirceJson.obj("type" -> CirceJson.fromString("JDBCVar")) + case sq: Subquery => CirceJson.obj("type" -> CirceJson.fromString("Subquery"), "query" -> sq.query.asJson) + case cwe: CaseWhenElse => + CirceJson.obj( + "type" -> CirceJson.fromString("CaseWhenElse"), + "target" -> cwe.target.asJson, + "cases" -> cwe.cases.map { case (e1, e2) => CirceJson.obj("when" -> e1.asJson, "then" -> e2.asJson) }.asJson, + "otherwise" -> cwe.otherwise.asJson + ) + case in: IsNull => CirceJson.obj("type" -> CirceJson.fromString("IsNull"), "target" -> in.target.asJson) + case n: Not => CirceJson.obj("type" -> CirceJson.fromString("Not"), "target" -> n.target.asJson) + case cast: Cast => + CirceJson.obj( + "type" -> CirceJson.fromString("Cast"), + "expression" -> cast.expression.asJson, + "typeName" -> cast.t.asJson + ) + case ie: InExpression => + CirceJson.obj( + "type" -> CirceJson.fromString("InExpression"), + "expression" -> ie.expression.asJson, + "source" -> (ie.source match { + case Left(exprs) => CirceJson.obj("type" -> CirceJson.fromString("Expressions"), "values" -> exprs.asJson) + case Right(query) => CirceJson.obj("type" -> CirceJson.fromString("SelectBody"), "query" -> query.asJson) + }) + ) + case _: 
NegatableExpression => // Handled by IsNull and InExpression cases above + throw new IllegalStateException("Unhandled NegatableExpression subtype") + } + + implicit val expressionDecoder: Decoder[Expression] = Decoder.instance { c => + c.get[String]("type").flatMap { + case "LongPrimitive" => c.get[Long]("value").map(LongPrimitive.apply) + case "DoublePrimitive" => c.get[Double]("value").map(DoublePrimitive.apply) + case "StringPrimitive" => c.get[String]("value").map(StringPrimitive.apply) + case "BooleanPrimitive" => c.get[scala.Boolean]("value").map(BooleanPrimitive.apply) + case "NullPrimitive" => Right(NullPrimitive()) + case "Column" => + for { + col <- c.get[Name]("column") + table <- c.get[Option[Name]]("table") + } yield Column(col, table) + case "Arithmetic" => + for { + lhs <- c.get[Expression]("lhs") + op <- c.get[Arithmetic.Op]("op") + rhs <- c.get[Expression]("rhs") + } yield Arithmetic(lhs, op, rhs) + case "Comparison" => + for { + lhs <- c.get[Expression]("lhs") + op <- c.get[Comparison.Op]("op") + rhs <- c.get[Expression]("rhs") + } yield Comparison(lhs, op, rhs) + case "Function" => + for { + name <- c.get[Name]("name") + params <- c.get[Option[Seq[Expression]]]("params") + distinct <- c.get[scala.Boolean]("distinct") + } yield Function(name, params, distinct) + case "WindowFunction" => + for { + func <- c.get[Function]("function") + partitionBy <- c.get[Option[Seq[Expression]]]("partitionBy") + orderBy <- c.get[Seq[OrderBy]]("orderBy") + } yield WindowFunction(func, partitionBy, orderBy) + case "JDBCVar" => Right(JDBCVar()) + case "Subquery" => c.get[SelectBody]("query").map(Subquery.apply) + case "CaseWhenElse" => + for { + target <- c.get[Option[Expression]]("target") + cases <- c.get[Seq[CirceJson]]("cases").flatMap { jsons => + jsons.traverse { json => + for { + when <- json.hcursor.get[Expression]("when") + thenExpr <- json.hcursor.get[Expression]("then") + } yield (when, thenExpr) + } + } + otherwise <- c.get[Expression]("otherwise") + } 
yield CaseWhenElse(target, cases, otherwise) + case "IsNull" => c.get[Expression]("target").map(IsNull.apply) + case "Not" => c.get[Expression]("target").map(Not.apply) + case "Cast" => + for { + expr <- c.get[Expression]("expression") + typeName <- c.get[Name]("typeName") + } yield Cast(expr, typeName) + case "InExpression" => + for { + expr <- c.get[Expression]("expression") + source <- c.get[CirceJson]("source").flatMap { json => + json.hcursor.get[String]("type").flatMap { + case "Expressions" => json.hcursor.get[Seq[Expression]]("values").map(Left.apply) + case "SelectBody" => json.hcursor.get[SelectBody]("query").map(Right.apply) + case other => Left(DecodingFailure(s"Unknown InExpression source type: $other", json.hcursor.history)) + } + } + } yield InExpression(expr, source) + case other => Left(DecodingFailure(s"Unknown Expression type: $other", c.history)) + } + } + + // SelectTarget codecs (using circe-generic-extras) + implicit val selectTargetEncoder: Encoder[SelectTarget] = deriveConfiguredEncoder[SelectTarget] + implicit val selectTargetDecoder: Decoder[SelectTarget] = deriveConfiguredDecoder[SelectTarget] + + // OrderBy codec + implicit val orderByEncoder: Encoder[OrderBy] = deriveEncoder[OrderBy] + implicit val orderByDecoder: Decoder[OrderBy] = deriveDecoder[OrderBy] + + // FromElement codecs (using circe-generic-extras) + implicit val fromElementEncoder: Encoder[FromElement] = deriveConfiguredEncoder[FromElement] + implicit val fromElementDecoder: Decoder[FromElement] = deriveConfiguredDecoder[FromElement] + + // SelectBody codec (recursive) + implicit val selectBodyEncoder: Encoder[SelectBody] = deriveEncoder[SelectBody] + implicit val selectBodyDecoder: Decoder[SelectBody] = deriveDecoder[SelectBody] + + // ColumnAnnotation codecs (using circe-generic-extras) + implicit val columnAnnotationEncoder: Encoder[ColumnAnnotation] = deriveConfiguredEncoder[ColumnAnnotation] + implicit val columnAnnotationDecoder: Decoder[ColumnAnnotation] = 
deriveConfiguredDecoder[ColumnAnnotation] + + // PrimitiveValue codec (needed for ColumnDefinition) + implicit val primitiveValueEncoder: Encoder[PrimitiveValue] = expressionEncoder.contramap(identity[Expression]) + implicit val primitiveValueDecoder: Decoder[PrimitiveValue] = expressionDecoder.emap { + case pv: PrimitiveValue => Right(pv) + case other => Left(s"Expected PrimitiveValue but got ${other.getClass.getSimpleName}") + } + + // ColumnDefinition codec + implicit val columnDefinitionEncoder: Encoder[ColumnDefinition] = deriveEncoder[ColumnDefinition] + implicit val columnDefinitionDecoder: Decoder[ColumnDefinition] = deriveDecoder[ColumnDefinition] + + // TableAnnotation codecs (using circe-generic-extras) + implicit val tableAnnotationEncoder: Encoder[TableAnnotation] = deriveConfiguredEncoder[TableAnnotation] + implicit val tableAnnotationDecoder: Decoder[TableAnnotation] = deriveConfiguredDecoder[TableAnnotation] + + // InsertValues codecs (using circe-generic-extras) + implicit val insertValuesEncoder: Encoder[InsertValues] = deriveConfiguredEncoder[InsertValues] + implicit val insertValuesDecoder: Decoder[InsertValues] = deriveConfiguredDecoder[InsertValues] + + // WithClause codec + implicit val withClauseEncoder: Encoder[WithClause] = deriveEncoder[WithClause] + implicit val withClauseDecoder: Decoder[WithClause] = deriveDecoder[WithClause] + + // AlterViewAction codecs (using circe-generic-extras) + implicit val alterViewActionEncoder: Encoder[AlterViewAction] = deriveConfiguredEncoder[AlterViewAction] + implicit val alterViewActionDecoder: Decoder[AlterViewAction] = deriveConfiguredDecoder[AlterViewAction] + + // Statement codecs (base and common) + implicit val statementEncoder: Encoder[Statement] = Encoder.instance { + case EmptyStatement() => CirceJson.obj("type" -> CirceJson.fromString("EmptyStatement")) + case u: Unparseable => + CirceJson.obj("type" -> CirceJson.fromString("Unparseable"), "content" -> CirceJson.fromString(u.content)) + case 
s: Select => + CirceJson.obj( + "type" -> CirceJson.fromString("Select"), + "body" -> s.body.asJson, + "withClause" -> s.withClause.asJson + ) + case u: Update => + CirceJson.obj( + "type" -> CirceJson.fromString("Update"), + "table" -> u.table.asJson, + "set" -> u.set.map { case (n, e) => CirceJson.obj("name" -> n.asJson, "expression" -> e.asJson) }.asJson, + "where" -> u.where.asJson + ) + case d: Delete => + CirceJson.obj("type" -> CirceJson.fromString("Delete"), "table" -> d.table.asJson, "where" -> d.where.asJson) + case i: Insert => + CirceJson.obj( + "type" -> CirceJson.fromString("Insert"), + "table" -> i.table.asJson, + "columns" -> i.columns.asJson, + "values" -> i.values.asJson, + "orReplace" -> CirceJson.fromBoolean(i.orReplace) + ) + case ct: CreateTable => + CirceJson.obj( + "type" -> CirceJson.fromString("CreateTable"), + "name" -> ct.name.asJson, + "orReplace" -> CirceJson.fromBoolean(ct.orReplace), + "columns" -> ct.columns.asJson, + "annotations" -> ct.annotations.asJson + ) + case cta: CreateTableAs => + CirceJson.obj( + "type" -> CirceJson.fromString("CreateTableAs"), + "name" -> cta.name.asJson, + "orReplace" -> CirceJson.fromBoolean(cta.orReplace), + "query" -> cta.query.asJson + ) + case cv: CreateView => + CirceJson.obj( + "type" -> CirceJson.fromString("CreateView"), + "name" -> cv.name.asJson, + "orReplace" -> CirceJson.fromBoolean(cv.orReplace), + "query" -> cv.query.asJson, + "materialized" -> CirceJson.fromBoolean(cv.materialized), + "temporary" -> CirceJson.fromBoolean(cv.temporary) + ) + case av: AlterView => + CirceJson.obj("type" -> CirceJson.fromString("AlterView"), "name" -> av.name.asJson, "action" -> av.action.asJson) + case dt: DropTable => + CirceJson.obj( + "type" -> CirceJson.fromString("DropTable"), + "name" -> dt.name.asJson, + "ifExists" -> CirceJson.fromBoolean(dt.ifExists) + ) + case dv: DropView => + CirceJson.obj( + "type" -> CirceJson.fromString("DropView"), + "name" -> dv.name.asJson, + "ifExists" -> 
CirceJson.fromBoolean(dv.ifExists) + ) + case e: Explain => CirceJson.obj("type" -> CirceJson.fromString("Explain"), "query" -> e.query.asJson) + // Oracle statements + case os: OracleStatement => oracleStatementEncoder(os) + } + + implicit val statementDecoder: Decoder[Statement] = Decoder.instance { c => + c.get[String]("type").flatMap { + case "EmptyStatement" => Right(EmptyStatement()) + case "Unparseable" => c.get[String]("content").map(Unparseable.apply) + case "Select" => + for { + body <- c.get[SelectBody]("body") + withClause <- c.get[Seq[WithClause]]("withClause") + } yield Select(body, withClause) + case "Update" => + for { + table <- c.get[Name]("table") + set <- c.get[Seq[CirceJson]]("set").flatMap { jsons => + jsons.traverse { json => + for { + name <- json.hcursor.get[Name]("name") + expr <- json.hcursor.get[Expression]("expression") + } yield (name, expr) + } + } + where <- c.get[Option[Expression]]("where") + } yield Update(table, set, where) + case "Delete" => + for { + table <- c.get[Name]("table") + where <- c.get[Option[Expression]]("where") + } yield Delete(table, where) + case "Insert" => + for { + table <- c.get[Name]("table") + columns <- c.get[Option[Seq[Name]]]("columns") + values <- c.get[InsertValues]("values") + orReplace <- c.get[Boolean]("orReplace") + } yield Insert(table, columns, values, orReplace) + case "CreateTable" => + for { + name <- c.get[Name]("name") + orReplace <- c.get[Boolean]("orReplace") + columns <- c.get[Seq[ColumnDefinition]]("columns") + annotations <- c.get[Seq[TableAnnotation]]("annotations") + } yield CreateTable(name, orReplace, columns, annotations) + case "CreateTableAs" => + for { + name <- c.get[Name]("name") + orReplace <- c.get[Boolean]("orReplace") + query <- c.get[SelectBody]("query") + } yield CreateTableAs(name, orReplace, query) + case "CreateView" => + for { + name <- c.get[Name]("name") + orReplace <- c.get[Boolean]("orReplace") + query <- c.get[SelectBody]("query") + materialized <- 
c.get[Boolean]("materialized") + temporary <- c.get[Boolean]("temporary") + } yield CreateView(name, orReplace, query, materialized, temporary) + case "AlterView" => + for { + name <- c.get[Name]("name") + action <- c.get[AlterViewAction]("action") + } yield AlterView(name, action) + case "DropTable" => + for { + name <- c.get[Name]("name") + ifExists <- c.get[Boolean]("ifExists") + } yield DropTable(name, ifExists) + case "DropView" => + for { + name <- c.get[Name]("name") + ifExists <- c.get[Boolean]("ifExists") + } yield DropView(name, ifExists) + case "Explain" => c.get[SelectBody]("query").map(Explain.apply) + case other => oracleStatementDecoderHelper(c, other) + } + } + + // Oracle-specific codecs + // ReferentialAction codecs + implicit val referentialActionEncoder: Encoder[ReferentialAction] = Encoder.encodeString.contramap(_.sql) + implicit val referentialActionDecoder: Decoder[ReferentialAction] = Decoder.decodeString.emap { + case "NO ACTION" => Right(ReferentialAction.NoAction) + case "RESTRICT" => Right(ReferentialAction.Restrict) + case "CASCADE" => Right(ReferentialAction.Cascade) + case "SET NULL" => Right(ReferentialAction.SetNull) + case "SET DEFAULT" => Right(ReferentialAction.SetDefault) + case other => Left(s"Unknown ReferentialAction: $other") + } + + // TableConstraint codecs + implicit val tableConstraintEncoder: Encoder[TableConstraint] = deriveConfiguredEncoder[TableConstraint] + implicit val tableConstraintDecoder: Decoder[TableConstraint] = deriveConfiguredDecoder[TableConstraint] + + // AlterTableAction codecs (using circe-generic-extras) + implicit val alterTableActionEncoder: Encoder[AlterTableAction] = deriveConfiguredEncoder[AlterTableAction] + implicit val alterTableActionDecoder: Decoder[AlterTableAction] = deriveConfiguredDecoder[AlterTableAction] + + implicit val columnModificationEncoder: Encoder[ColumnModification] = deriveEncoder[ColumnModification] + implicit val columnModificationDecoder: Decoder[ColumnModification] = 
deriveDecoder[ColumnModification] + + implicit val storageClauseEncoder: Encoder[StorageClause] = deriveEncoder[StorageClause] + implicit val storageClauseDecoder: Decoder[StorageClause] = deriveDecoder[StorageClause] + + // Custom codec for (Name, Expression) tuples to encode as objects with "name" and "expression" fields + implicit val nameExpressionTupleEncoder: Encoder[(Name, Expression)] = Encoder.instance { case (name, expr) => + CirceJson.obj("name" -> name.asJson, "expression" -> expr.asJson) + } + implicit val nameExpressionTupleDecoder: Decoder[(Name, Expression)] = Decoder.instance { c => + for { + name <- c.get[Name]("name") + expr <- c.get[Expression]("expression") + } yield (name, expr) + } + + // Custom codec for Either[SelectBody, Seq[ColumnDefinition]] used in OracleCreateTable.schema + // Encodes as an object with "type" field ("SelectBody" or "Columns") and corresponding data + implicit val oracleCreateTableSchemaEncoder: Encoder[Either[SelectBody, Seq[ColumnDefinition]]] = + Encoder.instance { + case Left(query) => + CirceJson.obj("type" -> CirceJson.fromString("SelectBody"), "query" -> query.asJson) + case Right(columns) => + CirceJson.obj("type" -> CirceJson.fromString("Columns"), "columns" -> columns.asJson) + } + implicit val oracleCreateTableSchemaDecoder: Decoder[Either[SelectBody, Seq[ColumnDefinition]]] = + Decoder.instance { c => + c.get[String]("type").flatMap { + case "SelectBody" => c.get[SelectBody]("query").map(Left.apply) + case "Columns" => c.get[Seq[ColumnDefinition]]("columns").map(Right.apply) + case other => + Left(DecodingFailure(s"Unknown OracleCreateTable schema type: $other", c.history)) + } + } + + // OracleStatement codecs (using circe-generic-extras) + // Note: The tuple codec above handles Update.set and OracleUpdate.set fields + implicit val oracleStatementEncoder: Encoder[OracleStatement] = deriveConfiguredEncoder[OracleStatement] + implicit val oracleStatementDecoder: Decoder[OracleStatement] = 
deriveConfiguredDecoder[OracleStatement] + + // Helper function to decode OracleStatement from Statement decoder + def oracleStatementDecoderHelper(c: HCursor, typeName: String): Decoder.Result[Statement] = { + // Use the automatic decoder for OracleStatement + oracleStatementDecoder(c).left.map { failure => + DecodingFailure(s"Unknown Statement type: $typeName - ${failure.message}", c.history) + } + } +} diff --git a/src/main/scala/sparsity/common/expression/Expression.scala b/src/main/scala/sparsity/common/expression/Expression.scala new file mode 100644 index 0000000..cb63679 --- /dev/null +++ b/src/main/scala/sparsity/common/expression/Expression.scala @@ -0,0 +1,367 @@ +package sparsity.common.expression + +import sparsity.common.Name +import sparsity.common.select.{OrderBy, SelectBody} + +trait ToSql { + def toSql: String +} + +sealed abstract class Expression extends ToSql { + def needsParenthesis: Boolean + def children: Seq[Expression] + def rebuild(newChildren: Seq[Expression]): Expression +} + +object Expression { + def parenthesize(e: Expression) = + if (e.needsParenthesis) { "(" + e.toSql + ")" } + else { e.toSql } + def escapeString(s: String) = + s.replaceAll("'", "''") +} + +sealed abstract class PrimitiveValue extends Expression { + def needsParenthesis = false + def children: Seq[Expression] = Seq() + def rebuild(newChildren: Seq[Expression]): Expression = this +} +case class LongPrimitive(v: Long) extends PrimitiveValue { override def toSql = v.toString } +case class DoublePrimitive(v: Double) extends PrimitiveValue { override def toSql = v.toString } +case class StringPrimitive(v: String) extends PrimitiveValue { + override def toSql = "'" + Expression.escapeString(v.toString) + "'" +} +case class BooleanPrimitive(v: Boolean) extends PrimitiveValue { + override def toSql = v.toString +} +case class NullPrimitive() extends PrimitiveValue { override def toSql = "NULL" } + +/** A column reference 'Table.Col' + */ +case class Column(column: Name, 
table: Option[Name] = None) extends Expression {
+  override def toSql = (table.toSeq ++ Seq(column)).map(_.toSql).mkString(".")
+  def needsParenthesis = false
+  def children: Seq[Expression] = Seq()
+  def rebuild(newChildren: Seq[Expression]): Expression = this
+}
+
+/** Any simple binary arithmetic expression See Arithmetic.scala for an enumeration of the possibilities *not* a case
+  * class to avoid namespace collisions. Arithmetic defines apply/unapply explicitly
+  */
+class Arithmetic(val lhs: Expression, val op: Arithmetic.Op, val rhs: Expression) extends Expression {
+  override def toSql =
+    Expression.parenthesize(lhs) + " " +
+      Arithmetic.opString(op) + " " +
+      Expression.parenthesize(rhs)
+  def needsParenthesis = true
+  override def equals(other: Any): Boolean =
+    other match {
+      case Arithmetic(otherlhs, otherop, otherrhs) =>
+        lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
+      case _ => false
+    }
+  def children: Seq[Expression] = Seq(lhs, rhs)
+  def rebuild(c: Seq[Expression]): Expression = Arithmetic(c(0), op, c(1))
+}
+
+object Arithmetic extends Enumeration {
+  type Op = Value
+  val Add, Sub, Mult, Div, And, Or, BitAnd, BitOr, ShiftLeft, ShiftRight = Value
+
+  def apply(lhs: Expression, op: Op, rhs: Expression) =
+    new Arithmetic(lhs, op, rhs)
+  def apply(lhs: Expression, op: String, rhs: Expression) =
+    new Arithmetic(lhs, fromString(op), rhs)
+
+  def unapply(e: Arithmetic): Option[(Expression, Op, Expression)] =
+    Some((e.lhs, e.op, e.rhs))
+
+  /** Regular expression to match any and all binary operations
+    */
+  def matchRegex = """\+|-|\*|/|\|\||&&|\||&""".r
+
+  /** Convert from the operator's string encoding to its Arith.Op rep
+    */
+  def fromString(a: String) =
+    a.toUpperCase match {
+      case "+" => Add
+      case "-" => Sub
+      case "*" => Mult
+      case "/" => Div
+      case "&" => BitAnd
+      case "|" => BitOr
+      case "<<" => ShiftLeft
+      case ">>" => ShiftRight
+      case "&&" => And
+      case "||" => Or
+      case "AND" => And
+      case "OR" => Or // BUGFIX: previously mapped to And, silently turning every textual OR into a conjunction
+      case x => throw new Exception("Invalid operand '" + x + "'")
+    }
+
+  /** Convert from the operator's Arith.Op representation to a string
+    */
+  def opString(v: Op): String =
+    v match {
+      case Add => "+"
+      case Sub => "-"
+      case Mult => "*"
+      case Div => "/"
+      case BitAnd => "&"
+      case BitOr => "|"
+      case ShiftLeft => "<<"
+      case ShiftRight => ">>"
+      case And => "AND"
+      case Or => "OR"
+    }
+
+  /** Is this binary operation a boolean operator (AND/OR)
+    */
+  def isBool(v: Op): Boolean =
+    v match {
+      case And | Or => true
+      case _ => false
+    }
+
+  /** Is this binary operation a numeric operator (+, -, *, /, & , |)
+    */
+  def isNumeric(v: Op): Boolean = !isBool(v)
+
+}
+
+/** Any simple comparison expression See Comparison.scala for an enumeration of the possibilities *not* a case class to
+  * avoid namespace collisions. Comparison defines apply/unapply explicitly
+  */
+class Comparison(val lhs: Expression, val op: Comparison.Op, val rhs: Expression) extends Expression {
+  override def toSql =
+    Expression.parenthesize(lhs) + " " +
+      Comparison.opString(op) + " " +
+      Expression.parenthesize(rhs)
+  def needsParenthesis = true
+  override def equals(other: Any): Boolean =
+    other match {
+      case Comparison(otherlhs, otherop, otherrhs) =>
+        lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs)
+      case _ => false
+    }
+  def children: Seq[Expression] = Seq(lhs, rhs)
+  def rebuild(c: Seq[Expression]): Expression = Comparison(c(0), op, c(1))
+}
+object Comparison extends Enumeration {
+  type Op = Value
+  val Eq, Neq, Gt, Lt, Gte, Lte, Like, NotLike, RLike, NotRLike = Value
+
+  val strings = Map(
+    "=" -> Eq,
+    "==" -> Eq,
+    "!=" -> Neq,
+    "<>" -> Neq,
+    ">" -> Gt,
+    "<" -> Lt,
+    ">=" -> Gte,
+    "<=" -> Lte,
+    "LIKE" -> Like, // SQL-style LIKE expression
+    "RLIKE" -> RLike, // Regular expression lookup
+    "NOT LIKE" -> NotLike, // Inverse LIKE
+    "NOT RLIKE" -> NotRLike // Inverse RLIKE — BUGFIX: was RLike, so "NOT RLIKE" parsed as a positive RLIKE
+  )
+
+  def apply(lhs: Expression, op: Op, rhs: Expression) =
+    new Comparison(lhs, op, rhs)
+  
def apply(lhs: Expression, op: String, rhs: Expression) = + new Comparison(lhs, strings(op.toUpperCase), rhs) + + def unapply(e: Comparison): Option[(Expression, Op, Expression)] = + Some((e.lhs, e.op, e.rhs)) + + def negate(v: Op): Op = + v match { + case Eq => Neq + case Neq => Eq + case Gt => Lte + case Gte => Lt + case Lt => Gte + case Lte => Gt + case Like => NotLike + case NotLike => Like + case RLike => NotRLike + case NotRLike => RLike + } + + def flip(v: Op): Option[Op] = + v match { + case Eq => Some(Eq) + case Neq => Some(Neq) + case Gt => Some(Lt) + case Gte => Some(Lte) + case Lt => Some(Gt) + case Lte => Some(Gte) + case Like => None + case NotLike => None + case RLike => None + case NotRLike => None + } + + def opString(v: Op): String = + v match { + case Eq => "=" + case Neq => "<>" + case Gt => ">" + case Gte => ">=" + case Lt => "<" + case Lte => "<=" + case Like => "LIKE" + case NotLike => "NOT LIKE" + case RLike => "RLIKE" + case NotRLike => "NOT RLIKE" + } +} + +case class Function(name: Name, params: Option[Seq[Expression]], distinct: Boolean = false) extends Expression { + override def toSql = + name.toSql + "(" + + (if (distinct) { "DISTINCT " } + else { "" }) + + params.map(_.map(_.toSql).mkString(", ")).getOrElse("*") + + ")" + def needsParenthesis = false + def children: Seq[Expression] = params.getOrElse(Seq()) + def rebuild(c: Seq[Expression]): Expression = + Function(name, params.map(_ => c), distinct) +} + +/** Window function expression: function(...) OVER (PARTITION BY ... ORDER BY ...) 
+ */ +case class WindowFunction( + function: Function, + partitionBy: Option[Seq[Expression]] = None, + orderBy: Seq[OrderBy] = Seq() +) extends Expression { + override def toSql = { + // SQL requires comma-separated expression lists inside PARTITION BY / ORDER BY, + // and NO comma between the two clauses: OVER (PARTITION BY a, b ORDER BY c, d) + val overClause = Seq("OVER") ++ + (if (partitionBy.isDefined || orderBy.nonEmpty) { + Seq("(") ++ + partitionBy + .map(exprs => Seq("PARTITION BY", exprs.map(_.toSql).mkString(", "))) + .getOrElse(Seq()) ++ + (if (orderBy.nonEmpty) Seq("ORDER BY", orderBy.map(_.toSql).mkString(", ")) else Seq()) ++ + Seq(")") + } else { + Seq("()") + }) + function.toSql + " " + overClause.mkString(" ") + } + def needsParenthesis = false + def children: Seq[Expression] = + function.children ++ partitionBy.getOrElse(Seq()) ++ orderBy.map(_.expression) + def rebuild(c: Seq[Expression]): Expression = { + val funcChildren = function.children + val funcRebuilt = Function(function.name, function.params.map(_ => c.take(funcChildren.length)), function.distinct) + val partitionRebuilt = + partitionBy.map(_ => c.slice(funcChildren.length, funcChildren.length + partitionBy.get.length)) + val orderByRebuilt = orderBy.zipWithIndex.map { case (ob, idx) => + OrderBy(c(funcChildren.length + partitionBy.map(_.length).getOrElse(0) + idx), ob.ascending) + } + WindowFunction(funcRebuilt, partitionRebuilt, orderByRebuilt) + } +} + +case class JDBCVar() extends Expression { + override def toSql = "?" + def needsParenthesis = false + def children: Seq[Expression] = Seq() + def rebuild(c: Seq[Expression]): Expression = this +} + +/** A scalar subquery expression: (SELECT ...) Used in contexts like VALUES clauses where a subquery can be used as a + * value. 
+ */ +case class Subquery(query: SelectBody) extends Expression { + override def toSql = "(" + query.toSql + ")" + def needsParenthesis = false + def children: Seq[Expression] = Seq() // Subquery is a leaf expression + def rebuild(c: Seq[Expression]): Expression = this +} + +case class CaseWhenElse(target: Option[Expression], cases: Seq[(Expression, Expression)], otherwise: Expression) + extends Expression { + override def toSql = + "CASE " + + target.map(_.toSql).getOrElse("") + + cases + .map { clause => + "WHEN " + Expression.parenthesize(clause._1) + + " THEN " + Expression.parenthesize(clause._2) + } + .mkString(" ") + " " + + "ELSE " + otherwise.toSql + " END" + def needsParenthesis = false + def children: Seq[Expression] = + target.toSeq ++ Seq(otherwise) ++ cases.flatMap(x => Seq(x._1, x._2)) + def rebuild(c: Seq[Expression]): Expression = { + val (newTarget, newOtherwise, newCases) = + if (c.length % 2 == 0) { (Some(c.head), c.tail.head, c.tail.tail) } + else { (None, c.head, c.tail) } + CaseWhenElse(newTarget, newCases.grouped(2).map(x => (x(0), x(1))).toSeq, newOtherwise) + } +} + +abstract class NegatableExpression extends Expression { + def toNegatedString: String +} + +case class IsNull(target: Expression) extends NegatableExpression { + override def toSql = + Expression.parenthesize(target) + " IS NULL" + def toNegatedString = + Expression.parenthesize(target) + " IS NOT NULL" + def needsParenthesis = false + def children: Seq[Expression] = Seq(target) + def rebuild(c: Seq[Expression]): Expression = IsNull(c(0)) +} + +case class Not(target: Expression) extends Expression { + override def toSql = + target match { + case neg: NegatableExpression => neg.toNegatedString + case _ => "NOT " + Expression.parenthesize(target) + } + def needsParenthesis = false + def children: Seq[Expression] = Seq(target) + def rebuild(c: Seq[Expression]): Expression = Not(c(0)) +} + +case class Cast(expression: Expression, t: Name) extends Expression { + override def toSql 
= "CAST(" + expression.toSql + " AS " + t.toSql + ")" + def needsParenthesis = false + def children: Seq[Expression] = Seq(expression) + def rebuild(c: Seq[Expression]): Expression = Cast(c(0), t) +} + +case class InExpression(expression: Expression, source: Either[Seq[Expression], SelectBody]) + extends NegatableExpression { + override def toSql = + Expression.parenthesize(expression) + " IN " + sourceString + override def toNegatedString = + Expression.parenthesize(expression) + " NOT IN " + sourceString + def needsParenthesis = false + def sourceString = + source match { + // fixed: the literal list must be parenthesized — `x IN (1, 2)`, not `x IN 1, 2` + case Left(elems) => "(" + elems.map(Expression.parenthesize(_)).mkString(", ") + ")" + case Right(query) => "(" + query.toSql + ")" + } + def children: Seq[Expression] = + Seq(expression) ++ (source match { + case Left(expr) => expr + case Right(_) => Seq() + }) + def rebuild(c: Seq[Expression]): Expression = + InExpression( + c.head, + source match { + case Left(_) => Left(c.tail) + case Right(query) => Right(query) + } + ) +} diff --git a/src/main/scala/sparsity/common/parser/ErrorMessage.scala b/src/main/scala/sparsity/common/parser/ErrorMessage.scala new file mode 100644 index 0000000..1d9ddc3 --- /dev/null +++ b/src/main/scala/sparsity/common/parser/ErrorMessage.scala @@ -0,0 +1,36 @@ +package sparsity.common.parser + +import fastparse.Parsed + +object ErrorMessage { + + def indexToLine(lines: Seq[String], index: Int): (Int, Int) = { + // +1 accounts for the '\n' stripped by split("\n"); without it the + // reported line/column drifts by one character per preceding line + val cumulativeLengths = lines.scanLeft(0)(_ + _.length + 1) + cumulativeLengths.zipWithIndex + .sliding(2) + .collectFirst { + case Seq((prev, _), (curr, i)) if index < curr => + (i - 1, index - prev) + } + .getOrElse((lines.length - 1, 0)) + } + + def format(query: String, msg: Parsed.Failure): String = { + val lines = query.split("\n") + + // println(msg.longMsg) + + val expected: String = "expected " + msg.label + val index: Int = msg.extra.index + // if(msg.extra.stack.isEmpty){ ("parse error", ) } + // else { + // val (expected, index) = msg.extra.stack(0) + // (" 
expected "+expected, index) + // } + val (lineNumber, linePosition) = indexToLine(lines, index) + return lines(lineNumber) + "\n" + + " " * linePosition + + "^--- " + expected + } + +} diff --git a/src/main/scala/sparsity/parser/ParseException.scala b/src/main/scala/sparsity/common/parser/ParseException.scala similarity index 76% rename from src/main/scala/sparsity/parser/ParseException.scala rename to src/main/scala/sparsity/common/parser/ParseException.scala index 78ef80f..20b37ab 100644 --- a/src/main/scala/sparsity/parser/ParseException.scala +++ b/src/main/scala/sparsity/common/parser/ParseException.scala @@ -1,4 +1,4 @@ -package sparsity.parser +package sparsity.common.parser import fastparse.Parsed diff --git a/src/main/scala/sparsity/common/parser/SQL.scala b/src/main/scala/sparsity/common/parser/SQL.scala new file mode 100644 index 0000000..bed4cf5 --- /dev/null +++ b/src/main/scala/sparsity/common/parser/SQL.scala @@ -0,0 +1,17 @@ +package sparsity.common.parser + +import fastparse._ +import scala.io._ +import java.io._ +import sparsity.common.statement._ +import sparsity.ansi.ANSISQL + +/** Backward-compatible SQL parser object. Delegates to ANSISQL for parsing. 
+ * + * @deprecated + * Consider using ANSISQL, OracleSQL, or PostgresSQL directly for dialect-specific parsing + */ +object SQL { + def apply(input: String): Parsed[Statement] = ANSISQL(input) + def apply(input: Reader): StreamParser[Statement] = ANSISQL(input) +} diff --git a/src/main/scala/sparsity/common/parser/SQLBase.scala b/src/main/scala/sparsity/common/parser/SQLBase.scala new file mode 100644 index 0000000..5b9b75b --- /dev/null +++ b/src/main/scala/sparsity/common/parser/SQLBase.scala @@ -0,0 +1,893 @@ +package sparsity.common.parser + +import fastparse._ +import scala.io._ +import java.io._ +import sparsity.common.Name +import sparsity.common.statement._ +import sparsity.common.select._ +import sparsity.common.alter._ +import sparsity.common.expression.{ + Arithmetic, + BooleanPrimitive, + CaseWhenElse, + Cast, + Column, + Comparison, + DoublePrimitive, + Expression, + Function, + InExpression, + IsNull, + JDBCVar, + LongPrimitive, + Not, + NullPrimitive, + StringPrimitive +} +import sparsity.common.DataType + +/** Base trait for SQL dialect parsers. Provides shared parser methods and configuration hooks for dialect-specific + * behavior. 
+ */ +trait SQLBase { + // Type members allow dialects to extend the Statement ADT + type Stmt <: Statement + // Configuration for simple variations (abstract, implemented by subclasses) + def caseSensitive: Boolean + def statementTerminator: String // ";" or "/" + def supportsCTEs: Boolean + def supportsReturning: Boolean + def supportsIfExists: Boolean + def stringConcatOp: String // "||" or "+" + implicit val whitespaceImplicit: fastparse.Whitespace = MultiLineWhitespace.whitespace + // Expression parsing methods + // These use `this` to access the SQL parser instance, removing hard-coded ANSISQL.instance references + + // Elements methods - inlined from Elements.scala to allow overriding + def anyKeyword[$ : P] = P( + StringInIgnoreCase( + "ADD", + "ALL", + "ALTER", + "AND", + "AS", + "ASC", + "BEGIN", + "BETWEEN", + "BY", + "CASE", + "CAST", + "COLUMN", + "COMMENT", + "COMMIT", + "CONSTRAINT", + "CREATE", + "DECLARE", + "DEFAULT", + "DEFINE", + "DELETE", + "DESC", + "DISABLE", + "DISTINCT", + "DOUBLE", + "DROP", + "ELSE", + "ELSIF", + "ENABLE", + "END", + "EXCEPTION", + "EXISTS", + "EXPLAIN", + "FALSE", + "FOR", + "FROM", + "FULL", + "GROUP", + "GRANT", + "HAVING", + "IF", + "IN", + "INDEX", + "INNER", + "INSERT", + "INTO", + "IS", + "JOIN", + "KEY", + "LEFT", + "LIMIT", + "LOOP", + "MATERIALIZE", + "MATERIALIZED", + "MODIFY", + "NATURAL", + "NOT", + "NULL", + "OFF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OVER", + "OUTER", + "PARTITION", + "PRECISION", + "PRIMARY", + "PROMPT", + "RENAME", + "REPLACE", + "REVERSE", + "REVOKE", + "RIGHT", + "SELECT", + "SET", + "SMALLINT", + "SYNONYM", + "TABLE", + "TEMPORARY", + "THEN", + "TO", + "TRUE", + "UNIQUE", + "UPDATE", + "UNION", + "VALUES", + "VIEW", + "WHEN", + "WHERE", + "WHILE", + "WITH", + "FOREIGN", + "REFERENCES", + "CASCADE", + "RESTRICT", + "ACTION", + "PURGE" + // avoid dropping keyword prefixes + // (e.g., 'int' matched by 'in') + ).! 
~~ !CharIn("a-zA-Z0-9_") + ) + + def keyword[$ : P](expected: String*) = P[Unit]( + anyKeyword + .opaque(expected.mkString(" or ")) + .filter(kw => expected.exists(_.equalsIgnoreCase(kw))) + .map(_ => ()) + ) + + def avoidReservedKeywords[$ : P] = P(!anyKeyword) + + def rawIdentifier[$ : P] = P( + avoidReservedKeywords ~~ + (CharIn("_a-zA-Z") ~~ CharsWhileIn("a-zA-Z0-9_").?).!.map(Name(_)) + ) + + def quotedIdentifier[$ : P] = P( + (("`" ~~/ CharsWhile(_ != '`').! ~ "`") + | ("\"" ~~/ CharsWhile(_ != '"').! ~~ "\"")).map(Name(_, true: scala.Boolean)) + ) + + def identifier[$ : P]: P[Name] = P(rawIdentifier | quotedIdentifier) + + def dottedPair[$ : P]: P[(Option[Name], Name)] = P((identifier ~~ ("." ~~ identifier).?).map { + case (x, None) => (None, x) + case (x, Some(y)) => (Some(x), y) + }) + + def dottedWildcard[$ : P]: P[Name] = P(identifier ~~ ".*") + + // Schema-qualified table name (schema.table) as a single Name + def qualifiedTableName[$ : P]: P[Name] = P((identifier ~~ ("." ~~ identifier).?).map { + case (x, None) => x + case (x, Some(y)) => Name(x.name + "." + y.name, x.quoted || y.quoted) + }) + + def digits[$ : P] = P(CharsWhileIn("0-9")) + def plusMinus[$ : P] = P("-" | "+") + def integral[$ : P] = "0" | CharIn("1-9") ~~ digits.? + + def integer[$ : P] = (plusMinus.? ~~ digits).!.map(_.toLong) ~~ !"." // Fail on a trailing period + def decimal[$ : P] = + (plusMinus.? ~~ digits ~~ ("." ~~ digits).? ~~ ("e" ~~ plusMinus.? 
~~ digits).?).!.map(_.toDouble) + + def escapeQuote[$ : P] = P("''".!.map(_.replaceAll("''", "'"))) + def escapedString[$ : P] = P((CharsWhile(_ != '\'') | escapeQuote).repX.!.map { + _.replaceAll("''", "'") + }) + def quotedString[$ : P] = P("'" ~~ escapedString ~~ "'") + + def comma[$ : P] = P(",") + + def expressionList[$ : P]: P[Seq[Expression]] = P(expression.rep(sep = comma)) + + def expression[$ : P]: P[Expression] = P(disjunction) + + def disjunction[$ : P] = P((conjunction ~ (!keyword("ORDER") ~ keyword("OR") ~ conjunction).rep).map { x => + x._2.fold(x._1) { (accum, current) => + Arithmetic(accum, Arithmetic.Or, current) + } + }) + + def conjunction[$ : P] = P((negation ~ (keyword("AND") ~ negation).rep).map { x => + x._2.fold(x._1) { (accum, current) => + Arithmetic(accum, Arithmetic.And, current) + } + }) + + def negation[$ : P] = P((keyword("NOT").!.? ~ comparison).map { + case (None, expression) => expression + case (_, expression) => Not(expression) + }) + + def comparison[$ : P] = P( + (isNullBetweenIn ~ + (StringInIgnoreCase("=", "==", "!=", "<>", ">", "<", ">=", "<=", "LIKE", "NOT LIKE", "RLIKE", "NOT RLIKE").! ~ + addSub).?).map { + case (expression, None) => expression + case (lhs, Some((op, rhs))) => Comparison(lhs, op, rhs) + } + ) + + def optionalNegation[$ : P]: P[Expression => Expression] = P(keyword("NOT").!.?.map { x => + if (x.isDefined) { y => Not(y) } + else { y => y } + }) + + def isNullBetweenIn[$ : P] = P( + (addSub ~ ( + // IS [NOT] NULL -> IsNull(...) + (keyword("IS") ~ optionalNegation ~ + keyword("NULL").map(_ => IsNull(_))) | ( + // [IS] [NOT] BETWEEN low AND high + keyword("IS").? ~ optionalNegation ~ + ( + keyword("BETWEEN") ~ + addSub ~ keyword("AND") ~ addSub + ).map { case (low, high) => + (lhs: Expression) => + Arithmetic(Comparison(lhs, Comparison.Gte, low), Arithmetic.And, Comparison(lhs, Comparison.Lte, high)) + } + ) | ( + optionalNegation ~ keyword("IN") ~/ ( + ( + // IN ( SELECT ... 
) + &("(" ~ keyword("SELECT")) ~/ + "(" ~/ this.select.map { query => + InExpression(_: Expression, Right(query)) + } ~ ")" + ) | ( + // IN ('list', 'of', 'items') + "(" ~/ + expressionList.map(exprs => InExpression(_: Expression, Left(exprs))) ~ + ")" + ) + ) + ) + ).?).map { + case (expression, None) => expression + case (expression, Some((neg, build))) => + val f: Expression => Expression = neg.asInstanceOf[Expression => Expression] + f(build(expression)) + } + ) + + def addSub[$ : P] = P((multDiv ~ ((CharIn("+\\-&\\|") | StringIn("<<", ">>")).! ~ multDiv).rep).map { x => + x._2.foldLeft(x._1: Expression) { (accum, current) => + Arithmetic(accum, current._1, current._2) + } + }) + + def multDivOp[$ : P] = P(CharIn("*/")) + + def multDiv[$ : P] = P((leaf ~ (multDivOp.! ~ leaf).rep).map { x => + x._2.foldLeft(x._1) { (accum, current) => + Arithmetic(accum, current._1, current._2) + } + }) + + def leaf[$ : P]: P[Expression] = P( + parens | + primitive | + jdbcvar | + caseWhen | ifThenElse | + cast | + nullLiteral | + // need to lookahead `function` to avoid conflicts with `column` + &(identifier ~ "(") ~ function | + column + ) + + def parens[$ : P] = P("(" ~ expression ~ ")") + + def primitive[$ : P] = P( + integer.map(v => LongPrimitive(v)) + | decimal.map(v => DoublePrimitive(v)) + | quotedString.map(v => StringPrimitive(v)) + | keyword("TRUE").map(_ => BooleanPrimitive(true)) + | keyword("FALSE").map(_ => BooleanPrimitive(false)) + ) + + def column[$ : P] = P(dottedPair.map(x => Column(x._2, x._1))) + + def nullLiteral[$ : P] = P(keyword("NULL").map(_ => NullPrimitive())) + + def function[$ : P] = P( + (identifier ~ "(" ~/ + keyword("DISTINCT").!.?.map { + _ != None + } ~ + ("*".!.map(_ => None) + | expressionList.map(Some(_))) ~ ")").map { case (name, distinct, args) => + Function(name, args, distinct) + } + ) + + def jdbcvar[$ : P] = P("?".!.map(_ => JDBCVar())) + + def caseWhen[$ : P] = P( + keyword("CASE") ~/ + (!keyword("WHEN") ~ expression).? 
~ + ( + keyword("WHEN") ~/ + expression ~ + keyword("THEN") ~/ + expression + ).rep ~ + keyword("ELSE") ~/ + expression ~ + keyword("END") + ).map { case (target, whenThen, orElse) => + CaseWhenElse(target, whenThen, orElse) + } + + def ifThenElse[$ : P] = P( + keyword("IF") ~/ + expression ~/ + keyword("THEN") ~/ + expression ~/ + keyword("ELSE") ~/ + expression ~/ + keyword("END") + ).map { case (condition, thenClause, elseClause) => + CaseWhenElse(None, Seq(condition -> thenClause), elseClause) + } + + def cast[$ : P] = P( + ( + keyword("CAST") ~/ "(" ~/ + expression ~ keyword("AS") ~/ + identifier ~ ")" + ).map { case (expression, t) => + Cast(expression, t) + } + ) + + def statementTerminatorParser[$ : P]: P[Unit] = P(statementTerminator) + + // Parser for unparseable statements (error recovery) + // Consumes everything up to and including the next statement terminator + // IMPORTANT: Must match at least one char to avoid infinite loops with .rep + def unparseableStatement[$ : P]: P[Unparseable] = P( + // Consume at least one char, then everything up to (but not including) the terminator + AnyChar.! ~ CharsWhile(c => c != statementTerminator.head).! ~ statementTerminatorParser.! 
+ ).map { case (first, rest, terminator) => Unparseable((first + rest + terminator).trim) } + + // Shared parsers with conditional logic + // Base implementation includes unparseable fallback - can be overridden by dialects that need special handling + def terminatedStatement[$ : P]: P[Stmt] = P( + (statement ~/ statementTerminatorParser) | + unparseableStatement.map(_.asInstanceOf[Stmt]) // Fallback for parse errors (requires at least one char) + ) + + def statement[$ : P]: P[Stmt] = + P( + Pass ~ // This trims off leading whitespace + (parenthesizedSelect.map(s => Select(s).asInstanceOf[Stmt]) + | update + | delete + | insert + | createStatement // Delegates to subclass + | (&(keyword("ALTER")) ~/ + alterView) + | dropTableOrView + | explainStatement + | dialectSpecificStatement // Hook for dialect extensions + ) + ) + + def explainStatement[$ : P]: P[Stmt] = P(keyword("EXPLAIN") ~ select.map(s => Explain(s).asInstanceOf[Stmt])) + + // Methods that can be overridden for dialect-specific behavior + def createStatement[$ : P]: P[Stmt] + def dialectSpecificStatement[$ : P]: P[Stmt] + def dataType[$ : P]: P[DataType] + + // Shared logic with conditional branches + def ifExists[$ : P]: P[Boolean] = + if (supportsIfExists) { + P((keyword("IF") ~ keyword("EXISTS")).!.?.map(_ != None)) + } else { + P(Pass).map(_ => false) + } + + def orReplace[$ : P]: P[Boolean] = P((keyword("OR") ~ keyword("REPLACE")).!.?.map(_ != None)) + + def alterView[$ : P]: P[Stmt] = P( + ( + keyword("ALTER") ~ + keyword("VIEW") ~/ + identifier ~ + ( + (keyword("MATERIALIZE").!.map(_ => Materialize(true))) + | (keyword("DROP") ~ + keyword("MATERIALIZE").!.map(_ => Materialize(false))) + ) + ).map { case (name, op) => AlterView(name, op).asInstanceOf[Stmt] } + ) + + def dropTableOrView[$ : P]: P[Stmt] = P( + ( + keyword("DROP") ~ + keyword("TABLE", "VIEW").!.map(_.toUpperCase) ~/ + ifExists ~ + identifier + ).map { + case ("TABLE", ifExists, name) => + DropTable(name, ifExists).asInstanceOf[Stmt] + 
case ("VIEW", ifExists, name) => + DropView(name, ifExists).asInstanceOf[Stmt] + case (_, _, _) => + throw new Exception("Internal Error") + } + ) + + def createView[$ : P]: P[Stmt] = P( + ( + keyword("CREATE") ~ + orReplace ~ + keyword("MATERIALIZED", "TEMPORARY").!.?.map { + _.map(_.toUpperCase) match { + case Some("MATERIALIZED") => (true, false) + case Some("TEMPORARY") => (false, true) + case _ => (false, false) + } + } ~ + keyword("VIEW") ~/ + identifier ~ + keyword("AS") ~/ + select + ).map { case (orReplace, (materialized, temporary), name, query) => + CreateView(name, orReplace, query, materialized, temporary).asInstanceOf[Stmt] + } + ) + + // Parser for function calls without parentheses (e.g., CURRENT_TIMESTAMP, SYSDATE) + def functionCallNoParens[$ : P]: P[Expression] = P(identifier.map { name => + // Treat as a function call with no arguments + Function(name, None, false) + }) + + def columnAnnotation[$ : P]: P[ColumnAnnotation] = P( + ( + keyword("PRIMARY") ~/ + keyword("KEY").map(_ => ColumnIsPrimaryKey()) + ) | ( + keyword("NOT") ~/ + keyword("NULL").map(_ => ColumnIsNotNullable()) + ) | ( + keyword("DEFAULT") ~/ + ( + ("(" ~ expression ~ ")") + | function // Function calls with parentheses + | functionCallNoParens // Function calls without parentheses (CURRENT_TIMESTAMP, etc.) 
+ | primitive + ).map(ColumnDefaultValue(_)) + ) + ) + + def oneOrMoreAttributes[$ : P]: P[Seq[Name]] = P( + ("(" ~/ identifier.rep(sep = comma, min = 1) ~ ")") + | identifier.map(Seq(_)) + ) + + def tableField[$ : P]: P[Either[TableAnnotation, ColumnDefinition]] = P( + ( + keyword("PRIMARY") ~/ + keyword("KEY") ~ + oneOrMoreAttributes.map(attrs => Left(TablePrimaryKey(attrs))) + ) | ( + keyword("INDEX") ~/ + keyword("ON") ~ + oneOrMoreAttributes.map(attrs => Left(TableIndexOn(attrs))) + ) | ( + ( + identifier ~/ + identifier ~ + ("(" ~ + primitive.rep(sep = ",") ~ + ")").?.map(_.getOrElse(Seq())) ~ + columnAnnotation.rep + ).map { case (name, t, args, annotations) => + Right(ColumnDefinition(name, t, args, annotations)) + } + ) + ) + + def createTable[$ : P]: P[Stmt] = P( + ( + keyword("CREATE") ~ + orReplace ~ + keyword("TABLE") ~/ + identifier ~ + ( + (keyword("AS") ~/ select).map(Left(_)) + | ("(" ~/ + tableField.rep(sep = comma) ~ + ")").map(Right(_)) + ) + ).map { + case (orReplace, table, Left(query)) => + CreateTableAs(table, orReplace, query).asInstanceOf[Stmt] + case (orReplace, table, Right(fields)) => + val columns = fields.collect { case Right(r) => r } + val annotations = fields.collect { case Left(l) => l } + CreateTable(table, orReplace, columns, annotations).asInstanceOf[Stmt] + } + ) + + def valueList[$ : P]: P[InsertValues] = P( + ( + keyword("VALUES") ~/ + ("(" ~/ expressionList ~ ")").rep(sep = comma) + ).map(ExplicitInsert(_)) + ) + + def insert[$ : P]: P[Stmt] = P( + ( + keyword("INSERT") ~/ + ( + keyword("OR") ~/ + keyword("REPLACE") + ).!.?.map { case None => false; case _ => true } ~ + keyword("INTO") ~/ + identifier ~ + ("(" ~/ + identifier ~ + (comma ~/ identifier).rep ~ + ")").map(x => Seq(x._1) ++ x._2).? 
~ + ( + (&(keyword("SELECT")) ~/ select.map(SelectInsert(_))) + | (&(keyword("VALUES")) ~/ valueList) + ) + ).map { case (orReplace, table, columns, values) => + Insert(table, columns, values, orReplace).asInstanceOf[Stmt] + } + ) + + def delete[$ : P]: P[Stmt] = P( + ( + keyword("DELETE") ~/ + keyword("FROM") ~/ + identifier ~ + ( + keyword("WHERE") ~/ + expression + ).? + ).map { case (table, where) => Delete(table, where).asInstanceOf[Stmt] } + ) + + def update[$ : P]: P[Stmt] = P( + ( + keyword("UPDATE") ~/ + identifier ~ + keyword("SET") ~/ + ( + identifier ~ + "=" ~/ + expression + ).rep(sep = comma, min = 1) ~ + ( + StringInIgnoreCase("WHERE") ~/ + expression + ).? + ).map { case (table, set, where) => + Update(table, set, where).asInstanceOf[Stmt] + } + ) + + def alias[$ : P]: P[Name] = + P(keyword("AS").? ~ identifier) + + def selectTarget[$ : P]: P[SelectTarget] = P( + P("*").map(_ => SelectAll()) + // Dotted wildcard needs a lookahead since a single token isn't + // enough to distinguish between `foo`.* and `foo` AS `bar` + | (&(dottedWildcard) ~ + dottedWildcard.map(SelectTable(_))) + | (expression ~ alias.?).map(x => SelectExpression(x._1, x._2)) + ) + + def simpleFromElement[$ : P]: P[FromElement] = P( + (("(" ~ select ~ ")" ~ alias).map(x => FromSelect(x._1, x._2))) + | ((dottedPair ~ alias.?).map { case (schema, table, alias) => + FromTable(schema, table, alias) + }) + | (("(" ~ fromElement ~ ")" ~ alias.?).map { + case (from, None) => from + case (from, Some(alias)) => from.withAlias(alias) + }) + ) + + def joinWith[$ : P]: P[Join.Type] = P( + keyword("JOIN").map(Unit => Join.Inner) + | (( + keyword("NATURAL").!.map(Unit => Join.Natural) + | keyword("INNER").map(Unit => Join.Inner) + | (( + keyword("LEFT").map(Unit => Join.LeftOuter) + | keyword("RIGHT").map(Unit => Join.RightOuter) + | keyword("FULL").map(Unit => Join.FullOuter) + ).?.map(_.getOrElse(Join.FullOuter)) ~/ + keyword("OUTER")) + ) ~/ keyword("JOIN")) + ) + + def fromElement[$ : P]: 
P[FromElement] = P( + ( + simpleFromElement ~ ( + &(joinWith) ~ + joinWith ~/ + simpleFromElement ~/ + ( + keyword("ON") ~/ + expression + ).? ~ + alias.? + ).rep + ).map { case (lhs, rest) => + rest.foldLeft(lhs) { (lhs, next) => + val (t, rhs, onClause, alias) = next + FromJoin(lhs, rhs, t, onClause.getOrElse(BooleanPrimitive(true)), alias) + } + } + ) + + def fromClause[$ : P] = P( + keyword("FROM") ~/ + fromElement.rep(sep = comma, min = 1) + ) + + def whereClause[$ : P] = P(keyword("WHERE") ~/ expression) + + def groupByClause[$ : P] = P( + keyword("GROUP") ~/ + keyword("BY") ~/ + expressionList + ) + + def havingClause[$ : P] = P(keyword("HAVING") ~ expression) + + def options[A](default: A, options: Map[String, A]): (Option[String] => A) = + _.map(_.toUpperCase).map(options(_)).getOrElse(default) + + def ascOrDesc[$ : P] = P(keyword("ASC", "DESC").!.?.map { + options(true, Map("ASC" -> true, "DESC" -> false)) + }) + + def orderBy[$ : P] = P((expression ~ ascOrDesc).map(x => OrderBy(x._1, x._2))) + + def orderByClause[$ : P] = P( + keyword("ORDER") ~/ + keyword("BY") ~/ + orderBy.rep(sep = comma, min = 1) + ) + + def limitClause[$ : P] = P( + keyword("LIMIT") ~/ + integer + ) + + def offsetClause[$ : P] = P( + keyword("OFFSET") ~/ + integer + ) + + def allOrDistinct[$ : P] = P(keyword("ALL", "DISTINCT").!.?.map { + options(Union.Distinct, Map("ALL" -> Union.All, "DISTINCT" -> Union.Distinct)) + }) + + def unionClause[$ : P] = P(keyword("UNION") ~/ allOrDistinct ~/ parenthesizedSelect) + + def parenthesizedSelect[$ : P]: P[SelectBody] = P( + ( + "(" ~/ select ~ ")" ~/ unionClause.? + ).map { + case (body, Some((unionType, unionBody))) => body.unionWith(unionType, unionBody) + case (body, None) => body + } | select + ) + + def select[$ : P]: P[SelectBody] = P( + ( + keyword("SELECT") ~/ + keyword("DISTINCT").!.?.map(_ != None) ~/ + selectTarget.rep(sep = ",") ~ + fromClause.?.map(_.toSeq.flatten) ~ + whereClause.? ~ + groupByClause.? ~ + havingClause.? 
~ + orderByClause.?.map(_.toSeq.flatten) ~ + limitClause.? ~ + offsetClause.? ~ + unionClause.? + ).map { case (distinct, targets, froms, where, groupBy, having, orderBy, limit, offset, union) => + SelectBody( + distinct = distinct, + target = targets, + from = froms, + where = where, + groupBy = groupBy, + having = having, + orderBy = orderBy, + limit = limit, + offset = offset, + union = union + ) + } + ) + + // Parse multiple statements with error recovery + // Handle End separately to avoid infinite loops + def allStatements[$ : P]: P[Seq[Stmt]] = P(Start ~ terminatedStatement.rep ~ End).map(_.toSeq) +} + +trait SQLBaseObject { + type Stmt <: Statement + def name: String + def apply(input: String): Parsed[Stmt] + def apply(input: Reader): StreamParser[Stmt] + + // Abstract method to get statement terminator for this dialect + protected def statementTerminator: String + + // Override this to specify all valid terminator characters (e.g., both ';' and '/') + protected def statementTerminatorChars: Set[Char] = statementTerminator.toSet + + // Parse multiple statements with error recovery, tracking positions + def parseAll(input: String): Seq[StatementParseResult] = { + import sparsity.common.statement.StatementParseResult + import fastparse._ + import fastparse.Parsed + + def parseStatementsSequentially( + remaining: String, + currentPos: Int, + results: List[StatementParseResult] + ): List[StatementParseResult] = { + // Skip whitespace and comment lines at the start + def skipWhitespaceAndComments(s: String): (String, Int) = { + var pos = 0 + var skipped = 0 + while (pos < s.length) { + val char = s(pos) + if (char.isWhitespace) { + pos += 1 + skipped += 1 + } else if (pos + 1 < s.length && s(pos) == '-' && s(pos + 1) == '-') { + // Skip -- comment until newline + while (pos < s.length && s(pos) != '\n') { + pos += 1 + skipped += 1 + } + // Skip the newline too + if (pos < s.length) { + pos += 1 + skipped += 1 + } + } else if (pos + 1 < s.length && s(pos) == '/' 
&& s(pos + 1) == '*') { + // Skip /* */ comment + pos += 2 + skipped += 2 + while (pos + 1 < s.length && !(s(pos) == '*' && s(pos + 1) == '/')) { + pos += 1 + skipped += 1 + } + if (pos + 1 < s.length) { + pos += 2 + skipped += 2 + } + } else { + // Found non-whitespace, non-comment content + return (s.substring(pos), skipped) + } + } + ("", skipped) + } + + // Find the next statement terminator, skipping over string literals and comments + def findNextTerminator(s: String, terminatorChars: Set[Char]): Int = { + var pos = 0 + while (pos < s.length) { + val char = s(pos) + if (char == '\'') { + // Inside string literal - skip until closing quote (handle escaped quotes '') + pos += 1 + while (pos < s.length) { + if (s(pos) == '\'') { + if (pos + 1 < s.length && s(pos + 1) == '\'') { + pos += 2 // Skip escaped quote '' + } else { + pos += 1 // End of string + return findNextTerminator(s.substring(pos), terminatorChars) match { + case -1 => -1 + case n => pos + n + } + } + } else { + pos += 1 + } + } + return -1 // Unclosed string + } else if (pos + 1 < s.length && char == '-' && s(pos + 1) == '-') { + // Skip -- comment until newline + while (pos < s.length && s(pos) != '\n') pos += 1 + } else if (pos + 1 < s.length && char == '/' && s(pos + 1) == '*') { + // Skip /* */ comment + pos += 2 + while (pos + 1 < s.length && !(s(pos) == '*' && s(pos + 1) == '/')) pos += 1 + if (pos + 1 < s.length) pos += 2 + } else if (terminatorChars.contains(char)) { + return pos + } else { + pos += 1 + } + } + -1 + } + + val (trimmed, skipped) = skipWhitespaceAndComments(remaining) + val statementStartPos = currentPos + skipped + + if (trimmed.isEmpty) { + return results.reverse + } + + // Try to parse a single statement + val parseResult = apply(trimmed) + + val (result, endPos) = parseResult match { + case Parsed.Success(stmt, index) => + val parsedLength = index.toInt + val statementEndPos = statementStartPos + parsedLength + val parsedResult: Either[Unparseable, Stmt] = stmt match { 
+ case u: Unparseable => Left(u) + case _ => Right(stmt.asInstanceOf[Stmt]) + } + (parsedResult, statementEndPos) + + case Parsed.Failure(label, index, extra) => + // Parse failed - extract unparseable content + val terminatorPos = findNextTerminator(trimmed, statementTerminatorChars) + val unparseableEnd = if (terminatorPos >= 0) terminatorPos + 1 else trimmed.length + + val unparseableText = trimmed.substring(0, unparseableEnd).trim + val unparseableEndPos = statementStartPos + unparseableEnd + (Left(Unparseable(unparseableText)), unparseableEndPos) + } + + val parseResultObj = StatementParseResult(result, statementStartPos, endPos) + + // Continue parsing from the end position + val nextRemaining = if (endPos < input.length) { + input.substring(endPos) + } else { + "" + } + + parseStatementsSequentially(nextRemaining, endPos, parseResultObj :: results) + } + + parseStatementsSequentially(input, 0, Nil) + } +} diff --git a/src/main/scala/sparsity/parser/StreamParser.scala b/src/main/scala/sparsity/common/parser/StreamParser.scala similarity index 61% rename from src/main/scala/sparsity/parser/StreamParser.scala rename to src/main/scala/sparsity/common/parser/StreamParser.scala index e3f1b35..9c8fc34 100644 --- a/src/main/scala/sparsity/parser/StreamParser.scala +++ b/src/main/scala/sparsity/common/parser/StreamParser.scala @@ -1,42 +1,36 @@ -package sparsity.parser +package sparsity.common.parser import scala.collection.mutable.Buffer import scala.io._ import java.io._ -import fastparse._, NoWhitespace._ -import com.typesafe.scalalogging.LazyLogging +import fastparse._ +import scribe.Logger -class StreamParser[R](parser:(Iterator[String] => Parsed[R]), source: Reader) - extends Iterator[Parsed[R]] - with LazyLogging -{ +class StreamParser[R](parser: (Iterator[String] => Parsed[R]), source: Reader) extends Iterator[Parsed[R]] { + private val logger = Logger("sparsity.common.parser.StreamParser") val bufferedSource = source match { - case b:BufferedReader => b + 
case b: BufferedReader => b + case _ => new BufferedReader(source) + } + val buffer = Buffer[String]() - def load(): Unit = { - while(bufferedSource.ready){ + def load(): Unit = + while (bufferedSource.ready) buffer += bufferedSource.readLine.replace("\\n", " "); - } - } - def loadBlocking(): Unit = { + def loadBlocking(): Unit = buffer += bufferedSource.readLine.replace("\\n", " "); - } def hasNext: Boolean = { load(); buffer.size > 0 } - def next(): Parsed[R] = - { + def next(): Parsed[R] = { load(); - if(buffer.size > 0) { + if (buffer.size > 0) { parser(buffer.iterator) match { - case r@Parsed.Success(result, index) => + case r @ Parsed.Success(result, index) => logger.debug(s"Parsed(index = $index): $result") skipBytes(index) return r - case f:Parsed.Failure => + case f: Parsed.Failure => buffer.clear() return f } @@ -45,29 +39,29 @@ class StreamParser[R](parser:(Iterator[String] => Parsed[R]), source: Reader) } } - def skipBytes(offset: Int) : Unit = { + def skipBytes(offset: Int): Unit = { var dropped = 0 - while(offset > dropped && !buffer.isEmpty){ + while (offset > dropped && !buffer.isEmpty) { logger.trace(s"Checking for drop: $dropped / $offset: Next unit: ${buffer.head.length}") - if(buffer.head.length < (offset - dropped)){ + if (buffer.head.length < (offset - dropped)) { dropped = dropped + buffer.head.length logger.trace(s"Dropping '${buffer.head}' ($dropped / $offset dropped so far)") buffer.remove(0) } else { logger.trace(s"Trimming '${buffer.head}' (by ${offset - dropped})") var head = buffer.head - .substring(offset - dropped) - .replace("^\\w+", "") + .substring(offset - dropped) + .replaceAll("^\\s+", "") logger.trace(s"Trimming leading whitespace") - while(head.length <= 0 && !buffer.isEmpty){ + while (head.length <= 0 && !buffer.isEmpty) { logger.trace(s"Nothing but whitespace left.
Dropping and trying again") buffer.remove(0) - if(!buffer.isEmpty) { + if (!buffer.isEmpty) { head = buffer.head.replace("^\\w+", "") } } logger.trace(s"Remaining = '$head'") - if(head.length > 0){ buffer.update(0, head) } + if (head.length > 0) { buffer.update(0, head) } return } } diff --git a/src/main/scala/sparsity/common/select/FromElement.scala b/src/main/scala/sparsity/common/select/FromElement.scala new file mode 100644 index 0000000..6926a41 --- /dev/null +++ b/src/main/scala/sparsity/common/select/FromElement.scala @@ -0,0 +1,64 @@ +package sparsity.common.select + +import sparsity.common.Name +import sparsity.common.expression.{BooleanPrimitive, Expression} + +sealed abstract class FromElement { + def aliases: Seq[Name] + def withAlias(newAlias: Name): FromElement + def toSqlWithParensIfNeeded: String = toSql + def toSql: String +} + +case class FromTable(schema: Option[Name], table: Name, alias: Option[Name]) extends FromElement { + def aliases = Seq(alias.getOrElse(table)) + override def toSql = + schema.map(_.toSql + ".").getOrElse("") + + table.toSql + + alias.map(" AS " + _.toSql).getOrElse("") + def withAlias(newAlias: Name) = FromTable(schema, table, Some(newAlias)) + +} +case class FromSelect(body: SelectBody, val alias: Name) extends FromElement { + def aliases = Seq(alias) + override def toSql = + "(" + body.toSql + ") AS " + alias.toSql + def withAlias(newAlias: Name) = FromSelect(body, newAlias) +} +case class FromJoin( + lhs: FromElement, + rhs: FromElement, + t: Join.Type = Join.Inner, + on: Expression = BooleanPrimitive(true), + alias: Option[Name] = None +) extends FromElement { + def aliases = + alias.map(Seq(_)).getOrElse(lhs.aliases ++ rhs.aliases) + override def toSqlWithParensIfNeeded: String = "(" + toSql + ")" + override def toSql = { + val baseString = + lhs.toSqlWithParensIfNeeded + " " + + Join.toSql(t) + " " + + rhs.toSqlWithParensIfNeeded + + (on match { case BooleanPrimitive(true) => ""; case _ => " ON " + on.toSql }) + alias 
match { + case Some(a) => "(" + baseString + ") AS " + a.toSql + case None => baseString + } + } + def withAlias(newAlias: Name) = FromJoin(lhs, rhs, t, on, Some(newAlias)) + +} + +object Join extends Enumeration { + type Type = Value + val Inner, Natural, LeftOuter, RightOuter, FullOuter = Value + + def toSql(t: Type) = t match { + case Inner => "JOIN" + case Natural => "NATURAL JOIN" + case LeftOuter => "LEFT OUTER JOIN" + case RightOuter => "RIGHT OUTER JOIN" + case FullOuter => "FULL OUTER JOIN" + } +} diff --git a/src/main/scala/sparsity/common/select/OrderBy.scala b/src/main/scala/sparsity/common/select/OrderBy.scala new file mode 100644 index 0000000..87b8896 --- /dev/null +++ b/src/main/scala/sparsity/common/select/OrderBy.scala @@ -0,0 +1,9 @@ +package sparsity.common.select + +import sparsity.common.expression.{Expression, ToSql} + +case class OrderBy(expression: Expression, ascending: Boolean) extends ToSql { + def descending = !ascending + override def toSql = expression.toSql + (if (ascending) { "" } + else { " DESC" }) +} diff --git a/src/main/scala/sparsity/common/select/SelectBody.scala b/src/main/scala/sparsity/common/select/SelectBody.scala new file mode 100644 index 0000000..3ce138e --- /dev/null +++ b/src/main/scala/sparsity/common/select/SelectBody.scala @@ -0,0 +1,43 @@ +package sparsity.common.select + +import sparsity.common.expression.{Expression, ToSql} + +case class SelectBody( + distinct: Boolean = false, + target: Seq[SelectTarget] = Seq(), + from: Seq[FromElement] = Seq(), + where: Option[Expression] = None, + groupBy: Option[Seq[Expression]] = None, + having: Option[Expression] = None, + orderBy: Seq[OrderBy] = Seq(), + limit: Option[Long] = None, + offset: Option[Long] = None, + union: Option[(Union.Type, SelectBody)] = None +) extends ToSql { + def stringElements: Seq[String] = + Seq("SELECT") ++ + (if (distinct) { Some("DISTINCT") } + else { None }) ++ + Seq(target.map(_.toSql).mkString(", ")) ++ + (if (from.isEmpty) { None } + 
else { Seq("FROM", from.map(_.toSql).mkString(", ")) }) ++ + (where.map(x => Seq("WHERE", x.toSql)).toSeq.flatten) ++ + (groupBy.map(x => Seq("GROUP BY") ++ x.map(_.toSql)).toSeq.flatten) ++ + (having.map(x => Seq("HAVING", x.toSql)).toSeq.flatten) ++ + (if (orderBy.isEmpty) { Seq() } + else { Seq("ORDER BY", orderBy.map(_.toSql).mkString(", ")) }) ++ + (limit.map(x => Seq("LIMIT", x.toString)).toSeq.flatten) ++ + (offset.map(x => Seq("OFFSET", x.toString)).toSeq.flatten) ++ + (union.map { case (t, b) => Seq(Union.toSql(t)) ++ b.stringElements }.toSeq.flatten) + override def toSql = stringElements.mkString(" ") + + def unionWith(t: Union.Type, body: SelectBody): SelectBody = { + val replacementUnion = + union match { + case Some(nested) => (nested._1, nested._2.unionWith(t, body)) + case None => (t, body) + } + + SelectBody(distinct, target, from, where, groupBy, having, orderBy, limit, offset, union = Some(replacementUnion)) + } +} diff --git a/src/main/scala/sparsity/common/select/SelectTarget.scala b/src/main/scala/sparsity/common/select/SelectTarget.scala new file mode 100644 index 0000000..738785d --- /dev/null +++ b/src/main/scala/sparsity/common/select/SelectTarget.scala @@ -0,0 +1,17 @@ +package sparsity.common.select + +import sparsity.common.Name +import sparsity.common.expression.{Expression, ToSql} + +sealed abstract class SelectTarget extends ToSql + +case class SelectAll() extends SelectTarget { + override def toSql = "*" +} +case class SelectTable(table: Name) extends SelectTarget { + override def toSql = table.toSql + ".*" +} +case class SelectExpression(expression: Expression, alias: Option[Name] = None) extends SelectTarget { + override def toSql = + expression.toSql ++ alias.map(" AS " + _.toSql).getOrElse("") +} diff --git a/src/main/scala/sparsity/select/Union.scala b/src/main/scala/sparsity/common/select/Union.scala similarity index 70% rename from src/main/scala/sparsity/select/Union.scala rename to 
src/main/scala/sparsity/common/select/Union.scala index a5cfb33..c0b688a 100644 --- a/src/main/scala/sparsity/select/Union.scala +++ b/src/main/scala/sparsity/common/select/Union.scala @@ -1,11 +1,11 @@ -package sparsity.select +package sparsity.common.select object Union extends Enumeration { type Type = Value val All, Distinct = Value - def toString(t: Type) = t match { + def toSql(t: Type) = t match { case All => "UNION ALL" case Distinct => "UNION DISTINCT" } -} \ No newline at end of file +} diff --git a/src/main/scala/sparsity/common/statement/ColumnDefinition.scala b/src/main/scala/sparsity/common/statement/ColumnDefinition.scala new file mode 100644 index 0000000..1ad4834 --- /dev/null +++ b/src/main/scala/sparsity/common/statement/ColumnDefinition.scala @@ -0,0 +1,32 @@ +package sparsity.common.statement + +import sparsity.common.expression.{Expression, PrimitiveValue, ToSql} +import sparsity.common.Name + +sealed abstract class ColumnAnnotation extends ToSql + +case class ColumnIsPrimaryKey() extends ColumnAnnotation { + override def toSql = "PRIMARY KEY" +} +case class ColumnIsNotNullable() extends ColumnAnnotation { + override def toSql = "NOT NULL" +} +case class ColumnIsNullable() extends ColumnAnnotation { + override def toSql = "NULL" +} +case class ColumnDefaultValue(v: Expression) extends ColumnAnnotation { + override def toSql = "DEFAULT VALUE " + v.toSql +} + +case class ColumnDefinition( + name: Name, + t: Name, + args: Seq[PrimitiveValue] = Seq(), + annotations: Seq[ColumnAnnotation] = Seq() +) { + def toSql = + name.toSql + " " + + t.toSql + (if (args.isEmpty) { "" } + else { "(" + args.map(_.toSql).mkString(", ") + ")" }) + + annotations.map(" " + _.toSql).mkString +} diff --git a/src/main/scala/sparsity/common/statement/InsertTarget.scala b/src/main/scala/sparsity/common/statement/InsertTarget.scala new file mode 100644 index 0000000..8e6d7a2 --- /dev/null +++ b/src/main/scala/sparsity/common/statement/InsertTarget.scala @@ -0,0 +1,15 @@ 
+package sparsity.common.statement + +import sparsity.common.expression.Expression +import sparsity.common.select.SelectBody +import sparsity.common.expression.ToSql + +sealed abstract class InsertValues extends ToSql + +case class ExplicitInsert(values: Seq[Seq[Expression]]) extends InsertValues { + override def toSql = + "VALUES " + values.map("(" + _.map(_.toSql).mkString(", ") + ")").mkString(", ") +} +case class SelectInsert(query: SelectBody) extends InsertValues { + override def toSql = query.toSql +} diff --git a/src/main/scala/sparsity/common/statement/Statement.scala b/src/main/scala/sparsity/common/statement/Statement.scala new file mode 100644 index 0000000..fa87cf1 --- /dev/null +++ b/src/main/scala/sparsity/common/statement/Statement.scala @@ -0,0 +1,123 @@ +package sparsity.common.statement + +import sparsity.common.Name +import sparsity.common.expression.Expression +import sparsity.common.select.SelectBody +import sparsity.common.alter._ + +abstract class Statement { + def toSql: String +} + +/** Empty statement (whitespace/comments only). Used for comment-only statements or empty statements with just + * terminators. + */ +case class EmptyStatement() extends Statement { + override def toSql = "" +} + +/** Unparseable statement (parse error recovery). Captures statements that failed to parse, allowing the parser to + * continue. 
+ */ +case class Unparseable(content: String) extends Statement { + override def toSql = content +} + +case class Select(body: SelectBody, withClause: Seq[WithClause] = Seq()) extends Statement { + override def toSql = + (if (withClause.isEmpty) { "" } + else { "WITH " + withClause.map(_.toSql).mkString(",\n") + "\n" }) + + body.toSql + ";" +} + +case class Update(table: Name, set: Seq[(Name, Expression)], where: Option[Expression]) extends Statement { + override def toSql = + s"UPDATE ${table.toSql} SET ${set.map(x => x._1.toSql + " = " + x._2.toSql).mkString(", ")}" + + where.map(w => " WHERE " + w.toSql).getOrElse("") + ";" +} + +case class Delete(table: Name, where: Option[Expression]) extends Statement { + override def toSql = + s"DELETE FROM ${table.toSql}" + + where.map(w => " WHERE " + w.toSql).getOrElse("") + ";" +} + +case class Insert(table: Name, columns: Option[Seq[Name]], values: InsertValues, orReplace: Boolean) extends Statement { + override def toSql = + s"INSERT${if (orReplace) { " OR REPLACE " } + else { "" }} INTO ${table.toSql}" + + columns.map("(" + _.map(_.toSql).mkString(", ") + ")").getOrElse("") + + " " + values.toSql + ";" + +} + +case class CreateTable( + name: Name, + orReplace: Boolean, + columns: Seq[ColumnDefinition], + annotations: Seq[TableAnnotation] +) extends Statement { + override def toSql = + s"CREATE ${if (orReplace) { "OR REPLACE " } + else { "" }}TABLE ${name.toSql}(" + + (columns.map(_.toSql) ++ + annotations.map(_.toSql)).mkString(", ") + + ");" +} + +case class CreateTableAs(name: Name, orReplace: Boolean, query: SelectBody) extends Statement { + override def toSql = + s"CREATE ${if (orReplace) { "OR REPLACE " } + else { "" }}TABLE ${name.toSql} AS ${query.toSql};" +} + +case class CreateView( + name: Name, + orReplace: Boolean, + query: SelectBody, + materialized: Boolean = false, + temporary: Boolean = false +) extends Statement { + override def toSql = + "CREATE " + + (if (orReplace) { "OR REPLACE " } + else { "" }) 
+ + (if (materialized) { "MATERIALIZED " } + else { "" }) + + s"VIEW ${name.toSql} AS ${query.toSql};" +} + +case class AlterView(name: Name, action: AlterViewAction) extends Statement { + override def toSql = s"ALTER VIEW ${name.toSql} ${action.toSql}" +} + +case class DropTable(name: Name, ifExists: Boolean) extends Statement { + override def toSql = + "DROP TABLE " + ( + if (ifExists) { "IF EXISTS " } + else { "" } + ) + name.toSql + ";" +} + +case class DropView(name: Name, ifExists: Boolean) extends Statement { + override def toSql = + "DROP VIEW " + ( + if (ifExists) { "IF EXISTS " } + else { "" } + ) + name.toSql + ";" +} + +case class Explain(query: SelectBody) extends Statement { + override def toSql = s"EXPLAIN ${query.toSql};" +} + +class CreateIndex(val name: Name, val table: Name, val columns: Seq[Name], val unique: Boolean = false) + extends Statement { + override def toSql = { + val uniqueStr = if (unique) "UNIQUE " else "" + s"CREATE ${uniqueStr}INDEX ${name.toSql} ON ${table.toSql}(${columns.map(_.toSql).mkString(", ")});" + } +} + +/** Result of parsing a single statement with position information */ +case class StatementParseResult(result: Either[Unparseable, Statement], startPos: Int, endPos: Int) diff --git a/src/main/scala/sparsity/common/statement/TableAnnotation.scala b/src/main/scala/sparsity/common/statement/TableAnnotation.scala new file mode 100644 index 0000000..a17dcc2 --- /dev/null +++ b/src/main/scala/sparsity/common/statement/TableAnnotation.scala @@ -0,0 +1,16 @@ +package sparsity.common.statement + +import sparsity.common.Name +import sparsity.common.expression.ToSql + +sealed abstract class TableAnnotation extends ToSql + +case class TablePrimaryKey(columns: Seq[Name]) extends TableAnnotation { + override def toSql = s"PRIMARY KEY (${columns.map(_.toSql).mkString(", ")})" +} +case class TableIndexOn(columns: Seq[Name]) extends TableAnnotation { + override def toSql = s"INDEX (${columns.map(_.toSql).mkString(", ")})" +} +case 
class TableUnique(columns: Seq[Name]) extends TableAnnotation { + override def toSql = s"UNIQUE (${columns.map(_.toSql).mkString(", ")})" +} diff --git a/src/main/scala/sparsity/common/statement/WithClause.scala b/src/main/scala/sparsity/common/statement/WithClause.scala new file mode 100644 index 0000000..a4c468a --- /dev/null +++ b/src/main/scala/sparsity/common/statement/WithClause.scala @@ -0,0 +1,9 @@ +package sparsity.common.statement + +import sparsity.common.Name +import sparsity.common.select.SelectBody +import sparsity.common.expression.ToSql + +case class WithClause(body: SelectBody, name: Name) extends ToSql { + override def toSql = s"${name.toSql} AS (${body.toSql})" +} diff --git a/src/main/scala/sparsity/expression/Expression.scala b/src/main/scala/sparsity/expression/Expression.scala deleted file mode 100644 index d89f198..0000000 --- a/src/main/scala/sparsity/expression/Expression.scala +++ /dev/null @@ -1,363 +0,0 @@ -package sparsity.expression - -import sparsity.Name - -sealed abstract class Expression -{ - def needsParenthesis: Boolean - def children: Seq[Expression] - def rebuild(newChildren: Seq[Expression]): Expression -} - -object Expression -{ - def parenthesize(e: Expression) = - if(e.needsParenthesis) { "(" +e.toString +")" } else { e.toString } - def escapeString(s: String) = - s.replaceAll("'", "''") -} - -sealed abstract class PrimitiveValue extends Expression -{ - def needsParenthesis = false - def children:Seq[Expression] = Seq() - def rebuild(newChildren: Seq[Expression]):Expression = this -} -case class LongPrimitive(v: Long) extends PrimitiveValue - { override def toString = v.toString } -case class DoublePrimitive(v: Double) extends PrimitiveValue - { override def toString = v.toString } -case class StringPrimitive(v: String) extends PrimitiveValue - { override def toString = "'"+Expression.escapeString(v.toString)+"'" } -case class BooleanPrimitive(v: Boolean) extends PrimitiveValue - { override def toString = v.toString } 
-case class NullPrimitive() extends PrimitiveValue - { override def toString = "NULL" } - -/** - * A column reference 'Table.Col' - */ -case class Column(column: Name, table: Option[Name] = None) extends Expression -{ - override def toString = (table.toSeq ++ Seq(column)).mkString(".") - def needsParenthesis = false - def children:Seq[Expression] = Seq() - def rebuild(newChildren: Seq[Expression]):Expression = this -} -/** - * Any simple binary arithmetic expression - * See Arithmetic.scala for an enumeration of the possibilities - * *not* a case class to avoid namespace collisions. - * Arithmetic defines apply/unapply explicitly - */ -class Arithmetic( - val lhs: Expression, - val op: Arithmetic.Op, - val rhs: Expression -) extends Expression -{ - override def toString = - Expression.parenthesize(lhs)+" "+ - Arithmetic.opString(op)+" "+ - Expression.parenthesize(rhs) - def needsParenthesis = true - override def equals(other:Any): Boolean = - other match { - case Arithmetic(otherlhs, otherop, otherrhs) => lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs) - case _ => false - } - def children:Seq[Expression] = Seq(lhs, rhs) - def rebuild(c: Seq[Expression]): Expression = Arithmetic(c(0), op, c(1)) -} - -object Arithmetic extends Enumeration { - type Op = Value - val Add, Sub, Mult, Div, And, Or, BitAnd, BitOr, ShiftLeft, ShiftRight = Value - - def apply(lhs: Expression, op: Op, rhs: Expression) = - new Arithmetic(lhs, op, rhs) - def apply(lhs: Expression, op: String, rhs: Expression) = - new Arithmetic(lhs, fromString(op), rhs) - - def unapply(e:Arithmetic): Option[(Expression, Op, Expression)] = - Some( (e.lhs, e.op, e.rhs) ) - - /** - * Regular expresion to match any and all binary operations - */ - def matchRegex = """\+|-|\*|/|\|\||&&|\||&""".r - - /** - * Convert from the operator's string encoding to its Arith.Op rep - */ - def fromString(a: String) = { - a.toUpperCase match { - case "+" => Add - case "-" => Sub - case "*" => Mult - case "/" => 
Div - case "&" => BitAnd - case "|" => BitOr - case "<<" => ShiftLeft - case ">>" => ShiftRight - case "&&" => And - case "||" => Or - case "AND" => And - case "OR" => And - case x => throw new Exception("Invalid operand '"+x+"'") - } - } - /** - * Convert from the operator's Arith.Op representation to a string - */ - def opString(v: Op): String = - { - v match { - case Add => "+" - case Sub => "-" - case Mult => "*" - case Div => "/" - case BitAnd => "&" - case BitOr => "|" - case ShiftLeft => "<<" - case ShiftRight => ">>" - case And => "AND" - case Or => "OR" - } - } - /** - * Is this binary operation a boolean operator (AND/OR) - */ - def isBool(v: Op): Boolean = - { - v match { - case And | Or => true - case _ => false - } - } - - /** - * Is this binary operation a numeric operator (+, -, *, /, & , |) - */ - def isNumeric(v: Op): Boolean = !isBool(v) - -} - -/** - * Any simple comparison expression - * See Comparison.scala for an enumeration of the possibilities - * *not* a case class to avoid namespace collisions. 
- * Comparison defines apply/unapply explicitly - */ -class Comparison( - val lhs: Expression, - val op: Comparison.Op, - val rhs: Expression -) extends Expression -{ - override def toString = - Expression.parenthesize(lhs)+" "+ - Comparison.opString(op)+" "+ - Expression.parenthesize(rhs) - def needsParenthesis = true - override def equals(other:Any): Boolean = - other match { - case Comparison(otherlhs, otherop, otherrhs) => lhs.equals(otherlhs) && (op == otherop) && rhs.equals(otherrhs) - case _ => false - } - def children:Seq[Expression] = Seq(lhs, rhs) - def rebuild(c: Seq[Expression]): Expression = Comparison(c(0), op, c(1)) -} -object Comparison extends Enumeration { - type Op = Value - val Eq, Neq, Gt, Lt, Gte, Lte, Like, NotLike, RLike, NotRLike = Value - - val strings = Map( - "=" -> Eq, - "==" -> Eq, - "!=" -> Neq, - "<>" -> Neq, - ">" -> Gt, - "<" -> Lt, - ">=" -> Gte, - "<=" -> Lte, - "LIKE" -> Like, // SQL-style LIKE expression - "RLIKE" -> RLike, // Regular expression lookup - "NOT LIKE" -> NotLike, // Inverse LIKE - "NOT RLIKE" -> RLike // Inverse NOT LIKE - ) - - def apply(lhs: Expression, op: Op, rhs: Expression) = - new Comparison(lhs, op, rhs) - def apply(lhs: Expression, op: String, rhs: Expression) = - new Comparison(lhs, strings(op.toUpperCase), rhs) - - def unapply(e:Comparison): Option[(Expression, Op, Expression)] = - Some( (e.lhs, e.op, e.rhs) ) - - def negate(v: Op): Op = - { - v match { - case Eq => Neq - case Neq => Eq - case Gt => Lte - case Gte => Lt - case Lt => Gte - case Lte => Gt - case Like => NotLike - case NotLike => Like - case RLike => NotRLike - case NotRLike => RLike - } - } - - def flip(v: Op): Option[Op] = - { - v match { - case Eq => Some(Eq) - case Neq => Some(Neq) - case Gt => Some(Lt) - case Gte => Some(Lte) - case Lt => Some(Gt) - case Lte => Some(Gte) - case Like => None - case NotLike => None - case RLike => None - case NotRLike => None - } - } - - def opString(v: Op): String = - { - v match { - case Eq => "=" - 
case Neq => "<>" - case Gt => ">" - case Gte => ">=" - case Lt => "<" - case Lte => "<=" - case Like => "LIKE" - case NotLike => "NOT LIKE" - case RLike => "RLIKE" - case NotRLike => "NOT RLIKE" - } - } -} - -case class Function(name: Name, params: Option[Seq[Expression]], distinct: Boolean = false) extends Expression -{ - override def toString = - name.toString + "(" + - (if(distinct){ "DISTINCT " } else { "" }) + - params.map{ _.mkString(", ")}.getOrElse("*")+ - ")" - def needsParenthesis = false - def children:Seq[Expression] = params.getOrElse { Seq() } - def rebuild(c: Seq[Expression]): Expression = - Function(name, params.map { _ => c }, distinct) -} - -case class JDBCVar() extends Expression -{ - override def toString = "?" - def needsParenthesis = false - def children: Seq[Expression] = Seq() - def rebuild(c: Seq[Expression]): Expression = this -} - -case class CaseWhenElse( - target:Option[Expression], - cases:Seq[(Expression, Expression)], - otherwise:Expression -) extends Expression -{ - override def toString = - { - "CASE "+ - target.map { _.toString }.getOrElse("")+ - cases.map { clause => - "WHEN "+Expression.parenthesize(clause._1)+ - " THEN "+Expression.parenthesize(clause._2) - }.mkString(" ")+" " + - "ELSE "+ otherwise.toString+" END" - } - def needsParenthesis = false - def children: Seq[Expression] = - target.toSeq ++ Seq(otherwise) ++ cases.flatMap { x => Seq(x._1, x._2) } - def rebuild(c: Seq[Expression]): Expression = - { - val (newTarget, newOtherwise, newCases) = - if(c.length % 2 == 0){ (Some(c.head), c.tail.head, c.tail.tail) } - else { (None, c.head, c.tail) } - CaseWhenElse( - newTarget, - newCases.grouped(2).map { x => (x(0), x(1)) }.toSeq, - newOtherwise - ) - } -} - -abstract class NegatableExpression extends Expression -{ - def toNegatedString: String -} - -case class IsNull(target: Expression) extends NegatableExpression -{ - override def toString = - Expression.parenthesize(target)+" IS NULL" - def toNegatedString = - 
Expression.parenthesize(target)+" IS NOT NULL" - def needsParenthesis = false - def children: Seq[Expression] = Seq(target) - def rebuild(c: Seq[Expression]): Expression = IsNull(c(0)) -} - -case class Not(target: Expression) extends Expression -{ - override def toString = - target match { - case neg:NegatableExpression => neg.toNegatedString - case _ => "NOT "+Expression.parenthesize(target) - } - def needsParenthesis = false - def children: Seq[Expression] = Seq(target) - def rebuild(c: Seq[Expression]): Expression = Not(c(0)) -} - -case class Cast(expression: Expression, t: Name) extends Expression -{ - override def toString = "CAST("+expression.toString+" AS "+t+")" - def needsParenthesis = false - def children: Seq[Expression] = Seq(expression) - def rebuild(c: Seq[Expression]): Expression = Cast(c(0), t) -} - -case class InExpression( - expression: Expression, - source: Either[Seq[Expression], sparsity.select.SelectBody] -) extends NegatableExpression -{ - override def toString = - Expression.parenthesize(expression) + " IN " + sourceString - override def toNegatedString = - Expression.parenthesize(expression) + " NOT IN " + sourceString - def needsParenthesis = false - def sourceString = - source match { - case Left(elems) => elems.map { Expression.parenthesize(_) }.mkString(", ") - case Right(query) => "("+query.toString+")" - } - def children: Seq[Expression] = - Seq(expression) ++ (source match { - case Left(expr) => expr - case Right(_) => Seq() - }) - def rebuild(c: Seq[Expression]): Expression = - InExpression(c.head, - source match { - case Left(_) => Left(c.tail) - case Right(query) => Right(query) - } - ) -} \ No newline at end of file diff --git a/src/main/scala/sparsity/oracle/OracleSQL.scala b/src/main/scala/sparsity/oracle/OracleSQL.scala new file mode 100644 index 0000000..a4d6b91 --- /dev/null +++ b/src/main/scala/sparsity/oracle/OracleSQL.scala @@ -0,0 +1,1225 @@ +package sparsity.oracle + +import fastparse._ +import fastparse.ParsingRun 
+import fastparse.internal.{Msgs, Util} +import scala.io._ +import java.io._ +import sparsity.common.statement._ +import sparsity.oracle._ +import sparsity.common.select.{SelectBody, _} +import sparsity.common.alter.Materialize +import sparsity.common.expression._ +import sparsity.common.{ + BigInt, + Blob, + Boolean, + Char, + Clob, + DataType, + Date, + Decimal, + Double, + Float, + Integer, + NVarChar2, + Name, + Number, + Raw, + Text, + Timestamp, + VarChar, + VarChar2 +} +import sparsity.common.parser.{SQLBase, SQLBaseObject, StreamParser} +import fastparse.CharIn + +/** Custom whitespace parser that includes SQL comments (-- ... and /* ... */). This allows comments to be automatically + * skipped anywhere whitespace is allowed. Based on FastParse's JavaWhitespace pattern. + */ +object OracleWhitespace { + implicit object whitespace extends Whitespace { + def apply(ctx: ParsingRun[_]) = { + val input = ctx.input + @scala.annotation.tailrec + def rec(current: Int, state: Int): ParsingRun[Unit] = + if (!input.isReachable(current)) { + if (state == 0 || state == 2) { + // Normal whitespace or inside -- comment - both are valid at EOF + if (ctx.verboseFailures) ctx.reportTerminalMsg(current, Msgs.empty) + ctx.freshSuccessUnit(current) + } else if (state == 1 || state == 3) { + // After first '-' or '/' but not a comment - return success at previous position + if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty) + ctx.freshSuccessUnit(current - 1) + } else { + // Inside /* */ comment but reached EOF - this is an error + ctx.cut = true + val res = ctx.freshFailure(current) + if (ctx.verboseFailures) ctx.reportTerminalMsg(current, () => Util.literalize("*/")) + res + } + } else { + val currentChar = input(current) + (state: @scala.annotation.switch) match { + case 0 => + // Normal whitespace mode + (currentChar: @scala.annotation.switch) match { + case ' ' | '\t' | '\n' | '\r' => rec(current + 1, state) + case '-' => rec(current + 1, state = 1) // 
Might be start of -- comment + case '/' => rec(current + 1, state = 3) // Might be start of /* comment + case _ => + if (ctx.verboseFailures) ctx.reportTerminalMsg(current, Msgs.empty) + ctx.freshSuccessUnit(current) + } + case 1 => + // After first '-', check if second '-' follows + if (currentChar == '-') { + rec(current + 1, state = 2) // Start of -- comment + } else { + // Not a comment, just a single '-', return success at previous position + if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty) + ctx.freshSuccessUnit(current - 1) + } + case 2 => + // Inside -- comment, consume until newline + rec(current + 1, state = if (currentChar == '\n') 0 else state) + case 3 => + // After first '/', check if '*' follows + (currentChar: @scala.annotation.switch) match { + case '*' => rec(current + 1, state = 4) // Start of /* comment + case _ => + // Not a comment, just a single '/', return success at previous position + if (ctx.verboseFailures) ctx.reportTerminalMsg(current - 1, Msgs.empty) + ctx.freshSuccessUnit(current - 1) + } + case 4 => + // Inside /* */ comment, waiting for '*' + rec(current + 1, state = if (currentChar == '*') 5 else state) + case 5 => + // Inside /* */ comment, after '*', checking for '/' + (currentChar: @scala.annotation.switch) match { + case '/' => rec(current + 1, state = 0) // End of /* comment + case '*' => rec(current + 1, state = 5) // Stay in state 5 if another '*' + case _ => rec(current + 1, state = 4) // Go back to state 4 if not '/' + } + } + } + rec(current = ctx.index, state = 0) + } + } +} + +/** Oracle SQL dialect implementation. Supports Oracle-specific features including PROMPT, VARCHAR2, NUMBER types, and + * Oracle CREATE TABLE clauses (TABLESPACE, STORAGE, etc.). 
For reference, see + * https://github.com/antlr/grammars-v4/blob/master/sql/plsql/PlSqlParser.g4 + */ +class OracleSQL extends SQLBase { + // Configuration + type Stmt = Statement + def caseSensitive = false + def statementTerminator = "/" // Oracle supports both "/" and ";" + def supportsCTEs = true + def supportsReturning = true + def supportsIfExists = true + def stringConcatOp = "||" + + // Oracle-specific config + def supportsPrompt = true + def supportsVarchar2 = true + + override implicit val whitespaceImplicit: fastparse.Whitespace = OracleWhitespace.whitespace + + // Override multDivOp to prevent "/" from being parsed as division when it's a statement terminator. + // In Oracle, "/" on its own line is a statement terminator, not division. + // Division "/" must be followed by an expression operand (identifier, number, paren, etc.) + // Statement terminator "/" is followed by newline/whitespace leading to EOF or next statement. + // We use negative lookahead: "/" is NOT division if followed by whitespace then a statement keyword. + override def multDivOp[$ : P]: P[Unit] = P( + "*" | ("/" ~ !(&( + // After whitespace (handled by implicit whitespace parser), check we're NOT at: + // - End of input + // - A statement-starting keyword + End | + keyword("CREATE") | keyword("SELECT") | keyword("INSERT") | keyword("UPDATE") | + keyword("DELETE") | keyword("DROP") | keyword("ALTER") | keyword("GRANT") | + keyword("REVOKE") | keyword("COMMIT") | keyword("EXPLAIN") | keyword("WITH") | + keyword("BEGIN") | keyword("DECLARE") | keyword("SET") | keyword("PROMPT") | + keyword("RENAME") | keyword("COMMENT") + ))) + ) + + // Helper parser for digits that allows whitespace consumption + // Note: We don't override base digits since it returns P[Unit], but we need P[String] here + def oracleDigits[$ : P]: P[String] = P(CharsWhileIn("0-9").!) 
+ + // Helper parser for NUMBER precision/scale that allows * for unspecified precision + def numberPrecision[$ : P]: P[String] = P("*".! | oracleDigits) + + // Oracle-specific identifier parser that works with OracleWhitespace + // Override base identifier to use Oracle-specific parsing + // Note: In Oracle, keywords can be used as identifiers in many contexts (especially after dots) + // So we allow keywords as identifiers, but prefer quoted identifiers when ambiguous + override def identifier[$ : P]: P[Name] = P( + (("`" ~~/ CharsWhile(_ != '`').! ~ "`") + | ("\"" ~~/ CharsWhile(_ != '"').! ~~ "\"")).map(Name(_, true: scala.Boolean)) + | (CharIn("_a-zA-Z") ~~ CharsWhileIn("a-zA-Z0-9_").?).!.map(Name(_)) + ) + + // Oracle-specific parser for table names with optional database links (schema.table@link) + def oracleQualifiedTableName[$ : P]: P[Name] = P((identifier ~~ ("." ~/ identifier).? ~~ ("@" ~/ identifier).?).map { + case (x: Name, y: Option[Name], link: Option[Name]) => + (x, y, link) match { + case (x, None, None) => x + case (x, Some(y), None) => Name(x.name + "." + y.name, x.quoted || y.quoted) + case (x, None, Some(link)) => Name(x.name + "@" + link.name, x.quoted || link.quoted) + case (x, Some(y), Some(link)) => + Name(x.name + "." + y.name + "@" + link.name, x.quoted || y.quoted || link.quoted) + } + }) + + // Oracle-specific parser for dotted pairs with optional database links (schema.table@link) + def oracleDottedPair[$ : P]: P[(Option[Name], Name)] = P( + (identifier ~~ ("." ~/ identifier).? 
~~ ("@" ~/ identifier).?).map { + case (x: Name, y: Option[Name], link: Option[Name]) => + (x, y, link) match { + case (x, None, None) => (None, x) + case (x, Some(y), None) => (Some(x), y) + case (x, None, Some(link)) => (None, Name(x.name + "@" + link.name, x.quoted || link.quoted)) + case (x, Some(y), Some(link)) => + (Some(x), Name(y.name + "@" + link.name, y.quoted || link.quoted)) + } + } + ) + + // Override alias to use rawIdentifier which checks for keywords + // This prevents clause keywords like WHERE from being parsed as aliases + override def alias[$ : P]: P[Name] = P(keyword("AS").? ~ rawIdentifier) + + // Override simpleFromElement to support Oracle database links + // Note: For subqueries, alias is optional in Oracle (unlike base class which requires it) + override def simpleFromElement[$ : P]: P[FromElement] = P( + (("(" ~ select ~ ")" ~ alias.?).map { + case (body, Some(alias)) => FromSelect(body, alias) + case (body, None) => + FromSelect(body, Name("")) // Anonymous alias for subqueries without alias + }) + | ((oracleDottedPair ~ alias.?).map { case (schema, table, alias) => + FromTable(schema, table, alias) + }) + | (("(" ~ fromElement ~ ")" ~ alias.?).map { + case (from, None) => from + case (from, Some(alias)) => from.withAlias(alias) + }) + ) + + // Override column to use oracleDottedPair for proper Oracle identifier parsing + override def column[$ : P]: P[Column] = P(oracleDottedPair.map(x => Column(x._2, x._1))) + + // Override leaf to allow keywords as function names (e.g., REPLACE(...)) + override def leaf[$ : P]: P[Expression] = P( + parens | + primitive | + jdbcvar | + caseWhen | ifThenElse | + cast | + nullLiteral | + // Window functions: function(...) OVER (...) - parse function then check for OVER + (function ~ (keyword("OVER") ~/ + "(" ~/ + ( + (keyword("PARTITION") ~/ keyword("BY") ~/ expressionList).? ~ + (keyword("ORDER") ~/ keyword("BY") ~/ orderBy + .rep(sep = comma, min = 1) + .map(_.toSeq)).? 
+ ) ~ + ")").?).map { + case (func, Some((partitionByOpt, orderByOpt))) => + WindowFunction(func, partitionByOpt, orderByOpt.getOrElse(Seq())) + case (func, None) => + func + } | + // Try column first to avoid conflicts with quoted identifiers + column | + // Allow keywords as function names when followed by ( + // Use lookahead that explicitly excludes quoted identifiers (starts with " or `) + &(!CharIn("\"`") ~ CharIn("_a-zA-Z") ~ CharsWhileIn("a-zA-Z0-9_").? ~ "(") ~ function + ) + + // Override parens to handle subqueries: (SELECT ...) + override def parens[$ : P]: P[Expression] = P( + // Try subquery first: (SELECT ...) + ("(" ~/ &(keyword("SELECT")) ~/ select ~ ")").map(Subquery(_)) | + // Fall back to regular parenthesized expression: (expression) + ("(" ~ expression ~ ")") + ) + + // Window function parser: function(...) OVER (PARTITION BY ... ORDER BY ...) + def windowFunction[$ : P]: P[WindowFunction] = P( + function ~ + keyword("OVER") ~/ + "(" ~/ + ( + (keyword("PARTITION") ~/ keyword("BY") ~/ expressionList).? ~ + (keyword("ORDER") ~/ keyword("BY") ~/ orderBy + .rep(sep = comma, min = 1) + .map(_.toSeq)).? 
+ ) ~ + ")" + ).map { case (func, (partitionByOpt, orderByOpt)) => + WindowFunction(func, partitionByOpt, orderByOpt.getOrElse(Seq())) + } + + // Override function to allow keywords as function names + // But explicitly reject quoted identifiers (they should be columns, not functions) + override def function[$ : P]: P[Function] = P( + // First check that we don't have a quoted identifier + !CharIn("\"`") ~ + (identifier ~ "(" ~/ + keyword("DISTINCT").!.?.map { + _ != None + } ~ + ("*".!.map(_ => None) + | expressionList.map(Some(_))) ~ ")").map { case (name, distinct, args) => + Function(name, args, distinct) + } + ) + + // Parser implementations + override def dialectSpecificStatement[$ : P]: P[OracleStatement] = P( + prompt | commit | setScanOff | setScanOn | setDefineOff | grant | revoke | dropSynonym | rename | comment + ) + + def prompt[$ : P]: P[Prompt] = P( + keyword("PROMPT") ~/ CharsWhile(c => c != '\n' && c != '/' && c != ';').!.map(_.trim) + .map(Prompt(_)) + ) + + def commit[$ : P]: P[Commit] = P(keyword("COMMIT") ~/ Pass.map(_ => Commit())) + + def setDefineOff[$ : P]: P[SetDefineOff] = P( + keyword("SET") ~ &(keyword("DEFINE")) ~ keyword("DEFINE") ~ keyword("OFF") ~ Pass.map(_ => SetDefineOff()) + ) + + def setScanOff[$ : P]: P[SetScanOff] = P( + keyword("SET") ~ StringInIgnoreCase("SCAN") ~ StringInIgnoreCase("OFF") ~ Pass.map(_ => SetScanOff()) + ) + + def setScanOn[$ : P]: P[SetScanOn] = P( + keyword("SET") ~ StringInIgnoreCase("SCAN") ~ StringInIgnoreCase("ON") ~ Pass.map(_ => SetScanOn()) + ) + + def grant[$ : P]: P[Grant] = P( + keyword("GRANT") ~/ + grantPrivilegeList ~ + keyword("ON") ~/ + oracleQualifiedTableName ~ + keyword("TO") ~/ + identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ + (keyword("WITH") ~/ keyword("GRANT") ~/ StringInIgnoreCase("OPTION")).!.?.map(_.isDefined) + ).map { case (privileges, onObject, toUsers, withGrantOption) => + Grant(privileges, onObject, toUsers, withGrantOption) + } + + def revoke[$ : P]: P[Revoke] = P( + 
keyword("REVOKE") ~/ + grantPrivilegeList ~ + keyword("ON") ~/ + oracleQualifiedTableName ~ + keyword("FROM") ~/ + identifier.rep(sep = comma, min = 1).map(_.toSeq) + ).map { case (privileges, onObject, fromUsers) => + Revoke(privileges, onObject, fromUsers) + } + + def grantPrivilegeList[$ : P]: P[Seq[String]] = P(grantPrivilege ~ (comma ~/ grantPrivilege).rep).map { + case (first, rest) => Seq(first) ++ rest.toSeq + } + + def grantPrivilege[$ : P]: P[String] = P( + // Multi-word privileges first (must come before single-word to avoid partial matches) + // Parse "ON COMMIT REFRESH" as a single unit + (StringInIgnoreCase("ON") ~/ StringInIgnoreCase("COMMIT") ~/ StringInIgnoreCase("REFRESH")).!.map(_ => + "ON COMMIT REFRESH" + ) | + (StringInIgnoreCase("QUERY") ~/ StringInIgnoreCase("REWRITE")).!.map(_ => "QUERY REWRITE") | + // Single-word privileges (exclude "ON" and "QUERY" to avoid conflicts) + StringInIgnoreCase( + "ALTER", + "DELETE", + "INDEX", + "INSERT", + "REFERENCES", + "SELECT", + "UPDATE", + "ALL", + "DEBUG", + "FLASHBACK" + ).! + ) + + // Parser for Oracle type names that may be multi-word (e.g., DOUBLE PRECISION, SMALLINT) + def oracleTypeName[$ : P]: P[Name] = P( + (keyword("DOUBLE") ~ keyword("PRECISION")).map(_ => Name("DOUBLE PRECISION")) | + keyword("DOUBLE").map(_ => Name("DOUBLE")) | + keyword("SMALLINT").map(_ => Name("SMALLINT")) | + identifier + ) + + // Override oneOrMoreAttributes to use OracleWhitespace instead of MultiLineWhitespace + override def oneOrMoreAttributes[$ : P]: P[Seq[Name]] = P( + ("(" ~/ identifier.rep(sep = comma, min = 1) ~ ")") + | identifier.map(Seq(_)) + ) + + // Helper parser for Oracle type parameters (NUMBER(p,s), VARCHAR2(size BYTE|CHAR), etc.) 
+ // Returns the parameter string (e.g., "(10,2)", "(15 BYTE)") + def typeParameters[$ : P]: P[String] = P( + "(" ~/ + ( + // NUMBER(*,0) or NUMBER(3,2) or NUMBER(3,2 BYTE) - precision and scale with optional unit + // Try this FIRST to avoid backtracking issues with NUMBER(10, 2) + (numberPrecision ~ + "," ~ numberPrecision ~ + (StringInIgnoreCase("BYTE", "CHAR").!.?)).map { case (size, scale, unit) => + val unitStr = unit.map(" " + _).getOrElse("") + s"($size,$scale$unitStr)" + } | + // VARCHAR2(15 BYTE) or NUMBER(3) or NUMBER(3 BYTE) - size with optional unit + // Parse digits directly to allow whitespace consumption + (oracleDigits ~ + (StringInIgnoreCase("BYTE", "CHAR").!.?)).map { case (size, unit) => + val unitStr = unit.map(" " + _).getOrElse("") + s"($size$unitStr)" + } + ) ~ + ")" + ).! + + // Override tableField to handle Oracle-specific data type syntax + override def tableField[$ : P]: P[Either[TableAnnotation, ColumnDefinition]] = P( + ( + keyword("CONSTRAINT") ~ + identifier ~ + keyword("UNIQUE") ~/ + oneOrMoreAttributes + ).map { case (_, attrs) => Left(TableUnique(attrs)) } | ( + keyword("CONSTRAINT") ~ + identifier ~ + keyword("PRIMARY") ~/ + keyword("KEY") ~/ + oneOrMoreAttributes + ).map { case (_, attrs) => Left(TablePrimaryKey(attrs)) } | ( + keyword("UNIQUE") ~/ + oneOrMoreAttributes.map(attrs => Left(TableUnique(attrs))) + ) | ( + keyword("PRIMARY") ~/ + keyword("KEY") ~/ + oneOrMoreAttributes.map(attrs => Left(TablePrimaryKey(attrs))) + ) | ( + keyword("INDEX") ~/ + keyword("ON") ~ + oneOrMoreAttributes.map(attrs => Left(TableIndexOn(attrs))) + ) | ( + ( + identifier ~/ + oracleTypeName ~ + typeParameters.? 
~ + oracleColumnAnnotation.rep + ).map { case (name, typeName, paramsOpt, annotations) => + val fullTypeName = Name(typeName.name + paramsOpt.getOrElse("")) + Right(ColumnDefinition(name, fullTypeName, Seq(), annotations)) + } + ) + ) + + // Override columnAnnotation to handle Oracle-specific ENABLE/DISABLE keywords and DEFAULT ON NULL + def oracleColumnAnnotation[$ : P]: P[ColumnAnnotation] = P( + // Try DEFAULT ON NULL first (Oracle-specific) - use lookahead to check full sequence + (&(keyword("DEFAULT") ~ keyword("ON") ~ keyword("NULL")) ~ + keyword("DEFAULT") ~/ + keyword("ON") ~/ + keyword("NULL") ~/ + ( + ("(" ~ expression ~ ")") + | function // Function calls with parentheses + | functionCallNoParens // Function calls without parentheses (CURRENT_TIMESTAMP, etc.) + | primitive + )).map(ColumnDefaultValue(_)) | + // DEFAULT NULL (without ON) - use lookahead to ensure NULL follows + (&(keyword("DEFAULT") ~ keyword("NULL")) ~ + keyword("DEFAULT") ~/ + keyword("NULL")).map(_ => ColumnDefaultValue(NullPrimitive())) | + // Standalone NULL annotation (Oracle allows explicit NULL, though it's redundant) + keyword("NULL").map(_ => ColumnIsNullable()) | + // Fall back to standard column annotations (including regular DEFAULT and NOT NULL) + // Oracle allows ENABLE/DISABLE followed by VALIDATE/NOVALIDATE + columnAnnotation ~ (keyword("ENABLE") | keyword("DISABLE")).? ~ (StringInIgnoreCase( + "VALIDATE" + ) | StringInIgnoreCase("NOVALIDATE")).? + ) + + // Override common statements to return OracleStatement wrappers + override def insert[$ : P]: P[OracleStatement] = P( + ( + keyword("INSERT") ~/ + ( + keyword("OR") ~/ + keyword("REPLACE") + ).!.?.map { case None => false; case _ => true } ~ + keyword("INTO") ~/ + oracleQualifiedTableName ~ + ("(" ~/ + identifier ~ + (comma ~/ identifier).rep ~ + ")").map(x => Seq(x._1) ++ x._2).? ~ + ( + // WITH ... 
SELECT + (&(keyword("WITH")) ~/ withSelect.map { case (ctes, body) => SelectInsert(body) }) | + // SELECT (without WITH) + (&(keyword("SELECT")) ~/ select.map(SelectInsert(_))) | + // VALUES + (&(keyword("VALUES")) ~/ valueList) + ) + ).map { case (orReplace, table, columns, values) => + OracleInsert(table, columns, values, orReplace) + } + ) + + // Parser for SELECT with optional WITH clause + def withSelect[$ : P]: P[(Seq[WithClause], SelectBody)] = P(withClauseList ~ select).map { case (ctes, body) => + (ctes, body) + } + + def withClauseList[$ : P]: P[Seq[WithClause]] = P( + keyword("WITH") ~/ + withClause.rep(sep = comma, min = 1).map(_.toSeq) + ) + + def withClause[$ : P]: P[WithClause] = P( + identifier ~ + keyword("AS") ~/ + "(" ~/ select ~ ")" + ).map { case (name, body) => WithClause(body, name) } + + override def update[$ : P]: P[OracleStatement] = P( + ( + keyword("UPDATE") ~/ + oracleQualifiedTableName ~ + // Optional table alias: identifier that is not part of a dotted name and is followed by SET + // Use negative lookahead to ensure we don't match "SET" as an alias + (!keyword("SET") ~~ identifier).? ~ + keyword("SET") ~/ + ( + dottedPair.map { case (schema, col) => + val colName = + schema.map(s => Name(s.name + "." + col.name, s.quoted || col.quoted)).getOrElse(col) + colName + } ~ + "=" ~/ + expression + ).rep(sep = comma, min = 1) ~ + ( + StringInIgnoreCase("WHERE") ~/ + expression + ).? + ).map { case (table, _, set, where) => + OracleUpdate(table, set, where) + } + ) + + override def delete[$ : P]: P[OracleStatement] = P( + ( + keyword("DELETE") ~/ + keyword("FROM").!.?.map(_ => ()) ~/ // FROM is optional in Oracle + oracleQualifiedTableName ~ + ( + keyword("WHERE") ~/ + expression + ).? 
+ ).map { case (table, where) => OracleDelete(table, where) } + ) + + override def alterView[$ : P]: P[OracleStatement] = P( + ( + keyword("ALTER") ~ + keyword("VIEW") ~/ + oracleQualifiedTableName ~ + ( + (keyword("MATERIALIZE").!.map(_ => Materialize(true))) + | (keyword("DROP") ~ + keyword("MATERIALIZE").!.map(_ => Materialize(false))) + ) + ).map { case (name, op) => OracleAlterView(name, op) } + ) + + def alterTable[$ : P]: P[OracleStatement] = P( + ( + keyword("ALTER") ~ + keyword("TABLE") ~/ + oracleQualifiedTableName ~ + ( + alterTableModify | + alterTableRenameColumn | + alterTableAdd | + alterTableDrop // Combined DROP COLUMN and DROP CONSTRAINT parser + ) + ).map { case (name, action) => OracleAlterTable(name, action) } + ) + + def alterTableModify[$ : P]: P[AlterTableAction] = P( + keyword("MODIFY") ~/ + ( + // MODIFY(column type, ...) - with parentheses + ("(" ~/ + columnModification.rep(sep = comma, min = 1) ~ + ")").map(mods => AlterTableModify(mods)) | + // MODIFY column type - without parentheses (single or multiple columns) + (columnModification ~ + (comma ~/ keyword("MODIFY") ~/ columnModification).rep).map { case (first, rest) => + AlterTableModify(Seq(first) ++ rest) + } + ) + ) + + def alterTableRenameColumn[$ : P]: P[AlterTableAction] = P( + keyword("RENAME") ~/ + keyword("COLUMN") ~/ + identifier ~ + keyword("TO") ~/ + identifier + ).map { case (oldName, newName) => AlterTableRenameColumn(oldName, newName) } + + def alterTableAdd[$ : P]: P[AlterTableAction] = P( + keyword("ADD") ~/ + ( + // ADD CONSTRAINT constraint_name UNIQUE/PRIMARY KEY/FOREIGN KEY (...) + (keyword("CONSTRAINT") ~/ + identifier ~ + tableConstraint).map { case (constraintName, constraint) => + AlterTableAddConstraint(constraintName, constraint) + } | + // ADD(column, ...) 
- with parentheses + ("(" ~/ + tableField.rep(sep = comma, min = 1) ~ + ")").map { fields => + val columns = fields.collect { case Right(colDef) => colDef } + val annotations = fields.collect { case Left(ann) => ann } + AlterTableAdd(columns.toSeq, annotations.toSeq) + } | + // ADD column - without parentheses (single or multiple columns) + (tableField ~ + (keyword("ADD") ~/ tableField).rep).map { case (first, rest) => + val allFields = Seq(first) ++ rest + val columns = allFields.collect { case Right(colDef) => colDef } + val annotations = allFields.collect { case Left(ann) => ann } + AlterTableAdd(columns, annotations) + } + ) + ) + + def tableConstraint[$ : P]: P[TableConstraint] = P( + // UNIQUE (col1, col2, ...) + (keyword("UNIQUE") ~/ + "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")").map { columns => + UniqueConstraint(columns) + } | + // PRIMARY KEY (col1, col2, ...) + (keyword("PRIMARY") ~/ + keyword("KEY") ~/ + "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")").map { columns => + PrimaryKeyConstraint(columns) + } | + // FOREIGN KEY (col1, ...) REFERENCES table(col1, ...) + (keyword("FOREIGN") ~/ + keyword("KEY") ~/ + "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")" ~ + keyword("REFERENCES") ~/ + oracleQualifiedTableName ~ + "(" ~/ identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ ")" ~ + (keyword("ON") ~/ keyword("DELETE") ~/ referentialAction).? 
~ + (keyword("ON") ~/ keyword("UPDATE") ~/ referentialAction).?).map { + case (columns, refTable, refColumns, onDelete, onUpdate) => + ForeignKeyConstraint(columns, refTable, refColumns, onDelete, onUpdate) + } + ) + + def referentialAction[$ : P]: P[ReferentialAction] = P( + (keyword("NO") ~/ keyword("ACTION")).map(_ => ReferentialAction.NoAction) | + keyword("RESTRICT").map(_ => ReferentialAction.Restrict) | + keyword("CASCADE").map(_ => ReferentialAction.Cascade) | + (keyword("SET") ~/ keyword("NULL")).map(_ => ReferentialAction.SetNull) | + (keyword("SET") ~/ keyword("DEFAULT")).map(_ => ReferentialAction.SetDefault) + ) + + def alterTableDrop[$ : P]: P[AlterTableAction] = P( + keyword("DROP") ~/ + ( + // DROP (col1, col2, ...) - multiple columns in parentheses + ("(" ~/ + identifier.rep(sep = comma, min = 1).map(_.toSeq) ~ + ")").map(cols => AlterTableDropColumns(cols)) | + // DROP COLUMN col_name - single column + keyword("COLUMN") ~/ identifier.map { colName => + AlterTableDropColumn(colName) + } | + // DROP CONSTRAINT constraint_name + keyword("CONSTRAINT") ~/ identifier.map { constraintName => + AlterTableDropConstraint(constraintName) + } + ) + ) + + def alterTableDropColumn[$ : P]: P[AlterTableAction] = P( + keyword("DROP") ~/ + keyword("COLUMN") ~/ + identifier + ).map(colName => AlterTableDropColumn(colName)) + + def alterTableDropConstraint[$ : P]: P[AlterTableAction] = P( + keyword("DROP") ~/ + keyword("CONSTRAINT") ~/ + identifier + ).map(constraintName => AlterTableDropConstraint(constraintName)) + + def columnModification[$ : P]: P[ColumnModification] = P( + ( + identifier ~/ + // Type is optional - MODIFY column DEFAULT/NOT NULL is valid without type + // Use lookahead to check if this looks like a type (not a keyword like DEFAULT/NOT) + ( + // Only parse type if NOT followed by annotation keywords + !(&(keyword("DEFAULT") | keyword("NOT") | keyword("PRIMARY"))) ~ + (oracleTypeName ~ + typeParameters.?).map { case (typeName, paramsOpt) => + 
Some(Name(typeName.name + paramsOpt.getOrElse(""))) + } + ).? ~ + oracleColumnAnnotation.rep + ).map { case (name: Name, typeOpt, annotations) => + ColumnModification(name, typeOpt.flatten, annotations) + } + ) + + override def dropTableOrView[$ : P]: P[OracleStatement] = P( + ( + keyword("DROP") ~ + keyword("TABLE", "VIEW").!.map(_.toUpperCase) ~/ + ifExists ~ + oracleQualifiedTableName ~ + (keyword("CASCADE") ~/ StringInIgnoreCase("CONSTRAINTS")).!.?.map(_.isDefined) ~ + keyword("PURGE").!.?.map(_.isDefined) + ).map { + case ("TABLE", ifExists, name, cascadeConstraints, purge) => + OracleDropTable(name, ifExists, cascadeConstraints, purge) + case ("VIEW", ifExists, name, _, _) => + OracleDropView(name, ifExists) + case (_, _, _, _, _) => + throw new Exception("Internal Error") + } + ) + + def dropSynonym[$ : P]: P[OracleStatement] = P( + ( + keyword("DROP") ~ + keyword("SYNONYM") ~/ + ifExists ~ + oracleQualifiedTableName + ).map { case (ifExists, name) => OracleDropSynonym(name, ifExists) } + ) + + def rename[$ : P]: P[OracleStatement] = P( + ( + keyword("RENAME") ~/ + oracleQualifiedTableName ~ + keyword("TO") ~/ + oracleQualifiedTableName + ).map { case (oldName, newName) => OracleRename(oldName, newName) } + ) + + // Parser for fully qualified column names (schema.table.column) + def oracleQualifiedColumnName[$ : P]: P[Name] = P( + (identifier ~ ("." 
~/ identifier).rep(min = 1, max = 2)).map { case (first, rest) => + val parts = Seq(first) ++ rest.toSeq + val nameStr = parts.map(_.name).mkString(".") + val quoted = parts.exists(_.quoted) + Name(nameStr, quoted) + } + ) + + def comment[$ : P]: P[Comment] = P( + ( + keyword("COMMENT") ~/ + keyword("ON") ~/ + keyword("COLUMN") ~/ + oracleQualifiedColumnName ~ + keyword("IS") ~/ + quotedString + ).map { case (columnName, commentText) => Comment(columnName, commentText) } + ) + + // PL/SQL recursive parser combinators for blocks and loops + // Uses FastParse's recursive descent to naturally handle nesting + + // Parse a PL/SQL label: <