diff --git a/Cargo.toml b/Cargo.toml index 73091085..7d2457ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [features] default = [] extensions = ["dep:serde_yaml"] -parse = ["dep:hex", "dep:thiserror", "semver"] +parse = ["dep:hex", "dep:thiserror", "dep:serde_yaml", "semver"] protoc = ["dep:protobuf-src"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] @@ -38,6 +38,8 @@ pbjson = { version = "0.8.0", optional = true } pbjson-types = { version = "0.8.0", optional = true } prost = "0.14.1" prost-types = "0.14.1" +# Required by generated text schemas: the typify-generated code emits +# ::regress::Regex for `pattern` validations. regress = "0.10.4" semver = { version = "1.0.27", optional = true } serde = { version = "1.0.228", features = ["derive"] } diff --git a/src/extensions.rs b/src/extensions.rs index 240799da..f6c72067 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -6,6 +6,7 @@ //! included in the packaged crate, ignored by git, and automatically kept //! in-sync. +#[cfg(feature = "extensions")] include!(concat!(env!("OUT_DIR"), "/extensions.in")); #[cfg(test)] diff --git a/src/parse/context.rs b/src/parse/context.rs index 58dc6af3..884ec47e 100644 --- a/src/parse/context.rs +++ b/src/parse/context.rs @@ -4,9 +4,9 @@ use thiserror::Error; -use crate::parse::{ - Anchor, Parse, proto::extensions::SimpleExtensionUrn, text::simple_extensions::SimpleExtensions, -}; +use crate::parse::proto::extensions::SimpleExtensionUrn; +use crate::parse::text::simple_extensions::ExtensionFile; +use crate::parse::{Anchor, Parse}; /// A parse context. /// @@ -22,22 +22,24 @@ pub trait Context { { item.parse(self) } +} +pub trait ProtoContext: Context { /// Add a [SimpleExtensionUrn] to this context. Must return an error for duplicate - /// anchors or when the urn is not supported. + /// anchors or when the URN is not supported. 
/// /// This function must eagerly resolve and parse the simple extension, returning an /// error if either fails. fn add_simple_extension_urn( &mut self, simple_extension_urn: &SimpleExtensionUrn, - ) -> Result<&SimpleExtensions, ContextError>; + ) -> Result<&ExtensionFile, ContextError>; /// Returns the simple extensions for the given simple extension anchor. fn simple_extensions( &self, anchor: &Anchor, - ) -> Result<&SimpleExtensions, ContextError>; + ) -> Result<&ExtensionFile, ContextError>; } /// Parse context errors. @@ -57,46 +59,39 @@ pub enum ContextError { } #[cfg(test)] -pub(crate) mod tests { +pub(crate) mod fixtures { use std::collections::{HashMap, hash_map::Entry}; use crate::parse::{ Anchor, context::ContextError, proto::extensions::SimpleExtensionUrn, - text::simple_extensions::SimpleExtensions, + text::simple_extensions::ExtensionFile, }; /// A test context. /// /// This currently mocks support for simple extensions (does not resolve or /// parse). + #[derive(Default)] pub struct Context { - empty_simple_extensions: SimpleExtensions, - simple_extensions: HashMap, SimpleExtensionUrn>, + simple_extensions: HashMap, ExtensionFile>, } - impl Default for Context { - fn default() -> Self { - Self { - empty_simple_extensions: SimpleExtensions {}, - simple_extensions: Default::default(), - } - } - } + impl super::Context for Context {} - impl super::Context for Context { + impl super::ProtoContext for Context { fn add_simple_extension_urn( &mut self, simple_extension_urn: &crate::parse::proto::extensions::SimpleExtensionUrn, - ) -> Result<&SimpleExtensions, ContextError> { + ) -> Result<&ExtensionFile, ContextError> { match self.simple_extensions.entry(simple_extension_urn.anchor()) { Entry::Occupied(_) => Err(ContextError::DuplicateSimpleExtension( simple_extension_urn.anchor(), )), Entry::Vacant(entry) => { - // TODO: fetch - entry.insert(simple_extension_urn.clone()); - // For now just return an empty extension - Ok(&self.empty_simple_extensions) + 
let f = ExtensionFile::empty(simple_extension_urn.urn().clone()); + let ext_ref = entry.insert(f); + + Ok(ext_ref) } } } @@ -104,10 +99,9 @@ pub(crate) mod tests { fn simple_extensions( &self, anchor: &Anchor, - ) -> Result<&SimpleExtensions, ContextError> { + ) -> Result<&ExtensionFile, ContextError> { self.simple_extensions - .contains_key(anchor) - .then_some(&self.empty_simple_extensions) + .get(anchor) .ok_or(ContextError::UndefinedSimpleExtension(*anchor)) } } diff --git a/src/parse/proto/extensions/simple_extension_urn.rs b/src/parse/proto/extensions/simple_extension_urn.rs index 7c164b83..9f34b85b 100644 --- a/src/parse/proto/extensions/simple_extension_urn.rs +++ b/src/parse/proto/extensions/simple_extension_urn.rs @@ -7,7 +7,7 @@ use std::str::FromStr; use thiserror::Error; use crate::{ - parse::{Anchor, Context, Parse, context::ContextError}, + parse::{Anchor, Parse, context::ContextError, context::ProtoContext}, proto, urn::{InvalidUrn, Urn}, }; @@ -50,7 +50,7 @@ pub enum SimpleExtensionUrnError { Context(#[from] ContextError), } -impl Parse for proto::extensions::SimpleExtensionUrn { +impl Parse for proto::extensions::SimpleExtensionUrn { type Parsed = SimpleExtensionUrn; type Error = SimpleExtensionUrnError; @@ -90,7 +90,7 @@ impl From for proto::extensions::SimpleExtensionUrn { #[cfg(test)] mod tests { use super::*; - use crate::parse::{Context as _, context::tests::Context}; + use crate::parse::{Context as _, context::fixtures::Context}; #[test] fn parse() -> Result<(), SimpleExtensionUrnError> { diff --git a/src/parse/proto/plan_version.rs b/src/parse/proto/plan_version.rs index e00d537f..6c4c247c 100644 --- a/src/parse/proto/plan_version.rs +++ b/src/parse/proto/plan_version.rs @@ -3,7 +3,7 @@ //! Parsing of [proto::PlanVersion]. 
use crate::{ - parse::{Parse, context::Context, proto::Version}, + parse::{Parse, context::ProtoContext, proto::Version}, proto, }; use thiserror::Error; @@ -38,7 +38,7 @@ pub enum PlanVersionError { Version(#[from] VersionError), } -impl Parse for proto::PlanVersion { +impl Parse for proto::PlanVersion { type Parsed = PlanVersion; type Error = PlanVersionError; @@ -71,7 +71,7 @@ impl From for proto::PlanVersion { mod tests { use super::*; use crate::{ - parse::{context::tests::Context, proto::VersionError}, + parse::{context::fixtures::Context, proto::VersionError}, version, }; diff --git a/src/parse/proto/version.rs b/src/parse/proto/version.rs index cf692829..78c7d239 100644 --- a/src/parse/proto/version.rs +++ b/src/parse/proto/version.rs @@ -3,7 +3,7 @@ //! Parsing of [proto::Version]. use crate::{ - parse::{Parse, context::Context}, + parse::{Parse, context::ProtoContext}, proto, version, }; use hex::FromHex; @@ -75,7 +75,7 @@ pub enum VersionError { Substrait(semver::Version, semver::VersionReq), } -impl Parse for proto::Version { +impl Parse for proto::Version { type Parsed = Version; type Error = VersionError; @@ -142,7 +142,7 @@ impl From for proto::Version { #[cfg(test)] mod tests { use super::*; - use crate::parse::context::tests::Context; + use crate::parse::context::fixtures::Context; #[test] fn version() -> Result<(), VersionError> { diff --git a/src/parse/text/mod.rs b/src/parse/text/mod.rs index 4f7b641e..73219362 100644 --- a/src/parse/text/mod.rs +++ b/src/parse/text/mod.rs @@ -1,5 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [text](crate::text) types. +//! Utilities for working with Substrait *text* objects. +//! +//! The generated [`crate::text`] module exposes the raw YAML-derived structs +//! (e.g. [`crate::text::simple_extensions::SimpleExtensions`]). This module +//! provides parsing helpers that validate those raw values and offer +//! higher-level wrappers for validation, lookups, and combining into protobuf +//! 
objects. pub mod simple_extensions; diff --git a/src/parse/text/simple_extensions/argument.rs b/src/parse/text/simple_extensions/argument.rs index 49896f02..783f714b 100644 --- a/src/parse/text/simple_extensions/argument.rs +++ b/src/parse/text/simple_extensions/argument.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [simple_extensions::ArgumentsItem]. +//! Parsing of type arguments: [`simple_extensions::ArgumentsItem`]. use std::{collections::HashSet, ops::Deref}; @@ -11,21 +11,24 @@ use crate::{ text::simple_extensions, }; -/// A parsed [simple_extensions::ArgumentsItem]. +/// A parsed [`simple_extensions::ArgumentsItem`]. #[derive(Clone, Debug)] pub enum ArgumentsItem { - /// Arguments that support a fixed set of declared values as constant arguments. + /// Arguments that support a fixed set of declared values as constant + /// arguments. EnumArgument(EnumerationArg), /// Arguments that refer to a data value. ValueArgument(ValueArg), - /// Arguments that are used only to inform the evaluation and/or type derivation of the function. + /// Arguments that are used only to inform the evaluation and/or type + /// derivation of the function. TypeArgument(TypeArg), } impl ArgumentsItem { - /// Parses an `Option` field, rejecting it if an empty string is provided. + /// Parses an `Option` field, rejecting it if an empty string is + /// provided. #[inline] fn parse_optional_string( name: &str, @@ -75,7 +78,7 @@ impl From for simple_extensions::ArgumentsItem { } } -/// Parse errors for [simple_extensions::ArgumentsItem]. +/// Parse errors for [`simple_extensions::ArgumentsItem`]. #[derive(Debug, Error, PartialEq)] pub enum ArgumentsItemError { /// Invalid enumeration options. @@ -103,21 +106,21 @@ pub struct EnumerationArg { impl EnumerationArg { /// Returns the name of this argument. /// - /// See [simple_extensions::EnumerationArg::name]. + /// See [`simple_extensions::EnumerationArg::name`]. 
pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::EnumerationArg::description]. + /// See [`simple_extensions::EnumerationArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } /// Returns the options of this argument. /// - /// See [simple_extensions::EnumerationArg::options]. + /// See [`simple_extensions::EnumerationArg::options`]. pub fn options(&self) -> &EnumOptions { &self.options } @@ -177,18 +180,26 @@ impl Parse for simple_extensions::EnumOptions { type Error = EnumOptionsError; fn parse(self, _ctx: &mut C) -> Result { - let options = self.0; + self.try_into() + } +} + +impl TryFrom for EnumOptions { + type Error = EnumOptionsError; + + fn try_from(raw: simple_extensions::EnumOptions) -> Result { + let options = raw.0; if options.is_empty() { return Err(EnumOptionsError::EmptyList); } let mut unique_options = HashSet::new(); - for option in options.iter() { + for option in options.into_iter() { if option.is_empty() { return Err(EnumOptionsError::EmptyOption); } if !unique_options.insert(option.clone()) { - return Err(EnumOptionsError::DuplicatedOption(option.clone())); + return Err(EnumOptionsError::DuplicatedOption(option)); } } @@ -202,7 +213,7 @@ impl From for simple_extensions::EnumOptions { } } -/// Parse errors for [simple_extensions::EnumOptions]. +/// Parse errors for [`simple_extensions::EnumOptions`]. #[derive(Debug, Error, PartialEq)] pub enum EnumOptionsError { /// Empty list. @@ -229,26 +240,27 @@ pub struct ValueArg { /// A fully defined type or a type expression. /// - /// todo: implement parsed [simple_extensions::Type]. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` + /// parser) so the caller does not have to interpret the raw string. value: simple_extensions::Type, - /// Whether this argument is required to be a constant for invocation. 
- /// For example, in some system a regular expression pattern would only be accepted as a literal - /// and not a column value reference. + /// Whether this argument is required to be a constant for invocation. For + /// example, in some system a regular expression pattern would only be + /// accepted as a literal and not a column value reference. constant: Option, } impl ValueArg { /// Returns the name of this argument. /// - /// See [simple_extensions::ValueArg::name]. + /// See [`simple_extensions::ValueArg::name`]. pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::ValueArg::description]. + /// See [`simple_extensions::ValueArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } @@ -256,7 +268,7 @@ impl ValueArg { /// Returns the constant of this argument. /// Defaults to `false` if the underlying value is `None`. /// - /// See [simple_extensions::ValueArg::constant]. + /// See [`simple_extensions::ValueArg::constant`]. pub fn constant(&self) -> bool { self.constant.unwrap_or(false) } @@ -300,10 +312,10 @@ impl From for ArgumentsItem { } } -/// Arguments that are used only to inform the evaluation and/or type derivation of the function. +/// A type argument to a parameterized type, e.g. the `T` in `List`. #[derive(Clone, Debug, PartialEq)] pub struct TypeArg { - /// A human-readable name for this argument to help clarify use. + /// A human-readable name for this argument to clarify use. name: Option, /// Additional description for this argument. @@ -311,21 +323,22 @@ pub struct TypeArg { /// A partially or completely parameterized type. E.g. `List` or `K`. /// - /// todo: implement parsed [simple_extensions::Type]. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` + /// parser) so the caller does not have to interpret the raw string. type_: String, } impl TypeArg { /// Returns the name of this argument. 
/// - /// See [simple_extensions::TypeArg::name]. + /// See [`simple_extensions::TypeArg::name`]. pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::TypeArg::description]. + /// See [`simple_extensions::TypeArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } @@ -371,7 +384,7 @@ impl From for ArgumentsItem { mod tests { use super::*; use crate::text::simple_extensions; - use crate::{parse::context::tests::Context, text}; + use crate::{parse::context::fixtures::Context, text}; #[test] fn parse_enum_argument() -> Result<(), ArgumentsItemError> { @@ -677,7 +690,7 @@ mod tests { #[test] fn parse_extensions() { use crate::extensions::EXTENSIONS; - use crate::parse::context::tests::Context; + use crate::parse::context::fixtures::Context; macro_rules! parse_arguments { ($url:expr, $fns:expr) => { diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs new file mode 100644 index 00000000..56344c00 --- /dev/null +++ b/src/parse/text/simple_extensions/extensions.rs @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Parsing context for extension processing. + +use std::collections::{HashMap, HashSet}; +use std::str::FromStr; + +use super::{SimpleExtensionsError, types::CustomType}; +use crate::{ + parse::{Context, Parse}, + text::simple_extensions::SimpleExtensions as RawExtensions, + urn::Urn, +}; + +/// The contents (types) in an [`ExtensionFile`](super::file::ExtensionFile). +/// +/// This structure stores and provides access to the individual objects defined +/// in an [`ExtensionFile`](super::file::ExtensionFile); [`SimpleExtensions`] +/// represents the contents of an extensions file. +/// +/// Currently, only the [`CustomType`]s are included; any scalar, window, or +/// aggregate functions are not yet included. 
+#[derive(Clone, Debug, Default)] +pub struct SimpleExtensions { + /// Types defined in this extension file + types: HashMap, +} + +impl SimpleExtensions { + /// Add a type to this extension. Name must be unique. + pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> { + if self.types.contains_key(&custom_type.name) { + return Err(SimpleExtensionsError::DuplicateTypeName { + name: custom_type.name.clone(), + }); + } + + self.types + .insert(custom_type.name.clone(), custom_type.clone()); + Ok(()) + } + + /// Check if a type with the given name exists in this extension + pub fn has_type(&self, name: &str) -> bool { + self.types.contains_key(name) + } + + /// Get a type by name from this extension + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.types.get(name) + } + + /// Get an iterator over all types in this extension + pub fn types(&self) -> impl Iterator { + self.types.values() + } + + /// Consume the parsed extension and return its types. + pub(crate) fn into_types(self) -> HashMap { + self.types + } +} + +/// A context for parsing simple extensions, tracking what type names are +/// resolved or unresolved. +#[derive(Debug, Default)] +pub struct TypeContext { + /// Types that have been seen so far, now resolved. + known: HashSet, + /// Types that have been linked to, not yet resolved. + linked: HashSet, +} + +impl TypeContext { + /// Mark a type as found + pub fn found(&mut self, name: &str) { + self.linked.remove(name); + self.known.insert(name.to_string()); + } + + /// Mark a type as linked to - some other type or function references it, + /// but we haven't seen it. + pub fn linked(&mut self, name: &str) { + if !self.known.contains(name) { + self.linked.insert(name.to_string()); + } + } +} + +impl Context for TypeContext { + // `TypeContext` relies on the default `Context` methods; nothing to override. +} + +// Implement parsing for the raw text representation to produce a `(Urn, SimpleExtensions)` pair. 
+impl Parse for RawExtensions { + type Parsed = (Urn, SimpleExtensions); + type Error = super::SimpleExtensionsError; + + fn parse(self, ctx: &mut TypeContext) -> Result { + let RawExtensions { urn, types, .. } = self; + let urn = Urn::from_str(&urn)?; + let mut extension = SimpleExtensions::default(); + + for type_item in types { + let custom_type = Parse::parse(type_item, ctx)?; + // Add the parsed type to the context so later types can reference it + extension.add_type(&custom_type)?; + } + + if let Some(missing) = ctx.linked.iter().next() { + // TODO: Track originating type(s) to improve this error message. + return Err(super::SimpleExtensionsError::UnresolvedTypeReference { + type_name: missing.clone(), + }); + } + + Ok((urn, extension)) + } +} + +impl From<(Urn, SimpleExtensions)> for RawExtensions { + fn from((urn, extension): (Urn, SimpleExtensions)) -> Self { + let types = extension + .into_types() + .into_values() + .map(Into::into) + .collect(); + + RawExtensions { + urn: urn.to_string(), + aggregate_functions: vec![], + dependencies: HashMap::new(), + scalar_functions: vec![], + type_variations: vec![], + types, + window_functions: vec![], + } + } +} diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs new file mode 100644 index 00000000..bc2a3dcf --- /dev/null +++ b/src/parse/text/simple_extensions/file.rs @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 + +use super::{CustomType, SimpleExtensions, SimpleExtensionsError}; +use crate::parse::Parse; +use crate::parse::text::simple_extensions::extensions::TypeContext; +use crate::text::simple_extensions::SimpleExtensions as RawExtensions; +use crate::urn::Urn; +use std::io::Read; + +/// A parsed and validated [`RawExtensions`]: a simple extensions file. +/// +/// An [`ExtensionFile`] has a canonical [`Urn`] and a parsed set of +/// [`SimpleExtensions`] data. It represents the extensions file as a whole. 
+#[derive(Debug)] +pub struct ExtensionFile { + /// The URN this extension was loaded from + pub urn: Urn, + /// The extension data containing types and eventually functions + extension: SimpleExtensions, +} + +impl ExtensionFile { + /// Create a new, empty [`ExtensionFile`] with an empty set of [`SimpleExtensions`]. + pub fn empty(urn: Urn) -> Self { + let extension = SimpleExtensions::default(); + Self { urn, extension } + } + + /// Create an [`ExtensionFile`] from raw simple extension data. + pub fn create(extensions: RawExtensions) -> Result { + // Parse all types; this fails on duplicate type names or on references to types the file never defines + let mut ctx = TypeContext::default(); + let (urn, extension) = Parse::parse(extensions, &mut ctx)?; + + // TODO: Use ctx.known/ctx.linked to validate unresolved references and cross-file links. + + Ok(Self { urn, extension }) + } + + /// Get a type by name + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.extension.get_type(name) + } + + /// Get an iterator over all types in this extension + pub fn types(&self) -> impl Iterator { + self.extension.types() + } + + /// Returns the [`Urn`] for this extension file. + pub fn urn(&self) -> &Urn { + &self.urn + } + + /// Get a reference to the underlying [`SimpleExtensions`]. + pub fn extension(&self) -> &SimpleExtensions { + &self.extension + } + + /// Convert the parsed extension file back into the raw text representation + /// by value. + pub fn into_raw(self) -> RawExtensions { + let ExtensionFile { urn, extension } = self; + RawExtensions::from((urn, extension)) + } + + /// Convert the parsed extension file back into the raw text representation + /// by reference. + pub fn to_raw(&self) -> RawExtensions { + RawExtensions::from((self.urn.clone(), self.extension.clone())) + } + + /// Read an extension file from a reader. + /// - `reader`: any [`Read`] instance with the YAML content + /// + /// Returns a parsed and validated [`ExtensionFile`] or an error. 
+ pub fn read(reader: R) -> Result { + let raw: RawExtensions = serde_yaml::from_reader(reader)?; + Self::create(raw) + } + + /// Read an extension file from a string slice. + pub fn read_from_str>(s: S) -> Result { + let raw: RawExtensions = serde_yaml::from_str(s.as_ref())?; + Self::create(raw) + } +} + +// Parsing and conversion implementations are defined on `SimpleExtensions` in `extensions.rs`. + +#[cfg(test)] +mod tests { + use super::*; + use crate::parse::text::simple_extensions::types::ParameterConstraint as RawParameterType; + + const YAML_PARAM_TEST: &str = r#" +%YAML 1.2 +--- +urn: extension:example.com:param_test +types: + - name: "ParamTest" + parameters: + - name: "K" + type: integer + min: 1 + max: 10 +"#; + + const YAML_UNRESOLVED_TYPE: &str = r#" +%YAML 1.2 +--- +urn: extension:example.com:unresolved +types: + - name: "Alias" + structure: List> +"#; + + #[test] + fn yaml_round_trip_integer_param_bounds() { + let ext = ExtensionFile::read_from_str(YAML_PARAM_TEST).expect("parse ok"); + assert_eq!(ext.urn().to_string(), "extension:example.com:param_test"); + + let ty = ext.get_type("ParamTest").expect("type exists"); + match &ty.parameters[..] 
{ + [param] => match &param.param_type { + RawParameterType::Integer { + min: actual_min, + max: actual_max, + } => { + assert_eq!(actual_min, &Some(1)); + assert_eq!(actual_max, &Some(10)); + } + other => panic!("unexpected param type: {other:?}"), + }, + other => panic!("unexpected parameters: {other:?}"), + } + + let back = ext.to_raw(); + assert_eq!(back.urn, "extension:example.com:param_test"); + let item = back + .types + .into_iter() + .find(|t| t.name == "ParamTest") + .expect("round-tripped type present"); + let param = item.parameters.unwrap().0.into_iter().next().unwrap(); + assert_eq!(param.name.as_deref(), Some("K")); + assert!(matches!( + param.type_, + crate::text::simple_extensions::TypeParamDefsItemType::Integer + )); + assert_eq!(param.min, Some(1.0)); + assert_eq!(param.max, Some(10.0)); + } + + #[test] + fn unresolved_type_reference_errors() { + let err = ExtensionFile::read_from_str(YAML_UNRESOLVED_TYPE) + .expect_err("expected unresolved type reference error"); + + match err { + SimpleExtensionsError::UnresolvedTypeReference { type_name } => { + assert_eq!(type_name, "MissingType"); + } + other => panic!("unexpected error type: {other:?}"), + } + } +} diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index 44867bf5..0e66e7c8 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -1,51 +1,69 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [text::simple_extensions] types. +//! Rustic types for validating and working with Substrait simple extensions. +//! +//! The raw YAML structs live in [`crate::text::simple_extensions`]. This +//! module parses those values into the typed representations used by this +//! crate: +//! * [`ExtensionFile`] – a fully validated extension document (URN plus its +//! definitions). +//! * [`SimpleExtensions`] – the validated objects declared by a single +//! extension file. +//! 
* [`CustomType`] / [`ConcreteType`] – type definitions and resolved type +//! structures used when checking function signatures. +//! * [`Registry`] – a reusable lookup structure that stores validated extension +//! files and exposes typed access to their contents. -use thiserror::Error; +use std::convert::Infallible; -use crate::{ - parse::{Context, Parse}, - text, -}; +use thiserror::Error; pub mod argument; +mod extensions; +mod file; +mod parsed_type; +mod registry; +mod types; -/// A parsed [text::simple_extensions::SimpleExtensions]. -pub struct SimpleExtensions { - // TODO -} +pub use extensions::SimpleExtensions; +pub use file::ExtensionFile; +pub use parsed_type::TypeExpr; +pub use registry::Registry; +pub use types::{ConcreteType, CustomType, ExtensionTypeError}; -/// Parse errors for [text::simple_extensions::SimpleExtensions]. -#[derive(Debug, Error, PartialEq)] +/// Errors for converting from YAML to [SimpleExtensions]. +#[derive(Debug, Error)] pub enum SimpleExtensionsError { - // TODO -} - -impl Parse for text::simple_extensions::SimpleExtensions { - type Parsed = SimpleExtensions; - type Error = SimpleExtensionsError; - - fn parse(self, _ctx: &mut C) -> Result { - // let text::simple_extensions::SimpleExtensions { - // aggregate_functions, - // dependencies, - // scalar_functions, - // type_variations, - // types, - // window_functions, - // } = self; - - todo!( - "text::simple_extensions::SimpleExtensions - https://github.com/substrait-io/substrait-rs/issues/157" - ) - } + /// Extension type error + #[error("Extension type error: {0}")] + ExtensionTypeError(#[from] ExtensionTypeError), + /// Failed to parse SimpleExtensions YAML + #[error("YAML parse error: {0}")] + YamlParse(#[from] serde_yaml::Error), + /// I/O error while reading extension content + #[error("io error: {0}")] + Io(#[from] std::io::Error), + /// Invalid URN provided + #[error("invalid urn")] + InvalidUrn(#[from] crate::urn::InvalidUrn), + /// Unresolved type reference in structure 
field + #[error("Type '{type_name}' referenced in structure not found")] + UnresolvedTypeReference { + /// The type name that could not be resolved + type_name: String, + // TODO: the location in the file where this came from would be nice + }, + /// Duplicate type definition within the same extension + #[error("duplicate type definition for `{name}`")] + DuplicateTypeName { + /// The repeated type name + name: String, + }, } -impl From for text::simple_extensions::SimpleExtensions { - fn from(_value: SimpleExtensions) -> Self { - todo!( - "text::simple_extensions::SimpleExtensions - https://github.com/substrait-io/substrait-rs/issues/157" - ) +// Needed for certain conversions - e.g. Urn -> Urn - to succeed. +impl From for SimpleExtensionsError { + fn from(_: Infallible) -> Self { + unreachable!() } } diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs new file mode 100644 index 00000000..1a784330 --- /dev/null +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Parsed type AST used by the simple extensions type parser. + +use crate::parse::text::simple_extensions::types::is_builtin_type_name; + +/// A parsed type expression from a type string, with lifetime tied to the original string. 
+#[derive(Clone, Debug, PartialEq)] +pub enum TypeExpr<'a> { + /// A type with a name, optional parameters, and nullability + Simple(&'a str, Vec>, bool), + /// A user-defined extension type, indicated by `u!Name`, with optional + /// parameters and nullability + UserDefined(&'a str, Vec>, bool), + /// Type variable (e.g., any1, any2) + TypeVariable(u32, bool), +} + +/// A parsed parameter to a parameterized type +#[derive(Clone, Debug, PartialEq)] +pub enum TypeExprParam<'a> { + /// A nested type parameter + Type(TypeExpr<'a>), + /// An integer literal parameter + Integer(i64), +} + +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum TypeParseError { + #[error("missing closing angle bracket in parameter list: {0}")] + ExpectedClosingAngleBracket(String), + #[error("Type variation syntax is not supported: {0}")] + UnsupportedVariation(String), +} + +impl<'a> TypeExpr<'a> { + /// Parse a type string into a [`TypeExpr`]. + pub fn parse(type_str: &'a str) -> Result { + // Handle type variables like any1, any2, etc. 
+ if let Some(suffix) = type_str.strip_prefix("any") { + let (middle, nullable) = match suffix.strip_suffix('?') { + Some(middle) => (middle, true), + None => (suffix, false), + }; + + if let Ok(id) = middle.parse::() { + return Ok(TypeExpr::TypeVariable(id, nullable)); + } + } + + let (user_defined, rest) = match type_str.strip_prefix("u!") { + Some(right) => (true, right), + None => (false, type_str), + }; + + let (name_and_nullable, params): (&'a str, Vec>) = + match rest.split_once('<') { + Some((n, p)) => match p.strip_suffix('>') { + Some(p) => (n, parse_params(p)?), + None => return Err(TypeParseError::ExpectedClosingAngleBracket(p.to_string())), + }, + None => (rest, vec![]), + }; + + if name_and_nullable.contains('[') || name_and_nullable.contains(']') { + return Err(TypeParseError::UnsupportedVariation(type_str.to_string())); + } + + let (name, nullable) = match name_and_nullable.strip_suffix('?') { + Some(name) => (name, true), + None => (name_and_nullable, false), + }; + + if user_defined { + Ok(TypeExpr::UserDefined(name, params, nullable)) + } else { + Ok(TypeExpr::Simple(name, params, nullable)) + } + } + + /// Visit all extension type references contained in a parsed type, calling `on_ext` + /// for each encountered extension name (including named extension forms). + pub fn visit_references(&self, on_ext: &mut F) + where + F: FnMut(&str), + { + match self { + TypeExpr::UserDefined(name, params, _) => { + // `name` already excludes the `u!` prefix (removed in `parse`), so report it as-is + on_ext(name); + for p in params { + if let TypeExprParam::Type(t) = p { + t.visit_references(on_ext); + } + } + } + TypeExpr::Simple(name, params, _) => { + if !is_builtin_type_name(name) { + on_ext(name); + } + for p in params { + if let TypeExprParam::Type(t) = p { + t.visit_references(on_ext); + } + } + } + TypeExpr::TypeVariable(..) 
=> {} + } + } +} + +fn parse_params<'a>(s: &'a str) -> Result>, TypeParseError> { + let mut result = Vec::new(); + let mut start = 0; + let mut depth = 0; + + for (i, c) in s.char_indices() { + match c { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + result.push(parse_param(s[start..i].trim())?); + start = i + 1; + } + _ => {} + } + } + + if depth != 0 { + return Err(TypeParseError::ExpectedClosingAngleBracket(s.to_string())); + } + + if start < s.len() { + result.push(parse_param(s[start..].trim())?); + } + + Ok(result) +} + +fn parse_param<'a>(s: &'a str) -> Result, TypeParseError> { + if let Ok(i) = s.parse::() { + return Ok(TypeExprParam::Integer(i)); + } + Ok(TypeExprParam::Type(TypeExpr::parse(s)?)) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse(expr: &str) -> TypeExpr<'_> { + TypeExpr::parse(expr).expect("parse succeeds") + } + + #[test] + fn test_simple_types() { + let cases = vec![ + ("i32", TypeExpr::Simple("i32", vec![], false)), + ("i32?", TypeExpr::Simple("i32", vec![], true)), + ("MAP", TypeExpr::Simple("MAP", vec![], false)), + ("timestamp", TypeExpr::Simple("timestamp", vec![], false)), + ( + "timestamp_tz?", + TypeExpr::Simple("timestamp_tz", vec![], true), + ), + ("time", TypeExpr::Simple("time", vec![], false)), + ]; + + for (expr, expected) in cases { + assert_eq!(parse(expr), expected, "unexpected parse for {expr}"); + } + } + + #[test] + fn test_type_variables() { + let cases = vec![ + ("any1", TypeExpr::TypeVariable(1, false)), + ("any2?", TypeExpr::TypeVariable(2, true)), + ]; + + for (expr, expected) in cases { + assert_eq!( + parse(expr), + expected, + "unexpected variable parse for {expr}" + ); + } + } + + #[test] + fn test_user_defined_and_parameters() { + let expr = "u!geo?>"; + match parse(expr) { + TypeExpr::UserDefined(name, params, nullable) => { + assert_eq!(name, "geo", "unexpected name for {expr}"); + assert!(nullable, "{expr} should be nullable"); + assert_eq!( + params, + vec![ + 
TypeExprParam::Type(TypeExpr::Simple("i32", vec![], true)), + TypeExprParam::Type(TypeExpr::Simple( + "point", + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + ], + false, + )), + ] + ); + } + other => panic!("unexpected parse result: {other:?}"), + } + + let map_expr = "Map?"; + assert_eq!( + parse(map_expr), + TypeExpr::Simple( + "Map", + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + TypeExprParam::Type(TypeExpr::Simple("string", vec![], false)), + ], + true, + ), + "unexpected map parse" + ); + } + + #[test] + fn test_visit_references_builtin_case_insensitive() { + let cases = vec![ + ("DECIMAL<10,2>", Vec::::new()), + ("List", Vec::::new()), + ("u!custom", vec!["custom".to_string()]), + ("Geo", vec!["Geo".to_string()]), + ]; + + for (expr, expected_refs) in cases { + let mut refs = Vec::new(); + parse(expr).visit_references(&mut |name| refs.push(name.to_string())); + assert_eq!(refs, expected_refs, "unexpected references for {expr}"); + } + } + + #[test] + fn test_variation_not_supported() { + let cases = vec!["i32[1]", "Foo?[1]", "u!bar[2]"]; + + for expr in cases { + match TypeExpr::parse(expr) { + Err(TypeParseError::UnsupportedVariation(s)) => assert_eq!(s, expr), + other => panic!("expected UnsupportedVariation for {expr}, got {other:?}"), + } + } + } +} diff --git a/src/parse/text/simple_extensions/registry.rs b/src/parse/text/simple_extensions/registry.rs new file mode 100644 index 00000000..c6bd39b7 --- /dev/null +++ b/src/parse/text/simple_extensions/registry.rs @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Substrait Extension Registry +//! +//! This module provides registries for Substrait extensions: +//! - **Global Registry**: Immutable, reusable across plans, URI+name based lookup +//! - **Local Registry**: Per-plan, anchor-based, references Global Registry (TODO) +//! +//! Currently only type definitions are supported. 
Function parsing will be added in a future update. +//! +//! This module is only available when the `parse` feature is enabled. + +use super::{ExtensionFile, types::CustomType}; +use crate::urn::Urn; + +/// Extension Registry that manages Substrait extensions +/// +/// This registry is immutable and reusable across multiple plans. +/// It provides URN + name based lookup for extension types. Function parsing will be added in a future update. +#[derive(Debug)] +pub struct Registry { + /// Pre-validated extension files + extensions: Vec, +} + +impl Registry { + /// Create a new Global Registry from validated extension files + pub fn new(extensions: Vec) -> Self { + Self { extensions } + } + + /// Get an iterator over all extension files in this registry + pub fn extensions(&self) -> impl Iterator { + self.extensions.iter() + } + + /// Create a Global Registry from the built-in core extensions + #[cfg(feature = "extensions")] + pub fn from_core_extensions() -> Self { + use crate::extensions::EXTENSIONS; + use std::sync::LazyLock; + + // Force evaluation of core extensions + LazyLock::force(&EXTENSIONS); + + // Convert HashMap to Vec + let extensions: Vec = EXTENSIONS + .iter() + .map(|(urn, simple_extensions)| { + let extension_file = ExtensionFile::create(simple_extensions.clone()) + .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {urn}: {err}")); + debug_assert_eq!(extension_file.urn(), urn); + extension_file + }) + .collect(); + + Self { extensions } + } + + // Private helper methods + + fn get_extension(&self, urn: &Urn) -> Option<&ExtensionFile> { + self.extensions.iter().find(|ext| ext.urn() == urn) + } + + /// Get a type by URI and name + pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> { + self.get_extension(urn)?.get_type(name) + } +} + +#[cfg(test)] +mod tests { + use super::{ExtensionFile, Registry}; + use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem}; 
+ use crate::urn::Urn; + use std::str::FromStr; + + fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile { + let types = type_names + .iter() + .map(|name| SimpleExtensionsTypesItem { + name: (*name).to_string(), + description: None, + parameters: None, + structure: None, + variadic: None, + }) + .collect(); + + let raw = SimpleExtensions { + scalar_functions: vec![], + aggregate_functions: vec![], + window_functions: vec![], + dependencies: Default::default(), + type_variations: vec![], + types, + urn: urn.to_string(), + }; + + ExtensionFile::create(raw).expect("valid extension file") + } + + #[test] + fn test_registry_iteration() { + let urns = vec![ + "extension:example.com:first", + "extension:example.com:second", + ]; + let registry = Registry::new( + urns.iter() + .map(|urn| extension_file(urn, &["type"])) + .collect(), + ); + + let collected: Vec<&Urn> = registry.extensions().map(|ext| ext.urn()).collect(); + assert_eq!(collected.len(), 2); + for urn in urns { + assert!( + collected + .iter() + .any(|candidate| candidate.to_string() == urn) + ); + } + } + + #[test] + fn test_type_lookup() { + let urn = Urn::from_str("extension:example.com:test").unwrap(); + let registry = Registry::new(vec![extension_file(&urn.to_string(), &["test_type"])]); + let other_urn = Urn::from_str("extension:example.com:other").unwrap(); + + let cases = vec![ + (&urn, "test_type", true), + (&urn, "missing", false), + (&other_urn, "test_type", false), + ]; + + for (query_urn, type_name, expected) in cases { + assert_eq!( + registry.get_type(query_urn, type_name).is_some(), + expected, + "unexpected lookup result for {query_urn}:{type_name}" + ); + } + } + + #[cfg(feature = "extensions")] + #[test] + fn test_from_core_extensions() { + let registry = Registry::from_core_extensions(); + assert!(registry.extensions().count() > 0); + + // Find the unknown.yaml extension dynamically + let unknown_extension = registry + .extensions() + .find(|ext| ext.urn().to_string() == 
"extension:io.substrait:unknown") + .expect("Should find unknown extension"); + + let unknown_type = unknown_extension.get_type("unknown"); + assert!( + unknown_type.is_some(), + "Should find 'unknown' type in unknown.yaml extension" + ); + + // Also test the registry's get_type method with the actual URI + let unknown_type_via_registry = registry.get_type(unknown_extension.urn(), "unknown"); + assert!(unknown_type_via_registry.is_some()); + } +} diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs new file mode 100644 index 00000000..2385e9cc --- /dev/null +++ b/src/parse/text/simple_extensions/types.rs @@ -0,0 +1,1906 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Concrete type system for function validation in the registry. +//! +//! This module provides a clean, type-safe wrapper around Substrait extension types, +//! separating function signature patterns from concrete argument types. + +use super::TypeExpr; +use super::argument::{ + EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError, +}; +use super::extensions::TypeContext; +use super::parsed_type::TypeExprParam; +use crate::parse::Parse; +use crate::parse::text::simple_extensions::parsed_type::TypeParseError; +use crate::text::simple_extensions::{ + EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs, + TypeParamDefsItem, TypeParamDefsItemType, +}; +use serde_json::Value; +use std::convert::TryFrom; +use std::fmt; +use std::ops::RangeInclusive; +use std::str::FromStr; +use thiserror::Error; + +/// Write a sequence of items separated by a separator, with a start and end +/// delimiter. +/// +/// Start and end are only included in the output if there is at least one item. 
+fn write_separated( + f: &mut fmt::Formatter<'_>, + iter: I, + start: &str, + end: &str, + sep: &str, +) -> fmt::Result +where + I: IntoIterator, + T: fmt::Display, +{ + let mut it = iter.into_iter(); + if let Some(first) = it.next() { + f.write_str(start)?; + write!(f, "{first}")?; + for item in it { + f.write_str(sep)?; + write!(f, "{item}")?; + } + f.write_str(end) + } else { + Ok(()) + } +} + +/// A pair of a key and a value, separated by a separator. For display purposes. +struct KeyValueDisplay(K, V, &'static str); + +impl fmt::Display for KeyValueDisplay +where + K: fmt::Display, + V: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}{}{}", self.0, self.2, self.1) + } +} + +/// Substrait primitive built-in types (no parameters required) +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum PrimitiveType { + /// Boolean type - `bool` + Boolean, + /// 8-bit signed integer - `i8` + I8, + /// 16-bit signed integer - `i16` + I16, + /// 32-bit signed integer - `i32` + I32, + /// 64-bit signed integer - `i64` + I64, + /// 32-bit floating point - `fp32` + Fp32, + /// 64-bit floating point - `fp64` + Fp64, + /// Variable-length string - `string` + String, + /// Variable-length binary data - `binary` + Binary, + /// Naive Timestamp + Timestamp, + /// Timestamp with time zone - `timestamp_tz` + TimestampTz, + /// Calendar date - `date` + Date, + /// Time of day - `time` + Time, + /// Year-month interval - `interval_year` + IntervalYear, + /// 128-bit UUID - `uuid` + Uuid, +} + +impl fmt::Display for PrimitiveType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + PrimitiveType::Boolean => "bool", + PrimitiveType::I8 => "i8", + PrimitiveType::I16 => "i16", + PrimitiveType::I32 => "i32", + PrimitiveType::I64 => "i64", + PrimitiveType::Fp32 => "fp32", + PrimitiveType::Fp64 => "fp64", + PrimitiveType::String => "string", + PrimitiveType::Binary => "binary", + PrimitiveType::Timestamp => 
"timestamp", + PrimitiveType::TimestampTz => "timestamp_tz", + PrimitiveType::Date => "date", + PrimitiveType::Time => "time", + PrimitiveType::IntervalYear => "interval_year", + PrimitiveType::Uuid => "uuid", + }; + f.write_str(s) + } +} + +/// Parameter for parameterized types +#[derive(Clone, Debug, PartialEq)] +pub enum TypeParameter { + /// Integer parameter (e.g., precision, scale) + Integer(i64), + /// Type parameter (nested type) + Type(ConcreteType), +} + +impl fmt::Display for TypeParameter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TypeParameter::Integer(i) => write!(f, "{i}"), + TypeParameter::Type(t) => write!(f, "{t}"), + } + } +} + +/// Parameterized builtin types that require non-type parameters, e.g. numbers +/// or enum +#[derive(Clone, Debug, PartialEq)] +pub enum BuiltinParameterized { + /// Fixed-length character string: `FIXEDCHAR` + FixedChar { + /// Length (number of characters), must be >= 1 + length: i32, + }, + /// Variable-length character string: `VARCHAR` + VarChar { + /// Maximum length (number of characters), must be >= 1 + length: i32, + }, + /// Fixed-length binary data: `FIXEDBINARY` + FixedBinary { + /// Length (number of bytes), must be >= 1 + length: i32, + }, + /// Fixed-point decimal: `DECIMAL` + Decimal { + /// Precision (total digits), <= 38 + precision: i32, + /// Scale (digits after decimal point), 0 <= S <= P + scale: i32, + }, + /// Time with sub-second precision: `PRECISIONTIME

` + PrecisionTime { + /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Timestamp with sub-second precision: `PRECISIONTIMESTAMP

` + PrecisionTimestamp { + /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Timezone-aware timestamp with precision: `PRECISIONTIMESTAMPTZ

` + PrecisionTimestampTz { + /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Day-time interval: `INTERVAL_DAY

` + IntervalDay { + /// Sub-second precision digits (0-9: seconds to nanoseconds) + precision: i32, + }, + /// Compound interval: `INTERVAL_COMPOUND

` + IntervalCompound { + /// Sub-second precision digits + precision: i32, + }, +} + +impl BuiltinParameterized { + /// Check if a string is a valid name for a parameterized builtin type. + /// + /// Only matches lowercase. + pub fn is_name(s: &str) -> bool { + matches!( + s, + "fixedchar" + | "varchar" + | "fixedbinary" + | "decimal" + | "precisiontime" + | "precisiontimestamp" + | "precisiontimestamptz" + | "interval_day" + | "interval_compound" + ) + } +} + +impl fmt::Display for BuiltinParameterized { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BuiltinParameterized::FixedChar { length } => { + write!(f, "FIXEDCHAR<{length}>") + } + BuiltinParameterized::VarChar { length } => { + write!(f, "VARCHAR<{length}>") + } + BuiltinParameterized::FixedBinary { length } => { + write!(f, "FIXEDBINARY<{length}>") + } + BuiltinParameterized::Decimal { precision, scale } => { + write!(f, "DECIMAL<{precision}, {scale}>") + } + BuiltinParameterized::PrecisionTime { precision } => { + write!(f, "PRECISIONTIME<{precision}>") + } + BuiltinParameterized::PrecisionTimestamp { precision } => { + write!(f, "PRECISIONTIMESTAMP<{precision}>") + } + BuiltinParameterized::PrecisionTimestampTz { precision } => { + write!(f, "PRECISIONTIMESTAMPTZ<{precision}>") + } + BuiltinParameterized::IntervalDay { precision } => { + write!(f, "INTERVAL_DAY<{precision}>") + } + BuiltinParameterized::IntervalCompound { precision } => { + write!(f, "INTERVAL_COMPOUND<{precision}>") + } + } + } +} + +/// Unified representation of simple builtin types (primitive or parameterized). +/// Does not include container types like List, Map, or Struct. 
+#[derive(Clone, Debug, PartialEq)] +pub enum BuiltinKind { + /// Primitive builtins like `i32` + Primitive(PrimitiveType), + /// Parameterized builtins like `DECIMAL` + Parameterized(BuiltinParameterized), +} + +impl fmt::Display for BuiltinKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BuiltinKind::Primitive(p) => write!(f, "{p}"), + BuiltinKind::Parameterized(p) => write!(f, "{p}"), + } + } +} + +impl From for BuiltinKind { + fn from(value: PrimitiveType) -> Self { + BuiltinKind::Primitive(value) + } +} + +impl From for BuiltinKind { + fn from(value: BuiltinParameterized) -> Self { + BuiltinKind::Parameterized(value) + } +} + +/// Check if a name corresponds to any built-in type (primitive, parameterized, +/// or container) +pub fn is_builtin_type_name(name: &str) -> bool { + let lower = name.to_ascii_lowercase(); + PrimitiveType::from_str(&lower).is_ok() + || BuiltinParameterized::is_name(&lower) + || matches!(lower.as_str(), "list" | "map" | "struct") +} + +/// Error when a builtin type name is not recognized +#[derive(Debug, Error)] +#[error("Unrecognized builtin type: {0}")] +pub struct UnrecognizedBuiltin(String); + +impl FromStr for PrimitiveType { + type Err = UnrecognizedBuiltin; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "bool" => Ok(PrimitiveType::Boolean), + "i8" => Ok(PrimitiveType::I8), + "i16" => Ok(PrimitiveType::I16), + "i32" => Ok(PrimitiveType::I32), + "i64" => Ok(PrimitiveType::I64), + "fp32" => Ok(PrimitiveType::Fp32), + "fp64" => Ok(PrimitiveType::Fp64), + "string" => Ok(PrimitiveType::String), + "binary" => Ok(PrimitiveType::Binary), + "timestamp" => Ok(PrimitiveType::Timestamp), + "timestamp_tz" => Ok(PrimitiveType::TimestampTz), + "date" => Ok(PrimitiveType::Date), + "time" => Ok(PrimitiveType::Time), + "interval_year" => Ok(PrimitiveType::IntervalYear), + "uuid" => Ok(PrimitiveType::Uuid), + _ => Err(UnrecognizedBuiltin(s.to_string())), + } + } +} + +/// Parameter type 
information for type definitions +#[derive(Clone, Debug, PartialEq)] +pub enum ParameterConstraint { + /// Data type parameter + DataType, + /// Integer parameter with range constraints + Integer { + /// Minimum value (inclusive), if specified + min: Option, + /// Maximum value (inclusive), if specified + max: Option, + }, + /// Enumeration parameter + Enumeration { + /// Valid enumeration values (validated, deduplicated) + options: ParsedEnumOptions, + }, + /// Boolean parameter + Boolean, + /// String parameter + String, +} + +impl ParameterConstraint { + /// Convert back to raw TypeParamDefsItemType + fn raw_type(&self) -> TypeParamDefsItemType { + match self { + ParameterConstraint::DataType => TypeParamDefsItemType::DataType, + ParameterConstraint::Boolean => TypeParamDefsItemType::Boolean, + ParameterConstraint::Integer { .. } => TypeParamDefsItemType::Integer, + ParameterConstraint::Enumeration { .. } => TypeParamDefsItemType::Enumeration, + ParameterConstraint::String => TypeParamDefsItemType::String, + } + } + + /// Extract raw bounds for integer parameters (min, max) + fn raw_bounds(&self) -> (Option, Option) { + match self { + ParameterConstraint::Integer { min, max } => { + (min.map(|i| i as f64), max.map(|i| i as f64)) + } + _ => (None, None), + } + } + + /// Extract raw enum options for enumeration parameters + fn raw_options(&self) -> Option { + match self { + ParameterConstraint::Enumeration { options } => Some(options.clone().into()), + _ => None, + } + } + + /// Check if a parameter value is valid for this parameter type + pub fn is_valid_value(&self, value: &Value) -> bool { + match (self, value) { + (ParameterConstraint::DataType, Value::String(_)) => true, + (ParameterConstraint::Integer { min, max }, Value::Number(n)) => { + if let Some(i) = n.as_i64() { + min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val) + } else { + false + } + } + (ParameterConstraint::Enumeration { options }, Value::String(s)) => 
options.contains(s), + (ParameterConstraint::Boolean, Value::Bool(_)) => true, + (ParameterConstraint::String, Value::String(_)) => true, + _ => false, + } + } + + fn from_raw( + t: TypeParamDefsItemType, + opts: Option, + min: Option, + max: Option, + ) -> Result { + Ok(match t { + TypeParamDefsItemType::DataType => Self::DataType, + TypeParamDefsItemType::Boolean => Self::Boolean, + TypeParamDefsItemType::Integer => { + if let Some(min_f) = min + && min_f.fract() != 0.0 + { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + if let Some(max_f) = max + && max_f.fract() != 0.0 + { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + + let min_i = min.map(|v| v as i64); + let max_i = max.map(|v| v as i64); + Self::Integer { + min: min_i, + max: max_i, + } + } + TypeParamDefsItemType::Enumeration => { + let options: ParsedEnumOptions = + opts.ok_or(TypeParamError::MissingEnumOptions)?.try_into()?; + Self::Enumeration { options } + } + TypeParamDefsItemType::String => Self::String, + }) + } +} + +/// A validated type parameter with name and constraints +#[derive(Clone, Debug, PartialEq)] +pub struct TypeParam { + /// Parameter name (e.g., "K" for a type variable) + pub name: String, + /// Parameter type constraints + pub param_type: ParameterConstraint, + /// Human-readable description + pub description: Option, +} + +impl TypeParam { + /// Create a new type parameter + pub fn new(name: String, param_type: ParameterConstraint, description: Option) -> Self { + Self { + name, + param_type, + description, + } + } + + /// Check if a parameter value is valid + pub fn is_valid_value(&self, value: &Value) -> bool { + self.param_type.is_valid_value(value) + } +} + +impl TryFrom for TypeParam { + type Error = TypeParamError; + + fn try_from(item: TypeParamDefsItem) -> Result { + let name = item.name.ok_or(TypeParamError::MissingName)?; + let param_type = + ParameterConstraint::from_raw(item.type_, item.options, item.min, item.max)?; + + 
Ok(Self { + name, + param_type, + description: item.description, + }) + } +} + +/// Error types for extension type validation +#[derive(Debug, Error, PartialEq)] +pub enum ExtensionTypeError { + /// Extension type name is invalid + #[error("{0}")] + InvalidName(#[from] InvalidTypeName), + /// Any type variable is invalid for concrete types + #[error("Any type variable is invalid for concrete types: any{}{}", id, nullability.then_some("?").unwrap_or(""))] + InvalidAnyTypeVariable { + /// The type variable name + id: u32, + /// Whether the type variable is nullable + nullability: bool, + }, + /// Parameter validation failed + #[error("Invalid parameter: {0}")] + InvalidParameter(#[from] TypeParamError), + /// Field type is invalid + #[error("Invalid structure field type: {0}")] + InvalidFieldType(String), + /// Type parameter count is invalid for the given type name + #[error("Type '{type_name}' expects {expected} parameters, got {actual}")] + InvalidParameterCount { + /// The type name being validated + type_name: String, + /// Human-readable description of the expected parameter count + expected: &'static str, + /// The actual number of parameters provided + actual: usize, + }, + /// Type parameter is of the wrong kind for the given position + #[error("Type '{type_name}' parameter {index} must be {expected}")] + InvalidParameterKind { + /// The type name being validated + type_name: String, + /// Zero-based index of the offending parameter + index: usize, + /// Expected parameter kind (e.g., integer, type) + expected: &'static str, + }, + /// Provided parameter value does not fit within the expected bounds + #[error("Type '{type_name}' parameter {index} value {value} is not within {expected}")] + InvalidParameterValue { + /// The type name being validated + type_name: String, + /// Zero-based index of the offending parameter + index: usize, + /// Provided parameter value + value: i64, + /// Description of the expected range or type + expected: &'static str, + }, + 
/// Provided parameter value does not fit within the expected bounds + #[error("Type '{type_name}' parameter {index} value {value} is out of range {expected:?}")] + InvalidParameterRange { + /// The type name being validated + type_name: String, + /// Zero-based index of the offending parameter + index: usize, + /// Provided parameter value + value: i64, + /// Description of the expected range or type + expected: RangeInclusive, + }, + /// Structure representation cannot be nullable + #[error("Structure representation cannot be nullable: {type_string}")] + StructureCannotBeNullable { + /// The type string that was nullable + type_string: String, + }, + /// Error parsing type + #[error("Error parsing type: {0}")] + ParseType(#[from] TypeParseError), +} + +/// Error types for TypeParam validation +#[derive(Debug, Error, PartialEq)] +pub enum TypeParamError { + /// Parameter name is missing + #[error("Parameter name is required")] + MissingName, + /// Integer parameter has non-integer min/max values + #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")] + InvalidIntegerBounds { + /// The invalid minimum value + min: Option, + /// The invalid maximum value + max: Option, + }, + /// Enumeration parameter is missing options + #[error("Enumeration parameter is missing options")] + MissingEnumOptions, + /// Enumeration parameter has invalid options + #[error("Enumeration parameter has invalid options: {0}")] + InvalidEnumOptions(#[from] ParsedEnumOptionsError), +} + +/// A validated Simple Extension type definition +#[derive(Clone, Debug, PartialEq)] +pub struct CustomType { + /// Type name + pub name: String, + /// Type parameters (e.g., for generic types) + pub parameters: Vec, + /// Concrete structure definition, if any + pub structure: Option, + /// Whether this type can have variadic parameters + pub variadic: Option, + /// Human-readable description + pub description: Option, +} + +impl CustomType { + /// Check if this type name is valid 
according to Substrait naming rules + /// (see the `Identifier` rule in `substrait/grammar/SubstraitLexer.g4`). + /// Identifiers are case-insensitive and must start with a an ASCII letter, + /// `_`, or `$`, followed by ASCII letters, digits, `_`, or `$`. + // + // Note: I'm not sure if `$` is actually something we want to allow, or if + // `_` is, but it's in the grammar so I'm allowing it here. + pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> { + let mut chars = name.chars(); + let first = chars + .next() + .ok_or_else(|| InvalidTypeName(name.to_string()))?; + if !(first.is_ascii_alphabetic() || first == '_' || first == '$') { + return Err(InvalidTypeName(name.to_string())); + } + + if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$') { + return Err(InvalidTypeName(name.to_string())); + } + + Ok(()) + } + + /// Create a new custom type with validation + pub fn new( + name: String, + parameters: Vec, + structure: Option, + variadic: Option, + description: Option, + ) -> Result { + Self::validate_name(&name)?; + + Ok(Self { + name, + parameters, + structure, + variadic, + description, + }) + } +} + +impl From for SimpleExtensionsTypesItem { + fn from(value: CustomType) -> Self { + // Convert parameters back to TypeParamDefs if any + let parameters = if value.parameters.is_empty() { + None + } else { + Some(TypeParamDefs( + value + .parameters + .into_iter() + .map(|param| { + let (min, max) = param.param_type.raw_bounds(); + TypeParamDefsItem { + name: Some(param.name), + description: param.description, + type_: param.param_type.raw_type(), + min, + max, + options: param.param_type.raw_options(), + optional: None, + } + }) + .collect(), + )) + }; + + // Convert structure back to Type if any + let structure = value.structure.map(Into::into); + + SimpleExtensionsTypesItem { + name: value.name, + description: value.description, + parameters, + structure, + variadic: value.variadic, + } + } +} + +impl Parse for 
SimpleExtensionsTypesItem { + type Parsed = CustomType; + type Error = ExtensionTypeError; + + fn parse(self, ctx: &mut TypeContext) -> Result { + let name = self.name; + CustomType::validate_name(&name)?; + + // Register this type as found + ctx.found(&name); + + let parameters = if let Some(param_defs) = self.parameters { + param_defs + .0 + .into_iter() + .map(TypeParam::try_from) + .collect::, _>>()? + } else { + Vec::new() + }; + + // Parse structure with context, so referenced extension types are recorded as linked + let structure = match self.structure { + Some(structure_data) => Some(Parse::parse(structure_data, ctx)?), + None => None, + }; + + Ok(CustomType { + name, + parameters, + structure, + variadic: self.variadic, + description: self.description, + }) + } +} + +impl Parse for RawType { + type Parsed = ConcreteType; + type Error = ExtensionTypeError; + + fn parse(self, ctx: &mut TypeContext) -> Result { + match self { + RawType::String(type_string) => { + let parsed_type = TypeExpr::parse(&type_string)?; + let mut link = |name: &str| ctx.linked(name); + parsed_type.visit_references(&mut link); + let concrete = ConcreteType::try_from(parsed_type)?; + + // Structure representation cannot be nullable + if concrete.nullable { + return Err(ExtensionTypeError::StructureCannotBeNullable { type_string }); + } + + Ok(concrete) + } + RawType::Object(field_map) => { + // Here we have the internal structure of a custom type, + // specified by (field name, type) pairs. Note that in the YAML + // itself, these are a map - and thus, while the text has an + // order, the data implicitly does not. In Rust, the field map + // is of type serde_json::Map, which also does not preserve + // order. + // + // So while it might be surprising in some ways that we are not + // preserving the order of fields as specified in the YAML, the + // nature of YAML somewhat precludes that. 
+ // + // To give an example: we are considering these two equivalent: + // + // ```yaml + // types: + // - name: point1 + // structure: + // longitude: i32 + // latitude: i32 + // - name: point2 + // structure: + // latitude: i32 + // longitude: i32 + // ``` + // + // In Rust, these come in as a `serde_json::Map`, which does not + // preserve insertion order. Here, we chose to store keys in + // lexicographic order, with an explicit sort so that our + // internal representation and any round-trip output remain + // deterministic. + let mut entries: Vec<_> = field_map.into_iter().collect(); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut field_names = Vec::new(); + let mut field_types = Vec::new(); + + for (field_name, field_type_value) in entries { + field_names.push(field_name); + + let type_string = match field_type_value { + serde_json::Value::String(s) => s, + _ => { + return Err(ExtensionTypeError::InvalidFieldType( + "Struct field types must be strings".to_string(), + )); + } + }; + + let parsed_field_type = TypeExpr::parse(&type_string)?; + let mut link = |name: &str| ctx.linked(name); + parsed_field_type.visit_references(&mut link); + let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; + + field_types.push(field_concrete_type); + } + + Ok(ConcreteType { + kind: ConcreteTypeKind::NamedStruct { + field_names, + field_types, + }, + nullable: false, + }) + } + } + } +} + +/// Invalid type name error +#[derive(Debug, Error, PartialEq)] +#[error("invalid type name `{0}`")] +pub struct InvalidTypeName(String); + +/// Known Substrait types (builtin + extension references) +#[derive(Clone, Debug, PartialEq)] +pub enum ConcreteTypeKind { + /// Built-in Substrait type (primitive or parameterized) + Builtin(BuiltinKind), + /// Extension type with optional parameters + Extension { + /// Extension type name + name: String, + /// Type parameters + parameters: Vec, + }, + /// List type with element type + List(Box), + /// Map type with key and value 
types + Map { + /// Key type + key: Box, + /// Value type + value: Box, + }, + /// Struct type (ordered fields without names) + Struct(Vec), + /// Named struct type (nstruct - ordered fields with names) + NamedStruct { + /// Field names + field_names: Vec, + /// Field types (same order as field_names) + field_types: Vec, + }, +} + +impl fmt::Display for ConcreteTypeKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConcreteTypeKind::Builtin(b) => write!(f, "{b}"), + ConcreteTypeKind::Extension { name, parameters } => { + write!(f, "{name}")?; + write_separated(f, parameters.iter(), "<", ">", ", ") + } + ConcreteTypeKind::List(elem) => write!(f, "list<{elem}>"), + ConcreteTypeKind::Map { key, value } => write!(f, "map<{key}, {value}>"), + ConcreteTypeKind::Struct(types) => { + write_separated(f, types.iter(), "struct<", ">", ", ") + } + ConcreteTypeKind::NamedStruct { + field_names, + field_types, + } => { + let kvs = field_names + .iter() + .zip(field_types.iter()) + .map(|(k, v)| KeyValueDisplay(k, v, ": ")); + + write_separated(f, kvs, "{", "}", ", ") + } + } + } +} + +/// A concrete, fully-resolved type instance +#[derive(Clone, Debug, PartialEq)] +pub struct ConcreteType { + /// The resolved type shape + pub kind: ConcreteTypeKind, + /// Whether this type is nullable + pub nullable: bool, +} + +impl ConcreteType { + /// Create a new primitive builtin type + pub fn builtin(builtin_type: PrimitiveType, nullable: bool) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Builtin(BuiltinKind::Primitive(builtin_type)), + nullable, + } + } + + /// Create a new parameterized builtin type + pub fn parameterized_builtin( + builtin_type: BuiltinParameterized, + nullable: bool, + ) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Builtin(BuiltinKind::Parameterized(builtin_type)), + nullable, + } + } + + /// Create a new extension type reference (without parameters) + pub fn extension(name: String, nullable: bool) -> 
ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Extension { + name, + parameters: Vec::new(), + }, + nullable, + } + } + + /// Create a new parameterized extension type + pub fn extension_with_params( + name: String, + parameters: Vec, + nullable: bool, + ) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Extension { name, parameters }, + nullable, + } + } + + /// Create a new list type + pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::List(Box::new(element_type)), + nullable, + } + } + + /// Create a new struct type (ordered fields without names) + pub fn r#struct(field_types: Vec, nullable: bool) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Struct(field_types), + nullable, + } + } + + /// Create a new map type + pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::Map { + key: Box::new(key_type), + value: Box::new(value_type), + }, + nullable, + } + } + + /// Create a new named struct type (nstruct - ordered fields with names) + pub fn named_struct( + field_names: Vec, + field_types: Vec, + nullable: bool, + ) -> ConcreteType { + ConcreteType { + kind: ConcreteTypeKind::NamedStruct { + field_names, + field_types, + }, + nullable, + } + } + + /// Check if this type matches another type exactly + pub fn matches(&self, other: &ConcreteType) -> bool { + self == other + } + + /// Check if this type is compatible with another type (considering nullability) + pub fn is_compatible_with(&self, other: &ConcreteType) -> bool { + // Types must match exactly, but nullable types can accept non-nullable values + self.kind == other.kind && (self.nullable || !other.nullable) + } +} + +impl fmt::Display for ConcreteType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.kind)?; + if self.nullable { + write!(f, "?")?; + } + Ok(()) + } +} + +impl From for 
RawType { + fn from(val: ConcreteType) -> Self { + match val.kind { + ConcreteTypeKind::NamedStruct { + field_names, + field_types, + } => { + let mut map = serde_json::Map::new(); + for (name, ty) in field_names.into_iter().zip(field_types.into_iter()) { + if let Some(v) = map.insert(name, serde_json::Value::String(ty.to_string())) { + // This should not happen - you should not have + // duplicate field names in a NamedStruct + panic!("duplicate value '{v:?}' in NamedStruct"); + } + } + RawType::Object(map) + } + _ => RawType::String(val.to_string()), + } + } +} + +fn expect_integer_param( + type_name: &str, + index: usize, + param: &TypeExprParam<'_>, + range: Option>, +) -> Result { + let value = match param { + TypeExprParam::Integer(value) => { + i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue { + type_name: type_name.to_string(), + index, + value: *value, + expected: "an i32", + }) + } + _ => Err(ExtensionTypeError::InvalidParameterKind { + type_name: type_name.to_string(), + index, + expected: "an integer", + }), + }?; + + if let Some(range) = &range + && !range.contains(&value) + { + return Err(ExtensionTypeError::InvalidParameterRange { + type_name: type_name.to_string(), + index, + value: i64::from(value), + expected: range.clone(), + }); + } + + Ok(value) +} + +fn expect_type_argument<'a>( + type_name: &str, + index: usize, + param: TypeExprParam<'a>, +) -> Result { + match param { + TypeExprParam::Type(t) => ConcreteType::try_from(t), + TypeExprParam::Integer(_) => Err(ExtensionTypeError::InvalidParameterKind { + type_name: type_name.to_string(), + index, + expected: "a type", + }), + } +} + +fn type_expr_param_to_type_parameter<'a>( + param: TypeExprParam<'a>, +) -> Result { + Ok(match param { + TypeExprParam::Integer(v) => TypeParameter::Integer(v), + TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?), + }) +} + +fn parse_parameterized_builtin<'a>( + display_name: &str, + lower_name: &str, + params: 
&[TypeExprParam<'a>], +) -> Result, ExtensionTypeError> { + match lower_name { + "fixedchar" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = expect_integer_param(display_name, 0, ¶ms[0], None)?; + Ok(Some(BuiltinParameterized::FixedChar { length })) + } + "varchar" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = expect_integer_param(display_name, 0, ¶ms[0], None)?; + Ok(Some(BuiltinParameterized::VarChar { length })) + } + "fixedbinary" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = expect_integer_param(display_name, 0, ¶ms[0], None)?; + Ok(Some(BuiltinParameterized::FixedBinary { length })) + } + "decimal" => { + if params.len() != 2 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "2", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], Some(1..=38))?; + let scale = expect_integer_param(display_name, 1, ¶ms[1], Some(0..=precision))?; + Ok(Some(BuiltinParameterized::Decimal { precision, scale })) + } + // Should we accept both "precision_time" and "precisiontime"? The + // docs/spec say PRECISIONTIME. The protos use underscores, so it could + // show up in generated code, although maybe that's out of spec. 
+ "precisiontime" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], Some(0..=12))?; + Ok(Some(BuiltinParameterized::PrecisionTime { precision })) + } + "precisiontimestamp" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], Some(0..=12))?; + Ok(Some(BuiltinParameterized::PrecisionTimestamp { precision })) + } + "precisiontimestamptz" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], Some(0..=12))?; + Ok(Some(BuiltinParameterized::PrecisionTimestampTz { + precision, + })) + } + "interval_day" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], Some(0..=9))?; + Ok(Some(BuiltinParameterized::IntervalDay { precision })) + } + "interval_compound" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0], None)?; + Ok(Some(BuiltinParameterized::IntervalCompound { precision })) + } + _ => Ok(None), + } +} + +impl<'a> TryFrom> for ConcreteType { + type Error = ExtensionTypeError; + + fn try_from(parsed_type: TypeExpr<'a>) -> Result { + match parsed_type { + TypeExpr::Simple(name, params, nullable) => { + let lower = name.to_ascii_lowercase(); + + 
match lower.as_str() { + "list" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let element = + expect_type_argument(name, 0, params.into_iter().next().unwrap())?; + return Ok(ConcreteType::list(element, nullable)); + } + "map" => { + if params.len() != 2 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "2", + actual: params.len(), + }); + } + let mut iter = params.into_iter(); + let key = expect_type_argument(name, 0, iter.next().unwrap())?; + let value = expect_type_argument(name, 1, iter.next().unwrap())?; + return Ok(ConcreteType::map(key, value, nullable)); + } + "struct" => { + let field_types = params + .into_iter() + .enumerate() + .map(|(idx, param)| expect_type_argument(name, idx, param)) + .collect::, _>>()?; + return Ok(ConcreteType::r#struct(field_types, nullable)); + } + _ => {} + } + + if let Some(parameterized) = + parse_parameterized_builtin(name, lower.as_str(), ¶ms)? 
+ { + return Ok(ConcreteType::parameterized_builtin(parameterized, nullable)); + } + + match PrimitiveType::from_str(&lower) { + Ok(builtin) => { + if !params.is_empty() { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "0", + actual: params.len(), + }); + } + Ok(ConcreteType::builtin(builtin, nullable)) + } + Err(_) => { + let parameters = params + .into_iter() + .map(type_expr_param_to_type_parameter) + .collect::, _>>()?; + Ok(ConcreteType::extension_with_params( + name.to_string(), + parameters, + nullable, + )) + } + } + } + TypeExpr::UserDefined(name, params, nullable) => { + let parameters = params + .into_iter() + .map(type_expr_param_to_type_parameter) + .collect::, _>>()?; + Ok(ConcreteType::extension_with_params( + name.to_string(), + parameters, + nullable, + )) + } + TypeExpr::TypeVariable(id, nullability) => { + Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability }) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::super::extensions::TypeContext; + use super::*; + use crate::parse::text::simple_extensions::TypeExpr; + use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions; + use crate::text::simple_extensions; + use std::iter::FromIterator; + + /// Create a [ConcreteType] from a [BuiltinParameterized] + fn concretize(builtin: BuiltinParameterized) -> ConcreteType { + ConcreteType::parameterized_builtin(builtin, false) + } + + /// Parse a string into a [ConcreteType] + fn parse_type(expr: &str) -> ConcreteType { + let parsed = TypeExpr::parse(expr).unwrap(); + ConcreteType::try_from(parsed).unwrap() + } + + /// Parse a string into a [ConcreteType], returning the result + fn parse_type_result(expr: &str) -> Result { + let parsed = TypeExpr::parse(expr).unwrap(); + ConcreteType::try_from(parsed) + } + + /// Parse a string into a builtin [ConcreteType], with no unresolved + /// extension references + fn parse_simple(s: &str) -> ConcreteType { + let 
parsed = TypeExpr::parse(s).unwrap(); + + let mut refs = Vec::new(); + parsed.visit_references(&mut |name| refs.push(name.to_string())); + assert!(refs.is_empty(), "{s} should not add an extension reference"); + + ConcreteType::try_from(parsed).unwrap() + } + + /// Create a type parameter from a type expression string + fn type_param(expr: &str) -> TypeParameter { + TypeParameter::Type(parse_type(expr)) + } + + /// Create an extension type + fn extension(name: &str, parameters: Vec, nullable: bool) -> ConcreteType { + ConcreteType::extension_with_params(name.to_string(), parameters, nullable) + } + + /// Convert a custom type to raw and back, ensuring round-trip consistency + fn round_trip(custom: &CustomType) { + let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx).unwrap(); + assert_eq!(&parsed, custom); + } + + /// Create a raw named struct (e.g. straight from YAML) from field name and + /// type pairs + fn raw_named_struct(fields: &[(&str, &str)]) -> RawType { + let map = serde_json::Map::from_iter( + fields + .iter() + .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))), + ); + + // Named struct YAML/json objects are inherently unordered; we sort the + // fields lexicographically when parsing so round-tripped output is + // deterministic. This test locks in that behaviour. 
+ RawType::Object(map) + } + + #[test] + fn test_primitive_type_parsing() { + let cases = vec![ + ("bool", Some(PrimitiveType::Boolean)), + ("i8", Some(PrimitiveType::I8)), + ("i16", Some(PrimitiveType::I16)), + ("i32", Some(PrimitiveType::I32)), + ("i64", Some(PrimitiveType::I64)), + ("fp32", Some(PrimitiveType::Fp32)), + ("fp64", Some(PrimitiveType::Fp64)), + ("STRING", Some(PrimitiveType::String)), + ("binary", Some(PrimitiveType::Binary)), + ("uuid", Some(PrimitiveType::Uuid)), + ("date", Some(PrimitiveType::Date)), + ("interval_year", Some(PrimitiveType::IntervalYear)), + ("time", Some(PrimitiveType::Time)), + ("timestamp", Some(PrimitiveType::Timestamp)), + ("timestamp_tz", Some(PrimitiveType::TimestampTz)), + ("invalid", None), + ]; + + for (input, expected) in cases { + match expected { + Some(expected_type) => { + assert_eq!( + PrimitiveType::from_str(input).unwrap(), + expected_type, + "expected primitive type for {input}" + ); + } + None => { + assert!( + PrimitiveType::from_str(input).is_err(), + "expected parsing {input} to fail" + ); + } + } + } + } + + #[test] + fn test_parameterized_builtin_types() { + let cases = vec![ + ( + "precisiontime<2>", + concretize(BuiltinParameterized::PrecisionTime { precision: 2 }), + ), + ( + "precisiontimestamp<1>", + concretize(BuiltinParameterized::PrecisionTimestamp { precision: 1 }), + ), + ( + "precisiontimestamptz<5>", + concretize(BuiltinParameterized::PrecisionTimestampTz { precision: 5 }), + ), + ( + "DECIMAL<10,2>", + concretize(BuiltinParameterized::Decimal { + precision: 10, + scale: 2, + }), + ), + ( + "fixedchar<12>", + concretize(BuiltinParameterized::FixedChar { length: 12 }), + ), + ( + "VarChar<42>", + concretize(BuiltinParameterized::VarChar { length: 42 }), + ), + ( + "fixedbinary<8>", + concretize(BuiltinParameterized::FixedBinary { length: 8 }), + ), + ( + "interval_day<7>", + concretize(BuiltinParameterized::IntervalDay { precision: 7 }), + ), + ( + "interval_compound<6>", + 
concretize(BuiltinParameterized::IntervalCompound { precision: 6 }), + ), + ]; + + for (expr, expected) in cases { + let found = parse_simple(expr); + assert_eq!(found, expected, "unexpected type for {expr}"); + } + } + + #[test] + fn test_parameterized_builtin_range_errors() { + use ExtensionTypeError::InvalidParameterRange; + + let cases = vec![ + ("precisiontime<13>", "precisiontime", 0, 13, 0..=12), + ("precisiontime<-1>", "precisiontime", 0, -1, 0..=12), + ( + "precisiontimestamp<13>", + "precisiontimestamp", + 0, + 13, + 0..=12, + ), + ( + "precisiontimestamp<-1>", + "precisiontimestamp", + 0, + -1, + 0..=12, + ), + ( + "precisiontimestamptz<20>", + "precisiontimestamptz", + 0, + 20, + 0..=12, + ), + ("interval_day<10>", "interval_day", 0, 10, 0..=9), + ("DECIMAL<39,0>", "DECIMAL", 0, 39, 1..=38), + ("DECIMAL<0,0>", "DECIMAL", 0, 0, 1..=38), + ("DECIMAL<10,-1>", "DECIMAL", 1, -1, 0..=10), + ("DECIMAL<10,12>", "DECIMAL", 1, 12, 0..=10), + ]; + + for (expr, expected_type, expected_index, expected_value, expected_range) in cases { + match parse_type_result(expr) { + Ok(value) => panic!("expected error parsing {expr}, got {value:?}"), + Err(InvalidParameterRange { + type_name, + index, + value, + expected, + }) => { + assert_eq!(type_name, expected_type, "unexpected type for {expr}"); + assert_eq!(index, expected_index, "unexpected index for {expr}"); + assert_eq!( + value, + i64::from(expected_value), + "unexpected value for {expr}" + ); + assert_eq!(expected, expected_range, "unexpected range for {expr}"); + } + Err(other) => panic!("expected InvalidParameterRange for {expr}, got {other:?}"), + } + } + } + + #[test] + fn test_container_types() { + let cases = vec![ + ( + "List", + ConcreteType::list(ConcreteType::builtin(PrimitiveType::I32, false), false), + ), + ( + "List", + ConcreteType::list(ConcreteType::builtin(PrimitiveType::Fp64, true), false), + ), + ( + "Map?", + ConcreteType::map( + ConcreteType::builtin(PrimitiveType::I64, false), + 
ConcreteType::builtin(PrimitiveType::String, true), + true, + ), + ), + ( + "Struct?", + ConcreteType::r#struct( + vec![ + ConcreteType::builtin(PrimitiveType::I8, false), + ConcreteType::builtin(PrimitiveType::String, true), + ], + true, + ), + ), + ]; + + for (expr, expected) in cases { + assert_eq!(parse_type(expr), expected, "unexpected parse for {expr}"); + } + } + + #[test] + fn test_extension_types() { + let cases = vec![ + ( + "u!geo, 10>", + extension( + "geo", + vec![type_param("List"), TypeParameter::Integer(10)], + false, + ), + ), + ( + "Geo?>", + extension("Geo", vec![type_param("List")], true), + ), + ( + "Custom", + extension( + "Custom", + vec![ + type_param("string?"), + TypeParameter::Type(ConcreteType::builtin(PrimitiveType::Boolean, false)), + ], + false, + ), + ), + ]; + + for (expr, expected) in cases { + assert_eq!( + parse_type(expr), + expected, + "unexpected extension for {expr}" + ); + } + } + + #[test] + fn test_parameter_type_validation() { + let int_param = ParameterConstraint::Integer { + min: Some(1), + max: Some(10), + }; + let enum_param = ParameterConstraint::Enumeration { + options: ParsedEnumOptions::try_from(simple_extensions::EnumOptions(vec![ + "OVERFLOW".to_string(), + "ERROR".to_string(), + ])) + .unwrap(), + }; + + let cases = vec![ + (&int_param, Value::Number(5.into()), true), + (&int_param, Value::Number(0.into()), false), + (&int_param, Value::Number(11.into()), false), + (&int_param, Value::String("not a number".into()), false), + (&enum_param, Value::String("OVERFLOW".into()), true), + (&enum_param, Value::String("INVALID".into()), false), + ]; + + for (param, value, expected) in cases { + assert_eq!( + param.is_valid_value(&value), + expected, + "unexpected validation result for {value:?}" + ); + } + } + + #[test] + fn test_type_round_trip_display() { + let cases = vec![ + ("i32", None), + ("I64?", Some("i64?")), + ("list", None), + ("List", Some("list")), + ("map>", None), + ("struct", None), + ( + "Struct, 
Map>>", + Some("struct, map>>"), + ), + ( + "Map, Struct>>", + Some("map, struct>>"), + ), + ("u!custom", Some("custom")), + ]; + + for (input, expected) in cases { + let parsed = TypeExpr::parse(input).unwrap(); + let concrete = ConcreteType::try_from(parsed).unwrap(); + let actual = concrete.to_string(); + if let Some(expected_display) = expected { + assert_eq!(actual, expected_display, "unexpected display for {input}"); + } else { + assert_eq!(actual, input, "unexpected canonical output for {input}"); + } + } + } + + /// Test that named struct field order is stable and sorted + /// lexicographically when round-tripping through RawType. + /// + /// Normally, order in structs in SQL / relational algebra is significant; + /// but the spec doesn't say that, and it starts as a YAML map, which + /// doesn't generally preserve order, so for both implementation ease and + /// test stability, we sort the fields when parsing named structs. + /// (Preserving order would be difficult - most YAML parsers don't preserve + /// map order) + #[test] + fn test_named_struct_field_order_stability() -> Result<(), ExtensionTypeError> { + let mut raw_fields = serde_json::Map::new(); + raw_fields.insert("beta".to_string(), Value::String("i32".to_string())); + raw_fields.insert("alpha".to_string(), Value::String("string?".to_string())); + + let raw = RawType::Object(raw_fields); + let mut ctx = TypeContext::default(); + let concrete = Parse::parse(raw, &mut ctx)?; + + let round_tripped: RawType = concrete.into(); + match round_tripped { + RawType::Object(result_map) => { + let keys: Vec<_> = result_map.keys().collect(); + assert_eq!(keys, vec!["alpha", "beta"], "field order should be sorted"); + } + other => panic!("expected Object, got {other:?}"), + } + + Ok(()) + } + + #[test] + fn test_integer_param_bounds_round_trip() { + let cases = vec![ + ( + "bounded", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: 
simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.0), + max: Some(10.0), + options: None, + optional: None, + }, + Ok((Some(1), Some(10))), + ), + ( + "fractional_min", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.5), + max: None, + options: None, + optional: None, + }, + Err(TypeParamError::InvalidIntegerBounds { + min: Some(1.5), + max: None, + }), + ), + ( + "fractional_max", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: None, + max: Some(9.9), + options: None, + optional: None, + }, + Err(TypeParamError::InvalidIntegerBounds { + min: None, + max: Some(9.9), + }), + ), + ]; + + for (label, item, expected) in cases { + match (TypeParam::try_from(item), expected) { + (Ok(tp), Ok((expected_min, expected_max))) => match tp.param_type { + ParameterConstraint::Integer { min, max } => { + assert_eq!(min, expected_min, "min mismatch for {label}"); + assert_eq!(max, expected_max, "max mismatch for {label}"); + } + _ => panic!("expected integer param type for {label}"), + }, + (Err(actual_err), Err(expected_err)) => { + assert_eq!(actual_err, expected_err, "unexpected error for {label}"); + } + (result, expected) => { + panic!("unexpected result for {label}: got {result:?}, expected {expected:?}") + } + } + } + } + + #[test] + fn test_custom_type_round_trip() -> Result<(), ExtensionTypeError> { + let fields = vec![ + ( + "x".to_string(), + ConcreteType::builtin(PrimitiveType::Fp64, false), + ), + ( + "y".to_string(), + ConcreteType::builtin(PrimitiveType::Fp64, false), + ), + ]; + let (names, types): (Vec<_>, Vec<_>) = fields.into_iter().unzip(); + + let cases = vec![ + CustomType::new( + "AliasType".to_string(), + vec![], + Some(ConcreteType::builtin(PrimitiveType::I32, false)), + None, + Some("a test alias 
type".to_string()), + )?, + CustomType::new( + "Point".to_string(), + vec![TypeParam::new( + "T".to_string(), + ParameterConstraint::DataType, + Some("a numeric value".to_string()), + )], + Some(ConcreteType::named_struct(names, types, false)), + None, + None, + )?, + ]; + + for custom in cases { + round_trip(&custom); + } + + Ok(()) + } + + #[test] + fn test_invalid_type_names() { + let cases = vec![ + ("", false), + ("bad name", false), + ("9bad", false), + ("bad-name", false), + ("bad.name", false), + ("GoodName", true), + ("also_good", true), + ("_underscore", true), + ("$dollar", true), + ("CamelCase123", true), + ]; + + for (name, expected_ok) in cases { + let result = CustomType::validate_name(name); + assert_eq!( + result.is_ok(), + expected_ok, + "unexpected validation for {name}" + ); + } + } + + #[test] + fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> { + let cases = vec![ + ( + "alias", + RawType::String("i32".to_string()), + ConcreteType::builtin(PrimitiveType::I32, false), + ), + ( + "named_struct", + raw_named_struct(&[("field1", "fp64"), ("field2", "i32?")]), + ConcreteType::named_struct( + vec!["field1".to_string(), "field2".to_string()], + vec![ + ConcreteType::builtin(PrimitiveType::Fp64, false), + ConcreteType::builtin(PrimitiveType::I32, true), + ], + false, + ), + ), + ]; + + for (label, raw, expected) in cases { + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(raw, &mut ctx)?; + assert_eq!(parsed, expected, "unexpected type for {label}"); + } + + Ok(()) + } + + #[test] + fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> { + let cases = vec![ + ( + "alias", + simple_extensions::SimpleExtensionsTypesItem { + name: "Alias".to_string(), + description: Some("Alias type".to_string()), + parameters: None, + structure: Some(RawType::String("BINARY".to_string())), + variadic: None, + }, + "Alias", + Some("Alias type"), + Some(ConcreteType::builtin(PrimitiveType::Binary, false)), + ), + ( + 
"named_struct", + simple_extensions::SimpleExtensionsTypesItem { + name: "Point".to_string(), + description: Some("A 2D point".to_string()), + parameters: None, + structure: Some(raw_named_struct(&[("x", "fp64"), ("y", "fp64?")])), + variadic: None, + }, + "Point", + Some("A 2D point"), + Some(ConcreteType::named_struct( + vec!["x".to_string(), "y".to_string()], + vec![ + ConcreteType::builtin(PrimitiveType::Fp64, false), + ConcreteType::builtin(PrimitiveType::Fp64, true), + ], + false, + )), + ), + ( + "no_structure", + simple_extensions::SimpleExtensionsTypesItem { + name: "Opaque".to_string(), + description: None, + parameters: None, + structure: None, + variadic: Some(true), + }, + "Opaque", + None, + None, + ), + ]; + + for (label, item, expected_name, expected_description, expected_structure) in cases { + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx)?; + assert_eq!(parsed.name, expected_name); + assert_eq!( + parsed.description.as_deref(), + expected_description, + "description mismatch for {label}" + ); + assert_eq!( + parsed.structure, expected_structure, + "structure mismatch for {label}" + ); + } + + Ok(()) + } + + /// A type defined with a structure cannot be defined as nullable; e.g. if + /// you define 'Integer' as an alias for 'i64?', then what do you mean by + /// 'INTEGER?' - is that now equal to `i64??` + #[test] + fn test_nullable_structure_rejected() { + let cases = vec![RawType::String("i32?".to_string())]; + + for raw in cases { + let mut ctx = TypeContext::default(); + let result = Parse::parse(raw, &mut ctx); + match result { + Err(ExtensionTypeError::StructureCannotBeNullable { type_string }) => { + assert!(type_string.contains('?')); + } + other => panic!("Expected nullable structure error, got {other:?}"), + } + } + } +}