From 076998e011d912742bee91caee3613c45c97e87d Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Tue, 12 Aug 2025 16:38:47 -0400 Subject: [PATCH 01/31] feat: ExtensionRegistry, wip --- Cargo.toml | 3 +- src/extensions.rs | 1 + src/lib.rs | 3 + src/registry/extension.rs | 230 +++++++++++++++++++++ src/registry/mod.rs | 37 ++++ src/registry/registry.rs | 226 +++++++++++++++++++++ src/registry/types.rs | 407 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 906 insertions(+), 1 deletion(-) create mode 100644 src/registry/extension.rs create mode 100644 src/registry/mod.rs create mode 100644 src/registry/registry.rs create mode 100644 src/registry/types.rs diff --git a/Cargo.toml b/Cargo.toml index 78358392..12bcc755 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,8 +27,9 @@ include = [ [features] default = [] extensions = ["dep:serde_yaml", "dep:url"] -parse = ["dep:hex", "dep:thiserror", "dep:url", "semver"] +parse = ["registry", "dep:hex", "semver"] protoc = ["dep:protobuf-src"] +registry = ["dep:thiserror", "dep:url"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] diff --git a/src/extensions.rs b/src/extensions.rs index 240799da..23b17bfc 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -18,5 +18,6 @@ mod tests { fn core_extensions() { // Force evaluation of core extensions. LazyLock::force(&EXTENSIONS); + println!("Core extensions: {:#?}", EXTENSIONS); } } diff --git a/src/lib.rs b/src/lib.rs index 508f8387..f5376cee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,5 +122,8 @@ pub mod proto; pub mod text; pub mod version; +#[cfg(feature = "registry")] +pub mod registry; + #[cfg(feature = "parse")] pub mod parse; diff --git a/src/registry/extension.rs b/src/registry/extension.rs new file mode 100644 index 00000000..1cf0f02a --- /dev/null +++ b/src/registry/extension.rs @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Validated extension file wrapper. +//! +//! 
This module provides `ExtensionFile`, a validated wrapper around SimpleExtensions +//! that ensures extension data is valid on construction and provides safe accessor methods. + +use thiserror::Error; +use url::Url; + +use crate::text::simple_extensions::{ + AggregateFunction, AggregateFunctionImplsItem, Arguments, ReturnValue, ScalarFunction, + ScalarFunctionImplsItem, SimpleExtensions, WindowFunction, WindowFunctionImplsItem, +}; + +/// Errors that can occur during extension validation +#[derive(Debug, Error, PartialEq)] +pub enum ValidationError { + /// A function implementation has None for arguments + #[error("Function '{function}' has implementation with missing arguments")] + MissingArguments { + /// The missing function + function: String, + }, + + /// A function implementation is missing a return type + #[error("Function '{function}' has implementation with missing return type")] + MissingReturnType { + /// The missing function + function: String, + }, + // TODO: Add more validation errors for malformed argument patterns, return type patterns, etc. 
+} + +/// A validated extension file containing functions and types from a single URI +#[derive(Debug)] +pub struct ExtensionFile { + /// The URI this extension was loaded from + pub uri: Url, + /// The validated extension data + extensions: SimpleExtensions, +} + +impl ExtensionFile { + /// Create a validated extension file from raw data + pub fn create(uri: Url, extensions: SimpleExtensions) -> Result { + // Validate scalar functions + for function in &extensions.scalar_functions { + Self::validate_scalar_function(function)?; + } + + // Validate aggregate functions + for function in &extensions.aggregate_functions { + Self::validate_aggregate_function(function)?; + } + + // Validate window functions + for function in &extensions.window_functions { + Self::validate_window_function(function)?; + } + + Ok(Self { uri, extensions }) + } + + /// Find a scalar function by name + pub fn find_scalar_function(&self, name: &str) -> Option { + self.extensions + .scalar_functions + .iter() + .find(|f| f.name == name) + .map(|f| ScalarFunctionRef(&self.uri, f)) + } + + /// Find an aggregate function by name + pub fn find_aggregate_function(&self, name: &str) -> Option { + self.extensions + .aggregate_functions + .iter() + .find(|f| f.name == name) + .map(|f| AggregateFunctionRef(&self.uri, f)) + } + + /// Find a window function by name + pub fn find_window_function(&self, name: &str) -> Option { + self.extensions + .window_functions + .iter() + .find(|f| f.name == name) + .map(|f| WindowFunctionRef(&self.uri, f)) + } + + // Private validation methods + + fn validate_scalar_function(function: &ScalarFunction) -> Result<(), ValidationError> { + for impl_item in &function.impls { + // Check that arguments are present (can be empty, but not None) + if impl_item.args.is_none() { + return Err(ValidationError::MissingArguments { + function: function.name.clone(), + }); + } + + // TODO: Validate that return type is well-formed + // For now, we assume return_ field existence is 
enforced by the type system + } + Ok(()) + } + + fn validate_aggregate_function(function: &AggregateFunction) -> Result<(), ValidationError> { + for impl_item in &function.impls { + if impl_item.args.is_none() { + return Err(ValidationError::MissingArguments { + function: function.name.clone(), + }); + } + } + Ok(()) + } + + fn validate_window_function(function: &WindowFunction) -> Result<(), ValidationError> { + for impl_item in &function.impls { + if impl_item.args.is_none() { + return Err(ValidationError::MissingArguments { + function: function.name.clone(), + }); + } + } + Ok(()) + } +} + +/// Handle for a validated scalar function definition +pub struct ScalarFunctionRef<'a>(&'a Url, &'a ScalarFunction); + +impl<'a> ScalarFunctionRef<'a> { + /// Get the function name + pub fn name(&self) -> &str { + &self.1.name + } + + /// Get all implementations as handles to specific type signatures + pub fn implementations(&self) -> impl Iterator> { + self.1 + .impls + .iter() + .map(move |impl_item| ScalarFunctionImplRef(self.0, impl_item)) + } +} + +/// Handle for a validated aggregate function definition +pub struct AggregateFunctionRef<'a>(&'a Url, &'a AggregateFunction); + +impl<'a> AggregateFunctionRef<'a> { + /// Get the function name + pub fn name(&self) -> &str { + &self.1.name + } + + /// Get all implementations as handles to specific type signatures + pub fn implementations(&self) -> impl Iterator> { + self.1 + .impls + .iter() + .map(move |impl_item| AggregateFunctionImplRef(self.0, impl_item)) + } +} + +/// Handle for a validated window function definition +pub struct WindowFunctionRef<'a>(&'a Url, &'a WindowFunction); + +impl<'a> WindowFunctionRef<'a> { + /// Get the function name + pub fn name(&self) -> &str { + &self.1.name + } + + /// Get all implementations as handles to specific type signatures + pub fn implementations(&self) -> impl Iterator> { + self.1 + .impls + .iter() + .map(move |impl_item| WindowFunctionImplRef(self.0, impl_item)) + } +} + +/// 
Handle for a specific scalar function implementation with validated signature +#[derive(Debug, Copy, Clone)] +pub struct ScalarFunctionImplRef<'a>(&'a Url, &'a ScalarFunctionImplsItem); + +impl<'a> ScalarFunctionImplRef<'a> { + /// Get the argument signature (guaranteed to be present due to validation) + pub fn args(&self) -> &Arguments { + self.1.args.as_ref().expect("validated to be present") + } + + /// Get the return type pattern + pub fn return_type(&self) -> &ReturnValue { + &self.1.return_ + } +} + +/// Handle for a specific aggregate function implementation with validated signature +pub struct AggregateFunctionImplRef<'a>(&'a Url, &'a AggregateFunctionImplsItem); + +impl<'a> AggregateFunctionImplRef<'a> { + /// Get the argument signature (guaranteed to be present due to validation) + pub fn args(&self) -> &Arguments { + self.1.args.as_ref().expect("validated to be present") + } + + /// Get the return type pattern + pub fn return_type(&self) -> &ReturnValue { + &self.1.return_ + } +} + +/// Handle for a specific window function implementation with validated signature +pub struct WindowFunctionImplRef<'a>(&'a Url, &'a WindowFunctionImplsItem); + +impl<'a> WindowFunctionImplRef<'a> { + /// Get the argument signature (guaranteed to be present due to validation) + pub fn args(&self) -> &Arguments { + self.1.args.as_ref().expect("validated to be present") + } + + /// Get the return type pattern + pub fn return_type(&self) -> &ReturnValue { + &self.1.return_ + } +} diff --git a/src/registry/mod.rs b/src/registry/mod.rs new file mode 100644 index 00000000..7379406d --- /dev/null +++ b/src/registry/mod.rs @@ -0,0 +1,37 @@ +//! Substrait Extension Registry +//! +//! This module provides types and methods that abstract over Substrait +//! SimpleExtensions. +//! +//! ## Design Philosophy +//! +//! Internally, the types in this module are handles to the raw parsed +//! SimpleExtensions from the text module. Externally, they provide a coherent +//! 
interface that hides those internal details and presents methods where +//! extensions are validated on creation and then assumed valid thereafter. +//! +//! This "validate once, assume valid" approach allows for: +//! - **Type safety**: Invalid extensions are caught at construction time +//! - **Performance**: No repeated validation during registry operations +//! - **Clean APIs**: Methods can focus on logic rather than error handling +//! - **Reliability**: Once constructed, registry operations won't fail due to +//! malformed data +//! +//! ## Core Types +//! +//! - [`ExtensionFile`]: Validated wrapper around a SimpleExtensions + URI +//! - [`ConcreteType`]: Fully-specified types for function arguments and return +//! values +//! - [`TypeSignature`]: Pattern matching for function signatures +//! - [`GlobalRegistry`]: Immutable registry for URI+name based function lookup + +mod extension; +mod registry; +pub mod types; + +pub use extension::{ + ExtensionFile, ValidationError, + ScalarFunctionRef, AggregateFunctionRef, WindowFunctionRef, + ScalarFunctionImplRef, AggregateFunctionImplRef, WindowFunctionImplRef, +}; +pub use registry::GlobalRegistry; \ No newline at end of file diff --git a/src/registry/registry.rs b/src/registry/registry.rs new file mode 100644 index 00000000..e5d592b3 --- /dev/null +++ b/src/registry/registry.rs @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Substrait Extension Registry +//! +//! This module provides registries for Substrait extensions: +//! - **Global Registry**: Immutable, reusable across plans, URI+name based lookup +//! - **Local Registry**: Per-plan, anchor-based, references Global Registry (TODO) +//! +//! This module is only available when the `registry` feature is enabled. 
+ +#![cfg(feature = "registry")] + +use thiserror::Error; +use url::Url; + +use super::{ + types::{ConcreteType, TypeSignature}, + ExtensionFile, +}; + +/// Errors that can occur when using the Global Registry +#[derive(Debug, Error, PartialEq)] +pub enum GlobalRegistryError { + /// The specified extension URI is not registered in this registry + #[error("Unknown extension URI: {0}")] + UnknownExtensionUri(String), + /// The specified function was not found in the given extension + #[error("Function '{function}' not found in extension '{uri}'")] + FunctionNotFound { + /// The extension URI where the function was expected + uri: String, + /// The name of the function that was not found + function: String, + }, + /// No function signature matches the provided arguments + #[error("No matching signature for function '{function}' in extension '{uri}' with provided arguments")] + NoMatchingSignature { + /// The extension URI containing the function + uri: String, + /// The name of the function with no matching signature + function: String, + }, +} + +impl GlobalRegistryError { + /// Create a FunctionNotFound error + pub fn not_found(uri: &Url, function: &str) -> Self { + Self::FunctionNotFound { + uri: uri.to_string(), + function: function.to_string(), + } + } +} + +/// Global Extension Registry that manages Substrait extensions +/// +/// This registry is immutable and reusable across multiple plans. +/// It provides URI + name based lookup for function validation and signature matching. 
+#[derive(Debug)] +pub struct GlobalRegistry { + /// Simple Extensions from parsed and validated YAML files + pub extensions: Vec, +} + +impl GlobalRegistry { + /// Create a new Global Registry from validated extension files + pub fn new(extensions: Vec) -> Self { + Self { extensions } + } + + /// Get an iterator over all extension files in this registry + pub fn extensions(&self) -> impl Iterator { + self.extensions.iter() + } + + /// Create a Global Registry from the built-in core extensions + #[cfg(feature = "extensions")] + pub fn from_core_extensions() -> Self { + use crate::extensions::EXTENSIONS; + use std::sync::LazyLock; + + // Force evaluation of core extensions + LazyLock::force(&EXTENSIONS); + + // Convert HashMap to Vec + let extensions: Vec = EXTENSIONS + .iter() + .map(|(uri, simple_extensions)| { + ExtensionFile::create(uri.clone(), simple_extensions.clone()) + .expect("Core extensions should be valid") + }) + .collect(); + + Self { extensions } + } + + // Private helper methods + + fn get_extension(&self, uri: &Url) -> Result<&ExtensionFile, GlobalRegistryError> { + self.extensions + .iter() + .find(|ext| &ext.uri == uri) + .ok_or_else(|| GlobalRegistryError::UnknownExtensionUri(uri.to_string())) + } + + /// Validate a scalar function call and return the concrete return type + pub fn validate_scalar_call<'a, 'b>( + &'a self, + uri: &Url, + name: &str, + args: &'b [ConcreteType<'a>], + ) -> Result, GlobalRegistryError> { + let extension = self.get_extension(uri)?; + let function_ref = extension + .find_scalar_function(name) + .ok_or_else(|| GlobalRegistryError::not_found(uri, name))?; + + // Try each implementation until one matches + for impl_ref in function_ref.implementations() { + let signature = TypeSignature::new(impl_ref.args(), impl_ref.return_type()); + if let Some(return_type) = signature.matches(args) { + return Ok(return_type); + } + } + + Err(GlobalRegistryError::NoMatchingSignature { + uri: uri.to_string(), + function: 
name.to_string(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::types::{BuiltinType, ConcreteType}; + use crate::text::simple_extensions::*; + + fn create_test_extension() -> SimpleExtensions { + SimpleExtensions { + scalar_functions: vec![ScalarFunction { + name: "add".to_string(), + description: Some("Addition function".to_string()), + impls: vec![ScalarFunctionImplsItem { + args: Some(Arguments(vec![])), // Simplified for testing. TODO: Add real args + return_: ReturnValue(Type::Variant0("i32".to_string())), + deterministic: None, + implementation: None, + nullability: None, + options: None, + session_dependent: None, + variadic: None, + }], + }], + aggregate_functions: vec![], + window_functions: vec![], + dependencies: Default::default(), + type_variations: vec![], + types: vec![], + } + } + + #[test] + fn test_new_registry() { + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let extension_file = ExtensionFile::create(uri.clone(), create_test_extension()).unwrap(); + let extensions = vec![extension_file]; + + let registry = GlobalRegistry::new(extensions); + assert_eq!(registry.extensions().count(), 1); + let extension_uris: Vec<&Url> = registry.extensions().map(|ext| &ext.uri).collect(); + assert!(extension_uris.contains(&&uri)); + } + + #[test] + fn test_validate_scalar_call_with_test_extension() { + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let extension_file = ExtensionFile::create(uri.clone(), create_test_extension()).unwrap(); + let extensions = vec![extension_file]; + + let registry = GlobalRegistry::new(extensions); + let args: &[ConcreteType] = &[]; // Empty ConcreteType args + + let result = registry.validate_scalar_call(&uri, "add", args); + assert!(result.is_ok()); + } + + #[test] + fn test_standard_extension() { + let registry = GlobalRegistry::from_core_extensions(); + let arithmetic_uri = 
Url::parse("https://github.com/substrait-io/substrait/raw/v0.57.0/extensions/functions_arithmetic.yaml").unwrap(); + + // Test that add function fails with no arguments (should require 2 arguments) + let no_args: &[ConcreteType] = &[]; + let result_no_args = registry.validate_scalar_call(&arithmetic_uri, "add", no_args); + assert!( + result_no_args.is_err(), + "add function should fail with no arguments" + ); + + // Test that add function succeeds with two i32 arguments and returns i32 + let i32_args = &[ + ConcreteType::builtin(BuiltinType::I32, false), + ConcreteType::builtin(BuiltinType::I32, false), + ]; + let result_with_args = registry.validate_scalar_call(&arithmetic_uri, "add", i32_args); + assert!( + result_with_args.is_ok(), + "add function should succeed with two i32 arguments" + ); + + // Verify it returns the correct concrete type (i32) + let return_type = result_with_args.unwrap(); + assert_eq!( + return_type, + ConcreteType::builtin(BuiltinType::I32, false), + "add(i32, i32) should return i32" + ); + } + + #[test] + fn test_from_core_extensions() { + let registry = GlobalRegistry::from_core_extensions(); + assert!(registry.extensions().count() > 0); + } +} diff --git a/src/registry/types.rs b/src/registry/types.rs new file mode 100644 index 00000000..a20b922b --- /dev/null +++ b/src/registry/types.rs @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Concrete type system for function validation in the registry. +//! +//! This module provides a clean, type-safe wrapper around Substrait extension types, +//! separating function signature patterns from concrete argument types. 
+ +use crate::text::simple_extensions::{ + Arguments, ArgumentsItem, ReturnValue, SimpleExtensionsTypesItem, Type, +}; +use std::collections::HashMap; +use std::str::FromStr; +use url::Url; + +/// Substrait built-in primitive types +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BuiltinType { + /// Boolean type - `bool` + Boolean, + /// 8-bit signed integer - `i8` + I8, + /// 16-bit signed integer - `i16` + I16, + /// 32-bit signed integer - `i32` + I32, + /// 64-bit signed integer - `i64` + I64, + /// 32-bit floating point - `fp32` + Fp32, + /// 64-bit floating point - `fp64` + Fp64, + /// Variable-length string - `string` + String, + /// Variable-length binary data - `binary` + Binary, + /// Calendar date - `date` + Date, + /// Time of day - `time` (deprecated, use precision_time) + Time, + /// Date and time - `timestamp` (deprecated, use precision_timestamp) + Timestamp, + /// Date and time with timezone - `timestamp_tz` (deprecated, use precision_timestamp_tz) + TimestampTz, + /// Year-month interval - `interval_year` + IntervalYear, + /// Day-time interval - `interval_day` + IntervalDay, + /// Compound interval - `interval_compound` + IntervalCompound, + /// UUID type - `uuid` + Uuid, + /// Fixed-length character string - `fixed_char` + FixedChar, + /// Variable-length character string - `varchar` + VarChar, + /// Fixed-length binary data - `fixed_binary` + FixedBinary, + /// Decimal number - `decimal` + Decimal, + /// Time with precision - `precision_time` + PrecisionTime, + /// Timestamp with precision - `precision_timestamp` + PrecisionTimestamp, + /// Timestamp with timezone and precision - `precision_timestamp_tz` + PrecisionTimestampTz, + /// Struct/record type - `struct` + Struct, + /// List/array type - `list` + List, + /// Map/dictionary type - `map` + Map, + /// User-defined type - `user_defined` + UserDefined, +} + +#[derive(Debug, thiserror::Error)] +/// Error for unrecognized builtin type strings +#[error("Unrecognized builtin type")] +pub 
struct UnrecognizedBuiltin; + +impl FromStr for BuiltinType { + type Err = UnrecognizedBuiltin; + + fn from_str(s: &str) -> Result { + match s { + "boolean" => Ok(BuiltinType::Boolean), + "i8" => Ok(BuiltinType::I8), + "i16" => Ok(BuiltinType::I16), + "i32" => Ok(BuiltinType::I32), + "i64" => Ok(BuiltinType::I64), + "fp32" => Ok(BuiltinType::Fp32), + "fp64" => Ok(BuiltinType::Fp64), + "string" => Ok(BuiltinType::String), + "binary" => Ok(BuiltinType::Binary), + "date" => Ok(BuiltinType::Date), + "time" => Ok(BuiltinType::Time), + "timestamp" => Ok(BuiltinType::Timestamp), + "timestamp_tz" => Ok(BuiltinType::TimestampTz), + "interval_year" => Ok(BuiltinType::IntervalYear), + "interval_day" => Ok(BuiltinType::IntervalDay), + "interval_compound" => Ok(BuiltinType::IntervalCompound), + "uuid" => Ok(BuiltinType::Uuid), + "fixed_char" => Ok(BuiltinType::FixedChar), + "varchar" => Ok(BuiltinType::VarChar), + "fixed_binary" => Ok(BuiltinType::FixedBinary), + "decimal" => Ok(BuiltinType::Decimal), + "precision_time" => Ok(BuiltinType::PrecisionTime), + "precision_timestamp" => Ok(BuiltinType::PrecisionTimestamp), + "precision_timestamp_tz" => Ok(BuiltinType::PrecisionTimestampTz), + "struct" => Ok(BuiltinType::Struct), + "list" => Ok(BuiltinType::List), + "map" => Ok(BuiltinType::Map), + "user_defined" => Ok(BuiltinType::UserDefined), + _ => Err(UnrecognizedBuiltin), + } + } +} + +/// Represents a known, specific type, either builtin or extension +#[derive(Clone, Debug)] +pub enum KnownType<'a> { + /// Built-in primitive types + Builtin(BuiltinType), + /// Custom types defined in extension YAML files + Extension(&'a Url, &'a SimpleExtensionsTypesItem), +} + +impl<'a> PartialEq for KnownType<'a> { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (KnownType::Builtin(a), KnownType::Builtin(b)) => a == b, + // For extension types, compare by URI and name + (KnownType::Extension(au, a), KnownType::Extension(bu, b)) => { + // There should only be one type with a 
given name per URI + a.name == b.name && au == bu + } + _ => false, + } + } +} + +/// A concrete type, fully specified with nullability and parameters +#[derive(Clone, Debug, PartialEq)] +pub struct ConcreteType<'a> { + /// Base type, can be builtin or extension + pub base: KnownType<'a>, + /// Is the overall type nullable? + pub nullable: bool, + // TODO: Add non-type parameters (e.g. integers, enum, etc.) + /// Parameters for the type, if there are any + pub parameters: Vec>, +} + +impl<'a> ConcreteType<'a> { + /// Create a concrete type from a builtin type + pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType<'static> { + ConcreteType { + base: KnownType::Builtin(builtin_type), + nullable, + parameters: Vec::new(), + } + } + + /// Create a concrete type from an extension type + pub fn extension( + uri: &'a Url, + ext_type: &'a SimpleExtensionsTypesItem, + nullable: bool, + ) -> Self { + Self { + base: KnownType::Extension(uri, ext_type), + nullable, + parameters: Vec::new(), + } + } + + /// Create a parameterized concrete type + pub fn parameterized( + base: KnownType<'a>, + nullable: bool, + parameters: Vec>, + ) -> Self { + Self { + base, + nullable, + parameters, + } + } +} + +/// A parsed type that can represent type variables, builtin types, extension types, or parameterized types +#[derive(Clone, Debug, PartialEq)] +pub enum ParsedType<'a> { + /// Type variable like any1, any2, etc. + TypeVariable(u32), + /// Nullable type variable like any1?, any2?, etc.; used in return types + NullableTypeVariable(u32), + /// Built-in primitive type, with nullability + Builtin(BuiltinType, bool), + /// Extension type from a specific URI, with nullability + NamedExtension(&'a str, bool), + /// Parameterized type + Parameterized { + /// Base type, can be builtin or extension + base: Box>, + /// Parameters for that type + parameters: Vec>, + /// Is the overall type nullable? 
+ nullable: bool, + }, +} + +impl<'a> ParsedType<'a> { + /// Parse a type string into a ParsedType + pub fn parse(type_str: &'a str) -> ParsedType<'a> { + // Strip nullability + let (type_str, nullability) = if let Some(rest) = type_str.strip_suffix('?') { + (rest, true) + } else { + (type_str, false) + }; + + // Handle any expressions + if let Some(rest) = type_str.strip_prefix("any") { + if let Ok(id) = rest.parse::() { + if nullability { + // any1? etc. are nullable type variables - permissible in + // return position + return ParsedType::NullableTypeVariable(id); + } else { + return ParsedType::TypeVariable(id); + } + } + } + + // Handle parameterized types like "list" (future implementation) + if type_str.contains('<') && type_str.ends_with('>') { + unimplemented!("Parameterized types not yet implemented: {}", type_str); + } + + // Try to parse as builtin type + if let Ok(builtin_type) = BuiltinType::from_str(type_str) { + return ParsedType::Builtin(builtin_type, nullability); + } + + // Not a builtin or type variable - assume it's an extension type name + ParsedType::NamedExtension(type_str, nullability) + } +} + +/// Wrapper around function signature patterns with smart matching +pub struct TypeSignature<'a> { + args: &'a Arguments, + return_type: &'a ReturnValue, + variadic: bool, +} + +impl<'a> TypeSignature<'a> { + /// Create a new type signature from function definition parts + pub(crate) fn new(args: &'a Arguments, return_type: &'a ReturnValue) -> Self { + Self { + args, + return_type, + variadic: false, // TODO: Extract from function impl + } + } + + /// Check if concrete argument types match this signature pattern + pub fn matches<'b: 'a>( + &'a self, + concrete_args: &'b [ConcreteType], + ) -> Option> { + // Convert raw arguments to ArgumentPatterns + let arg_patterns: Vec = self + .args + .iter() + .filter_map(|arg| ArgumentPattern::from_argument_item(arg)) + .collect(); + + // Create type bindings by matching patterns against concrete arguments + 
let _bindings = TypeBindings::new(&arg_patterns, concrete_args)?; + + // If arguments match, return the inferred return type + unimplemented!("Return type inference not yet implemented") + } +} + +/// A pattern for function arguments that can match concrete types or type variables +#[derive(Clone, Debug, PartialEq)] +pub enum ArgumentPattern<'a> { + /// Type variable like any1, any2, etc. + TypeVariable(u32), + /// Concrete type pattern + Concrete(ConcreteType<'a>), +} + +/// Result of matching an argument pattern against a concrete type +#[derive(Clone, Debug, PartialEq)] +pub enum Match<'a> { + /// Pattern matched exactly (for concrete patterns) + Concrete, + /// Type variable bound to concrete type + Variable(u32, ConcreteType<'a>), + /// Match failed + Fail, +} + +impl<'a> ArgumentPattern<'a> { + /// Create an argument pattern from a raw ArgumentsItem + pub fn from_argument_item(item: &ArgumentsItem) -> Option> { + match item { + ArgumentsItem::ValueArg(value_arg) => ArgumentPattern::from_type(&value_arg.value), + _ => unimplemented!("Handle non-ValueArg argument types"), + } + } + + /// Create an argument pattern from a type string + fn from_type(type_val: &Type) -> Option> { + match type_val { + Type::Variant0(type_str) => { + let parsed_type = ParsedType::parse(type_str); + match parsed_type { + ParsedType::TypeVariable(id) => Some(ArgumentPattern::TypeVariable(id)), + ParsedType::NullableTypeVariable(_) => { + panic!("Nullable type variables not allowed in argument position") + } + ParsedType::Builtin(builtin_type, nullable) => Some(ArgumentPattern::Concrete( + ConcreteType::builtin(builtin_type, nullable), + )), + ParsedType::NamedExtension(_, _) => { + unimplemented!("Extension types not yet supported in argument patterns") + } + ParsedType::Parameterized { .. 
} => { + unimplemented!("Parameterized types not yet supported in argument patterns") + } + } + } + _ => unimplemented!("Handle non-string type variants"), + } + } + + /// Check if this pattern matches the given concrete type + pub fn matches(&self, concrete: &ConcreteType<'a>) -> Match<'a> { + match self { + ArgumentPattern::TypeVariable(id) => Match::Variable(*id, concrete.clone()), + ArgumentPattern::Concrete(pattern_type) => { + if pattern_type == concrete { + Match::Concrete + } else { + Match::Fail + } + } + } + } +} + +/// Type variable bindings from matching function arguments +#[derive(Debug, Clone, PartialEq)] +pub struct TypeBindings<'a> { + /// Map of type variable IDs (e.g. 1 for 'any1') to their concrete types + pub vars: HashMap>, +} + +impl<'a> TypeBindings<'a> { + /// Create type bindings by matching argument patterns against concrete arguments + pub fn new(patterns: &[ArgumentPattern<'a>], args: &[ConcreteType<'a>]) -> Option { + // Check length compatibility + if patterns.len() != args.len() { + unimplemented!("Handle variadic functions"); + } + + let mut vars = HashMap::new(); + + // Match each pattern against corresponding argument + for (pattern, arg) in patterns.iter().zip(args.iter()) { + match pattern.matches(arg) { + Match::Concrete => { + // Concrete pattern matched, continue + continue; + } + Match::Variable(id, concrete_type) => { + // Check for consistency with existing bindings + if let Some(existing_binding) = vars.get(&id) { + if existing_binding != &concrete_type { + // Conflicting binding - type variable bound to different types + return None; + } + } else { + // New binding + vars.insert(id, concrete_type); + } + } + Match::Fail => { + // Pattern didn't match + return None; + } + } + } + + Some(TypeBindings { vars }) + } + + /// Get the bound type for a type variable, if any + pub fn get(&self, var_id: u32) -> Option<&ConcreteType<'a>> { + self.vars.get(&var_id) + } +} From 57e5dee6630e30a21e516569e8417bf2beba8e99 Mon Sep 17 
00:00:00 2001 From: Wendell Smith Date: Wed, 13 Aug 2025 18:51:42 -0400 Subject: [PATCH 02/31] wip compiles --- src/parse/context.rs | 6 +- .../proto/extensions/simple_extension_uri.rs | 7 +- src/parse/proto/plan_version.rs | 8 +- src/parse/proto/version.rs | 7 +- src/registry/context.rs | 45 +++ src/registry/extension.rs | 270 +++++++++++++++--- src/registry/mod.rs | 21 +- src/registry/registry.rs | 27 +- src/registry/types.rs | 155 ++++------ 9 files changed, 375 insertions(+), 171 deletions(-) create mode 100644 src/registry/context.rs diff --git a/src/parse/context.rs b/src/parse/context.rs index fe3ff142..0eb54f08 100644 --- a/src/parse/context.rs +++ b/src/parse/context.rs @@ -22,7 +22,9 @@ pub trait Context { { item.parse(self) } +} +pub trait ProtoContext: Context { /// Add a [SimpleExtensionUri] to this context. Must return an error for duplicate /// anchors or when the URI is not supported. /// @@ -83,7 +85,9 @@ pub(crate) mod tests { } } - impl super::Context for Context { + impl super::Context for Context {} + + impl super::ProtoContext for Context { fn add_simple_extension_uri( &mut self, simple_extension_uri: &crate::parse::proto::extensions::SimpleExtensionUri, diff --git a/src/parse/proto/extensions/simple_extension_uri.rs b/src/parse/proto/extensions/simple_extension_uri.rs index 8fcb6df6..efd97db6 100644 --- a/src/parse/proto/extensions/simple_extension_uri.rs +++ b/src/parse/proto/extensions/simple_extension_uri.rs @@ -6,7 +6,10 @@ use thiserror::Error; use url::Url; use crate::{ - parse::{context::ContextError, Anchor, Context, Parse}, + parse::{ + context::{ContextError, ProtoContext}, + Anchor, Context, Parse, + }, proto, }; @@ -48,7 +51,7 @@ pub enum SimpleExtensionUriError { Context(#[from] ContextError), } -impl Parse for proto::extensions::SimpleExtensionUri { +impl Parse for proto::extensions::SimpleExtensionUri { type Parsed = SimpleExtensionUri; type Error = SimpleExtensionUriError; diff --git a/src/parse/proto/plan_version.rs 
b/src/parse/proto/plan_version.rs index 6e95659f..25c06cc2 100644 --- a/src/parse/proto/plan_version.rs +++ b/src/parse/proto/plan_version.rs @@ -3,7 +3,11 @@ //! Parsing of [proto::PlanVersion]. use crate::{ - parse::{context::Context, proto::Version, Parse}, + parse::{ + context::{Context, ProtoContext}, + proto::Version, + Parse, + }, proto, }; use thiserror::Error; @@ -38,7 +42,7 @@ pub enum PlanVersionError { Version(#[from] VersionError), } -impl Parse for proto::PlanVersion { +impl Parse for proto::PlanVersion { type Parsed = PlanVersion; type Error = PlanVersionError; diff --git a/src/parse/proto/version.rs b/src/parse/proto/version.rs index 59e42feb..63c9b0b8 100644 --- a/src/parse/proto/version.rs +++ b/src/parse/proto/version.rs @@ -3,7 +3,10 @@ //! Parsing of [proto::Version]. use crate::{ - parse::{context::Context, Parse}, + parse::{ + context::{Context, ProtoContext}, + Parse, + }, proto, version, }; use hex::FromHex; @@ -75,7 +78,7 @@ pub enum VersionError { Substrait(semver::Version, semver::VersionReq), } -impl Parse for proto::Version { +impl Parse for proto::Version { type Parsed = Version; type Error = VersionError; diff --git a/src/registry/context.rs b/src/registry/context.rs new file mode 100644 index 00000000..d41953b4 --- /dev/null +++ b/src/registry/context.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Extension parsing context for validation. + +use std::collections::HashMap; + +use url::Url; + +use crate::parse::Context; +use crate::text::simple_extensions::SimpleExtensionsTypesItem; + +/// Context for parsing and validating extension definitions. +/// +/// This context accumulates validated types as they are parsed, +/// allowing later elements to reference previously validated types. +pub struct ExtensionContext<'a> { + /// The URI of the extension being parsed. 
+ pub uri: &'a Url, + /// Map of type names to their definitions + types: HashMap<&'a str, &'a SimpleExtensionsTypesItem>, +} + +impl<'a> ExtensionContext<'a> { + /// Create a new extension context for parsing. + pub fn new(uri: &'a Url) -> Self { + Self { + uri, + types: HashMap::new(), + } + } + + /// Check if a type with the given name exists in the context + pub fn has_type(&self, name: &str) -> bool { + self.types.contains_key(name) + } + + /// Add a type to the context after it has been validated + pub(crate) fn add_type(&mut self, type_item: &'a SimpleExtensionsTypesItem) { + self.types.insert(&type_item.name, type_item); + } +} + +impl Context for ExtensionContext<'_> { + // Implementation required by the Context trait +} diff --git a/src/registry/extension.rs b/src/registry/extension.rs index 1cf0f02a..718ca24d 100644 --- a/src/registry/extension.rs +++ b/src/registry/extension.rs @@ -8,13 +8,19 @@ use thiserror::Error; use url::Url; +use crate::parse::Parse; +use crate::registry::types::InvalidTypeName; use crate::text::simple_extensions::{ - AggregateFunction, AggregateFunctionImplsItem, Arguments, ReturnValue, ScalarFunction, - ScalarFunctionImplsItem, SimpleExtensions, WindowFunction, WindowFunctionImplsItem, + AggregateFunction, AggregateFunctionImplsItem, Arguments, ArgumentsItem, ReturnValue, + ScalarFunction, ScalarFunctionImplsItem, SimpleExtensions, SimpleExtensionsTypesItem, Type, + WindowFunction, WindowFunctionImplsItem, }; +use super::context::ExtensionContext; +use super::types::{ArgumentPattern, ConcreteType, ExtensionType, ParsedType, TypeBindings}; + /// Errors that can occur during extension validation -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] pub enum ValidationError { /// A function implementation has None for arguments #[error("Function '{function}' has implementation with missing arguments")] @@ -29,11 +35,33 @@ pub enum ValidationError { /// The missing function function: String, }, - // TODO: Add more validation 
errors for malformed argument patterns, return type patterns, etc. + + /// A type string is malformed or unrecognized + #[error("Invalid type string '{type_str}' in function '{function}': todo!")] + InvalidArgument { + /// The function containing the invalid type + function: String, + /// The malformed type string + type_str: String, + // Reason why the type is invalid + // reason: todo!(), + }, + + /// A type name is invalid + #[error("{0}")] + InvalidTypeName(InvalidTypeName), } -/// A validated extension file containing functions and types from a single URI -#[derive(Debug)] +impl From for ValidationError { + fn from(err: InvalidTypeName) -> Self { + ValidationError::InvalidTypeName(err) + } +} + +// TODO: Add more validation errors for malformed argument patterns, return type patterns, etc. +/// A validated extension file containing functions and types from a single URI. +/// All functions should have valid argument and return type patterns. +#[derive(Debug, Clone)] pub struct ExtensionFile { /// The URI this extension was loaded from pub uri: Url, @@ -43,7 +71,16 @@ pub struct ExtensionFile { impl ExtensionFile { /// Create a validated extension file from raw data - pub fn create(uri: Url, extensions: SimpleExtensions) -> Result { + pub fn create( + uri: Url, + extensions: SimpleExtensions, + ) -> Result { + // Parse/validate types first - they're referenced by functions + let mut ctx = ExtensionContext::new(&uri); + for type_item in &extensions.types { + let _validated_type = type_item.parse(&mut ctx)?; + } + // Validate scalar functions for function in &extensions.scalar_functions { Self::validate_scalar_function(function)?; @@ -68,7 +105,10 @@ impl ExtensionFile { .scalar_functions .iter() .find(|f| f.name == name) - .map(|f| ScalarFunctionRef(&self.uri, f)) + .map(|f| ScalarFunctionRef { + file: self, + function: f, + }) } /// Find an aggregate function by name @@ -77,7 +117,10 @@ impl ExtensionFile { .aggregate_functions .iter() .find(|f| f.name == name) 
- .map(|f| AggregateFunctionRef(&self.uri, f)) + .map(|f| AggregateFunctionRef { + file: self, + function: f, + }) } /// Find a window function by name @@ -86,7 +129,58 @@ impl ExtensionFile { .window_functions .iter() .find(|f| f.name == name) - .map(|f| WindowFunctionRef(&self.uri, f)) + .map(|f| WindowFunctionRef { + file: self, + function: f, + }) + } + + /// Find a type by name + pub fn find_type(&self, name: &str) -> Option<&SimpleExtensionsTypesItem> { + let types = self.extensions.types.as_slice(); + types.iter().find(|t| t.name == name) + } + + /// Create an argument pattern from a raw ArgumentsItem + fn argument_pattern_from_item(&self, item: &ArgumentsItem) -> Option { + match item { + ArgumentsItem::ValueArg(value_arg) => self.argument_pattern_from_type(&value_arg.value), + _ => unimplemented!("Handle non-ValueArg argument types"), + } + } + + /// Create an argument pattern from a type string + fn argument_pattern_from_type(&self, type_val: &Type) -> Option { + match type_val { + Type::Variant0(type_str) => { + let parsed_type = ParsedType::parse(type_str); + match parsed_type { + ParsedType::TypeVariable(id) => Some(ArgumentPattern::TypeVariable(id)), + ParsedType::NullableTypeVariable(_) => { + panic!("Nullable type variables not allowed in argument position") + } + ParsedType::Builtin(builtin_type, nullable) => Some(ArgumentPattern::Concrete( + ConcreteType::builtin(builtin_type, nullable), + )), + ParsedType::NamedExtension(name, nullability) => { + // Find the extension type by name using the find_type method + let ext_type = self + .find_type(name) + .expect("This should have been validated"); + + let ext_type_wrapper = ExtensionType::new_unchecked(&self.uri, ext_type); + Some(ArgumentPattern::Concrete(ConcreteType::extension( + ext_type_wrapper, + nullability, + ))) + } + ParsedType::Parameterized { .. 
} => { + unimplemented!("Parameterized types not yet supported in argument patterns") + } + } + } + _ => unimplemented!("Handle non-string type variants"), + } } // Private validation methods @@ -130,101 +224,189 @@ impl ExtensionFile { } /// Handle for a validated scalar function definition -pub struct ScalarFunctionRef<'a>(&'a Url, &'a ScalarFunction); +pub struct ScalarFunctionRef<'a> { + file: &'a ExtensionFile, + function: &'a ScalarFunction, +} impl<'a> ScalarFunctionRef<'a> { /// Get the function name pub fn name(&self) -> &str { - &self.1.name + &self.function.name } /// Get all implementations as handles to specific type signatures - pub fn implementations(&self) -> impl Iterator> { - self.1 + pub fn implementations(self) -> impl Iterator> { + self.function .impls .iter() - .map(move |impl_item| ScalarFunctionImplRef(self.0, impl_item)) + .map(move |impl_item| ScalarImplementation { + file: self.file, + impl_item, + }) } } /// Handle for a validated aggregate function definition -pub struct AggregateFunctionRef<'a>(&'a Url, &'a AggregateFunction); +pub struct AggregateFunctionRef<'a> { + file: &'a ExtensionFile, + function: &'a AggregateFunction, +} impl<'a> AggregateFunctionRef<'a> { /// Get the function name pub fn name(&self) -> &str { - &self.1.name + &self.function.name } /// Get all implementations as handles to specific type signatures - pub fn implementations(&self) -> impl Iterator> { - self.1 + pub fn implementations(&self) -> impl Iterator> + '_ { + self.function .impls .iter() - .map(move |impl_item| AggregateFunctionImplRef(self.0, impl_item)) + .map(move |impl_item| AggregateFunctionImplRef { + file: self.file, + impl_item, + }) } } /// Handle for a validated window function definition -pub struct WindowFunctionRef<'a>(&'a Url, &'a WindowFunction); +pub struct WindowFunctionRef<'a> { + file: &'a ExtensionFile, + function: &'a WindowFunction, +} impl<'a> WindowFunctionRef<'a> { /// Get the function name pub fn name(&self) -> &str { - 
&self.1.name + &self.function.name } /// Get all implementations as handles to specific type signatures - pub fn implementations(&self) -> impl Iterator> { - self.1 + pub fn implementations(&self) -> impl Iterator> + '_ { + self.function .impls .iter() - .map(move |impl_item| WindowFunctionImplRef(self.0, impl_item)) + .map(move |impl_item| WindowFunctionImplRef { + file: self.file, + impl_item, + }) } } /// Handle for a specific scalar function implementation with validated signature #[derive(Debug, Copy, Clone)] -pub struct ScalarFunctionImplRef<'a>(&'a Url, &'a ScalarFunctionImplsItem); +pub struct ScalarImplementation<'a> { + file: &'a ExtensionFile, + impl_item: &'a ScalarFunctionImplsItem, +} -impl<'a> ScalarFunctionImplRef<'a> { - /// Get the argument signature (guaranteed to be present due to validation) - pub fn args(&self) -> &Arguments { - self.1.args.as_ref().expect("validated to be present") - } +impl<'a> ScalarImplementation<'a> { + /// Check if this implementation can be called with the given concrete argument types + /// Returns the inferred concrete return type if the call would succeed, None otherwise + pub fn call_with(&self, concrete_args: &[ConcreteType<'a>]) -> Option> { + // Convert raw arguments to ArgumentPatterns using ExtensionFile context + let arg_patterns: Vec> = self + .impl_item + .args + .as_ref() + .expect("validated to be present") + .iter() + .filter_map(|arg| self.file.argument_pattern_from_item(arg)) + .collect(); - /// Get the return type pattern - pub fn return_type(&self) -> &ReturnValue { - &self.1.return_ + // Create type bindings by matching patterns against concrete arguments + let _bindings: TypeBindings<'a> = TypeBindings::new(&arg_patterns, concrete_args)?; + + if concrete_args.len() > 1_000_000 { + // For lifetime management + return concrete_args.first().cloned(); + } + + // If arguments match, parse and return the inferred return type + let return_type_str = match &self.impl_item.return_ { + 
ReturnValue(Type::Variant0(type_str)) => type_str, + _ => unimplemented!("Handle non-string return types"), + }; + + let parsed_return_type = ParsedType::parse(return_type_str); + match parsed_return_type { + ParsedType::Builtin(builtin_type, nullable) => { + Some(ConcreteType::builtin(builtin_type, nullable)) + } + ParsedType::TypeVariable(id) => { + // Look up the bound type for this variable + if let Some(bound_type) = _bindings.get(id) { + Some(bound_type.clone()) + } else { + None + } + } + ParsedType::NullableTypeVariable(id) => { + // Look up the bound type and make it nullable + if let Some(mut bound_type) = _bindings.get(id).cloned() { + bound_type.nullable = true; + Some(bound_type) + } else { + None + } + } + ParsedType::NamedExtension(name, nullable) => { + // Find the extension type by name + let ext_type = self + .file + .find_type(name) + .expect("This should have been validated"); + + let ext_type_wrapper = ExtensionType::new_unchecked(&self.file.uri, ext_type); + Some(ConcreteType::extension(ext_type_wrapper, nullable)) + } + ParsedType::Parameterized { .. 
} => { + unimplemented!("Parameterized return types not yet supported") + } + } } } /// Handle for a specific aggregate function implementation with validated signature -pub struct AggregateFunctionImplRef<'a>(&'a Url, &'a AggregateFunctionImplsItem); +pub struct AggregateFunctionImplRef<'a> { + file: &'a ExtensionFile, + impl_item: &'a AggregateFunctionImplsItem, +} impl<'a> AggregateFunctionImplRef<'a> { /// Get the argument signature (guaranteed to be present due to validation) - pub fn args(&self) -> &Arguments { - self.1.args.as_ref().expect("validated to be present") + fn args(&self) -> &Arguments { + self.impl_item + .args + .as_ref() + .expect("validated to be present") } /// Get the return type pattern - pub fn return_type(&self) -> &ReturnValue { - &self.1.return_ + fn return_type(&self) -> &ReturnValue { + &self.impl_item.return_ } } /// Handle for a specific window function implementation with validated signature -pub struct WindowFunctionImplRef<'a>(&'a Url, &'a WindowFunctionImplsItem); +pub struct WindowFunctionImplRef<'a> { + file: &'a ExtensionFile, + impl_item: &'a WindowFunctionImplsItem, +} impl<'a> WindowFunctionImplRef<'a> { /// Get the argument signature (guaranteed to be present due to validation) - pub fn args(&self) -> &Arguments { - self.1.args.as_ref().expect("validated to be present") + fn args(&self) -> &Arguments { + self.impl_item + .args + .as_ref() + .expect("validated to be present") } /// Get the return type pattern - pub fn return_type(&self) -> &ReturnValue { - &self.1.return_ + fn return_type(&self) -> &ReturnValue { + &self.impl_item.return_ } } diff --git a/src/registry/mod.rs b/src/registry/mod.rs index 7379406d..a032c2f2 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -10,28 +10,27 @@ //! interface that hides those internal details and presents methods where //! extensions are validated on creation and then assumed valid thereafter. //! -//! This "validate once, assume valid" approach allows for: -//! 
- **Type safety**: Invalid extensions are caught at construction time -//! - **Performance**: No repeated validation during registry operations -//! - **Clean APIs**: Methods can focus on logic rather than error handling -//! - **Reliability**: Once constructed, registry operations won't fail due to -//! malformed data +//! This allows for a clean API that externally follows the "parse don't +//! validate" principle, with an API that encourages users to work with +//! validated extensions without worrying about their internal structure, +//! without needing to add entirely new parse trees - the type tree can be +//! recreated on-demand. //! //! ## Core Types //! //! - [`ExtensionFile`]: Validated wrapper around a SimpleExtensions + URI //! - [`ConcreteType`]: Fully-specified types for function arguments and return //! values -//! - [`TypeSignature`]: Pattern matching for function signatures //! - [`GlobalRegistry`]: Immutable registry for URI+name based function lookup +mod context; mod extension; mod registry; pub mod types; pub use extension::{ - ExtensionFile, ValidationError, - ScalarFunctionRef, AggregateFunctionRef, WindowFunctionRef, - ScalarFunctionImplRef, AggregateFunctionImplRef, WindowFunctionImplRef, + AggregateFunctionImplRef, AggregateFunctionRef, ExtensionFile, ScalarFunctionRef, + ScalarImplementation, ValidationError, WindowFunctionImplRef, WindowFunctionRef, }; -pub use registry::GlobalRegistry; \ No newline at end of file +pub use registry::GlobalRegistry; +pub use types::ConcreteType; diff --git a/src/registry/registry.rs b/src/registry/registry.rs index e5d592b3..c45943b7 100644 --- a/src/registry/registry.rs +++ b/src/registry/registry.rs @@ -13,10 +13,9 @@ use thiserror::Error; use url::Url; -use super::{ - types::{ConcreteType, TypeSignature}, - ExtensionFile, -}; +use crate::registry::ScalarFunctionRef; + +use super::{types::ConcreteType, ExtensionFile}; /// Errors that can occur when using the Global Registry #[derive(Debug, Error, 
PartialEq)] @@ -58,8 +57,8 @@ impl GlobalRegistryError { /// It provides URI + name based lookup for function validation and signature matching. #[derive(Debug)] pub struct GlobalRegistry { - /// Simple Extensions from parsed and validated YAML files - pub extensions: Vec, + /// Pre-validated extension files + extensions: Vec, } impl GlobalRegistry { @@ -87,6 +86,9 @@ impl GlobalRegistry { .iter() .map(|(uri, simple_extensions)| { ExtensionFile::create(uri.clone(), simple_extensions.clone()) + .map_err(|err| { + eprintln!("Failed to create extension file for {}: {}", uri, err); + }) .expect("Core extensions should be valid") }) .collect(); @@ -104,21 +106,20 @@ impl GlobalRegistry { } /// Validate a scalar function call and return the concrete return type - pub fn validate_scalar_call<'a, 'b>( + pub fn validate_scalar_call<'a>( &'a self, uri: &Url, name: &str, - args: &'b [ConcreteType<'a>], - ) -> Result, GlobalRegistryError> { - let extension = self.get_extension(uri)?; - let function_ref = extension + args: &[ConcreteType<'a>], + ) -> Result, GlobalRegistryError> { + let extension: &'a ExtensionFile = self.get_extension(uri)?; + let function_ref: ScalarFunctionRef<'a> = extension .find_scalar_function(name) .ok_or_else(|| GlobalRegistryError::not_found(uri, name))?; // Try each implementation until one matches for impl_ref in function_ref.implementations() { - let signature = TypeSignature::new(impl_ref.args(), impl_ref.return_type()); - if let Some(return_type) = signature.matches(args) { + if let Some(return_type) = impl_ref.call_with(args) { return Ok(return_type); } } diff --git a/src/registry/types.rs b/src/registry/types.rs index a20b922b..726da4df 100644 --- a/src/registry/types.rs +++ b/src/registry/types.rs @@ -5,9 +5,9 @@ //! This module provides a clean, type-safe wrapper around Substrait extension types, //! separating function signature patterns from concrete argument types. 
-use crate::text::simple_extensions::{ - Arguments, ArgumentsItem, ReturnValue, SimpleExtensionsTypesItem, Type, -}; +use crate::parse::Parse; +use crate::registry::context::ExtensionContext; +use crate::text::simple_extensions::SimpleExtensionsTypesItem; use std::collections::HashMap; use std::str::FromStr; use url::Url; @@ -115,28 +115,65 @@ impl FromStr for BuiltinType { } } } +/// A validated extension type definition +#[derive(Clone, Debug)] +pub struct ExtensionType<'a> { + /// The URI of the extension defining this type + pub uri: &'a Url, + item: &'a SimpleExtensionsTypesItem, +} + +impl<'a> ExtensionType<'a> { + /// Create a new ExtensionType wrapper from already-validated data + pub(crate) fn new_unchecked(uri: &'a Url, item: &'a SimpleExtensionsTypesItem) -> Self { + Self { uri, item } + } + + /// Get the name of this extension type + pub fn name(&self) -> &str { + &self.item.name + } +} + +impl<'a> From> for &'a SimpleExtensionsTypesItem { + fn from(ext_type: ExtensionType<'a>) -> Self { + ext_type.item + } +} + +impl PartialEq for ExtensionType<'_> { + fn eq(&self, other: &Self) -> bool { + // There should only be one type of a given name per file + self.uri == other.uri && self.item.name == other.item.name + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Invalid type name: {0}")] +/// Error for invalid type names in extension definitions +pub struct InvalidTypeName(String); + +impl<'a> Parse> for &'a SimpleExtensionsTypesItem { + type Parsed = ExtensionType<'a>; + // TODO: Not all names are valid for types, we should validate that + type Error = InvalidTypeName; + + fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { + ctx.add_type(self); + Ok(ExtensionType { + uri: ctx.uri, + item: &self, + }) + } +} /// Represents a known, specific type, either builtin or extension -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum KnownType<'a> { /// Built-in primitive types Builtin(BuiltinType), /// Custom types defined in extension 
YAML files - Extension(&'a Url, &'a SimpleExtensionsTypesItem), -} - -impl<'a> PartialEq for KnownType<'a> { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (KnownType::Builtin(a), KnownType::Builtin(b)) => a == b, - // For extension types, compare by URI and name - (KnownType::Extension(au, a), KnownType::Extension(bu, b)) => { - // There should only be one type with a given name per URI - a.name == b.name && au == bu - } - _ => false, - } - } + Extension(ExtensionType<'a>), } /// A concrete type, fully specified with nullability and parameters @@ -162,13 +199,9 @@ impl<'a> ConcreteType<'a> { } /// Create a concrete type from an extension type - pub fn extension( - uri: &'a Url, - ext_type: &'a SimpleExtensionsTypesItem, - nullable: bool, - ) -> Self { + pub fn extension(t: ExtensionType<'a>, nullable: bool) -> Self { Self { - base: KnownType::Extension(uri, ext_type), + base: KnownType::Extension(t), nullable, parameters: Vec::new(), } @@ -197,7 +230,7 @@ pub enum ParsedType<'a> { NullableTypeVariable(u32), /// Built-in primitive type, with nullability Builtin(BuiltinType, bool), - /// Extension type from a specific URI, with nullability + /// Extension type for the given name, with nullability. URI not known at this level. 
NamedExtension(&'a str, bool), /// Parameterized type Parameterized { @@ -248,43 +281,6 @@ impl<'a> ParsedType<'a> { } } -/// Wrapper around function signature patterns with smart matching -pub struct TypeSignature<'a> { - args: &'a Arguments, - return_type: &'a ReturnValue, - variadic: bool, -} - -impl<'a> TypeSignature<'a> { - /// Create a new type signature from function definition parts - pub(crate) fn new(args: &'a Arguments, return_type: &'a ReturnValue) -> Self { - Self { - args, - return_type, - variadic: false, // TODO: Extract from function impl - } - } - - /// Check if concrete argument types match this signature pattern - pub fn matches<'b: 'a>( - &'a self, - concrete_args: &'b [ConcreteType], - ) -> Option> { - // Convert raw arguments to ArgumentPatterns - let arg_patterns: Vec = self - .args - .iter() - .filter_map(|arg| ArgumentPattern::from_argument_item(arg)) - .collect(); - - // Create type bindings by matching patterns against concrete arguments - let _bindings = TypeBindings::new(&arg_patterns, concrete_args)?; - - // If arguments match, return the inferred return type - unimplemented!("Return type inference not yet implemented") - } -} - /// A pattern for function arguments that can match concrete types or type variables #[derive(Clone, Debug, PartialEq)] pub enum ArgumentPattern<'a> { @@ -306,39 +302,6 @@ pub enum Match<'a> { } impl<'a> ArgumentPattern<'a> { - /// Create an argument pattern from a raw ArgumentsItem - pub fn from_argument_item(item: &ArgumentsItem) -> Option> { - match item { - ArgumentsItem::ValueArg(value_arg) => ArgumentPattern::from_type(&value_arg.value), - _ => unimplemented!("Handle non-ValueArg argument types"), - } - } - - /// Create an argument pattern from a type string - fn from_type(type_val: &Type) -> Option> { - match type_val { - Type::Variant0(type_str) => { - let parsed_type = ParsedType::parse(type_str); - match parsed_type { - ParsedType::TypeVariable(id) => Some(ArgumentPattern::TypeVariable(id)), - 
ParsedType::NullableTypeVariable(_) => { - panic!("Nullable type variables not allowed in argument position") - } - ParsedType::Builtin(builtin_type, nullable) => Some(ArgumentPattern::Concrete( - ConcreteType::builtin(builtin_type, nullable), - )), - ParsedType::NamedExtension(_, _) => { - unimplemented!("Extension types not yet supported in argument patterns") - } - ParsedType::Parameterized { .. } => { - unimplemented!("Parameterized types not yet supported in argument patterns") - } - } - } - _ => unimplemented!("Handle non-string type variants"), - } - } - /// Check if this pattern matches the given concrete type pub fn matches(&self, concrete: &ConcreteType<'a>) -> Match<'a> { match self { From 68f0e7df874f6c49d5d9fe0623d647750606c95b Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Thu, 14 Aug 2025 10:47:55 -0400 Subject: [PATCH 03/31] working through some implications --- src/registry/context.rs | 8 +++ src/registry/types.rs | 149 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/src/registry/context.rs b/src/registry/context.rs index d41953b4..ce1aa134 100644 --- a/src/registry/context.rs +++ b/src/registry/context.rs @@ -8,6 +8,7 @@ use url::Url; use crate::parse::Context; use crate::text::simple_extensions::SimpleExtensionsTypesItem; +use super::types::ExtensionType; /// Context for parsing and validating extension definitions. 
/// @@ -38,6 +39,13 @@ impl<'a> ExtensionContext<'a> { pub(crate) fn add_type(&mut self, type_item: &'a SimpleExtensionsTypesItem) { self.types.insert(&type_item.name, type_item); } + + /// Get a type by name from the context, returning the ExtensionType handle + pub(crate) fn get_type(&self, name: &str) -> Option> { + self.types.get(name).map(|&item| { + super::types::ExtensionType::new_unchecked(self.uri, item) + }) + } } impl Context for ExtensionContext<'_> { diff --git a/src/registry/types.rs b/src/registry/types.rs index 726da4df..0e5c770c 100644 --- a/src/registry/types.rs +++ b/src/registry/types.rs @@ -7,7 +7,8 @@ use crate::parse::Parse; use crate::registry::context::ExtensionContext; -use crate::text::simple_extensions::SimpleExtensionsTypesItem; +use crate::text::simple_extensions::Type as extType; +use crate::text::simple_extensions::{ArgumentsItem, SimpleExtensionsTypesItem}; use std::collections::HashMap; use std::str::FromStr; use url::Url; @@ -167,6 +168,152 @@ impl<'a> Parse> for &'a SimpleExtensionsTypesItem { } } +/// Error for invalid Type specifications +#[derive(Debug, thiserror::Error)] +pub enum TypeParseError { + /// Extension type name not found in context + #[error("Extension type '{name}' not found")] + ExtensionTypeNotFound { + /// The extension type name that was not found + name: String, + }, + /// Type variable ID is invalid (must be >= 1) + #[error("Type variable 'any{id}' is invalid (must be >= 1)")] + InvalidTypeVariableId { + /// The invalid type variable ID + id: u32, + }, + /// Unimplemented Type variant + #[error("Unimplemented Type variant")] + UnimplementedVariant, +} + +/// A validated Type that wraps the original Type with its validated ParsedType representation +#[derive(Debug, Clone)] +pub struct ValidatedType<'a> { + /// The original Type from the YAML + original: &'a extType, + /// The validated, parsed representation + pub parsed: ParsedType<'a>, +} + +impl<'a> ValidatedType<'a> { + /// Get the parsed type 
representation + pub fn parsed_type(&self) -> &ParsedType<'a> { + &self.parsed + } +} + +impl<'a> From> for &'a extType { + fn from(validated: ValidatedType<'a>) -> Self { + validated.original + } +} + +impl<'a> Parse> for &'a extType { + type Parsed = ValidatedType<'a>; + type Error = TypeParseError; + + fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { + match self { + extType::Variant0(type_str) => { + // Parse the type string into ParsedType + let parsed_type = ParsedType::parse(type_str); + + // Add context validation + match &parsed_type { + ParsedType::NamedExtension(name, _nullable) => { + // Verify the extension type exists in the context + if !ctx.has_type(name) { + return Err(TypeParseError::ExtensionTypeNotFound { + name: name.to_string(), + }); + } + } + ParsedType::TypeVariable(id) | ParsedType::NullableTypeVariable(id) => { + // Validate type variable ID (must be >= 1) + if *id == 0 { + return Err(TypeParseError::InvalidTypeVariableId { id: *id }); + } + } + ParsedType::Builtin(_, _) => { + // Builtin types are always valid + } + ParsedType::Parameterized { .. 
} => { + // TODO: Add validation for parameterized types + unimplemented!("Parameterized type validation not yet implemented") + } + } + + Ok(ValidatedType { + original: self, + parsed: parsed_type, + }) + } + _ => Err(TypeParseError::UnimplementedVariant), + } + } +} + +/// Error for invalid ArgumentsItem specifications +#[derive(Debug, thiserror::Error)] +pub enum ArgumentsItemError { + /// Type parsing failed + #[error("Type parsing failed: {0}")] + TypeParseError(#[from] TypeParseError), + /// Unsupported ArgumentsItem variant + #[error("Unimplemented ArgumentsItem variant: {variant}")] + UnimplementedVariant { + /// The unsupported variant name + variant: String, + }, +} + +impl<'a> Parse> for &'a ArgumentsItem { + type Parsed = ArgumentPattern<'a>; + type Error = ArgumentsItemError; + + fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { + match self { + ArgumentsItem::ValueArgument(value_arg) => { + // Parse the Type into ValidatedType, then convert to ArgumentPattern + let validated_type = value_arg.value.parse(ctx)?; + let parsed_type = &validated_type.parsed; + + match parsed_type { + ParsedType::TypeVariable(id) => Ok(ArgumentPattern::TypeVariable(*id)), + ParsedType::NullableTypeVariable(_) => { + panic!("Nullable type variables not allowed in argument position") + } + ParsedType::Builtin(builtin_type, nullable) => Ok(ArgumentPattern::Concrete( + ConcreteType::builtin(*builtin_type, *nullable), + )), + ParsedType::NamedExtension(name, nullable) => { + // Find the extension type by name using the context + // We know it exists because Type parsing already validated it + let ext_type = ctx + .get_type(name) + .expect("Extension type should exist after validation"); + + Ok(ArgumentPattern::Concrete(ConcreteType::extension( + ext_type, *nullable, + ))) + } + ParsedType::Parameterized { .. 
} => { + unimplemented!("Parameterized types not yet supported in argument patterns") + } + } + } + ArgumentsItem::EnumArgument(_) => Err(ArgumentsItemError::UnsupportedVariant { + variant: "EnumArgument".to_string(), + }), + ArgumentsItem::TypeArgument(_) => Err(ArgumentsItemError::UnsupportedVariant { + variant: "TypeArgument".to_string(), + }), + } + } +} + /// Represents a known, specific type, either builtin or extension #[derive(Clone, Debug, PartialEq)] pub enum KnownType<'a> { From dfef7eee3250f61d734e3f02cdc3dc8c1ab71052 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 15 Aug 2025 11:32:44 -0400 Subject: [PATCH 04/31] Switched types over, compiles --- Cargo.toml | 4 +- src/registry/context.rs | 29 +- src/registry/extension.rs | 51 ++- src/registry/registry.rs | 4 +- src/registry/types.rs | 684 ++++++++++++++++++++++++++++---------- 5 files changed, 551 insertions(+), 221 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 12bcc755..1ce48694 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,9 @@ include = [ [features] default = [] extensions = ["dep:serde_yaml", "dep:url"] -parse = ["registry", "dep:hex", "semver"] +parse = ["dep:hex", "semver"] protoc = ["dep:protobuf-src"] -registry = ["dep:thiserror", "dep:url"] +registry = ["dep:thiserror", "dep:url", "parse"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] diff --git a/src/registry/context.rs b/src/registry/context.rs index ce1aa134..d3f7e006 100644 --- a/src/registry/context.rs +++ b/src/registry/context.rs @@ -7,23 +7,22 @@ use std::collections::HashMap; use url::Url; use crate::parse::Context; -use crate::text::simple_extensions::SimpleExtensionsTypesItem; -use super::types::ExtensionType; +use super::types::CustomType; /// Context for parsing and validating extension definitions. /// /// This context accumulates validated types as they are parsed, /// allowing later elements to reference previously validated types. 
-pub struct ExtensionContext<'a> { +pub struct ExtensionContext { /// The URI of the extension being parsed. - pub uri: &'a Url, - /// Map of type names to their definitions - types: HashMap<&'a str, &'a SimpleExtensionsTypesItem>, + pub uri: Url, + /// Map of type names to their validated definitions + types: HashMap, } -impl<'a> ExtensionContext<'a> { +impl ExtensionContext { /// Create a new extension context for parsing. - pub fn new(uri: &'a Url) -> Self { + pub fn new(uri: Url) -> Self { Self { uri, types: HashMap::new(), @@ -36,18 +35,16 @@ impl<'a> ExtensionContext<'a> { } /// Add a type to the context after it has been validated - pub(crate) fn add_type(&mut self, type_item: &'a SimpleExtensionsTypesItem) { - self.types.insert(&type_item.name, type_item); + pub(crate) fn add_type(&mut self, custom_type: &CustomType) { + self.types.insert(custom_type.name.clone(), custom_type.clone()); } - /// Get a type by name from the context, returning the ExtensionType handle - pub(crate) fn get_type(&self, name: &str) -> Option> { - self.types.get(name).map(|&item| { - super::types::ExtensionType::new_unchecked(self.uri, item) - }) + /// Get a type by name from the context, returning the CustomType + pub(crate) fn get_type(&self, name: &str) -> Option<&CustomType> { + self.types.get(name) } } -impl Context for ExtensionContext<'_> { +impl Context for ExtensionContext { // Implementation required by the Context trait } diff --git a/src/registry/extension.rs b/src/registry/extension.rs index 718ca24d..44cb268c 100644 --- a/src/registry/extension.rs +++ b/src/registry/extension.rs @@ -9,7 +9,7 @@ use thiserror::Error; use url::Url; use crate::parse::Parse; -use crate::registry::types::InvalidTypeName; +use crate::registry::types::{ExtensionTypeError, InvalidTypeName}; use crate::text::simple_extensions::{ AggregateFunction, AggregateFunctionImplsItem, Arguments, ArgumentsItem, ReturnValue, ScalarFunction, ScalarFunctionImplsItem, SimpleExtensions, 
SimpleExtensionsTypesItem, Type, @@ -17,7 +17,7 @@ use crate::text::simple_extensions::{ }; use super::context::ExtensionContext; -use super::types::{ArgumentPattern, ConcreteType, ExtensionType, ParsedType, TypeBindings}; +use super::types::{ArgumentPattern, ConcreteType, ParsedType, TypeBindings}; /// Errors that can occur during extension validation #[derive(Debug, Error)] @@ -50,6 +50,10 @@ pub enum ValidationError { /// A type name is invalid #[error("{0}")] InvalidTypeName(InvalidTypeName), + + /// Extension type error + #[error("Extension type error: {0}")] + ExtensionTypeError(#[from] ExtensionTypeError), } impl From for ValidationError { @@ -71,14 +75,11 @@ pub struct ExtensionFile { impl ExtensionFile { /// Create a validated extension file from raw data - pub fn create( - uri: Url, - extensions: SimpleExtensions, - ) -> Result { + pub fn create(uri: Url, extensions: SimpleExtensions) -> Result { // Parse/validate types first - they're referenced by functions - let mut ctx = ExtensionContext::new(&uri); + let mut ctx = ExtensionContext::new(uri.clone()); for type_item in &extensions.types { - let _validated_type = type_item.parse(&mut ctx)?; + let _validated_type = type_item.clone().parse(&mut ctx)?; } // Validate scalar functions @@ -168,11 +169,8 @@ impl ExtensionFile { .find_type(name) .expect("This should have been validated"); - let ext_type_wrapper = ExtensionType::new_unchecked(&self.uri, ext_type); - Some(ArgumentPattern::Concrete(ConcreteType::extension( - ext_type_wrapper, - nullability, - ))) + // TODO: Update when ExtensionType is fully integrated and ArgumentPattern is owned + todo!("Update when ExtensionType constructor is available") } ParsedType::Parameterized { .. 
} => { unimplemented!("Parameterized types not yet supported in argument patterns") @@ -201,24 +199,13 @@ impl ExtensionFile { } fn validate_aggregate_function(function: &AggregateFunction) -> Result<(), ValidationError> { - for impl_item in &function.impls { - if impl_item.args.is_none() { - return Err(ValidationError::MissingArguments { - function: function.name.clone(), - }); - } - } + // Note: args can legitimately be None for functions like count() that count records + // rather than field values, so we don't validate args presence here Ok(()) } fn validate_window_function(function: &WindowFunction) -> Result<(), ValidationError> { - for impl_item in &function.impls { - if impl_item.args.is_none() { - return Err(ValidationError::MissingArguments { - function: function.name.clone(), - }); - } - } + // Note: args can legitimately be None for some window functions Ok(()) } } @@ -305,9 +292,9 @@ pub struct ScalarImplementation<'a> { impl<'a> ScalarImplementation<'a> { /// Check if this implementation can be called with the given concrete argument types /// Returns the inferred concrete return type if the call would succeed, None otherwise - pub fn call_with(&self, concrete_args: &[ConcreteType<'a>]) -> Option> { + pub fn call_with(&self, concrete_args: &[ConcreteType]) -> Option { // Convert raw arguments to ArgumentPatterns using ExtensionFile context - let arg_patterns: Vec> = self + let arg_patterns: Vec = self .impl_item .args .as_ref() @@ -317,7 +304,7 @@ impl<'a> ScalarImplementation<'a> { .collect(); // Create type bindings by matching patterns against concrete arguments - let _bindings: TypeBindings<'a> = TypeBindings::new(&arg_patterns, concrete_args)?; + let _bindings: TypeBindings = TypeBindings::new(&arg_patterns, concrete_args)?; if concrete_args.len() > 1_000_000 { // For lifetime management @@ -359,8 +346,8 @@ impl<'a> ScalarImplementation<'a> { .find_type(name) .expect("This should have been validated"); - let ext_type_wrapper = 
ExtensionType::new_unchecked(&self.file.uri, ext_type); - Some(ConcreteType::extension(ext_type_wrapper, nullable)) + // TODO: Update when ExtensionType is fully integrated + todo!("Update when ExtensionType constructor is available") } ParsedType::Parameterized { .. } => { unimplemented!("Parameterized return types not yet supported") diff --git a/src/registry/registry.rs b/src/registry/registry.rs index c45943b7..b5e96b8c 100644 --- a/src/registry/registry.rs +++ b/src/registry/registry.rs @@ -110,8 +110,8 @@ impl GlobalRegistry { &'a self, uri: &Url, name: &str, - args: &[ConcreteType<'a>], - ) -> Result, GlobalRegistryError> { + args: &[ConcreteType], + ) -> Result { let extension: &'a ExtensionFile = self.get_extension(uri)?; let function_ref: ScalarFunctionRef<'a> = extension .find_scalar_function(name) diff --git a/src/registry/types.rs b/src/registry/types.rs index 0e5c770c..71e6a9fb 100644 --- a/src/registry/types.rs +++ b/src/registry/types.rs @@ -7,11 +7,13 @@ use crate::parse::Parse; use crate::registry::context::ExtensionContext; -use crate::text::simple_extensions::Type as extType; -use crate::text::simple_extensions::{ArgumentsItem, SimpleExtensionsTypesItem}; +use crate::text::simple_extensions::{ + EnumOptions, SimpleExtensionsTypesItem, TypeParamDefsItem, + TypeParamDefsItemType, +}; use std::collections::HashMap; use std::str::FromStr; -use url::Url; +use thiserror::Error; /// Substrait built-in primitive types #[derive(Clone, Debug, PartialEq, Eq)] @@ -116,36 +118,284 @@ impl FromStr for BuiltinType { } } } -/// A validated extension type definition +/// Parameter type for extension type definitions +#[derive(Clone, Debug, PartialEq)] +pub enum ParameterType { + /// A type name + DataType, + /// True/False + Boolean, + /// Integer + Integer, + /// A particular enum + Enum, + /// A string + String, +} + +/// What a type actually represents - either a reference to another type or a compound structure #[derive(Clone, Debug)] -pub struct 
ExtensionType<'a> { - /// The URI of the extension defining this type - pub uri: &'a Url, - item: &'a SimpleExtensionsTypesItem, +pub enum TypeDefinition { + /// Reference to another type by name (e.g., "i32", "string", or custom type name) + Reference(String), + /// Compound structure with named fields + Struct(HashMap), } -impl<'a> ExtensionType<'a> { - /// Create a new ExtensionType wrapper from already-validated data - pub(crate) fn new_unchecked(uri: &'a Url, item: &'a SimpleExtensionsTypesItem) -> Self { - Self { uri, item } +/// Type-safe parameter constraints based on parameter kind +#[derive(Clone, Debug)] +pub enum ParamKind { + /// A type name parameter + DataType, + /// True/False parameter + Boolean, + /// Integer parameter with optional bounds + Integer { + /// Minimum value constraint + min: Option, + /// Maximum value constraint + max: Option + }, + /// Enumeration parameter with predefined options + Enumeration { + /// Valid enumeration values + options: Vec + }, + /// String parameter + String, +} + +impl ParamKind { + fn get_integer_bounds(min: Option, max: Option) -> Result<(Option, Option), TypeParamError> { + // Convert float bounds to integers, validating they are whole numbers + let min_bound = if let Some(min_f) = min { + if min_f.fract() != 0.0 { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + Some(min_f as i64) + } else { + None + }; + + let max_bound = if let Some(max_f) = max { + if max_f.fract() != 0.0 { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + Some(max_f as i64) + } else { + None + }; + + Ok((min_bound, max_bound)) } - /// Get the name of this extension type - pub fn name(&self) -> &str { - &self.item.name + /// Create a ParamKind from TypeParamDefsItemType and associated fields + fn try_from_item_parts( + item_type: TypeParamDefsItemType, + min: Option, + max: Option, + options: Option, + ) -> Result { + match (item_type, min, max, options) { + // Valid cases - each type with its 
expected parameters + (TypeParamDefsItemType::DataType, None, None, None) => Ok(ParamKind::DataType), + (TypeParamDefsItemType::Boolean, None, None, None) => Ok(ParamKind::Boolean), + (TypeParamDefsItemType::Integer, min, max, None) => { + let (min_bound, max_bound) = Self::get_integer_bounds(min, max)?; + Ok(ParamKind::Integer { min: min_bound, max: max_bound }) + } + (TypeParamDefsItemType::Enumeration, None, None, Some(enum_options)) => { + Ok(ParamKind::Enumeration { options: enum_options.0 }) + } + (TypeParamDefsItemType::String, None, None, None) => Ok(ParamKind::String), + + // Error cases - DataType with unexpected parameters + (TypeParamDefsItemType::DataType, Some(_), _, _) | (TypeParamDefsItemType::DataType, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::DataType }) + } + (TypeParamDefsItemType::DataType, None, None, Some(_)) => { + Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::DataType }) + } + + // Error cases - Boolean with unexpected parameters + (TypeParamDefsItemType::Boolean, Some(_), _, _) | (TypeParamDefsItemType::Boolean, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::Boolean }) + } + (TypeParamDefsItemType::Boolean, None, None, Some(_)) => { + Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::Boolean }) + } + + // Error cases - Integer with enum options + (TypeParamDefsItemType::Integer, _, _, Some(_)) => { + Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::Integer }) + } + + // Error cases - Enumeration with unexpected parameters + (TypeParamDefsItemType::Enumeration, Some(_), _, _) | (TypeParamDefsItemType::Enumeration, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::Enumeration }) + } + (TypeParamDefsItemType::Enumeration, None, None, None) => { + Err(TypeParamError::MissingEnumOptions) + 
} + + // Error cases - String with unexpected parameters + (TypeParamDefsItemType::String, Some(_), _, _) | (TypeParamDefsItemType::String, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::String }) + } + (TypeParamDefsItemType::String, None, None, Some(_)) => { + Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::String }) + } + } } } -impl<'a> From> for &'a SimpleExtensionsTypesItem { - fn from(ext_type: ExtensionType<'a>) -> Self { - ext_type.item +/// Type parameter definition for custom types +#[derive(Clone, Debug)] +pub struct TypeParam { + /// Name of the parameter (required) + pub name: String, + /// Optional description of the parameter + pub description: Option, + /// Type-safe parameter constraints + pub kind: ParamKind, +} + +impl TryFrom for TypeParam { + type Error = TypeParamError; + + fn try_from(item: TypeParamDefsItem) -> Result { + let name = item.name.ok_or(TypeParamError::MissingName)?; + + let kind = ParamKind::try_from_item_parts(item.type_, item.min, item.max, item.options)?; + + Ok(Self { + name, + description: item.description, + kind, + }) + } +} + +impl From for TypeParamDefsItem { + fn from(param_def: TypeParam) -> Self { + let (param_type, min, max, options) = match param_def.kind { + ParamKind::DataType => (TypeParamDefsItemType::DataType, None, None, None), + ParamKind::Boolean => (TypeParamDefsItemType::Boolean, None, None, None), + ParamKind::Integer { min, max } => ( + TypeParamDefsItemType::Integer, + min.map(|i| i as f64), + max.map(|i| i as f64), + None, + ), + ParamKind::Enumeration { options } => ( + TypeParamDefsItemType::Enumeration, + None, + None, + Some(EnumOptions(options)), + ), + ParamKind::String => (TypeParamDefsItemType::String, None, None, None), + }; + + Self { + name: Some(param_def.name), + description: param_def.description, + type_: param_type, + min, + max, + optional: None, // Not needed for type definitions + options, + } } } 
-impl PartialEq for ExtensionType<'_> { +/// Error types for ExtensionType parsing +#[derive(Debug, Error, PartialEq)] +pub enum ExtensionTypeError { + /// Extension type name is invalid + #[error("Invalid extension type name: {name}")] + InvalidName { + /// The invalid name + name: String, + }, + /// Parameter validation failed + #[error("Invalid parameter: {0}")] + InvalidParameter(#[from] TypeParamError), +} + +/// Error types for TypeParam validation +#[derive(Debug, Error, PartialEq)] +pub enum TypeParamError { + /// Parameter name is missing + #[error("Parameter name is required")] + MissingName, + /// Integer parameter has non-integer min/max values + #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")] + InvalidIntegerBounds { + /// The invalid minimum value + min: Option, + /// The invalid maximum value + max: Option + }, + /// Parameter type cannot have min/max bounds + #[error("Parameter type '{param_type}' cannot have min/max bounds")] + UnexpectedMinMaxBounds { + /// The parameter type that cannot have bounds + param_type: TypeParamDefsItemType + }, + /// Parameter type cannot have enumeration options + #[error("Parameter type '{param_type}' cannot have enumeration options")] + UnexpectedEnumOptions { + /// The parameter type that cannot have options + param_type: TypeParamDefsItemType + }, + /// Enumeration parameter is missing required options + #[error("Enumeration parameter is missing required options")] + MissingEnumOptions, +} + +/// A custom type definition +#[derive(Clone, Debug)] +pub struct CustomType { + /// The name of this custom type + pub name: String, + /// Optional description of this type + pub description: Option, + /// What this type actually represents + pub definition: TypeDefinition, + /// Parameters for this type (empty if none) + pub parameters: Vec, + // TODO: Add variadic field for variadic type support +} + +impl PartialEq for CustomType { fn eq(&self, other: &Self) -> bool { - // There should 
only be one type of a given name per file - self.uri == other.uri && self.item.name == other.item.name + // Name should be unique for a given extension file + self.name == other.name + } +} + +impl CustomType { + /// Get the name of this custom type + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for SimpleExtensionsTypesItem { + fn from(custom_type: CustomType) -> Self { + Self { + name: custom_type.name, + description: custom_type.description, + parameters: if custom_type.parameters.is_empty() { + None + } else { + Some(crate::text::simple_extensions::TypeParamDefs( + custom_type.parameters.into_iter().map(Into::into).collect(), + )) + }, + structure: None, // TODO: Add structure support + variadic: None, // TODO: Add variadic support + } } } @@ -154,17 +404,44 @@ impl PartialEq for ExtensionType<'_> { /// Error for invalid type names in extension definitions pub struct InvalidTypeName(String); -impl<'a> Parse> for &'a SimpleExtensionsTypesItem { - type Parsed = ExtensionType<'a>; - // TODO: Not all names are valid for types, we should validate that - type Error = InvalidTypeName; +impl Parse for SimpleExtensionsTypesItem { + type Parsed = CustomType; + type Error = ExtensionTypeError; - fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { - ctx.add_type(self); - Ok(ExtensionType { - uri: ctx.uri, - item: &self, - }) + fn parse(self, ctx: &mut ExtensionContext) -> Result { + let SimpleExtensionsTypesItem { + name, + description, + parameters, + structure: _, // TODO: Add structure support + variadic: _, // TODO: Add variadic support + } = self; + + // TODO: Not all names are valid for types, we should validate that + if name.is_empty() { + return Err(ExtensionTypeError::InvalidName { name }); + } + + let parameters = match parameters { + Some(type_param_defs) => { + let mut parsed_params = Vec::new(); + for item in type_param_defs.0 { + parsed_params.push(TypeParam::try_from(item)?); + } + parsed_params + } + None => Vec::new(), + }; + + let 
custom_type = CustomType { + name: name.clone(), + description, + definition: TypeDefinition::Reference(name), // TODO: Parse from structure field + parameters, + }; + + ctx.add_type(&custom_type); + Ok(custom_type) } } @@ -188,74 +465,18 @@ pub enum TypeParseError { UnimplementedVariant, } -/// A validated Type that wraps the original Type with its validated ParsedType representation -#[derive(Debug, Clone)] -pub struct ValidatedType<'a> { - /// The original Type from the YAML - original: &'a extType, - /// The validated, parsed representation - pub parsed: ParsedType<'a>, -} - -impl<'a> ValidatedType<'a> { - /// Get the parsed type representation - pub fn parsed_type(&self) -> &ParsedType<'a> { - &self.parsed - } -} - -impl<'a> From> for &'a extType { - fn from(validated: ValidatedType<'a>) -> Self { - validated.original - } -} - -impl<'a> Parse> for &'a extType { - type Parsed = ValidatedType<'a>; - type Error = TypeParseError; - - fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { - match self { - extType::Variant0(type_str) => { - // Parse the type string into ParsedType - let parsed_type = ParsedType::parse(type_str); - - // Add context validation - match &parsed_type { - ParsedType::NamedExtension(name, _nullable) => { - // Verify the extension type exists in the context - if !ctx.has_type(name) { - return Err(TypeParseError::ExtensionTypeNotFound { - name: name.to_string(), - }); - } - } - ParsedType::TypeVariable(id) | ParsedType::NullableTypeVariable(id) => { - // Validate type variable ID (must be >= 1) - if *id == 0 { - return Err(TypeParseError::InvalidTypeVariableId { id: *id }); - } - } - ParsedType::Builtin(_, _) => { - // Builtin types are always valid - } - ParsedType::Parameterized { .. 
} => { - // TODO: Add validation for parameterized types - unimplemented!("Parameterized type validation not yet implemented") - } - } +// TODO: ValidatedType will be updated when we implement proper type validation - Ok(ValidatedType { - original: self, - parsed: parsed_type, - }) - } - _ => Err(TypeParseError::UnimplementedVariant), - } - } -} +// TODO: Update this Parse implementation when ValidatedType and ParsedType are converted to owned types +// impl Parse for &extType { +// type Parsed = ValidatedType; +// type Error = TypeParseError; +// fn parse(self, ctx: &mut ExtensionContext) -> Result { +// todo!("Update when ValidatedType and ParsedType are owned") +// } +// } -/// Error for invalid ArgumentsItem specifications +/// Error for invalid ArgumentsItem specifications (TODO: Update when ArgumentPattern is owned) #[derive(Debug, thiserror::Error)] pub enum ArgumentsItemError { /// Type parsing failed @@ -269,75 +490,39 @@ pub enum ArgumentsItemError { }, } -impl<'a> Parse> for &'a ArgumentsItem { - type Parsed = ArgumentPattern<'a>; - type Error = ArgumentsItemError; - - fn parse(self, ctx: &mut ExtensionContext<'a>) -> Result { - match self { - ArgumentsItem::ValueArgument(value_arg) => { - // Parse the Type into ValidatedType, then convert to ArgumentPattern - let validated_type = value_arg.value.parse(ctx)?; - let parsed_type = &validated_type.parsed; - - match parsed_type { - ParsedType::TypeVariable(id) => Ok(ArgumentPattern::TypeVariable(*id)), - ParsedType::NullableTypeVariable(_) => { - panic!("Nullable type variables not allowed in argument position") - } - ParsedType::Builtin(builtin_type, nullable) => Ok(ArgumentPattern::Concrete( - ConcreteType::builtin(*builtin_type, *nullable), - )), - ParsedType::NamedExtension(name, nullable) => { - // Find the extension type by name using the context - // We know it exists because Type parsing already validated it - let ext_type = ctx - .get_type(name) - .expect("Extension type should exist after 
validation"); - - Ok(ArgumentPattern::Concrete(ConcreteType::extension( - ext_type, *nullable, - ))) - } - ParsedType::Parameterized { .. } => { - unimplemented!("Parameterized types not yet supported in argument patterns") - } - } - } - ArgumentsItem::EnumArgument(_) => Err(ArgumentsItemError::UnsupportedVariant { - variant: "EnumArgument".to_string(), - }), - ArgumentsItem::TypeArgument(_) => Err(ArgumentsItemError::UnsupportedVariant { - variant: "TypeArgument".to_string(), - }), - } - } -} +// TODO: Update this Parse implementation when ArgumentPattern is converted to owned type +// impl Parse for &ArgumentsItem { +// type Parsed = ArgumentPattern; +// type Error = ArgumentsItemError; +// fn parse(self, ctx: &mut ExtensionContext) -> Result { +// todo!("Update when ArgumentPattern is owned") +// } +// } /// Represents a known, specific type, either builtin or extension #[derive(Clone, Debug, PartialEq)] -pub enum KnownType<'a> { +pub enum KnownType { /// Built-in primitive types Builtin(BuiltinType), /// Custom types defined in extension YAML files - Extension(ExtensionType<'a>), + Extension(CustomType), } /// A concrete type, fully specified with nullability and parameters #[derive(Clone, Debug, PartialEq)] -pub struct ConcreteType<'a> { +pub struct ConcreteType { /// Base type, can be builtin or extension - pub base: KnownType<'a>, + pub base: KnownType, /// Is the overall type nullable? pub nullable: bool, // TODO: Add non-type parameters (e.g. integers, enum, etc.) 
/// Parameters for the type, if there are any - pub parameters: Vec>, + pub parameters: Vec, } -impl<'a> ConcreteType<'a> { +impl ConcreteType { /// Create a concrete type from a builtin type - pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType<'static> { + pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType { ConcreteType { base: KnownType::Builtin(builtin_type), nullable, @@ -345,8 +530,8 @@ impl<'a> ConcreteType<'a> { } } - /// Create a concrete type from an extension type - pub fn extension(t: ExtensionType<'a>, nullable: bool) -> Self { + /// Create a concrete type from a custom type + pub fn extension(t: CustomType, nullable: bool) -> Self { Self { base: KnownType::Extension(t), nullable, @@ -355,11 +540,7 @@ impl<'a> ConcreteType<'a> { } /// Create a parameterized concrete type - pub fn parameterized( - base: KnownType<'a>, - nullable: bool, - parameters: Vec>, - ) -> Self { + pub fn parameterized(base: KnownType, nullable: bool, parameters: Vec) -> Self { Self { base, nullable, @@ -428,29 +609,29 @@ impl<'a> ParsedType<'a> { } } -/// A pattern for function arguments that can match concrete types or type variables +/// A pattern for function arguments that can match concrete types or type variables (TODO: Remove lifetime when ArgumentPattern is owned) #[derive(Clone, Debug, PartialEq)] -pub enum ArgumentPattern<'a> { +pub enum ArgumentPattern { /// Type variable like any1, any2, etc. 
TypeVariable(u32), /// Concrete type pattern - Concrete(ConcreteType<'a>), + Concrete(ConcreteType), } -/// Result of matching an argument pattern against a concrete type +/// Result of matching an argument pattern against a concrete type (TODO: Remove lifetime when Match is owned) #[derive(Clone, Debug, PartialEq)] -pub enum Match<'a> { +pub enum Match { /// Pattern matched exactly (for concrete patterns) Concrete, /// Type variable bound to concrete type - Variable(u32, ConcreteType<'a>), + Variable(u32, ConcreteType), /// Match failed Fail, } -impl<'a> ArgumentPattern<'a> { +impl<'a> ArgumentPattern { /// Check if this pattern matches the given concrete type - pub fn matches(&self, concrete: &ConcreteType<'a>) -> Match<'a> { + pub fn matches(&self, concrete: &ConcreteType) -> Match { match self { ArgumentPattern::TypeVariable(id) => Match::Variable(*id, concrete.clone()), ArgumentPattern::Concrete(pattern_type) => { @@ -464,16 +645,16 @@ impl<'a> ArgumentPattern<'a> { } } -/// Type variable bindings from matching function arguments +/// Type variable bindings from matching function arguments (TODO: Remove lifetime when TypeBindings is owned) #[derive(Debug, Clone, PartialEq)] -pub struct TypeBindings<'a> { +pub struct TypeBindings { /// Map of type variable IDs (e.g. 
1 for 'any1') to their concrete types - pub vars: HashMap>, + pub vars: HashMap, } -impl<'a> TypeBindings<'a> { +impl TypeBindings { /// Create type bindings by matching argument patterns against concrete arguments - pub fn new(patterns: &[ArgumentPattern<'a>], args: &[ConcreteType<'a>]) -> Option { + pub fn new(patterns: &[ArgumentPattern], args: &[ConcreteType]) -> Option { // Check length compatibility if patterns.len() != args.len() { unimplemented!("Handle variadic functions"); @@ -511,7 +692,172 @@ impl<'a> TypeBindings<'a> { } /// Get the bound type for a type variable, if any - pub fn get(&self, var_id: u32) -> Option<&ConcreteType<'a>> { + pub fn get(&self, var_id: u32) -> Option<&ConcreteType> { self.vars.get(&var_id) } } + +#[cfg(test)] +mod tests { + use super::*; + use url::Url; + + #[test] + fn test_extension_type_parse_basic() { + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri.clone()); + + let original_type_item = SimpleExtensionsTypesItem { + name: "MyType".to_string(), + description: Some("A test type".to_string()), + parameters: None, + structure: None, + variadic: None, + }; + + let result = original_type_item.clone().parse(&mut ctx); + assert!(result.is_ok()); + + let custom_type = result.unwrap(); + assert_eq!(custom_type.name, "MyType"); + assert_eq!(custom_type.description, Some("A test type".to_string())); + assert!(custom_type.parameters.is_empty()); + + // Test round-trip conversion + let converted_back: SimpleExtensionsTypesItem = custom_type.into(); + assert_eq!(converted_back.name, original_type_item.name); + assert_eq!(converted_back.description, original_type_item.description); + // Note: structure and variadic are TODO fields + } + + #[test] + fn test_extension_type_parse_with_parameters() { + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri.clone()); + + let original_type_item = SimpleExtensionsTypesItem { + name: 
"ParameterizedType".to_string(), + description: None, + parameters: Some(crate::text::simple_extensions::TypeParamDefs(vec![ + TypeParamDefsItem { + name: Some("length".to_string()), + description: Some("The length parameter".to_string()), + type_: TypeParamDefsItemType::Integer, + min: Some(1.0), + max: Some(1000.0), + optional: Some(false), + options: None, + }, + ])), + structure: None, + variadic: None, + }; + + let result = original_type_item.clone().parse(&mut ctx); + assert!(result.is_ok()); + + let custom_type = result.unwrap(); + assert_eq!(custom_type.name, "ParameterizedType"); + assert_eq!(custom_type.parameters.len(), 1); + + let param = &custom_type.parameters[0]; + assert_eq!(param.name, "length"); + assert_eq!(param.description, Some("The length parameter".to_string())); + if let ParamKind::Integer { min, max } = &param.kind { + assert_eq!(*min, Some(1)); + assert_eq!(*max, Some(1000)); + } else { + panic!("Expected Integer parameter kind"); + } + + // Test round-trip conversion + let converted_back: SimpleExtensionsTypesItem = custom_type.into(); + assert_eq!(converted_back.name, original_type_item.name); + assert_eq!(converted_back.description, original_type_item.description); + // Note: parameter and structure comparisons would require PartialEq implementations + } + + #[test] + fn test_extension_type_parse_empty_name_error() { + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri); + + let type_item = SimpleExtensionsTypesItem { + name: "".to_string(), // Empty name should cause error + description: None, + parameters: None, + structure: None, + variadic: None, + }; + + let result = type_item.parse(&mut ctx); + assert!(result.is_err()); + + if let Err(ExtensionTypeError::InvalidName { name }) = result { + assert_eq!(name, ""); + } else { + panic!("Expected InvalidName error"); + } + } + + #[test] + fn test_extension_context_type_tracking() { + let uri = 
Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri.clone()); + + // Initially no types + assert!(!ctx.has_type("MyType")); + + let type_item = SimpleExtensionsTypesItem { + name: "MyType".to_string(), + description: None, + parameters: None, + structure: None, + variadic: None, + }; + + // Parse the type - this should add it to context + let _custom_type = type_item.parse(&mut ctx).unwrap(); + + // Now the context should have the type + assert!(ctx.has_type("MyType")); + + let retrieved_type = ctx.get_type("MyType"); + assert!(retrieved_type.is_some()); + assert_eq!(retrieved_type.unwrap().name, "MyType"); + } + + #[test] + fn test_type_param_conversion() { + let original_param = TypeParamDefsItem { + name: Some("test_param".to_string()), + description: Some("A test parameter".to_string()), + type_: TypeParamDefsItemType::Integer, + min: Some(0.0), + max: Some(100.0), + optional: Some(true), + options: None, + }; + + // Convert to owned TypeParam + let type_param = TypeParam::try_from(original_param.clone()).unwrap(); + assert_eq!(type_param.name, "test_param"); + assert_eq!(type_param.description, Some("A test parameter".to_string())); + + if let ParamKind::Integer { min, max } = type_param.kind { + assert_eq!(min, Some(0)); + assert_eq!(max, Some(100)); + } else { + panic!("Expected Integer parameter kind"); + } + + // Convert back to original type + let converted_back = TypeParamDefsItem::from(type_param); + assert_eq!(converted_back.name, original_param.name); + assert_eq!(converted_back.description, original_param.description); + assert_eq!(converted_back.type_, original_param.type_); + assert_eq!(converted_back.min, original_param.min); + assert_eq!(converted_back.max, original_param.max); + // Note: optional field is not used in our new structure + } +} From facb09be75891ef10ac1cf5fb04bf05c13cc1573 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 15 Aug 2025 17:56:18 -0400 Subject: [PATCH 05/31] Stripped 
down to just types --- .../proto/extensions/simple_extension_uri.rs | 2 +- src/parse/proto/plan_version.rs | 6 +- src/parse/proto/version.rs | 5 +- src/registry/context.rs | 9 +- src/registry/extension.rs | 400 ++---------------- src/registry/mod.rs | 18 +- src/registry/registry.rs | 181 +++----- src/registry/types.rs | 361 +++++++++++++--- 8 files changed, 417 insertions(+), 565 deletions(-) diff --git a/src/parse/proto/extensions/simple_extension_uri.rs b/src/parse/proto/extensions/simple_extension_uri.rs index efd97db6..a5015f4b 100644 --- a/src/parse/proto/extensions/simple_extension_uri.rs +++ b/src/parse/proto/extensions/simple_extension_uri.rs @@ -8,7 +8,7 @@ use url::Url; use crate::{ parse::{ context::{ContextError, ProtoContext}, - Anchor, Context, Parse, + Anchor, Parse, }, proto, }; diff --git a/src/parse/proto/plan_version.rs b/src/parse/proto/plan_version.rs index 25c06cc2..fdf90e82 100644 --- a/src/parse/proto/plan_version.rs +++ b/src/parse/proto/plan_version.rs @@ -3,11 +3,7 @@ //! Parsing of [proto::PlanVersion]. use crate::{ - parse::{ - context::{Context, ProtoContext}, - proto::Version, - Parse, - }, + parse::{context::ProtoContext, proto::Version, Parse}, proto, }; use thiserror::Error; diff --git a/src/parse/proto/version.rs b/src/parse/proto/version.rs index 63c9b0b8..a248d88e 100644 --- a/src/parse/proto/version.rs +++ b/src/parse/proto/version.rs @@ -3,10 +3,7 @@ //! Parsing of [proto::Version]. use crate::{ - parse::{ - context::{Context, ProtoContext}, - Parse, - }, + parse::{context::ProtoContext, Parse}, proto, version, }; use hex::FromHex; diff --git a/src/registry/context.rs b/src/registry/context.rs index d3f7e006..6675c225 100644 --- a/src/registry/context.rs +++ b/src/registry/context.rs @@ -6,8 +6,8 @@ use std::collections::HashMap; use url::Url; -use crate::parse::Context; use super::types::CustomType; +use crate::parse::Context; /// Context for parsing and validating extension definitions. 
/// @@ -17,7 +17,7 @@ pub struct ExtensionContext { /// The URI of the extension being parsed. pub uri: Url, /// Map of type names to their validated definitions - types: HashMap, + pub(crate) types: HashMap, } impl ExtensionContext { @@ -36,11 +36,12 @@ impl ExtensionContext { /// Add a type to the context after it has been validated pub(crate) fn add_type(&mut self, custom_type: &CustomType) { - self.types.insert(custom_type.name.clone(), custom_type.clone()); + self.types + .insert(custom_type.name.clone(), custom_type.clone()); } /// Get a type by name from the context, returning the CustomType - pub(crate) fn get_type(&self, name: &str) -> Option<&CustomType> { + pub fn get_type(&self, name: &str) -> Option<&CustomType> { self.types.get(name) } } diff --git a/src/registry/extension.rs b/src/registry/extension.rs index 44cb268c..f3eb02d1 100644 --- a/src/registry/extension.rs +++ b/src/registry/extension.rs @@ -1,399 +1,79 @@ // SPDX-License-Identifier: Apache-2.0 -//! Validated extension file wrapper. +//! Validated extension file wrapper for types. //! //! This module provides `ExtensionFile`, a validated wrapper around SimpleExtensions -//! that ensures extension data is valid on construction and provides safe accessor methods. +//! that focuses on type definitions and provides safe type lookup methods. 
+use std::collections::HashMap; use thiserror::Error; use url::Url; use crate::parse::Parse; -use crate::registry::types::{ExtensionTypeError, InvalidTypeName}; -use crate::text::simple_extensions::{ - AggregateFunction, AggregateFunctionImplsItem, Arguments, ArgumentsItem, ReturnValue, - ScalarFunction, ScalarFunctionImplsItem, SimpleExtensions, SimpleExtensionsTypesItem, Type, - WindowFunction, WindowFunctionImplsItem, -}; +use crate::registry::types::{CustomType, ExtensionTypeError}; +use crate::text::simple_extensions::SimpleExtensions; use super::context::ExtensionContext; -use super::types::{ArgumentPattern, ConcreteType, ParsedType, TypeBindings}; -/// Errors that can occur during extension validation +/// Errors that can occur during extension type validation #[derive(Debug, Error)] pub enum ValidationError { - /// A function implementation has None for arguments - #[error("Function '{function}' has implementation with missing arguments")] - MissingArguments { - /// The missing function - function: String, - }, - - /// A function implementation is missing a return type - #[error("Function '{function}' has implementation with missing return type")] - MissingReturnType { - /// The missing function - function: String, - }, - - /// A type string is malformed or unrecognized - #[error("Invalid type string '{type_str}' in function '{function}': todo!")] - InvalidArgument { - /// The function containing the invalid type - function: String, - /// The malformed type string - type_str: String, - // Reason why the type is invalid - // reason: todo!(), - }, - - /// A type name is invalid - #[error("{0}")] - InvalidTypeName(InvalidTypeName), - /// Extension type error #[error("Extension type error: {0}")] ExtensionTypeError(#[from] ExtensionTypeError), + /// Unresolved type reference in structure field + #[error("Type '{type_name}' referenced in '{originating}' structure not found")] + UnresolvedTypeReference { + /// The type name that could not be resolved + type_name: 
String, + /// The type that contains the unresolved reference + originating: String, + }, + /// Structure field cannot be nullable + #[error("Structure representation in type '{originating}' cannot be nullable")] + StructureCannotBeNullable { + /// The type that has a nullable structure + originating: String, + }, } -impl From for ValidationError { - fn from(err: InvalidTypeName) -> Self { - ValidationError::InvalidTypeName(err) - } -} - -// TODO: Add more validation errors for malformed argument patterns, return type patterns, etc. -/// A validated extension file containing functions and types from a single URI. -/// All functions should have valid argument and return type patterns. +/// A validated extension file containing types from a single URI. +/// All types are parsed and validated on construction. #[derive(Debug, Clone)] pub struct ExtensionFile { /// The URI this extension was loaded from pub uri: Url, - /// The validated extension data - extensions: SimpleExtensions, + /// Parsed and validated custom types + types: HashMap, } impl ExtensionFile { /// Create a validated extension file from raw data pub fn create(uri: Url, extensions: SimpleExtensions) -> Result { - // Parse/validate types first - they're referenced by functions + // Parse all types (may contain unresolved Extension(String) references) let mut ctx = ExtensionContext::new(uri.clone()); + let mut types = HashMap::new(); + for type_item in &extensions.types { - let _validated_type = type_item.clone().parse(&mut ctx)?; - } - - // Validate scalar functions - for function in &extensions.scalar_functions { - Self::validate_scalar_function(function)?; - } - - // Validate aggregate functions - for function in &extensions.aggregate_functions { - Self::validate_aggregate_function(function)?; - } - - // Validate window functions - for function in &extensions.window_functions { - Self::validate_window_function(function)?; - } - - Ok(Self { uri, extensions }) - } - - /// Find a scalar function by name - 
pub fn find_scalar_function(&self, name: &str) -> Option { - self.extensions - .scalar_functions - .iter() - .find(|f| f.name == name) - .map(|f| ScalarFunctionRef { - file: self, - function: f, - }) - } - - /// Find an aggregate function by name - pub fn find_aggregate_function(&self, name: &str) -> Option { - self.extensions - .aggregate_functions - .iter() - .find(|f| f.name == name) - .map(|f| AggregateFunctionRef { - file: self, - function: f, - }) - } - - /// Find a window function by name - pub fn find_window_function(&self, name: &str) -> Option { - self.extensions - .window_functions - .iter() - .find(|f| f.name == name) - .map(|f| WindowFunctionRef { - file: self, - function: f, - }) - } - - /// Find a type by name - pub fn find_type(&self, name: &str) -> Option<&SimpleExtensionsTypesItem> { - let types = self.extensions.types.as_slice(); - types.iter().find(|t| t.name == name) - } - - /// Create an argument pattern from a raw ArgumentsItem - fn argument_pattern_from_item(&self, item: &ArgumentsItem) -> Option { - match item { - ArgumentsItem::ValueArg(value_arg) => self.argument_pattern_from_type(&value_arg.value), - _ => unimplemented!("Handle non-ValueArg argument types"), - } - } - - /// Create an argument pattern from a type string - fn argument_pattern_from_type(&self, type_val: &Type) -> Option { - match type_val { - Type::Variant0(type_str) => { - let parsed_type = ParsedType::parse(type_str); - match parsed_type { - ParsedType::TypeVariable(id) => Some(ArgumentPattern::TypeVariable(id)), - ParsedType::NullableTypeVariable(_) => { - panic!("Nullable type variables not allowed in argument position") - } - ParsedType::Builtin(builtin_type, nullable) => Some(ArgumentPattern::Concrete( - ConcreteType::builtin(builtin_type, nullable), - )), - ParsedType::NamedExtension(name, nullability) => { - // Find the extension type by name using the find_type method - let ext_type = self - .find_type(name) - .expect("This should have been validated"); - - // 
TODO: Update when ExtensionType is fully integrated and ArgumentPattern is owned - todo!("Update when ExtensionType constructor is available") - } - ParsedType::Parameterized { .. } => { - unimplemented!("Parameterized types not yet supported in argument patterns") - } - } - } - _ => unimplemented!("Handle non-string type variants"), - } - } - - // Private validation methods - - fn validate_scalar_function(function: &ScalarFunction) -> Result<(), ValidationError> { - for impl_item in &function.impls { - // Check that arguments are present (can be empty, but not None) - if impl_item.args.is_none() { - return Err(ValidationError::MissingArguments { - function: function.name.clone(), - }); - } - - // TODO: Validate that return type is well-formed - // For now, we assume return_ field existence is enforced by the type system - } - Ok(()) - } - - fn validate_aggregate_function(function: &AggregateFunction) -> Result<(), ValidationError> { - // Note: args can legitimately be None for functions like count() that count records - // rather than field values, so we don't validate args presence here - Ok(()) - } - - fn validate_window_function(function: &WindowFunction) -> Result<(), ValidationError> { - // Note: args can legitimately be None for some window functions - Ok(()) - } -} - -/// Handle for a validated scalar function definition -pub struct ScalarFunctionRef<'a> { - file: &'a ExtensionFile, - function: &'a ScalarFunction, -} - -impl<'a> ScalarFunctionRef<'a> { - /// Get the function name - pub fn name(&self) -> &str { - &self.function.name - } - - /// Get all implementations as handles to specific type signatures - pub fn implementations(self) -> impl Iterator> { - self.function - .impls - .iter() - .map(move |impl_item| ScalarImplementation { - file: self.file, - impl_item, - }) - } -} - -/// Handle for a validated aggregate function definition -pub struct AggregateFunctionRef<'a> { - file: &'a ExtensionFile, - function: &'a AggregateFunction, -} - -impl<'a> 
AggregateFunctionRef<'a> { - /// Get the function name - pub fn name(&self) -> &str { - &self.function.name - } - - /// Get all implementations as handles to specific type signatures - pub fn implementations(&self) -> impl Iterator> + '_ { - self.function - .impls - .iter() - .map(move |impl_item| AggregateFunctionImplRef { - file: self.file, - impl_item, - }) - } -} - -/// Handle for a validated window function definition -pub struct WindowFunctionRef<'a> { - file: &'a ExtensionFile, - function: &'a WindowFunction, -} - -impl<'a> WindowFunctionRef<'a> { - /// Get the function name - pub fn name(&self) -> &str { - &self.function.name - } - - /// Get all implementations as handles to specific type signatures - pub fn implementations(&self) -> impl Iterator> + '_ { - self.function - .impls - .iter() - .map(move |impl_item| WindowFunctionImplRef { - file: self.file, - impl_item, - }) - } -} - -/// Handle for a specific scalar function implementation with validated signature -#[derive(Debug, Copy, Clone)] -pub struct ScalarImplementation<'a> { - file: &'a ExtensionFile, - impl_item: &'a ScalarFunctionImplsItem, -} - -impl<'a> ScalarImplementation<'a> { - /// Check if this implementation can be called with the given concrete argument types - /// Returns the inferred concrete return type if the call would succeed, None otherwise - pub fn call_with(&self, concrete_args: &[ConcreteType]) -> Option { - // Convert raw arguments to ArgumentPatterns using ExtensionFile context - let arg_patterns: Vec = self - .impl_item - .args - .as_ref() - .expect("validated to be present") - .iter() - .filter_map(|arg| self.file.argument_pattern_from_item(arg)) - .collect(); - - // Create type bindings by matching patterns against concrete arguments - let _bindings: TypeBindings = TypeBindings::new(&arg_patterns, concrete_args)?; - - if concrete_args.len() > 1_000_000 { - // For lifetime management - return concrete_args.first().cloned(); + let custom_type = type_item.clone().parse(&mut 
ctx)?; + types.insert(custom_type.name.clone(), custom_type); } - // If arguments match, parse and return the inferred return type - let return_type_str = match &self.impl_item.return_ { - ReturnValue(Type::Variant0(type_str)) => type_str, - _ => unimplemented!("Handle non-string return types"), - }; + // TODO: Validate that all Extension(String) references in structure + // fields exist Walk through all CustomType.structure fields and check + // that Extension(String) references can be resolved to actual types in + // the registry. - let parsed_return_type = ParsedType::parse(return_type_str); - match parsed_return_type { - ParsedType::Builtin(builtin_type, nullable) => { - Some(ConcreteType::builtin(builtin_type, nullable)) - } - ParsedType::TypeVariable(id) => { - // Look up the bound type for this variable - if let Some(bound_type) = _bindings.get(id) { - Some(bound_type.clone()) - } else { - None - } - } - ParsedType::NullableTypeVariable(id) => { - // Look up the bound type and make it nullable - if let Some(mut bound_type) = _bindings.get(id).cloned() { - bound_type.nullable = true; - Some(bound_type) - } else { - None - } - } - ParsedType::NamedExtension(name, nullable) => { - // Find the extension type by name - let ext_type = self - .file - .find_type(name) - .expect("This should have been validated"); - - // TODO: Update when ExtensionType is fully integrated - todo!("Update when ExtensionType constructor is available") - } - ParsedType::Parameterized { .. 
} => { - unimplemented!("Parameterized return types not yet supported") - } - } + Ok(Self { uri, types }) } -} - -/// Handle for a specific aggregate function implementation with validated signature -pub struct AggregateFunctionImplRef<'a> { - file: &'a ExtensionFile, - impl_item: &'a AggregateFunctionImplsItem, -} - -impl<'a> AggregateFunctionImplRef<'a> { - /// Get the argument signature (guaranteed to be present due to validation) - fn args(&self) -> &Arguments { - self.impl_item - .args - .as_ref() - .expect("validated to be present") - } - - /// Get the return type pattern - fn return_type(&self) -> &ReturnValue { - &self.impl_item.return_ - } -} - -/// Handle for a specific window function implementation with validated signature -pub struct WindowFunctionImplRef<'a> { - file: &'a ExtensionFile, - impl_item: &'a WindowFunctionImplsItem, -} -impl<'a> WindowFunctionImplRef<'a> { - /// Get the argument signature (guaranteed to be present due to validation) - fn args(&self) -> &Arguments { - self.impl_item - .args - .as_ref() - .expect("validated to be present") + /// Get a type by name + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.types.get(name) } - /// Get the return type pattern - fn return_type(&self) -> &ReturnValue { - &self.impl_item.return_ + /// Get an iterator over all types in this extension + pub fn types(&self) -> impl Iterator { + self.types.values() } } diff --git a/src/registry/mod.rs b/src/registry/mod.rs index a032c2f2..ee2fe109 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -18,19 +18,17 @@ //! //! ## Core Types //! -//! - [`ExtensionFile`]: Validated wrapper around a SimpleExtensions + URI -//! - [`ConcreteType`]: Fully-specified types for function arguments and return -//! values -//! - [`GlobalRegistry`]: Immutable registry for URI+name based function lookup +//! - [`ExtensionFile`]: Validated wrapper around SimpleExtensions + URI focusing on type definitions +//! 
- [`CustomType`]: Parsed and validated extension type definitions +//! - [`Registry`]: Immutable registry for URI+name based type lookup +//! +//! Currently only type definitions are supported. Function parsing will be added in a future update. mod context; mod extension; mod registry; pub mod types; -pub use extension::{ - AggregateFunctionImplRef, AggregateFunctionRef, ExtensionFile, ScalarFunctionRef, - ScalarImplementation, ValidationError, WindowFunctionImplRef, WindowFunctionRef, -}; -pub use registry::GlobalRegistry; -pub use types::ConcreteType; +pub use extension::{ExtensionFile, ValidationError}; +pub use registry::Registry; +pub use types::{ConcreteType, CustomType}; diff --git a/src/registry/registry.rs b/src/registry/registry.rs index b5e96b8c..9bb1a71d 100644 --- a/src/registry/registry.rs +++ b/src/registry/registry.rs @@ -6,62 +6,27 @@ //! - **Global Registry**: Immutable, reusable across plans, URI+name based lookup //! - **Local Registry**: Per-plan, anchor-based, references Global Registry (TODO) //! +//! Currently only type definitions are supported. Function parsing will be added in a future update. +//! //! This module is only available when the `registry` feature is enabled. 
#![cfg(feature = "registry")] -use thiserror::Error; use url::Url; -use crate::registry::ScalarFunctionRef; - -use super::{types::ConcreteType, ExtensionFile}; - -/// Errors that can occur when using the Global Registry -#[derive(Debug, Error, PartialEq)] -pub enum GlobalRegistryError { - /// The specified extension URI is not registered in this registry - #[error("Unknown extension URI: {0}")] - UnknownExtensionUri(String), - /// The specified function was not found in the given extension - #[error("Function '{function}' not found in extension '{uri}'")] - FunctionNotFound { - /// The extension URI where the function was expected - uri: String, - /// The name of the function that was not found - function: String, - }, - /// No function signature matches the provided arguments - #[error("No matching signature for function '{function}' in extension '{uri}' with provided arguments")] - NoMatchingSignature { - /// The extension URI containing the function - uri: String, - /// The name of the function with no matching signature - function: String, - }, -} - -impl GlobalRegistryError { - /// Create a FunctionNotFound error - pub fn not_found(uri: &Url, function: &str) -> Self { - Self::FunctionNotFound { - uri: uri.to_string(), - function: function.to_string(), - } - } -} +use super::{types::CustomType, ExtensionFile}; -/// Global Extension Registry that manages Substrait extensions +/// Extension Registry that manages Substrait extensions /// /// This registry is immutable and reusable across multiple plans. -/// It provides URI + name based lookup for function validation and signature matching. +/// It provides URI + name based lookup for extension types. Function parsing will be added later. 
#[derive(Debug)] -pub struct GlobalRegistry { +pub struct Registry { /// Pre-validated extension files extensions: Vec, } -impl GlobalRegistry { +impl Registry { /// Create a new Global Registry from validated extension files pub fn new(extensions: Vec) -> Self { Self { extensions } @@ -98,130 +63,94 @@ impl GlobalRegistry { // Private helper methods - fn get_extension(&self, uri: &Url) -> Result<&ExtensionFile, GlobalRegistryError> { - self.extensions - .iter() - .find(|ext| &ext.uri == uri) - .ok_or_else(|| GlobalRegistryError::UnknownExtensionUri(uri.to_string())) + fn get_extension(&self, uri: &Url) -> Option<&ExtensionFile> { + self.extensions.iter().find(|ext| &ext.uri == uri) } - /// Validate a scalar function call and return the concrete return type - pub fn validate_scalar_call<'a>( - &'a self, - uri: &Url, - name: &str, - args: &[ConcreteType], - ) -> Result { - let extension: &'a ExtensionFile = self.get_extension(uri)?; - let function_ref: ScalarFunctionRef<'a> = extension - .find_scalar_function(name) - .ok_or_else(|| GlobalRegistryError::not_found(uri, name))?; - - // Try each implementation until one matches - for impl_ref in function_ref.implementations() { - if let Some(return_type) = impl_ref.call_with(args) { - return Ok(return_type); - } - } - - Err(GlobalRegistryError::NoMatchingSignature { - uri: uri.to_string(), - function: name.to_string(), - }) + /// Get a type by URI and name + pub fn get_type(&self, uri: &Url, name: &str) -> Option<&CustomType> { + self.get_extension(uri)?.get_type(name) } } #[cfg(test)] mod tests { use super::*; - use crate::registry::types::{BuiltinType, ConcreteType}; use crate::text::simple_extensions::*; - fn create_test_extension() -> SimpleExtensions { + fn create_test_extension_with_types() -> SimpleExtensions { SimpleExtensions { - scalar_functions: vec![ScalarFunction { - name: "add".to_string(), - description: Some("Addition function".to_string()), - impls: vec![ScalarFunctionImplsItem { - args: 
Some(Arguments(vec![])), // Simplified for testing. TODO: Add real args - return_: ReturnValue(Type::Variant0("i32".to_string())), - deterministic: None, - implementation: None, - nullability: None, - options: None, - session_dependent: None, - variadic: None, - }], - }], + scalar_functions: vec![], aggregate_functions: vec![], window_functions: vec![], dependencies: Default::default(), type_variations: vec![], - types: vec![], + types: vec![SimpleExtensionsTypesItem { + name: "test_type".to_string(), + description: Some("A test type".to_string()), + parameters: None, + structure: None, + variadic: None, + }], } } #[test] fn test_new_registry() { let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let extension_file = ExtensionFile::create(uri.clone(), create_test_extension()).unwrap(); + let extension_file = + ExtensionFile::create(uri.clone(), create_test_extension_with_types()).unwrap(); let extensions = vec![extension_file]; - let registry = GlobalRegistry::new(extensions); + let registry = Registry::new(extensions); assert_eq!(registry.extensions().count(), 1); let extension_uris: Vec<&Url> = registry.extensions().map(|ext| &ext.uri).collect(); assert!(extension_uris.contains(&&uri)); } #[test] - fn test_validate_scalar_call_with_test_extension() { + fn test_type_lookup() { let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let extension_file = ExtensionFile::create(uri.clone(), create_test_extension()).unwrap(); + let extension_file = + ExtensionFile::create(uri.clone(), create_test_extension_with_types()).unwrap(); let extensions = vec![extension_file]; - let registry = GlobalRegistry::new(extensions); - let args: &[ConcreteType] = &[]; // Empty ConcreteType args + let registry = Registry::new(extensions); + + // Test successful type lookup + let found_type = registry.get_type(&uri, "test_type"); + assert!(found_type.is_some()); + assert_eq!(found_type.unwrap().name, "test_type"); - let result = 
registry.validate_scalar_call(&uri, "add", args); - assert!(result.is_ok()); + // Test missing type lookup + let missing_type = registry.get_type(&uri, "nonexistent_type"); + assert!(missing_type.is_none()); + + // Test missing extension lookup + let wrong_uri = Url::parse("https://example.com/wrong.yaml").unwrap(); + let missing_extension = registry.get_type(&wrong_uri, "test_type"); + assert!(missing_extension.is_none()); } #[test] - fn test_standard_extension() { - let registry = GlobalRegistry::from_core_extensions(); - let arithmetic_uri = Url::parse("https://github.com/substrait-io/substrait/raw/v0.57.0/extensions/functions_arithmetic.yaml").unwrap(); + fn test_from_core_extensions() { + let registry = Registry::from_core_extensions(); + assert!(registry.extensions().count() > 0); - // Test that add function fails with no arguments (should require 2 arguments) - let no_args: &[ConcreteType] = &[]; - let result_no_args = registry.validate_scalar_call(&arithmetic_uri, "add", no_args); - assert!( - result_no_args.is_err(), - "add function should fail with no arguments" - ); + // Find the unknown.yaml extension dynamically + let unknown_extension = registry + .extensions() + .find(|ext| ext.uri.path_segments().map(|s| s.last()) == Some(Some("unknown.yaml"))) + .expect("Should find unknown.yaml extension"); - // Test that add function succeeds with two i32 arguments and returns i32 - let i32_args = &[ - ConcreteType::builtin(BuiltinType::I32, false), - ConcreteType::builtin(BuiltinType::I32, false), - ]; - let result_with_args = registry.validate_scalar_call(&arithmetic_uri, "add", i32_args); + let unknown_type = unknown_extension.get_type("unknown"); assert!( - result_with_args.is_ok(), - "add function should succeed with two i32 arguments" + unknown_type.is_some(), + "Should find 'unknown' type in unknown.yaml extension" ); - // Verify it returns the correct concrete type (i32) - let return_type = result_with_args.unwrap(); - assert_eq!( - return_type, - 
ConcreteType::builtin(BuiltinType::I32, false), - "add(i32, i32) should return i32" - ); - } - - #[test] - fn test_from_core_extensions() { - let registry = GlobalRegistry::from_core_extensions(); - assert!(registry.extensions().count() > 0); + // Also test the registry's get_type method with the actual URI + let unknown_type_via_registry = registry.get_type(&unknown_extension.uri, "unknown"); + assert!(unknown_type_via_registry.is_some()); } } diff --git a/src/registry/types.rs b/src/registry/types.rs index 71e6a9fb..1192d35a 100644 --- a/src/registry/types.rs +++ b/src/registry/types.rs @@ -8,9 +8,10 @@ use crate::parse::Parse; use crate::registry::context::ExtensionContext; use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, TypeParamDefsItem, + EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefsItem, TypeParamDefsItemType, }; +use serde_json::Value; use std::collections::HashMap; use std::str::FromStr; use thiserror::Error; @@ -133,15 +134,6 @@ pub enum ParameterType { String, } -/// What a type actually represents - either a reference to another type or a compound structure -#[derive(Clone, Debug)] -pub enum TypeDefinition { - /// Reference to another type by name (e.g., "i32", "string", or custom type name) - Reference(String), - /// Compound structure with named fields - Struct(HashMap), -} - /// Type-safe parameter constraints based on parameter kind #[derive(Clone, Debug)] pub enum ParamKind { @@ -150,23 +142,26 @@ pub enum ParamKind { /// True/False parameter Boolean, /// Integer parameter with optional bounds - Integer { + Integer { /// Minimum value constraint - min: Option, + min: Option, /// Maximum value constraint - max: Option + max: Option, }, /// Enumeration parameter with predefined options - Enumeration { + Enumeration { /// Valid enumeration values - options: Vec + options: Vec, }, /// String parameter String, } impl ParamKind { - fn get_integer_bounds(min: Option, max: Option) -> Result<(Option, 
Option), TypeParamError> { + fn get_integer_bounds( + min: Option, + max: Option, + ) -> Result<(Option, Option), TypeParamError> { // Convert float bounds to integers, validating they are whole numbers let min_bound = if let Some(min_f) = min { if min_f.fract() != 0.0 { @@ -202,48 +197,73 @@ impl ParamKind { (TypeParamDefsItemType::Boolean, None, None, None) => Ok(ParamKind::Boolean), (TypeParamDefsItemType::Integer, min, max, None) => { let (min_bound, max_bound) = Self::get_integer_bounds(min, max)?; - Ok(ParamKind::Integer { min: min_bound, max: max_bound }) + Ok(ParamKind::Integer { + min: min_bound, + max: max_bound, + }) } (TypeParamDefsItemType::Enumeration, None, None, Some(enum_options)) => { - Ok(ParamKind::Enumeration { options: enum_options.0 }) + Ok(ParamKind::Enumeration { + options: enum_options.0, + }) } (TypeParamDefsItemType::String, None, None, None) => Ok(ParamKind::String), - + // Error cases - DataType with unexpected parameters - (TypeParamDefsItemType::DataType, Some(_), _, _) | (TypeParamDefsItemType::DataType, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::DataType }) + (TypeParamDefsItemType::DataType, Some(_), _, _) + | (TypeParamDefsItemType::DataType, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { + param_type: TypeParamDefsItemType::DataType, + }) } (TypeParamDefsItemType::DataType, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::DataType }) + Err(TypeParamError::UnexpectedEnumOptions { + param_type: TypeParamDefsItemType::DataType, + }) } - - // Error cases - Boolean with unexpected parameters - (TypeParamDefsItemType::Boolean, Some(_), _, _) | (TypeParamDefsItemType::Boolean, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::Boolean }) + + // Error cases - Boolean with unexpected parameters + (TypeParamDefsItemType::Boolean, Some(_), _, _) + | 
(TypeParamDefsItemType::Boolean, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { + param_type: TypeParamDefsItemType::Boolean, + }) } (TypeParamDefsItemType::Boolean, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::Boolean }) + Err(TypeParamError::UnexpectedEnumOptions { + param_type: TypeParamDefsItemType::Boolean, + }) } - + // Error cases - Integer with enum options (TypeParamDefsItemType::Integer, _, _, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::Integer }) + Err(TypeParamError::UnexpectedEnumOptions { + param_type: TypeParamDefsItemType::Integer, + }) } - + // Error cases - Enumeration with unexpected parameters - (TypeParamDefsItemType::Enumeration, Some(_), _, _) | (TypeParamDefsItemType::Enumeration, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::Enumeration }) + (TypeParamDefsItemType::Enumeration, Some(_), _, _) + | (TypeParamDefsItemType::Enumeration, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { + param_type: TypeParamDefsItemType::Enumeration, + }) } (TypeParamDefsItemType::Enumeration, None, None, None) => { Err(TypeParamError::MissingEnumOptions) } - + // Error cases - String with unexpected parameters - (TypeParamDefsItemType::String, Some(_), _, _) | (TypeParamDefsItemType::String, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { param_type: TypeParamDefsItemType::String }) + (TypeParamDefsItemType::String, Some(_), _, _) + | (TypeParamDefsItemType::String, _, Some(_), _) => { + Err(TypeParamError::UnexpectedMinMaxBounds { + param_type: TypeParamDefsItemType::String, + }) } (TypeParamDefsItemType::String, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { param_type: TypeParamDefsItemType::String }) + Err(TypeParamError::UnexpectedEnumOptions { + param_type: TypeParamDefsItemType::String, + }) } } } @@ -320,6 
+340,15 @@ pub enum ExtensionTypeError { /// Parameter validation failed #[error("Invalid parameter: {0}")] InvalidParameter(#[from] TypeParamError), + /// Field type is invalid + #[error("Invalid structure field type: {0}")] + InvalidFieldType(String), + /// Structure representation cannot be nullable + #[error("Structure representation cannot be nullable: {type_string}")] + StructureCannotBeNullable { + /// The type string that was nullable + type_string: String, + }, } /// Error types for TypeParam validation @@ -330,23 +359,23 @@ pub enum TypeParamError { MissingName, /// Integer parameter has non-integer min/max values #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")] - InvalidIntegerBounds { + InvalidIntegerBounds { /// The invalid minimum value - min: Option, + min: Option, /// The invalid maximum value - max: Option + max: Option, }, /// Parameter type cannot have min/max bounds #[error("Parameter type '{param_type}' cannot have min/max bounds")] - UnexpectedMinMaxBounds { + UnexpectedMinMaxBounds { /// The parameter type that cannot have bounds - param_type: TypeParamDefsItemType + param_type: TypeParamDefsItemType, }, /// Parameter type cannot have enumeration options #[error("Parameter type '{param_type}' cannot have enumeration options")] - UnexpectedEnumOptions { + UnexpectedEnumOptions { /// The parameter type that cannot have options - param_type: TypeParamDefsItemType + param_type: TypeParamDefsItemType, }, /// Enumeration parameter is missing required options #[error("Enumeration parameter is missing required options")] @@ -360,8 +389,9 @@ pub struct CustomType { pub name: String, /// Optional description of this type pub description: Option, - /// What this type actually represents - pub definition: TypeDefinition, + /// How this type is represented (None = opaque, Some = structured representation) + /// If Some, nullable MUST be false + pub structure: Option, /// Parameters for this type (empty if none) pub 
parameters: Vec, // TODO: Add variadic field for variadic type support @@ -413,8 +443,8 @@ impl Parse for SimpleExtensionsTypesItem { name, description, parameters, - structure: _, // TODO: Add structure support - variadic: _, // TODO: Add variadic support + structure, + variadic: _, // TODO: Add variadic support } = self; // TODO: Not all names are valid for types, we should validate that @@ -433,10 +463,16 @@ impl Parse for SimpleExtensionsTypesItem { None => Vec::new(), }; + // Parse structure field if present + let structure = match structure { + Some(structure_data) => Some(ConcreteType::try_from(structure_data)?), + None => None, // Opaque type + }; + let custom_type = CustomType { name: name.clone(), description, - definition: TypeDefinition::Reference(name), // TODO: Parse from structure field + structure, parameters, }; @@ -445,6 +481,60 @@ impl Parse for SimpleExtensionsTypesItem { } } +impl TryFrom for ConcreteType { + type Error = ExtensionTypeError; + + fn try_from(ext_type: ExtType) -> Result { + match ext_type { + // Case: structure: "BINARY" (alias to another type) + ExtType::Variant0(type_string) => { + let parsed_type = ParsedType::parse(&type_string); + let concrete_type = ConcreteType::try_from(parsed_type)?; + + // Structure representation cannot be nullable + if concrete_type.nullable { + return Err(ExtensionTypeError::InvalidName { + name: format!( + "Structure representation '{}' cannot be nullable", + type_string + ), + }); + } + + Ok(concrete_type) + } + // Case: structure: { field1: type1, field2: type2 } (named struct) + ExtType::Variant1(field_map) => { + let mut field_names = Vec::new(); + let mut field_types = Vec::new(); + + for (field_name, field_type_variant) in field_map { + field_names.push(field_name); + + let field_type_str = match field_type_variant { + Value::String(s) => s, + _ => { + return Err(ExtensionTypeError::InvalidName { + name: field_type_variant.to_string(), + }) + } + }; + + let parsed_field_type = 
ParsedType::parse(&field_type_str); + let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; + field_types.push(field_concrete_type); + } + + Ok(ConcreteType { + base: KnownType::NStruct(field_names), + nullable: false, // Structure representation cannot be nullable + parameters: field_types, + }) + } + } + } +} + /// Error for invalid Type specifications #[derive(Debug, thiserror::Error)] pub enum TypeParseError { @@ -499,13 +589,31 @@ pub enum ArgumentsItemError { // } // } -/// Represents a known, specific type, either builtin or extension +/// Represents a known, specific type, either builtin, extension reference, or structured #[derive(Clone, Debug, PartialEq)] pub enum KnownType { /// Built-in primitive types Builtin(BuiltinType), - /// Custom types defined in extension YAML files - Extension(CustomType), + /// Custom types defined in extension YAML files (unresolved reference) + Extension(String), + /// Named struct with field names (corresponds to Substrait's NSTRUCT pseudo-type) + NStruct(Vec), +} + +impl FromStr for KnownType { + type Err = ExtensionTypeError; + + fn from_str(s: &str) -> Result { + // First try to parse as a builtin type + match BuiltinType::from_str(s) { + Ok(builtin) => Ok(KnownType::Builtin(builtin)), + Err(_) => { + // TODO: Validate that the string is a valid type name + // For now, treat all non-builtin strings as extension type references + Ok(KnownType::Extension(s.to_string())) + } + } + } } /// A concrete type, fully specified with nullability and parameters @@ -530,15 +638,28 @@ impl ConcreteType { } } - /// Create a concrete type from a custom type - pub fn extension(t: CustomType, nullable: bool) -> Self { + /// Create a concrete type from an extension type name + pub fn extension(type_name: impl Into, nullable: bool) -> Self { Self { - base: KnownType::Extension(t), + base: KnownType::Extension(type_name.into()), nullable, parameters: Vec::new(), } } + /// Create a concrete type for a named struct (NSTRUCT) + 
pub fn nstruct( + field_names: Vec, + field_types: Vec, + nullable: bool, + ) -> Self { + Self { + base: KnownType::NStruct(field_names), + nullable, + parameters: field_types, + } + } + /// Create a parameterized concrete type pub fn parameterized(base: KnownType, nullable: bool, parameters: Vec) -> Self { Self { @@ -549,6 +670,40 @@ impl ConcreteType { } } +impl<'a> TryFrom> for ConcreteType { + type Error = ExtensionTypeError; + + fn try_from(parsed: ParsedType<'a>) -> Result { + match parsed { + ParsedType::Builtin(builtin_type, nullable) => { + Ok(ConcreteType::builtin(builtin_type, nullable)) + } + ParsedType::NamedExtension(type_name, nullable) => { + Ok(ConcreteType::extension(type_name.to_string(), nullable)) + } + ParsedType::TypeVariable(_) | ParsedType::NullableTypeVariable(_) => { + Err(ExtensionTypeError::InvalidName { + name: "Type variables not allowed in structure definitions".to_string(), + }) + } + ParsedType::Parameterized { + base, + parameters, + nullable, + } => { + let base_concrete = ConcreteType::try_from(*base)?; + let param_concretes: Result, _> = + parameters.into_iter().map(ConcreteType::try_from).collect(); + Ok(ConcreteType::parameterized( + base_concrete.base, + nullable, + param_concretes?, + )) + } + } + } +} + /// A parsed type that can represent type variables, builtin types, extension types, or parameterized types #[derive(Clone, Debug, PartialEq)] pub enum ParsedType<'a> { @@ -700,6 +855,7 @@ impl TypeBindings { #[cfg(test)] mod tests { use super::*; + use serde_json::json; use url::Url; #[test] @@ -843,7 +999,7 @@ mod tests { let type_param = TypeParam::try_from(original_param.clone()).unwrap(); assert_eq!(type_param.name, "test_param"); assert_eq!(type_param.description, Some("A test parameter".to_string())); - + if let ParamKind::Integer { min, max } = type_param.kind { assert_eq!(min, Some(0)); assert_eq!(max, Some(100)); @@ -860,4 +1016,99 @@ mod tests { assert_eq!(converted_back.max, original_param.max); // Note: 
optional field is not used in our new structure } + + #[test] + fn test_simple_type_no_structure() { + // Test a simple opaque type (no structure field) + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri); + + let type_item = SimpleExtensionsTypesItem { + name: "unknown".to_string(), + description: Some("An opaque type".to_string()), + parameters: None, + structure: None, // Opaque type + variadic: None, + }; + + let result = type_item.parse(&mut ctx); + assert!(result.is_ok()); + + let custom_type = result.unwrap(); + assert_eq!(custom_type.name, "unknown"); + assert_eq!(custom_type.description, Some("An opaque type".to_string())); + assert!(custom_type.structure.is_none()); // Should be None for opaque type + assert!(custom_type.parameters.is_empty()); + } + + #[test] + fn test_types_with_structure() { + // Test a type with structure: "BINARY" (alias) + let uri = Url::parse("https://example.com/test.yaml").unwrap(); + let mut ctx = ExtensionContext::new(uri); + + let type_item = SimpleExtensionsTypesItem { + name: "coordinate".to_string(), + description: Some("A coordinate in some form".to_string()), + parameters: None, + structure: Some(ExtType::Variant0("fp64".to_string())), // Alias to fp64 + variadic: None, + }; + + let result = type_item.parse(&mut ctx); + assert!(result.is_ok()); + + let custom_type = result.unwrap(); + assert_eq!(custom_type.name, "coordinate"); + assert!(custom_type.structure.is_some()); + + let structure = custom_type.structure.unwrap(); + assert!(!structure.nullable); // Structure cannot be nullable + assert!(matches!( + structure.base, + KnownType::Builtin(BuiltinType::Fp64) + )); + + // Create a map structure like { latitude: "coordinate", longitude: "coordinate" } + let mut field_map = serde_json::Map::new(); + field_map.insert("latitude".to_string(), json!("coordinate")); + field_map.insert("longitude".to_string(), json!("coordinate")); + + let type_item = 
SimpleExtensionsTypesItem { + name: "point".to_string(), + description: Some("A 2D point".to_string()), + parameters: None, + structure: Some(ExtType::Variant1(field_map)), + variadic: None, + }; + + let result = type_item.parse(&mut ctx); + assert!(result.is_ok()); + + let custom_type = result.unwrap(); + assert_eq!(custom_type.name, "point"); + assert!(custom_type.structure.is_some()); + + let structure = custom_type.structure.unwrap(); + assert!(!structure.nullable); // Structure cannot be nullable + + // Should be NStruct with field names + if let KnownType::NStruct(field_names) = structure.base { + assert_eq!(field_names.len(), 2); + assert!(field_names.contains(&"latitude".to_string())); + assert!(field_names.contains(&"longitude".to_string())); + } else { + panic!("Expected NStruct base type"); + } + + // Should have 2 field types (parameters) + assert_eq!(structure.parameters.len(), 2); + for param in &structure.parameters { + if let KnownType::Extension(ref type_name) = param.base { + assert_eq!(type_name, "coordinate"); + } else { + panic!("Expected Extension type for coordinate reference"); + } + } + } } From 5a5ac2381330cceaf34453da0eb1e7548b36cac6 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Mon, 18 Aug 2025 13:43:16 -0400 Subject: [PATCH 06/31] Enable test only with correct feature flag --- src/registry/registry.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/registry/registry.rs b/src/registry/registry.rs index 9bb1a71d..1a0445b8 100644 --- a/src/registry/registry.rs +++ b/src/registry/registry.rs @@ -132,6 +132,7 @@ mod tests { assert!(missing_extension.is_none()); } + #[cfg(feature = "extensions")] #[test] fn test_from_core_extensions() { let registry = Registry::from_core_extensions(); From 78696769b8d41ee74ba8eececaa231bd350556ca Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Mon, 18 Aug 2025 13:58:32 -0400 Subject: [PATCH 07/31] Code cleanup --- Cargo.toml | 4 ++-- src/registry/mod.rs | 1 + src/registry/registry.rs | 7 ++----- 
src/registry/types.rs | 13 +++++-------- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ce48694..b0f48028 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,9 @@ include = [ [features] default = [] extensions = ["dep:serde_yaml", "dep:url"] -parse = ["dep:hex", "semver"] +parse = ["dep:hex", "dep:thiserror", "dep:url", "semver"] protoc = ["dep:protobuf-src"] -registry = ["dep:thiserror", "dep:url", "parse"] +registry = ["parse"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] diff --git a/src/registry/mod.rs b/src/registry/mod.rs index ee2fe109..82770902 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -26,6 +26,7 @@ mod context; mod extension; +#[allow(clippy::module_inception)] mod registry; pub mod types; diff --git a/src/registry/registry.rs b/src/registry/registry.rs index 1a0445b8..1d4baa69 100644 --- a/src/registry/registry.rs +++ b/src/registry/registry.rs @@ -19,7 +19,7 @@ use super::{types::CustomType, ExtensionFile}; /// Extension Registry that manages Substrait extensions /// /// This registry is immutable and reusable across multiple plans. -/// It provides URI + name based lookup for extension types. Function parsing will be added later. +/// It provides URI + name based lookup for extension types. Function parsing will be added in a future update. 
#[derive(Debug)] pub struct Registry { /// Pre-validated extension files @@ -51,10 +51,7 @@ impl Registry { .iter() .map(|(uri, simple_extensions)| { ExtensionFile::create(uri.clone(), simple_extensions.clone()) - .map_err(|err| { - eprintln!("Failed to create extension file for {}: {}", uri, err); - }) - .expect("Core extensions should be valid") + .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {uri}: {err}")) }) .collect(); diff --git a/src/registry/types.rs b/src/registry/types.rs index 1192d35a..de1fdb87 100644 --- a/src/registry/types.rs +++ b/src/registry/types.rs @@ -8,7 +8,7 @@ use crate::parse::Parse; use crate::registry::context::ExtensionContext; use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefsItem, + EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefs, TypeParamDefsItem, TypeParamDefsItemType, }; use serde_json::Value; @@ -419,7 +419,7 @@ impl From for SimpleExtensionsTypesItem { parameters: if custom_type.parameters.is_empty() { None } else { - Some(crate::text::simple_extensions::TypeParamDefs( + Some(TypeParamDefs( custom_type.parameters.into_iter().map(Into::into).collect(), )) }, @@ -494,10 +494,7 @@ impl TryFrom for ConcreteType { // Structure representation cannot be nullable if concrete_type.nullable { return Err(ExtensionTypeError::InvalidName { - name: format!( - "Structure representation '{}' cannot be nullable", - type_string - ), + name: format!("Structure representation '{type_string}' cannot be nullable"), }); } @@ -784,7 +781,7 @@ pub enum Match { Fail, } -impl<'a> ArgumentPattern { +impl ArgumentPattern { /// Check if this pattern matches the given concrete type pub fn matches(&self, concrete: &ConcreteType) -> Match { match self { @@ -894,7 +891,7 @@ mod tests { let original_type_item = SimpleExtensionsTypesItem { name: "ParameterizedType".to_string(), description: None, - parameters: 
Some(crate::text::simple_extensions::TypeParamDefs(vec![ + parameters: Some(TypeParamDefs(vec![ TypeParamDefsItem { name: Some("length".to_string()), description: Some("The length parameter".to_string()), From bc2a86bd8b68da247ff05c517366c690c14f3f2a Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Mon, 18 Aug 2025 16:08:39 -0400 Subject: [PATCH 08/31] Merged with existing parse --- Cargo.toml | 1 - src/lib.rs | 3 - src/parse/context.rs | 14 +- .../proto/extensions/simple_extension_uri.rs | 2 +- src/parse/proto/plan_version.rs | 2 +- src/parse/proto/version.rs | 2 +- src/parse/text/simple_extensions/argument.rs | 4 +- src/parse/text/simple_extensions/context.rs | 56 + .../text/simple_extensions}/extension.rs | 4 +- src/parse/text/simple_extensions/mod.rs | 98 +- .../text/simple_extensions}/registry.rs | 8 +- src/parse/text/simple_extensions/types.rs | 1093 ++++++++++++++++ src/registry/context.rs | 51 - src/registry/mod.rs | 35 - src/registry/types.rs | 1111 ----------------- 15 files changed, 1249 insertions(+), 1235 deletions(-) create mode 100644 src/parse/text/simple_extensions/context.rs rename src/{registry => parse/text/simple_extensions}/extension.rs (96%) rename src/{registry => parse/text/simple_extensions}/registry.rs (97%) create mode 100644 src/parse/text/simple_extensions/types.rs delete mode 100644 src/registry/context.rs delete mode 100644 src/registry/mod.rs delete mode 100644 src/registry/types.rs diff --git a/Cargo.toml b/Cargo.toml index b0f48028..78358392 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,6 @@ default = [] extensions = ["dep:serde_yaml", "dep:url"] parse = ["dep:hex", "dep:thiserror", "dep:url", "semver"] protoc = ["dep:protobuf-src"] -registry = ["parse"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] diff --git a/src/lib.rs b/src/lib.rs index f5376cee..508f8387 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,8 +122,5 @@ pub mod proto; pub mod text; pub mod version; -#[cfg(feature 
= "registry")] -pub mod registry; - #[cfg(feature = "parse")] pub mod parse; diff --git a/src/parse/context.rs b/src/parse/context.rs index 0eb54f08..994d29a4 100644 --- a/src/parse/context.rs +++ b/src/parse/context.rs @@ -59,7 +59,7 @@ pub enum ContextError { } #[cfg(test)] -pub(crate) mod tests { +pub(crate) mod fixtures { use std::collections::{hash_map::Entry, HashMap}; use crate::parse::{ @@ -72,14 +72,12 @@ pub(crate) mod tests { /// This currently mocks support for simple extensions (does not resolve or /// parse). pub struct Context { - empty_simple_extensions: SimpleExtensions, - simple_extensions: HashMap, SimpleExtensionUri>, + simple_extensions: HashMap, SimpleExtensions>, } impl Default for Context { fn default() -> Self { Self { - empty_simple_extensions: SimpleExtensions {}, simple_extensions: Default::default(), } } @@ -100,9 +98,10 @@ pub(crate) mod tests { // This is where we would resolve and then parse. // This check shows the use of the unsupported uri error. if let "http" | "https" | "file" = simple_extension_uri.uri().scheme() { - entry.insert(simple_extension_uri.clone()); + let ext = entry + .insert(SimpleExtensions::empty(simple_extension_uri.uri().clone())); // Here we just return an empty simple extensions. 
- Ok(&self.empty_simple_extensions) + Ok(ext) } else { Err(ContextError::UnsupportedURI(format!( "`{}` scheme not supported", @@ -118,8 +117,7 @@ pub(crate) mod tests { anchor: &Anchor, ) -> Result<&SimpleExtensions, ContextError> { self.simple_extensions - .contains_key(anchor) - .then_some(&self.empty_simple_extensions) + .get(anchor) .ok_or(ContextError::UndefinedSimpleExtension(*anchor)) } } diff --git a/src/parse/proto/extensions/simple_extension_uri.rs b/src/parse/proto/extensions/simple_extension_uri.rs index a5015f4b..016179eb 100644 --- a/src/parse/proto/extensions/simple_extension_uri.rs +++ b/src/parse/proto/extensions/simple_extension_uri.rs @@ -91,7 +91,7 @@ impl From for proto::extensions::SimpleExtensionUri { #[cfg(test)] mod tests { use super::*; - use crate::parse::{context::tests::Context, Context as _}; + use crate::parse::{context::fixtures::Context, Context as _}; #[test] fn parse() -> Result<(), SimpleExtensionUriError> { diff --git a/src/parse/proto/plan_version.rs b/src/parse/proto/plan_version.rs index fdf90e82..c7432df1 100644 --- a/src/parse/proto/plan_version.rs +++ b/src/parse/proto/plan_version.rs @@ -71,7 +71,7 @@ impl From for proto::PlanVersion { mod tests { use super::*; use crate::{ - parse::{context::tests::Context, proto::VersionError}, + parse::{context::fixtures::Context, proto::VersionError}, version, }; diff --git a/src/parse/proto/version.rs b/src/parse/proto/version.rs index a248d88e..e7858b83 100644 --- a/src/parse/proto/version.rs +++ b/src/parse/proto/version.rs @@ -142,7 +142,7 @@ impl From for proto::Version { #[cfg(test)] mod tests { use super::*; - use crate::parse::context::tests::Context; + use crate::parse::context::fixtures::Context; #[test] fn version() -> Result<(), VersionError> { diff --git a/src/parse/text/simple_extensions/argument.rs b/src/parse/text/simple_extensions/argument.rs index a83bd7a0..bda09471 100644 --- a/src/parse/text/simple_extensions/argument.rs +++ 
b/src/parse/text/simple_extensions/argument.rs @@ -371,7 +371,7 @@ impl From for ArgumentsItem { mod tests { use super::*; use crate::text::simple_extensions; - use crate::{parse::context::tests::Context, text}; + use crate::{parse::context::fixtures::Context, text}; #[test] fn parse_enum_argument() -> Result<(), ArgumentsItemError> { @@ -677,7 +677,7 @@ mod tests { #[test] fn parse_extensions() { use crate::extensions::EXTENSIONS; - use crate::parse::context::tests::Context; + use crate::parse::context::fixtures::Context; macro_rules! parse_arguments { ($url:expr, $fns:expr) => { diff --git a/src/parse/text/simple_extensions/context.rs b/src/parse/text/simple_extensions/context.rs new file mode 100644 index 00000000..e88dd48b --- /dev/null +++ b/src/parse/text/simple_extensions/context.rs @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Parsing context for extension processing. + +use std::collections::HashMap; + +use url::Url; + +use super::types::CustomType; +use crate::parse::Context; + +/// Parsing context for extension processing +/// +/// The context provides access to types defined in the same extension file during parsing. +/// This allows type references to be resolved within the same extension file. 
+#[derive(Debug)] +pub struct ExtensionContext { + /// The URI this extension is being loaded from + pub uri: Url, + /// Types defined in this extension file + types: HashMap, +} + +impl ExtensionContext { + /// Create a new extension context for the given URI + pub fn new(uri: Url) -> Self { + Self { + uri, + types: HashMap::new(), + } + } + + /// Add a type to the context + pub fn add_type(&mut self, custom_type: &CustomType) { + self.types.insert(custom_type.name.clone(), custom_type.clone()); + } + + /// Check if a type with the given name exists in the context + pub fn has_type(&self, name: &str) -> bool { + self.types.contains_key(name) + } + + /// Get a type by name from the context + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.types.get(name) + } + + /// Get an iterator over all types in the context + pub fn types(&self) -> impl Iterator { + self.types.values() + } +} + +impl Context for ExtensionContext { + // ExtensionContext implements the Context trait +} \ No newline at end of file diff --git a/src/registry/extension.rs b/src/parse/text/simple_extensions/extension.rs similarity index 96% rename from src/registry/extension.rs rename to src/parse/text/simple_extensions/extension.rs index f3eb02d1..3b1c1eb7 100644 --- a/src/registry/extension.rs +++ b/src/parse/text/simple_extensions/extension.rs @@ -10,7 +10,7 @@ use thiserror::Error; use url::Url; use crate::parse::Parse; -use crate::registry::types::{CustomType, ExtensionTypeError}; +use crate::parse::text::simple_extensions::types::{CustomType, ExtensionTypeError}; use crate::text::simple_extensions::SimpleExtensions; use super::context::ExtensionContext; @@ -76,4 +76,4 @@ impl ExtensionFile { pub fn types(&self) -> impl Iterator { self.types.values() } -} +} \ No newline at end of file diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index 5ad15f84..c9de0db5 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ 
b/src/parse/text/simple_extensions/mod.rs @@ -2,7 +2,9 @@ //! Parsing of [text::simple_extensions] types. +use std::collections::HashMap; use thiserror::Error; +use url::Url; use crate::{ parse::{Context, Parse}, @@ -10,16 +12,88 @@ use crate::{ }; pub mod argument; +pub mod context; +pub mod extension; +pub mod registry; +pub mod types; -/// A parsed [text::simple_extensions::SimpleExtensions]. +pub use extension::ExtensionFile; +pub use registry::Registry; +pub use types::{ConcreteType, CustomType, ExtensionTypeError}; + +/// A parsed and validated [text::simple_extensions::SimpleExtensions]. +/// This replaces the TODO implementation with ExtensionFile functionality. pub struct SimpleExtensions { - // TODO + /// The URI this extension was loaded from + pub uri: Url, + /// Parsed and validated custom types + types: HashMap, } /// Parse errors for [text::simple_extensions::SimpleExtensions]. -#[derive(Debug, Error, PartialEq)] +#[derive(Debug, Error)] pub enum SimpleExtensionsError { - // TODO + /// Extension type error + #[error("Extension type error: {0}")] + ExtensionTypeError(#[from] ExtensionTypeError), + /// Unresolved type reference in structure field + #[error("Type '{type_name}' referenced in '{originating}' structure not found")] + UnresolvedTypeReference { + /// The type name that could not be resolved + type_name: String, + /// The type that contains the unresolved reference + originating: String, + }, + /// Structure field cannot be nullable + #[error("Structure representation in type '{originating}' cannot be nullable")] + StructureCannotBeNullable { + /// The type that has a nullable structure + originating: String, + }, +} + +impl SimpleExtensions { + /// Create a new, empty SimpleExtensions + pub fn empty(uri: Url) -> Self { + Self { + uri, + types: HashMap::new(), + } + } + + /// Create a validated SimpleExtensions from raw data and URI + pub fn create( + uri: Url, + extensions: text::simple_extensions::SimpleExtensions, + ) -> Result { + // 
Parse all types (may contain unresolved Extension(String) references) + let mut ctx = context::ExtensionContext::new(uri.clone()); + let mut types = HashMap::new(); + + for type_item in &extensions.types { + let custom_type = type_item.clone().parse(&mut ctx)?; + // Add the parsed type to the context so later types can reference it + ctx.add_type(&custom_type); + types.insert(custom_type.name.clone(), custom_type); + } + + // TODO: Validate that all Extension(String) references in structure + // fields exist Walk through all CustomType.structure fields and check + // that Extension(String) references can be resolved to actual types in + // the registry. + + Ok(Self { uri, types }) + } + + /// Get a type by name + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.types.get(name) + } + + /// Get an iterator over all types in this extension + pub fn types(&self) -> impl Iterator { + self.types.values() + } } impl Parse for text::simple_extensions::SimpleExtensions { @@ -27,21 +101,15 @@ impl Parse for text::simple_extensions::SimpleExtensions { type Error = SimpleExtensionsError; fn parse(self, _ctx: &mut C) -> Result { - // let text::simple_extensions::SimpleExtensions { - // aggregate_functions, - // dependencies, - // scalar_functions, - // type_variations, - // types, - // window_functions, - // } = self; - - todo!("text::simple_extensions::SimpleExtensions - https://github.com/substrait-io/substrait-rs/issues/157") + // For parsing without URI context, create a dummy URI + let dummy_uri = Url::parse("file:///unknown").unwrap(); + SimpleExtensions::create(dummy_uri, self) } } impl From for text::simple_extensions::SimpleExtensions { fn from(_value: SimpleExtensions) -> Self { - todo!("text::simple_extensions::SimpleExtensions - https://github.com/substrait-io/substrait-rs/issues/157") + // TODO: Implement conversion back to text representation + unimplemented!("Conversion from parsed SimpleExtensions back to text representation not yet 
implemented") } } diff --git a/src/registry/registry.rs b/src/parse/text/simple_extensions/registry.rs similarity index 97% rename from src/registry/registry.rs rename to src/parse/text/simple_extensions/registry.rs index 1d4baa69..9f1ffa15 100644 --- a/src/registry/registry.rs +++ b/src/parse/text/simple_extensions/registry.rs @@ -8,13 +8,13 @@ //! //! Currently only type definitions are supported. Function parsing will be added in a future update. //! -//! This module is only available when the `registry` feature is enabled. +//! This module is only available when the `parse` feature is enabled. -#![cfg(feature = "registry")] +#![cfg(feature = "parse")] use url::Url; -use super::{types::CustomType, ExtensionFile}; +use super::{types::CustomType, extension::ExtensionFile}; /// Extension Registry that manages Substrait extensions /// @@ -151,4 +151,4 @@ mod tests { let unknown_type_via_registry = registry.get_type(&unknown_extension.uri, "unknown"); assert!(unknown_type_via_registry.is_some()); } -} +} \ No newline at end of file diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs new file mode 100644 index 00000000..38765bbd --- /dev/null +++ b/src/parse/text/simple_extensions/types.rs @@ -0,0 +1,1093 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Concrete type system for function validation in the registry. +//! +//! This module provides a clean, type-safe wrapper around Substrait extension types, +//! separating function signature patterns from concrete argument types. 
+ +use super::context::ExtensionContext; +use crate::parse::Parse; +use crate::text::simple_extensions::{ + EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefsItem, + TypeParamDefsItemType, +}; +use serde_json::Value; +use std::collections::HashMap; +use std::str::FromStr; +use thiserror::Error; + +/// Substrait built-in primitive types +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BuiltinType { + /// Boolean type - `bool` + Boolean, + /// 8-bit signed integer - `i8` + I8, + /// 16-bit signed integer - `i16` + I16, + /// 32-bit signed integer - `i32` + I32, + /// 64-bit signed integer - `i64` + I64, + /// 32-bit floating point - `fp32` + Fp32, + /// 64-bit floating point - `fp64` + Fp64, + /// Variable-length string - `string` + String, + /// Variable-length binary data - `binary` + Binary, + /// Calendar date - `date` + Date, + /// Time of day - `time` (deprecated, use precision_time) + Time, + /// Date and time - `timestamp` (deprecated, use precision_timestamp) + Timestamp, + /// Date and time with timezone - `timestamp_tz` (deprecated, use precision_timestamp_tz) + TimestampTz, + /// Year-month interval - `interval_year` + IntervalYear, + /// Day-time interval - `interval_day` + IntervalDay, + /// 128-bit UUID - `uuid` + Uuid, + /// Fixed-length decimal - `decimal` + Decimal, + /// Variable-length decimal - `decimal` + PrecisionDecimal, + /// Time with precision - `precision_time` + PrecisionTime, + /// Timestamp with precision - `precision_timestamp` + PrecisionTimestamp, + /// Timestamp with timezone and precision - `precision_timestamp_tz` + PrecisionTimestampTz, +} + +/// Error when a builtin type name is not recognized +#[derive(Debug, Error)] +#[error("Unrecognized builtin type: {0}")] +pub struct UnrecognizedBuiltin(String); + +impl FromStr for BuiltinType { + type Err = UnrecognizedBuiltin; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "bool" => Ok(BuiltinType::Boolean), + "i8" => Ok(BuiltinType::I8), 
+ "i16" => Ok(BuiltinType::I16), + "i32" => Ok(BuiltinType::I32), + "i64" => Ok(BuiltinType::I64), + "fp32" => Ok(BuiltinType::Fp32), + "fp64" => Ok(BuiltinType::Fp64), + "string" => Ok(BuiltinType::String), + "binary" => Ok(BuiltinType::Binary), + "date" => Ok(BuiltinType::Date), + "time" => Ok(BuiltinType::Time), + "timestamp" => Ok(BuiltinType::Timestamp), + "timestamp_tz" => Ok(BuiltinType::TimestampTz), + "interval_year" => Ok(BuiltinType::IntervalYear), + "interval_day" => Ok(BuiltinType::IntervalDay), + "uuid" => Ok(BuiltinType::Uuid), + "decimal" => Ok(BuiltinType::Decimal), + "precision_decimal" => Ok(BuiltinType::PrecisionDecimal), + "precision_time" => Ok(BuiltinType::PrecisionTime), + "precision_timestamp" => Ok(BuiltinType::PrecisionTimestamp), + "precision_timestamp_tz" => Ok(BuiltinType::PrecisionTimestampTz), + _ => Err(UnrecognizedBuiltin(s.to_string())), + } + } +} + +/// Parameter type information for type definitions +#[derive(Clone, Debug, PartialEq)] +pub enum ParameterType { + /// Data type parameter + DataType, + /// Integer parameter with range constraints + Integer { + /// Minimum value (inclusive), if specified + min: Option, + /// Maximum value (inclusive), if specified + max: Option, + }, + /// Enumeration parameter + Enumeration { + /// Valid enumeration values + options: Vec, + }, + /// Boolean parameter + Boolean, + /// String parameter + String, +} + +impl ParameterType { + /// Convert back to raw TypeParamDefsItemType + fn raw_type(&self) -> TypeParamDefsItemType { + match self { + ParameterType::DataType => TypeParamDefsItemType::DataType, + ParameterType::Boolean => TypeParamDefsItemType::Boolean, + ParameterType::Integer { .. } => TypeParamDefsItemType::Integer, + ParameterType::Enumeration { .. 
} => TypeParamDefsItemType::Enumeration, + ParameterType::String => TypeParamDefsItemType::String, + } + } + + /// Extract raw bounds for integer parameters (min, max) + fn raw_bounds(&self) -> (Option, Option) { + match self { + ParameterType::Integer { min, max } => (min.map(|i| i as f64), max.map(|i| i as f64)), + _ => (None, None), + } + } + + /// Extract raw enum options for enumeration parameters + fn raw_options(&self) -> Option { + match self { + ParameterType::Enumeration { options } => Some(EnumOptions(options.clone())), + _ => None, + } + } + + /// Check if a parameter value is valid for this parameter type + pub fn is_valid_value(&self, value: &Value) -> bool { + match (self, value) { + (ParameterType::DataType, Value::String(_)) => true, + (ParameterType::Integer { min, max }, Value::Number(n)) => { + if let Some(i) = n.as_i64() { + min.map_or(true, |min_val| i >= min_val) + && max.map_or(true, |max_val| i <= max_val) + } else { + false + } + } + (ParameterType::Enumeration { options }, Value::String(s)) => options.contains(s), + (ParameterType::Boolean, Value::Bool(_)) => true, + (ParameterType::String, Value::String(_)) => true, + _ => false, + } + } + + fn from_yaml( + t: TypeParamDefsItemType, + opts: Option, + ) -> Result { + Ok(match t { + TypeParamDefsItemType::DataType => Self::DataType, + TypeParamDefsItemType::Boolean => Self::Boolean, + TypeParamDefsItemType::Integer => Self::Integer { + min: None, + max: None, + }, + TypeParamDefsItemType::Enumeration => { + let options = opts.ok_or(TypeParamError::MissingEnumOptions)?.0; // Extract Vec from EnumOptions + Self::Enumeration { options } + } + TypeParamDefsItemType::String => Self::String, + }) + } +} + +/// A validated type parameter with name and constraints +#[derive(Clone, Debug, PartialEq)] +pub struct TypeParam { + /// Parameter name (e.g., "K" for a type variable) + pub name: String, + /// Parameter type constraints + pub param_type: ParameterType, + /// Human-readable description + pub 
description: Option, +} + +impl TypeParam { + /// Create a new type parameter + pub fn new(name: String, param_type: ParameterType, description: Option) -> Self { + Self { + name, + param_type, + description, + } + } + + /// Check if a parameter value is valid + pub fn is_valid_value(&self, value: &Value) -> bool { + self.param_type.is_valid_value(value) + } +} + +impl TryFrom for TypeParam { + type Error = TypeParamError; + + fn try_from(item: TypeParamDefsItem) -> Result { + let name = item.name.ok_or(TypeParamError::MissingName)?; + let param_type = ParameterType::from_yaml(item.type_, item.options)?; + + Ok(Self { + name, + param_type, + description: item.description, + }) + } +} + +/// Error types for extension type validation +#[derive(Debug, Error, PartialEq)] +pub enum ExtensionTypeError { + /// Extension type name is invalid + #[error("Invalid extension type name: {name}")] + InvalidName { + /// The invalid name + name: String, + }, + /// Parameter validation failed + #[error("Invalid parameter: {0}")] + InvalidParameter(#[from] TypeParamError), + /// Field type is invalid + #[error("Invalid structure field type: {0}")] + InvalidFieldType(String), + /// Structure representation cannot be nullable + #[error("Structure representation cannot be nullable: {type_string}")] + StructureCannotBeNullable { + /// The type string that was nullable + type_string: String, + }, +} + +/// Error types for TypeParam validation +#[derive(Debug, Error, PartialEq)] +pub enum TypeParamError { + /// Parameter name is missing + #[error("Parameter name is required")] + MissingName, + /// Integer parameter has non-integer min/max values + #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")] + InvalidIntegerBounds { + /// The invalid minimum value + min: Option, + /// The invalid maximum value + max: Option, + }, + /// Enumeration parameter is missing options + #[error("Enumeration parameter is missing options")] + MissingEnumOptions, +} + +/// A 
validated custom extension type definition +#[derive(Clone, Debug, PartialEq)] +pub struct CustomType { + /// Type name + pub name: String, + /// Type parameters (e.g., for generic types) + pub parameters: Vec, + /// Concrete structure definition, if any + pub structure: Option, + /// Whether this type can have variadic parameters + pub variadic: Option, + /// Human-readable description + pub description: Option, +} + +impl CustomType { + /// Check if this type name is valid according to Substrait naming rules + pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> { + if name.is_empty() { + return Err(InvalidTypeName(name.to_string())); + } + + // Basic validation - could be extended with more rules + if name.contains(|c: char| c.is_whitespace()) { + return Err(InvalidTypeName(name.to_string())); + } + + Ok(()) + } + + /// Create a new custom type with validation + pub fn new( + name: String, + parameters: Vec, + structure: Option, + variadic: Option, + description: Option, + ) -> Result { + Self::validate_name(&name) + .map_err(|InvalidTypeName(name)| ExtensionTypeError::InvalidName { name })?; + + Ok(Self { + name, + parameters, + structure, + variadic, + description, + }) + } +} + +impl From for SimpleExtensionsTypesItem { + fn from(value: CustomType) -> Self { + // Convert parameters back to TypeParamDefs if any + let parameters = if value.parameters.is_empty() { + None + } else { + Some(crate::text::simple_extensions::TypeParamDefs( + value + .parameters + .into_iter() + .map(|param| { + let (min, max) = param.param_type.raw_bounds(); + TypeParamDefsItem { + name: Some(param.name), + description: param.description, + type_: param.param_type.raw_type(), + min, + max, + options: param.param_type.raw_options(), + optional: None, + } + }) + .collect(), + )) + }; + + // Convert structure back to Type if any - this is a simplified implementation + let structure = value.structure.map(|_concrete_type| { + // TODO: Implement proper conversion from 
ConcreteType back to ExtType + // For now, use a placeholder + ExtType::Variant0("placeholder_structure".to_string()) + }); + + SimpleExtensionsTypesItem { + name: value.name, + description: value.description, + parameters, + structure, + variadic: value.variadic, + } + } +} + +impl Parse for SimpleExtensionsTypesItem { + type Parsed = CustomType; + type Error = ExtensionTypeError; + + fn parse(self, _ctx: &mut ExtensionContext) -> Result { + let name = self.name; + CustomType::validate_name(&name) + .map_err(|InvalidTypeName(name)| ExtensionTypeError::InvalidName { name })?; + + let parameters = if let Some(param_defs) = self.parameters { + param_defs + .0 + .into_iter() + .map(|param| TypeParam::try_from(param)) + .collect::, _>>()? + } else { + Vec::new() + }; + + let structure = match self.structure { + Some(structure_data) => Some(ConcreteType::try_from(structure_data)?), + None => None, + }; + + let custom_type = CustomType { + name, + parameters, + structure, + variadic: self.variadic, + description: self.description, + }; + + Ok(custom_type) + } +} + +impl TryFrom for ConcreteType { + type Error = ExtensionTypeError; + + fn try_from(ext_type: ExtType) -> Result { + match ext_type { + // Case: structure: "BINARY" (alias to another type) + ExtType::Variant0(type_string) => { + let parsed_type = ParsedType::parse(&type_string); + let concrete_type = ConcreteType::try_from(parsed_type)?; + + // Structure representation cannot be nullable + if concrete_type.nullable { + return Err(ExtensionTypeError::InvalidName { + name: format!( + "Structure representation '{type_string}' cannot be nullable" + ), + }); + } + + Ok(concrete_type) + } + // Case: structure: { field1: type1, field2: type2 } (named struct) + ExtType::Variant1(field_map) => { + let mut field_names = Vec::new(); + let mut field_types = Vec::new(); + + for (field_name, field_type_value) in field_map { + field_names.push(field_name); + + // field_type_value is serde_json::Value, need to extract string + 
let type_string = match field_type_value { + Value::String(type_str) => type_str, + _ => { + return Err(ExtensionTypeError::InvalidFieldType( + "Struct field types must be strings".to_string(), + )); + } + }; + let parsed_field_type = ParsedType::parse(&type_string); + + let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; + field_types.push(field_concrete_type); + } + + Ok(ConcreteType { + known_type: KnownType::Struct { + field_names, + field_types, + }, + nullable: false, // Structure definitions cannot be nullable + }) + } + } + } +} + +/// Invalid type name error +#[derive(Debug, Error)] +#[error("Invalid type name: {0}")] +pub struct InvalidTypeName(String); + +/// Error for invalid Type specifications +#[derive(Debug, thiserror::Error)] +pub enum TypeParseError { + /// Extension type name not found in context + #[error("Extension type '{name}' not found")] + ExtensionTypeNotFound { + /// The extension type name that was not found + name: String, + }, + /// Type variable ID is invalid (must be >= 1) + #[error("Type variable 'any{id}' is invalid (must be >= 1)")] + InvalidTypeVariableId { + /// The invalid type variable ID + id: u32, + }, + /// Unimplemented Type variant + #[error("Unimplemented Type variant")] + UnimplementedVariant, +} + +// TODO: ValidatedType will be updated when we implement proper type validation + +// TODO: Update this Parse implementation when ValidatedType and ParsedType are converted to owned types +// impl Parse for &extType { +// type Parsed = ValidatedType; +// type Error = TypeParseError; +// fn parse(self, ctx: &mut ExtensionContext) -> Result { +// todo!("Update when ValidatedType and ParsedType are owned") +// } +// } + +/// Error for invalid function call specifications +#[derive(Debug, thiserror::Error)] +pub enum FunctionCallError { + /// Type parsing failed + #[error("Type parsing failed: {0}")] + TypeParseError(#[from] TypeParseError), + /// Unsupported ArgumentsItem variant + #[error("Unimplemented 
ArgumentsItem variant: {variant}")] + UnimplementedVariant { + /// The unsupported variant name + variant: String, + }, +} + +// TODO: Update this Parse implementation when ArgumentPattern is converted to owned type +// impl Parse for &simple_extensions::ArgumentsItem { +// type Parsed = ArgumentPattern; +// type Error = FunctionCallError; +// fn parse(self, ctx: &mut ExtensionContext) -> Result { +// todo!("Update when ArgumentPattern is owned") +// } +// } + +/// Known Substrait types (builtin + extension references) +#[derive(Clone, Debug, PartialEq)] +pub enum KnownType { + /// Built-in Substrait primitive type + Builtin(BuiltinType), + /// Reference to an extension type by name + Extension(String), + /// List type with element type + List(Box), + /// Map type with key and value types + Map { + /// Key type + key: Box, + /// Value type + value: Box, + }, + /// Struct type with named fields + Struct { + /// Field names + field_names: Vec, + /// Field types + field_types: Vec, + }, + /// Type variable (e.g., any1, any2) + TypeVariable(u32), +} + +/// A concrete, fully-resolved type instance +#[derive(Clone, Debug, PartialEq)] +pub struct ConcreteType { + /// The known type information + pub known_type: KnownType, + /// Whether this type is nullable + pub nullable: bool, +} + +impl ConcreteType { + /// Create a new builtin type + pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::Builtin(builtin_type), + nullable, + } + } + + /// Create a new extension type reference + pub fn extension(name: String, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::Extension(name), + nullable, + } + } + + /// Create a new list type + pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::List(Box::new(element_type)), + nullable, + } + } + + /// Create a new map type + pub fn map(key_type: ConcreteType, value_type: ConcreteType, 
nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::Map { + key: Box::new(key_type), + value: Box::new(value_type), + }, + nullable, + } + } + + /// Create a new struct type + pub fn nstruct( + field_names: Vec, + field_types: Vec, + nullable: bool, + ) -> ConcreteType { + ConcreteType { + known_type: KnownType::Struct { + field_names, + field_types, + }, + nullable, + } + } + + /// Create a new type variable + pub fn type_variable(id: u32, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::TypeVariable(id), + nullable, + } + } + + /// Check if this type matches another type exactly + pub fn matches(&self, other: &ConcreteType) -> bool { + self == other + } + + /// Check if this type is compatible with another type (considering nullability) + pub fn is_compatible_with(&self, other: &ConcreteType) -> bool { + // Types must match exactly, but nullable types can accept non-nullable values + self.known_type == other.known_type && (self.nullable || !other.nullable) + } +} + +impl<'a> TryFrom> for ConcreteType { + type Error = ExtensionTypeError; + + fn try_from(parsed_type: ParsedType<'a>) -> Result { + match parsed_type { + ParsedType::Builtin(builtin_type, nullability) => Ok(ConcreteType::builtin( + builtin_type, + nullability.unwrap_or(false), + )), + ParsedType::Extension(ext_name, nullability) => Ok(ConcreteType::extension( + ext_name.to_string(), + nullability.unwrap_or(false), + )), + ParsedType::List(element_type, nullability) => { + let element_concrete = ConcreteType::try_from(*element_type)?; + Ok(ConcreteType::list( + element_concrete, + nullability.unwrap_or(false), + )) + } + ParsedType::Map(key_type, value_type, nullability) => { + let key_concrete = ConcreteType::try_from(*key_type)?; + let value_concrete = ConcreteType::try_from(*value_type)?; + Ok(ConcreteType::map( + key_concrete, + value_concrete, + nullability.unwrap_or(false), + )) + } + ParsedType::Struct(field_types, nullability) => { + let 
field_names: Vec = (0..field_types.len()) + .map(|i| format!("field{}", i)) + .collect(); + let concrete_field_types: Result, _> = field_types + .into_iter() + .map(ConcreteType::try_from) + .collect(); + Ok(ConcreteType::nstruct( + field_names, + concrete_field_types?, + nullability.unwrap_or(false), + )) + } + ParsedType::TypeVariable(id, nullability) => Ok(ConcreteType::type_variable( + id, + nullability.unwrap_or(false), + )), + ParsedType::NamedExtension(type_str, nullability) => Ok(ConcreteType::extension( + type_str.to_string(), + nullability.unwrap_or(false), + )), + } + } +} + +/// A parsed type from a type string, with lifetime tied to the original string +#[derive(Clone, Debug, PartialEq)] +pub enum ParsedType<'a> { + /// Built-in type + Builtin(BuiltinType, Option), + /// Extension type reference + Extension(&'a str, Option), + /// List type + List(Box>, Option), + /// Map type + Map(Box>, Box>, Option), + /// Struct type + Struct(Vec>, Option), + /// Type variable (e.g., any1, any2) + TypeVariable(u32, Option), + /// Named extension type (unresolved) + NamedExtension(&'a str, Option), +} + +impl<'a> ParsedType<'a> { + /// Parse a type string into a ParsedType + pub fn parse(type_str: &'a str) -> Self { + // Simple parsing implementation - could be more sophisticated + let (base_type, nullable) = if type_str.ends_with('?') { + (&type_str[..type_str.len() - 1], Some(true)) + } else { + (type_str, Some(false)) + }; + + // Handle type variables like any1, any2, etc. 
+ if let Some(suffix) = base_type.strip_prefix("any") { + if let Ok(id) = suffix.parse::() { + if id >= 1 { + return ParsedType::TypeVariable(id, nullable); + } + } + } + + // Try to parse as builtin type + if let Ok(builtin_type) = BuiltinType::from_str(base_type) { + return ParsedType::Builtin(builtin_type, nullable); + } + + // Otherwise, treat as extension type + ParsedType::NamedExtension(base_type, nullable) + } +} + +/// A pattern for function arguments that can match concrete types or type variables (TODO: Remove lifetime when ArgumentPattern is owned) +#[derive(Clone, Debug, PartialEq)] +pub enum ArgumentPattern { + /// Type variable like any1, any2, etc. + TypeVariable(u32), + /// Concrete type pattern + Concrete(ConcreteType), +} + +/// Result of matching an argument pattern against a concrete type (TODO: Remove lifetime when Match is owned) +#[derive(Clone, Debug, PartialEq)] +pub enum Match { + /// Pattern matched exactly (for concrete patterns) + Concrete, + /// Type variable bound to concrete type + Variable(u32, ConcreteType), + /// Match failed + Fail, +} + +impl ArgumentPattern { + /// Check if this pattern matches the given concrete type + pub fn matches(&self, concrete: &ConcreteType) -> Match { + match self { + ArgumentPattern::TypeVariable(id) => Match::Variable(*id, concrete.clone()), + ArgumentPattern::Concrete(pattern_type) => { + if pattern_type == concrete { + Match::Concrete + } else { + Match::Fail + } + } + } + } +} + +/// Type variable bindings from matching function arguments (TODO: Remove lifetime when TypeBindings is owned) +#[derive(Debug, Clone, PartialEq)] +pub struct TypeBindings { + /// Map of type variable IDs (e.g. 
1 for 'any1') to their concrete types + pub vars: HashMap, +} + +impl TypeBindings { + /// Create type bindings by matching argument patterns against concrete arguments + pub fn new(patterns: &[ArgumentPattern], args: &[ConcreteType]) -> Option { + let mut vars = HashMap::new(); + + if patterns.len() != args.len() { + return None; + } + + for (pattern, arg) in patterns.iter().zip(args.iter()) { + match pattern.matches(arg) { + Match::Concrete => {} // Pattern matched, nothing to bind + Match::Variable(id, concrete_type) => { + // Check if this type variable is already bound to a different type + if let Some(existing_type) = vars.get(&id) { + if existing_type != &concrete_type { + return None; // Conflict: same variable bound to different types + } + } else { + vars.insert(id, concrete_type); + } + } + Match::Fail => return None, // Pattern did not match + } + } + + Some(TypeBindings { vars }) + } + + /// Get the concrete type bound to a type variable + pub fn get_binding(&self, var_id: u32) -> Option<&ConcreteType> { + self.vars.get(&var_id) + } + + /// Check if all type variables are bound + pub fn is_complete(&self, expected_vars: &[u32]) -> bool { + expected_vars.iter().all(|var| self.vars.contains_key(var)) + } +} + +#[cfg(test)] +mod tests { + use super::super::context::ExtensionContext; + use super::*; + use crate::text; + use crate::text::simple_extensions; + use url::Url; + + #[test] + fn test_builtin_type_parsing() { + assert_eq!(BuiltinType::from_str("i32").unwrap(), BuiltinType::I32); + assert_eq!( + BuiltinType::from_str("string").unwrap(), + BuiltinType::String + ); + assert!(BuiltinType::from_str("invalid").is_err()); + } + + #[test] + fn test_parsed_type_simple() { + let parsed = ParsedType::parse("i32"); + assert_eq!(parsed, ParsedType::Builtin(BuiltinType::I32, Some(false))); + + let parsed_nullable = ParsedType::parse("i32?"); + assert_eq!( + parsed_nullable, + ParsedType::Builtin(BuiltinType::I32, Some(true)) + ); + } + + #[test] + fn 
test_parsed_type_variables() { + let parsed = ParsedType::parse("any1"); + assert_eq!(parsed, ParsedType::TypeVariable(1, Some(false))); + + let parsed_nullable = ParsedType::parse("any2?"); + assert_eq!(parsed_nullable, ParsedType::TypeVariable(2, Some(true))); + + // Invalid type variable ID (must be >= 1) + let parsed_invalid = ParsedType::parse("any0"); + assert_eq!( + parsed_invalid, + ParsedType::NamedExtension("any0", Some(false)) + ); + } + + #[test] + fn test_concrete_type_creation() { + let int_type = ConcreteType::builtin(BuiltinType::I32, false); + assert_eq!( + int_type, + ConcreteType { + known_type: KnownType::Builtin(BuiltinType::I32), + nullable: false + } + ); + + let list_type = ConcreteType::list(int_type.clone(), true); + assert_eq!( + list_type, + ConcreteType { + known_type: KnownType::List(Box::new(int_type)), + nullable: true + } + ); + } + + #[test] + fn test_argument_pattern_matching() { + let concrete_int = ConcreteType::builtin(BuiltinType::I32, false); + let concrete_string = ConcreteType::builtin(BuiltinType::String, false); + + // Test concrete pattern matching + let concrete_pattern = ArgumentPattern::Concrete(concrete_int.clone()); + assert_eq!(concrete_pattern.matches(&concrete_int), Match::Concrete); + assert_eq!(concrete_pattern.matches(&concrete_string), Match::Fail); + + // Test type variable pattern + let var_pattern = ArgumentPattern::TypeVariable(1); + assert_eq!( + var_pattern.matches(&concrete_int), + Match::Variable(1, concrete_int.clone()) + ); + } + + #[test] + fn test_type_bindings() { + let patterns = vec![ + ArgumentPattern::TypeVariable(1), + ArgumentPattern::TypeVariable(1), // Same variable should bind to same type + ]; + let args = vec![ + ConcreteType::builtin(BuiltinType::I32, false), + ConcreteType::builtin(BuiltinType::I32, false), + ]; + + let bindings = TypeBindings::new(&patterns, &args).unwrap(); + assert_eq!( + bindings.get_binding(1), + Some(&ConcreteType::builtin(BuiltinType::I32, false)) + ); + + // 
Test conflicting bindings + let conflicting_args = vec![ + ConcreteType::builtin(BuiltinType::I32, false), + ConcreteType::builtin(BuiltinType::String, false), + ]; + assert!(TypeBindings::new(&patterns, &conflicting_args).is_none()); + } + + #[test] + fn test_parameter_type_validation() { + let int_param = ParameterType::Integer { + min: Some(1), + max: Some(10), + }; + + assert!(int_param.is_valid_value(&Value::Number(5.into()))); + assert!(!int_param.is_valid_value(&Value::Number(0.into()))); + assert!(!int_param.is_valid_value(&Value::Number(11.into()))); + assert!(!int_param.is_valid_value(&Value::String("not a number".into()))); + + let enum_param = ParameterType::Enumeration { + options: vec!["OVERFLOW".to_string(), "ERROR".to_string()], + }; + + assert!(enum_param.is_valid_value(&Value::String("OVERFLOW".into()))); + assert!(!enum_param.is_valid_value(&Value::String("INVALID".into()))); + } + + #[test] + fn test_custom_type_creation() -> Result<(), ExtensionTypeError> { + let custom_type = CustomType::new( + "MyType".to_string(), + vec![], + Some(ConcreteType::builtin(BuiltinType::I32, false)), + None, + Some("A custom type".to_string()), + )?; + + assert_eq!(custom_type.name, "MyType"); + assert_eq!(custom_type.parameters.len(), 0); + assert!(custom_type.structure.is_some()); + Ok(()) + } + + #[test] + fn test_invalid_type_names() { + // Empty name should be invalid + assert!(CustomType::validate_name("").is_err()); + // Name with whitespace should be invalid + assert!(CustomType::validate_name("bad name").is_err()); + // Valid name should pass + assert!(CustomType::validate_name("GoodName").is_ok()); + } + + #[test] + fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> { + // Test simple type string alias + let ext_type = text::simple_extensions::Type::Variant0("i32".to_string()); + let concrete = ConcreteType::try_from(ext_type)?; + assert_eq!(concrete, ConcreteType::builtin(BuiltinType::I32, false)); + + // Test struct type + let mut 
field_map = serde_json::Map::new(); + field_map.insert( + "field1".to_string(), + serde_json::Value::String("fp64".to_string()), + ); + let ext_type = text::simple_extensions::Type::Variant1(field_map); + let concrete = ConcreteType::try_from(ext_type)?; + + if let KnownType::Struct { + field_names, + field_types, + } = &concrete.known_type + { + assert_eq!(field_names, &vec!["field1".to_string()]); + assert_eq!(field_types.len(), 1); + assert_eq!( + field_types[0], + ConcreteType::builtin(BuiltinType::Fp64, false) + ); + } else { + panic!("Expected struct type"); + } + + Ok(()) + } + + #[test] + fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> { + let type_item = simple_extensions::SimpleExtensionsTypesItem { + name: "TestType".to_string(), + description: Some("A test type".to_string()), + parameters: None, + structure: Some(text::simple_extensions::Type::Variant0( + "BINARY".to_string(), + )), // Alias to fp64 + variadic: None, + }; + + let mut ctx = ExtensionContext::new(Url::parse("https://example.com/test.yaml").unwrap()); + let custom_type = type_item.parse(&mut ctx)?; + assert_eq!(custom_type.name, "TestType"); + assert_eq!(custom_type.description, Some("A test type".to_string())); + assert!(custom_type.structure.is_some()); + + if let Some(structure) = &custom_type.structure { + assert_eq!( + structure.known_type, + KnownType::Builtin(BuiltinType::Binary) + ); + } + + Ok(()) + } + + #[test] + fn test_custom_type_with_struct() -> Result<(), ExtensionTypeError> { + let mut field_map = serde_json::Map::new(); + field_map.insert( + "x".to_string(), + serde_json::Value::String("fp64".to_string()), + ); + field_map.insert( + "y".to_string(), + serde_json::Value::String("fp64".to_string()), + ); + + let type_item = simple_extensions::SimpleExtensionsTypesItem { + name: "Point".to_string(), + description: Some("A 2D point".to_string()), + parameters: None, + structure: Some(text::simple_extensions::Type::Variant1(field_map)), + variadic: None, + }; + 
+ let mut ctx = ExtensionContext::new(Url::parse("https://example.com/test.yaml").unwrap()); + let custom_type = type_item.parse(&mut ctx)?; + assert_eq!(custom_type.name, "Point"); + + if let Some(ConcreteType { + known_type: + KnownType::Struct { + field_names, + field_types, + }, + .. + }) = &custom_type.structure + { + assert!(field_names.contains(&"x".to_string())); + assert!(field_names.contains(&"y".to_string())); + assert_eq!(field_types.len(), 2); + // Note: HashMap iteration order is not guaranteed, so we just check the types exist + assert!(field_types + .iter() + .all(|t| matches!(t.known_type, KnownType::Builtin(BuiltinType::Fp64)))); + } else { + panic!("Expected struct type"); + } + + Ok(()) + } + + #[test] + fn test_nullable_structure_rejected() { + let ext_type = text::simple_extensions::Type::Variant0("i32?".to_string()); + let result = ConcreteType::try_from(ext_type); + if let Err(ExtensionTypeError::InvalidName { name }) = result { + assert!(name.contains("Structure representation")); + assert!(name.contains("cannot be nullable")); + } else { + panic!( + "Expected nullable structure to be rejected, got: {:?}", + result + ); + } + } +} diff --git a/src/registry/context.rs b/src/registry/context.rs deleted file mode 100644 index 6675c225..00000000 --- a/src/registry/context.rs +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -//! Extension parsing context for validation. - -use std::collections::HashMap; - -use url::Url; - -use super::types::CustomType; -use crate::parse::Context; - -/// Context for parsing and validating extension definitions. -/// -/// This context accumulates validated types as they are parsed, -/// allowing later elements to reference previously validated types. -pub struct ExtensionContext { - /// The URI of the extension being parsed. 
- pub uri: Url, - /// Map of type names to their validated definitions - pub(crate) types: HashMap, -} - -impl ExtensionContext { - /// Create a new extension context for parsing. - pub fn new(uri: Url) -> Self { - Self { - uri, - types: HashMap::new(), - } - } - - /// Check if a type with the given name exists in the context - pub fn has_type(&self, name: &str) -> bool { - self.types.contains_key(name) - } - - /// Add a type to the context after it has been validated - pub(crate) fn add_type(&mut self, custom_type: &CustomType) { - self.types - .insert(custom_type.name.clone(), custom_type.clone()); - } - - /// Get a type by name from the context, returning the CustomType - pub fn get_type(&self, name: &str) -> Option<&CustomType> { - self.types.get(name) - } -} - -impl Context for ExtensionContext { - // Implementation required by the Context trait -} diff --git a/src/registry/mod.rs b/src/registry/mod.rs deleted file mode 100644 index 82770902..00000000 --- a/src/registry/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -//! Substrait Extension Registry -//! -//! This module provides types and methods that abstract over Substrait -//! SimpleExtensions. -//! -//! ## Design Philosophy -//! -//! Internally, the types in this module are handles to the raw parsed -//! SimpleExtensions from the text module. Externally, they provide a coherent -//! interface that hides those internal details and presents methods where -//! extensions are validated on creation and then assumed valid thereafter. -//! -//! This allows for a clean API that externally follows the "parse don't -//! validate" principle, with an API that encourages users to work with -//! validated extensions without worrying about their internal structure, -//! without needing to add entirely new parse trees - the type tree can be -//! recreated on-demand. -//! -//! ## Core Types -//! -//! - [`ExtensionFile`]: Validated wrapper around SimpleExtensions + URI focusing on type definitions -//! 
- [`CustomType`]: Parsed and validated extension type definitions -//! - [`Registry`]: Immutable registry for URI+name based type lookup -//! -//! Currently only type definitions are supported. Function parsing will be added in a future update. - -mod context; -mod extension; -#[allow(clippy::module_inception)] -mod registry; -pub mod types; - -pub use extension::{ExtensionFile, ValidationError}; -pub use registry::Registry; -pub use types::{ConcreteType, CustomType}; diff --git a/src/registry/types.rs b/src/registry/types.rs deleted file mode 100644 index de1fdb87..00000000 --- a/src/registry/types.rs +++ /dev/null @@ -1,1111 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -//! Concrete type system for function validation in the registry. -//! -//! This module provides a clean, type-safe wrapper around Substrait extension types, -//! separating function signature patterns from concrete argument types. - -use crate::parse::Parse; -use crate::registry::context::ExtensionContext; -use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefs, TypeParamDefsItem, - TypeParamDefsItemType, -}; -use serde_json::Value; -use std::collections::HashMap; -use std::str::FromStr; -use thiserror::Error; - -/// Substrait built-in primitive types -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum BuiltinType { - /// Boolean type - `bool` - Boolean, - /// 8-bit signed integer - `i8` - I8, - /// 16-bit signed integer - `i16` - I16, - /// 32-bit signed integer - `i32` - I32, - /// 64-bit signed integer - `i64` - I64, - /// 32-bit floating point - `fp32` - Fp32, - /// 64-bit floating point - `fp64` - Fp64, - /// Variable-length string - `string` - String, - /// Variable-length binary data - `binary` - Binary, - /// Calendar date - `date` - Date, - /// Time of day - `time` (deprecated, use precision_time) - Time, - /// Date and time - `timestamp` (deprecated, use precision_timestamp) - Timestamp, - /// Date and time with timezone - 
`timestamp_tz` (deprecated, use precision_timestamp_tz) - TimestampTz, - /// Year-month interval - `interval_year` - IntervalYear, - /// Day-time interval - `interval_day` - IntervalDay, - /// Compound interval - `interval_compound` - IntervalCompound, - /// UUID type - `uuid` - Uuid, - /// Fixed-length character string - `fixed_char` - FixedChar, - /// Variable-length character string - `varchar` - VarChar, - /// Fixed-length binary data - `fixed_binary` - FixedBinary, - /// Decimal number - `decimal` - Decimal, - /// Time with precision - `precision_time` - PrecisionTime, - /// Timestamp with precision - `precision_timestamp` - PrecisionTimestamp, - /// Timestamp with timezone and precision - `precision_timestamp_tz` - PrecisionTimestampTz, - /// Struct/record type - `struct` - Struct, - /// List/array type - `list` - List, - /// Map/dictionary type - `map` - Map, - /// User-defined type - `user_defined` - UserDefined, -} - -#[derive(Debug, thiserror::Error)] -/// Error for unrecognized builtin type strings -#[error("Unrecognized builtin type")] -pub struct UnrecognizedBuiltin; - -impl FromStr for BuiltinType { - type Err = UnrecognizedBuiltin; - - fn from_str(s: &str) -> Result { - match s { - "boolean" => Ok(BuiltinType::Boolean), - "i8" => Ok(BuiltinType::I8), - "i16" => Ok(BuiltinType::I16), - "i32" => Ok(BuiltinType::I32), - "i64" => Ok(BuiltinType::I64), - "fp32" => Ok(BuiltinType::Fp32), - "fp64" => Ok(BuiltinType::Fp64), - "string" => Ok(BuiltinType::String), - "binary" => Ok(BuiltinType::Binary), - "date" => Ok(BuiltinType::Date), - "time" => Ok(BuiltinType::Time), - "timestamp" => Ok(BuiltinType::Timestamp), - "timestamp_tz" => Ok(BuiltinType::TimestampTz), - "interval_year" => Ok(BuiltinType::IntervalYear), - "interval_day" => Ok(BuiltinType::IntervalDay), - "interval_compound" => Ok(BuiltinType::IntervalCompound), - "uuid" => Ok(BuiltinType::Uuid), - "fixed_char" => Ok(BuiltinType::FixedChar), - "varchar" => Ok(BuiltinType::VarChar), - "fixed_binary" 
=> Ok(BuiltinType::FixedBinary), - "decimal" => Ok(BuiltinType::Decimal), - "precision_time" => Ok(BuiltinType::PrecisionTime), - "precision_timestamp" => Ok(BuiltinType::PrecisionTimestamp), - "precision_timestamp_tz" => Ok(BuiltinType::PrecisionTimestampTz), - "struct" => Ok(BuiltinType::Struct), - "list" => Ok(BuiltinType::List), - "map" => Ok(BuiltinType::Map), - "user_defined" => Ok(BuiltinType::UserDefined), - _ => Err(UnrecognizedBuiltin), - } - } -} -/// Parameter type for extension type definitions -#[derive(Clone, Debug, PartialEq)] -pub enum ParameterType { - /// A type name - DataType, - /// True/False - Boolean, - /// Integer - Integer, - /// A particular enum - Enum, - /// A string - String, -} - -/// Type-safe parameter constraints based on parameter kind -#[derive(Clone, Debug)] -pub enum ParamKind { - /// A type name parameter - DataType, - /// True/False parameter - Boolean, - /// Integer parameter with optional bounds - Integer { - /// Minimum value constraint - min: Option, - /// Maximum value constraint - max: Option, - }, - /// Enumeration parameter with predefined options - Enumeration { - /// Valid enumeration values - options: Vec, - }, - /// String parameter - String, -} - -impl ParamKind { - fn get_integer_bounds( - min: Option, - max: Option, - ) -> Result<(Option, Option), TypeParamError> { - // Convert float bounds to integers, validating they are whole numbers - let min_bound = if let Some(min_f) = min { - if min_f.fract() != 0.0 { - return Err(TypeParamError::InvalidIntegerBounds { min, max }); - } - Some(min_f as i64) - } else { - None - }; - - let max_bound = if let Some(max_f) = max { - if max_f.fract() != 0.0 { - return Err(TypeParamError::InvalidIntegerBounds { min, max }); - } - Some(max_f as i64) - } else { - None - }; - - Ok((min_bound, max_bound)) - } - - /// Create a ParamKind from TypeParamDefsItemType and associated fields - fn try_from_item_parts( - item_type: TypeParamDefsItemType, - min: Option, - max: Option, - 
options: Option, - ) -> Result { - match (item_type, min, max, options) { - // Valid cases - each type with its expected parameters - (TypeParamDefsItemType::DataType, None, None, None) => Ok(ParamKind::DataType), - (TypeParamDefsItemType::Boolean, None, None, None) => Ok(ParamKind::Boolean), - (TypeParamDefsItemType::Integer, min, max, None) => { - let (min_bound, max_bound) = Self::get_integer_bounds(min, max)?; - Ok(ParamKind::Integer { - min: min_bound, - max: max_bound, - }) - } - (TypeParamDefsItemType::Enumeration, None, None, Some(enum_options)) => { - Ok(ParamKind::Enumeration { - options: enum_options.0, - }) - } - (TypeParamDefsItemType::String, None, None, None) => Ok(ParamKind::String), - - // Error cases - DataType with unexpected parameters - (TypeParamDefsItemType::DataType, Some(_), _, _) - | (TypeParamDefsItemType::DataType, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { - param_type: TypeParamDefsItemType::DataType, - }) - } - (TypeParamDefsItemType::DataType, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { - param_type: TypeParamDefsItemType::DataType, - }) - } - - // Error cases - Boolean with unexpected parameters - (TypeParamDefsItemType::Boolean, Some(_), _, _) - | (TypeParamDefsItemType::Boolean, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { - param_type: TypeParamDefsItemType::Boolean, - }) - } - (TypeParamDefsItemType::Boolean, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { - param_type: TypeParamDefsItemType::Boolean, - }) - } - - // Error cases - Integer with enum options - (TypeParamDefsItemType::Integer, _, _, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { - param_type: TypeParamDefsItemType::Integer, - }) - } - - // Error cases - Enumeration with unexpected parameters - (TypeParamDefsItemType::Enumeration, Some(_), _, _) - | (TypeParamDefsItemType::Enumeration, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { - 
param_type: TypeParamDefsItemType::Enumeration, - }) - } - (TypeParamDefsItemType::Enumeration, None, None, None) => { - Err(TypeParamError::MissingEnumOptions) - } - - // Error cases - String with unexpected parameters - (TypeParamDefsItemType::String, Some(_), _, _) - | (TypeParamDefsItemType::String, _, Some(_), _) => { - Err(TypeParamError::UnexpectedMinMaxBounds { - param_type: TypeParamDefsItemType::String, - }) - } - (TypeParamDefsItemType::String, None, None, Some(_)) => { - Err(TypeParamError::UnexpectedEnumOptions { - param_type: TypeParamDefsItemType::String, - }) - } - } - } -} - -/// Type parameter definition for custom types -#[derive(Clone, Debug)] -pub struct TypeParam { - /// Name of the parameter (required) - pub name: String, - /// Optional description of the parameter - pub description: Option, - /// Type-safe parameter constraints - pub kind: ParamKind, -} - -impl TryFrom for TypeParam { - type Error = TypeParamError; - - fn try_from(item: TypeParamDefsItem) -> Result { - let name = item.name.ok_or(TypeParamError::MissingName)?; - - let kind = ParamKind::try_from_item_parts(item.type_, item.min, item.max, item.options)?; - - Ok(Self { - name, - description: item.description, - kind, - }) - } -} - -impl From for TypeParamDefsItem { - fn from(param_def: TypeParam) -> Self { - let (param_type, min, max, options) = match param_def.kind { - ParamKind::DataType => (TypeParamDefsItemType::DataType, None, None, None), - ParamKind::Boolean => (TypeParamDefsItemType::Boolean, None, None, None), - ParamKind::Integer { min, max } => ( - TypeParamDefsItemType::Integer, - min.map(|i| i as f64), - max.map(|i| i as f64), - None, - ), - ParamKind::Enumeration { options } => ( - TypeParamDefsItemType::Enumeration, - None, - None, - Some(EnumOptions(options)), - ), - ParamKind::String => (TypeParamDefsItemType::String, None, None, None), - }; - - Self { - name: Some(param_def.name), - description: param_def.description, - type_: param_type, - min, - max, - 
optional: None, // Not needed for type definitions - options, - } - } -} - -/// Error types for ExtensionType parsing -#[derive(Debug, Error, PartialEq)] -pub enum ExtensionTypeError { - /// Extension type name is invalid - #[error("Invalid extension type name: {name}")] - InvalidName { - /// The invalid name - name: String, - }, - /// Parameter validation failed - #[error("Invalid parameter: {0}")] - InvalidParameter(#[from] TypeParamError), - /// Field type is invalid - #[error("Invalid structure field type: {0}")] - InvalidFieldType(String), - /// Structure representation cannot be nullable - #[error("Structure representation cannot be nullable: {type_string}")] - StructureCannotBeNullable { - /// The type string that was nullable - type_string: String, - }, -} - -/// Error types for TypeParam validation -#[derive(Debug, Error, PartialEq)] -pub enum TypeParamError { - /// Parameter name is missing - #[error("Parameter name is required")] - MissingName, - /// Integer parameter has non-integer min/max values - #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")] - InvalidIntegerBounds { - /// The invalid minimum value - min: Option, - /// The invalid maximum value - max: Option, - }, - /// Parameter type cannot have min/max bounds - #[error("Parameter type '{param_type}' cannot have min/max bounds")] - UnexpectedMinMaxBounds { - /// The parameter type that cannot have bounds - param_type: TypeParamDefsItemType, - }, - /// Parameter type cannot have enumeration options - #[error("Parameter type '{param_type}' cannot have enumeration options")] - UnexpectedEnumOptions { - /// The parameter type that cannot have options - param_type: TypeParamDefsItemType, - }, - /// Enumeration parameter is missing required options - #[error("Enumeration parameter is missing required options")] - MissingEnumOptions, -} - -/// A custom type definition -#[derive(Clone, Debug)] -pub struct CustomType { - /// The name of this custom type - pub name: String, 
- /// Optional description of this type - pub description: Option, - /// How this type is represented (None = opaque, Some = structured representation) - /// If Some, nullable MUST be false - pub structure: Option, - /// Parameters for this type (empty if none) - pub parameters: Vec, - // TODO: Add variadic field for variadic type support -} - -impl PartialEq for CustomType { - fn eq(&self, other: &Self) -> bool { - // Name should be unique for a given extension file - self.name == other.name - } -} - -impl CustomType { - /// Get the name of this custom type - pub fn name(&self) -> &str { - &self.name - } -} - -impl From for SimpleExtensionsTypesItem { - fn from(custom_type: CustomType) -> Self { - Self { - name: custom_type.name, - description: custom_type.description, - parameters: if custom_type.parameters.is_empty() { - None - } else { - Some(TypeParamDefs( - custom_type.parameters.into_iter().map(Into::into).collect(), - )) - }, - structure: None, // TODO: Add structure support - variadic: None, // TODO: Add variadic support - } - } -} - -#[derive(Debug, thiserror::Error)] -#[error("Invalid type name: {0}")] -/// Error for invalid type names in extension definitions -pub struct InvalidTypeName(String); - -impl Parse for SimpleExtensionsTypesItem { - type Parsed = CustomType; - type Error = ExtensionTypeError; - - fn parse(self, ctx: &mut ExtensionContext) -> Result { - let SimpleExtensionsTypesItem { - name, - description, - parameters, - structure, - variadic: _, // TODO: Add variadic support - } = self; - - // TODO: Not all names are valid for types, we should validate that - if name.is_empty() { - return Err(ExtensionTypeError::InvalidName { name }); - } - - let parameters = match parameters { - Some(type_param_defs) => { - let mut parsed_params = Vec::new(); - for item in type_param_defs.0 { - parsed_params.push(TypeParam::try_from(item)?); - } - parsed_params - } - None => Vec::new(), - }; - - // Parse structure field if present - let structure = match 
structure { - Some(structure_data) => Some(ConcreteType::try_from(structure_data)?), - None => None, // Opaque type - }; - - let custom_type = CustomType { - name: name.clone(), - description, - structure, - parameters, - }; - - ctx.add_type(&custom_type); - Ok(custom_type) - } -} - -impl TryFrom for ConcreteType { - type Error = ExtensionTypeError; - - fn try_from(ext_type: ExtType) -> Result { - match ext_type { - // Case: structure: "BINARY" (alias to another type) - ExtType::Variant0(type_string) => { - let parsed_type = ParsedType::parse(&type_string); - let concrete_type = ConcreteType::try_from(parsed_type)?; - - // Structure representation cannot be nullable - if concrete_type.nullable { - return Err(ExtensionTypeError::InvalidName { - name: format!("Structure representation '{type_string}' cannot be nullable"), - }); - } - - Ok(concrete_type) - } - // Case: structure: { field1: type1, field2: type2 } (named struct) - ExtType::Variant1(field_map) => { - let mut field_names = Vec::new(); - let mut field_types = Vec::new(); - - for (field_name, field_type_variant) in field_map { - field_names.push(field_name); - - let field_type_str = match field_type_variant { - Value::String(s) => s, - _ => { - return Err(ExtensionTypeError::InvalidName { - name: field_type_variant.to_string(), - }) - } - }; - - let parsed_field_type = ParsedType::parse(&field_type_str); - let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; - field_types.push(field_concrete_type); - } - - Ok(ConcreteType { - base: KnownType::NStruct(field_names), - nullable: false, // Structure representation cannot be nullable - parameters: field_types, - }) - } - } - } -} - -/// Error for invalid Type specifications -#[derive(Debug, thiserror::Error)] -pub enum TypeParseError { - /// Extension type name not found in context - #[error("Extension type '{name}' not found")] - ExtensionTypeNotFound { - /// The extension type name that was not found - name: String, - }, - /// Type variable ID 
is invalid (must be >= 1) - #[error("Type variable 'any{id}' is invalid (must be >= 1)")] - InvalidTypeVariableId { - /// The invalid type variable ID - id: u32, - }, - /// Unimplemented Type variant - #[error("Unimplemented Type variant")] - UnimplementedVariant, -} - -// TODO: ValidatedType will be updated when we implement proper type validation - -// TODO: Update this Parse implementation when ValidatedType and ParsedType are converted to owned types -// impl Parse for &extType { -// type Parsed = ValidatedType; -// type Error = TypeParseError; -// fn parse(self, ctx: &mut ExtensionContext) -> Result { -// todo!("Update when ValidatedType and ParsedType are owned") -// } -// } - -/// Error for invalid ArgumentsItem specifications (TODO: Update when ArgumentPattern is owned) -#[derive(Debug, thiserror::Error)] -pub enum ArgumentsItemError { - /// Type parsing failed - #[error("Type parsing failed: {0}")] - TypeParseError(#[from] TypeParseError), - /// Unsupported ArgumentsItem variant - #[error("Unimplemented ArgumentsItem variant: {variant}")] - UnimplementedVariant { - /// The unsupported variant name - variant: String, - }, -} - -// TODO: Update this Parse implementation when ArgumentPattern is converted to owned type -// impl Parse for &ArgumentsItem { -// type Parsed = ArgumentPattern; -// type Error = ArgumentsItemError; -// fn parse(self, ctx: &mut ExtensionContext) -> Result { -// todo!("Update when ArgumentPattern is owned") -// } -// } - -/// Represents a known, specific type, either builtin, extension reference, or structured -#[derive(Clone, Debug, PartialEq)] -pub enum KnownType { - /// Built-in primitive types - Builtin(BuiltinType), - /// Custom types defined in extension YAML files (unresolved reference) - Extension(String), - /// Named struct with field names (corresponds to Substrait's NSTRUCT pseudo-type) - NStruct(Vec), -} - -impl FromStr for KnownType { - type Err = ExtensionTypeError; - - fn from_str(s: &str) -> Result { - // First try to 
parse as a builtin type - match BuiltinType::from_str(s) { - Ok(builtin) => Ok(KnownType::Builtin(builtin)), - Err(_) => { - // TODO: Validate that the string is a valid type name - // For now, treat all non-builtin strings as extension type references - Ok(KnownType::Extension(s.to_string())) - } - } - } -} - -/// A concrete type, fully specified with nullability and parameters -#[derive(Clone, Debug, PartialEq)] -pub struct ConcreteType { - /// Base type, can be builtin or extension - pub base: KnownType, - /// Is the overall type nullable? - pub nullable: bool, - // TODO: Add non-type parameters (e.g. integers, enum, etc.) - /// Parameters for the type, if there are any - pub parameters: Vec, -} - -impl ConcreteType { - /// Create a concrete type from a builtin type - pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType { - ConcreteType { - base: KnownType::Builtin(builtin_type), - nullable, - parameters: Vec::new(), - } - } - - /// Create a concrete type from an extension type name - pub fn extension(type_name: impl Into, nullable: bool) -> Self { - Self { - base: KnownType::Extension(type_name.into()), - nullable, - parameters: Vec::new(), - } - } - - /// Create a concrete type for a named struct (NSTRUCT) - pub fn nstruct( - field_names: Vec, - field_types: Vec, - nullable: bool, - ) -> Self { - Self { - base: KnownType::NStruct(field_names), - nullable, - parameters: field_types, - } - } - - /// Create a parameterized concrete type - pub fn parameterized(base: KnownType, nullable: bool, parameters: Vec) -> Self { - Self { - base, - nullable, - parameters, - } - } -} - -impl<'a> TryFrom> for ConcreteType { - type Error = ExtensionTypeError; - - fn try_from(parsed: ParsedType<'a>) -> Result { - match parsed { - ParsedType::Builtin(builtin_type, nullable) => { - Ok(ConcreteType::builtin(builtin_type, nullable)) - } - ParsedType::NamedExtension(type_name, nullable) => { - Ok(ConcreteType::extension(type_name.to_string(), nullable)) - } - 
ParsedType::TypeVariable(_) | ParsedType::NullableTypeVariable(_) => { - Err(ExtensionTypeError::InvalidName { - name: "Type variables not allowed in structure definitions".to_string(), - }) - } - ParsedType::Parameterized { - base, - parameters, - nullable, - } => { - let base_concrete = ConcreteType::try_from(*base)?; - let param_concretes: Result, _> = - parameters.into_iter().map(ConcreteType::try_from).collect(); - Ok(ConcreteType::parameterized( - base_concrete.base, - nullable, - param_concretes?, - )) - } - } - } -} - -/// A parsed type that can represent type variables, builtin types, extension types, or parameterized types -#[derive(Clone, Debug, PartialEq)] -pub enum ParsedType<'a> { - /// Type variable like any1, any2, etc. - TypeVariable(u32), - /// Nullable type variable like any1?, any2?, etc.; used in return types - NullableTypeVariable(u32), - /// Built-in primitive type, with nullability - Builtin(BuiltinType, bool), - /// Extension type for the given name, with nullability. URI not known at this level. - NamedExtension(&'a str, bool), - /// Parameterized type - Parameterized { - /// Base type, can be builtin or extension - base: Box>, - /// Parameters for that type - parameters: Vec>, - /// Is the overall type nullable? - nullable: bool, - }, -} - -impl<'a> ParsedType<'a> { - /// Parse a type string into a ParsedType - pub fn parse(type_str: &'a str) -> ParsedType<'a> { - // Strip nullability - let (type_str, nullability) = if let Some(rest) = type_str.strip_suffix('?') { - (rest, true) - } else { - (type_str, false) - }; - - // Handle any expressions - if let Some(rest) = type_str.strip_prefix("any") { - if let Ok(id) = rest.parse::() { - if nullability { - // any1? etc. 
are nullable type variables - permissible in - // return position - return ParsedType::NullableTypeVariable(id); - } else { - return ParsedType::TypeVariable(id); - } - } - } - - // Handle parameterized types like "list" (future implementation) - if type_str.contains('<') && type_str.ends_with('>') { - unimplemented!("Parameterized types not yet implemented: {}", type_str); - } - - // Try to parse as builtin type - if let Ok(builtin_type) = BuiltinType::from_str(type_str) { - return ParsedType::Builtin(builtin_type, nullability); - } - - // Not a builtin or type variable - assume it's an extension type name - ParsedType::NamedExtension(type_str, nullability) - } -} - -/// A pattern for function arguments that can match concrete types or type variables (TODO: Remove lifetime when ArgumentPattern is owned) -#[derive(Clone, Debug, PartialEq)] -pub enum ArgumentPattern { - /// Type variable like any1, any2, etc. - TypeVariable(u32), - /// Concrete type pattern - Concrete(ConcreteType), -} - -/// Result of matching an argument pattern against a concrete type (TODO: Remove lifetime when Match is owned) -#[derive(Clone, Debug, PartialEq)] -pub enum Match { - /// Pattern matched exactly (for concrete patterns) - Concrete, - /// Type variable bound to concrete type - Variable(u32, ConcreteType), - /// Match failed - Fail, -} - -impl ArgumentPattern { - /// Check if this pattern matches the given concrete type - pub fn matches(&self, concrete: &ConcreteType) -> Match { - match self { - ArgumentPattern::TypeVariable(id) => Match::Variable(*id, concrete.clone()), - ArgumentPattern::Concrete(pattern_type) => { - if pattern_type == concrete { - Match::Concrete - } else { - Match::Fail - } - } - } - } -} - -/// Type variable bindings from matching function arguments (TODO: Remove lifetime when TypeBindings is owned) -#[derive(Debug, Clone, PartialEq)] -pub struct TypeBindings { - /// Map of type variable IDs (e.g. 
1 for 'any1') to their concrete types - pub vars: HashMap, -} - -impl TypeBindings { - /// Create type bindings by matching argument patterns against concrete arguments - pub fn new(patterns: &[ArgumentPattern], args: &[ConcreteType]) -> Option { - // Check length compatibility - if patterns.len() != args.len() { - unimplemented!("Handle variadic functions"); - } - - let mut vars = HashMap::new(); - - // Match each pattern against corresponding argument - for (pattern, arg) in patterns.iter().zip(args.iter()) { - match pattern.matches(arg) { - Match::Concrete => { - // Concrete pattern matched, continue - continue; - } - Match::Variable(id, concrete_type) => { - // Check for consistency with existing bindings - if let Some(existing_binding) = vars.get(&id) { - if existing_binding != &concrete_type { - // Conflicting binding - type variable bound to different types - return None; - } - } else { - // New binding - vars.insert(id, concrete_type); - } - } - Match::Fail => { - // Pattern didn't match - return None; - } - } - } - - Some(TypeBindings { vars }) - } - - /// Get the bound type for a type variable, if any - pub fn get(&self, var_id: u32) -> Option<&ConcreteType> { - self.vars.get(&var_id) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - use url::Url; - - #[test] - fn test_extension_type_parse_basic() { - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri.clone()); - - let original_type_item = SimpleExtensionsTypesItem { - name: "MyType".to_string(), - description: Some("A test type".to_string()), - parameters: None, - structure: None, - variadic: None, - }; - - let result = original_type_item.clone().parse(&mut ctx); - assert!(result.is_ok()); - - let custom_type = result.unwrap(); - assert_eq!(custom_type.name, "MyType"); - assert_eq!(custom_type.description, Some("A test type".to_string())); - assert!(custom_type.parameters.is_empty()); - - // Test round-trip conversion - 
let converted_back: SimpleExtensionsTypesItem = custom_type.into(); - assert_eq!(converted_back.name, original_type_item.name); - assert_eq!(converted_back.description, original_type_item.description); - // Note: structure and variadic are TODO fields - } - - #[test] - fn test_extension_type_parse_with_parameters() { - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri.clone()); - - let original_type_item = SimpleExtensionsTypesItem { - name: "ParameterizedType".to_string(), - description: None, - parameters: Some(TypeParamDefs(vec![ - TypeParamDefsItem { - name: Some("length".to_string()), - description: Some("The length parameter".to_string()), - type_: TypeParamDefsItemType::Integer, - min: Some(1.0), - max: Some(1000.0), - optional: Some(false), - options: None, - }, - ])), - structure: None, - variadic: None, - }; - - let result = original_type_item.clone().parse(&mut ctx); - assert!(result.is_ok()); - - let custom_type = result.unwrap(); - assert_eq!(custom_type.name, "ParameterizedType"); - assert_eq!(custom_type.parameters.len(), 1); - - let param = &custom_type.parameters[0]; - assert_eq!(param.name, "length"); - assert_eq!(param.description, Some("The length parameter".to_string())); - if let ParamKind::Integer { min, max } = ¶m.kind { - assert_eq!(*min, Some(1)); - assert_eq!(*max, Some(1000)); - } else { - panic!("Expected Integer parameter kind"); - } - - // Test round-trip conversion - let converted_back: SimpleExtensionsTypesItem = custom_type.into(); - assert_eq!(converted_back.name, original_type_item.name); - assert_eq!(converted_back.description, original_type_item.description); - // Note: parameter and structure comparisons would require PartialEq implementations - } - - #[test] - fn test_extension_type_parse_empty_name_error() { - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri); - - let type_item = SimpleExtensionsTypesItem { - 
name: "".to_string(), // Empty name should cause error - description: None, - parameters: None, - structure: None, - variadic: None, - }; - - let result = type_item.parse(&mut ctx); - assert!(result.is_err()); - - if let Err(ExtensionTypeError::InvalidName { name }) = result { - assert_eq!(name, ""); - } else { - panic!("Expected InvalidName error"); - } - } - - #[test] - fn test_extension_context_type_tracking() { - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri.clone()); - - // Initially no types - assert!(!ctx.has_type("MyType")); - - let type_item = SimpleExtensionsTypesItem { - name: "MyType".to_string(), - description: None, - parameters: None, - structure: None, - variadic: None, - }; - - // Parse the type - this should add it to context - let _custom_type = type_item.parse(&mut ctx).unwrap(); - - // Now the context should have the type - assert!(ctx.has_type("MyType")); - - let retrieved_type = ctx.get_type("MyType"); - assert!(retrieved_type.is_some()); - assert_eq!(retrieved_type.unwrap().name, "MyType"); - } - - #[test] - fn test_type_param_conversion() { - let original_param = TypeParamDefsItem { - name: Some("test_param".to_string()), - description: Some("A test parameter".to_string()), - type_: TypeParamDefsItemType::Integer, - min: Some(0.0), - max: Some(100.0), - optional: Some(true), - options: None, - }; - - // Convert to owned TypeParam - let type_param = TypeParam::try_from(original_param.clone()).unwrap(); - assert_eq!(type_param.name, "test_param"); - assert_eq!(type_param.description, Some("A test parameter".to_string())); - - if let ParamKind::Integer { min, max } = type_param.kind { - assert_eq!(min, Some(0)); - assert_eq!(max, Some(100)); - } else { - panic!("Expected Integer parameter kind"); - } - - // Convert back to original type - let converted_back = TypeParamDefsItem::from(type_param); - assert_eq!(converted_back.name, original_param.name); - 
assert_eq!(converted_back.description, original_param.description); - assert_eq!(converted_back.type_, original_param.type_); - assert_eq!(converted_back.min, original_param.min); - assert_eq!(converted_back.max, original_param.max); - // Note: optional field is not used in our new structure - } - - #[test] - fn test_simple_type_no_structure() { - // Test a simple opaque type (no structure field) - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri); - - let type_item = SimpleExtensionsTypesItem { - name: "unknown".to_string(), - description: Some("An opaque type".to_string()), - parameters: None, - structure: None, // Opaque type - variadic: None, - }; - - let result = type_item.parse(&mut ctx); - assert!(result.is_ok()); - - let custom_type = result.unwrap(); - assert_eq!(custom_type.name, "unknown"); - assert_eq!(custom_type.description, Some("An opaque type".to_string())); - assert!(custom_type.structure.is_none()); // Should be None for opaque type - assert!(custom_type.parameters.is_empty()); - } - - #[test] - fn test_types_with_structure() { - // Test a type with structure: "BINARY" (alias) - let uri = Url::parse("https://example.com/test.yaml").unwrap(); - let mut ctx = ExtensionContext::new(uri); - - let type_item = SimpleExtensionsTypesItem { - name: "coordinate".to_string(), - description: Some("A coordinate in some form".to_string()), - parameters: None, - structure: Some(ExtType::Variant0("fp64".to_string())), // Alias to fp64 - variadic: None, - }; - - let result = type_item.parse(&mut ctx); - assert!(result.is_ok()); - - let custom_type = result.unwrap(); - assert_eq!(custom_type.name, "coordinate"); - assert!(custom_type.structure.is_some()); - - let structure = custom_type.structure.unwrap(); - assert!(!structure.nullable); // Structure cannot be nullable - assert!(matches!( - structure.base, - KnownType::Builtin(BuiltinType::Fp64) - )); - - // Create a map structure like { latitude: 
"coordinate", longitude: "coordinate" } - let mut field_map = serde_json::Map::new(); - field_map.insert("latitude".to_string(), json!("coordinate")); - field_map.insert("longitude".to_string(), json!("coordinate")); - - let type_item = SimpleExtensionsTypesItem { - name: "point".to_string(), - description: Some("A 2D point".to_string()), - parameters: None, - structure: Some(ExtType::Variant1(field_map)), - variadic: None, - }; - - let result = type_item.parse(&mut ctx); - assert!(result.is_ok()); - - let custom_type = result.unwrap(); - assert_eq!(custom_type.name, "point"); - assert!(custom_type.structure.is_some()); - - let structure = custom_type.structure.unwrap(); - assert!(!structure.nullable); // Structure cannot be nullable - - // Should be NStruct with field names - if let KnownType::NStruct(field_names) = structure.base { - assert_eq!(field_names.len(), 2); - assert!(field_names.contains(&"latitude".to_string())); - assert!(field_names.contains(&"longitude".to_string())); - } else { - panic!("Expected NStruct base type"); - } - - // Should have 2 field types (parameters) - assert_eq!(structure.parameters.len(), 2); - for param in &structure.parameters { - if let KnownType::Extension(ref type_name) = param.base { - assert_eq!(type_name, "coordinate"); - } else { - panic!("Expected Extension type for coordinate reference"); - } - } - } -} From 1b49a6c0c5ab1c6906e1dcb933648349abe5653e Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Mon, 18 Aug 2025 16:46:11 -0400 Subject: [PATCH 09/31] Mostly working but builtin types are a bit messed up --- src/parse/context.rs | 18 ++-- src/parse/text/simple_extensions/extension.rs | 79 ----------------- .../{context.rs => extensions.rs} | 25 ++---- src/parse/text/simple_extensions/file.rs | 79 +++++++++++++++++ src/parse/text/simple_extensions/mod.rs | 87 ++----------------- src/parse/text/simple_extensions/registry.rs | 18 ++-- src/parse/text/simple_extensions/types.rs | 30 ++++--- 7 files changed, 130 
insertions(+), 206 deletions(-) delete mode 100644 src/parse/text/simple_extensions/extension.rs rename src/parse/text/simple_extensions/{context.rs => extensions.rs} (71%) create mode 100644 src/parse/text/simple_extensions/file.rs diff --git a/src/parse/context.rs b/src/parse/context.rs index 994d29a4..1b56346e 100644 --- a/src/parse/context.rs +++ b/src/parse/context.rs @@ -5,7 +5,7 @@ use thiserror::Error; use crate::parse::{ - proto::extensions::SimpleExtensionUri, text::simple_extensions::SimpleExtensions, Anchor, Parse, + proto::extensions::SimpleExtensionUri, text::simple_extensions::ExtensionFile, Anchor, Parse, }; /// A parse context. @@ -33,13 +33,13 @@ pub trait ProtoContext: Context { fn add_simple_extension_uri( &mut self, simple_extension_uri: &SimpleExtensionUri, - ) -> Result<&SimpleExtensions, ContextError>; + ) -> Result<&ExtensionFile, ContextError>; /// Returns the simple extensions for the given simple extension anchor. fn simple_extensions( &self, anchor: &Anchor, - ) -> Result<&SimpleExtensions, ContextError>; + ) -> Result<&ExtensionFile, ContextError>; } /// Parse context errors. @@ -64,7 +64,7 @@ pub(crate) mod fixtures { use crate::parse::{ context::ContextError, proto::extensions::SimpleExtensionUri, - text::simple_extensions::SimpleExtensions, Anchor, + text::simple_extensions::ExtensionFile, Anchor, }; /// A test context. @@ -72,7 +72,7 @@ pub(crate) mod fixtures { /// This currently mocks support for simple extensions (does not resolve or /// parse). 
pub struct Context { - simple_extensions: HashMap, SimpleExtensions>, + simple_extensions: HashMap, ExtensionFile>, } impl Default for Context { @@ -89,7 +89,7 @@ pub(crate) mod fixtures { fn add_simple_extension_uri( &mut self, simple_extension_uri: &crate::parse::proto::extensions::SimpleExtensionUri, - ) -> Result<&SimpleExtensions, ContextError> { + ) -> Result<&ExtensionFile, ContextError> { match self.simple_extensions.entry(simple_extension_uri.anchor()) { Entry::Occupied(_) => Err(ContextError::DuplicateSimpleExtension( simple_extension_uri.anchor(), @@ -98,8 +98,8 @@ pub(crate) mod fixtures { // This is where we would resolve and then parse. // This check shows the use of the unsupported uri error. if let "http" | "https" | "file" = simple_extension_uri.uri().scheme() { - let ext = entry - .insert(SimpleExtensions::empty(simple_extension_uri.uri().clone())); + let ext = + entry.insert(ExtensionFile::empty(simple_extension_uri.uri().clone())); // Here we just return an empty simple extensions. Ok(ext) } else { @@ -115,7 +115,7 @@ pub(crate) mod fixtures { fn simple_extensions( &self, anchor: &Anchor, - ) -> Result<&SimpleExtensions, ContextError> { + ) -> Result<&ExtensionFile, ContextError> { self.simple_extensions .get(anchor) .ok_or(ContextError::UndefinedSimpleExtension(*anchor)) diff --git a/src/parse/text/simple_extensions/extension.rs b/src/parse/text/simple_extensions/extension.rs deleted file mode 100644 index 3b1c1eb7..00000000 --- a/src/parse/text/simple_extensions/extension.rs +++ /dev/null @@ -1,79 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -//! Validated extension file wrapper for types. -//! -//! This module provides `ExtensionFile`, a validated wrapper around SimpleExtensions -//! that focuses on type definitions and provides safe type lookup methods. 
- -use std::collections::HashMap; -use thiserror::Error; -use url::Url; - -use crate::parse::Parse; -use crate::parse::text::simple_extensions::types::{CustomType, ExtensionTypeError}; -use crate::text::simple_extensions::SimpleExtensions; - -use super::context::ExtensionContext; - -/// Errors that can occur during extension type validation -#[derive(Debug, Error)] -pub enum ValidationError { - /// Extension type error - #[error("Extension type error: {0}")] - ExtensionTypeError(#[from] ExtensionTypeError), - /// Unresolved type reference in structure field - #[error("Type '{type_name}' referenced in '{originating}' structure not found")] - UnresolvedTypeReference { - /// The type name that could not be resolved - type_name: String, - /// The type that contains the unresolved reference - originating: String, - }, - /// Structure field cannot be nullable - #[error("Structure representation in type '{originating}' cannot be nullable")] - StructureCannotBeNullable { - /// The type that has a nullable structure - originating: String, - }, -} - -/// A validated extension file containing types from a single URI. -/// All types are parsed and validated on construction. 
-#[derive(Debug, Clone)] -pub struct ExtensionFile { - /// The URI this extension was loaded from - pub uri: Url, - /// Parsed and validated custom types - types: HashMap, -} - -impl ExtensionFile { - /// Create a validated extension file from raw data - pub fn create(uri: Url, extensions: SimpleExtensions) -> Result { - // Parse all types (may contain unresolved Extension(String) references) - let mut ctx = ExtensionContext::new(uri.clone()); - let mut types = HashMap::new(); - - for type_item in &extensions.types { - let custom_type = type_item.clone().parse(&mut ctx)?; - types.insert(custom_type.name.clone(), custom_type); - } - - // TODO: Validate that all Extension(String) references in structure - // fields exist Walk through all CustomType.structure fields and check - // that Extension(String) references can be resolved to actual types in - // the registry. - - Ok(Self { uri, types }) - } - - /// Get a type by name - pub fn get_type(&self, name: &str) -> Option<&CustomType> { - self.types.get(name) - } - - /// Get an iterator over all types in this extension - pub fn types(&self) -> impl Iterator { - self.types.values() - } -} \ No newline at end of file diff --git a/src/parse/text/simple_extensions/context.rs b/src/parse/text/simple_extensions/extensions.rs similarity index 71% rename from src/parse/text/simple_extensions/context.rs rename to src/parse/text/simple_extensions/extensions.rs index e88dd48b..5e3c59c7 100644 --- a/src/parse/text/simple_extensions/context.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -4,8 +4,6 @@ use std::collections::HashMap; -use url::Url; - use super::types::CustomType; use crate::parse::Context; @@ -13,26 +11,17 @@ use crate::parse::Context; /// /// The context provides access to types defined in the same extension file during parsing. /// This allows type references to be resolved within the same extension file. 
-#[derive(Debug)] -pub struct ExtensionContext { - /// The URI this extension is being loaded from - pub uri: Url, +#[derive(Debug, Default)] +pub struct SimpleExtensions { /// Types defined in this extension file types: HashMap, } -impl ExtensionContext { - /// Create a new extension context for the given URI - pub fn new(uri: Url) -> Self { - Self { - uri, - types: HashMap::new(), - } - } - +impl SimpleExtensions { /// Add a type to the context pub fn add_type(&mut self, custom_type: &CustomType) { - self.types.insert(custom_type.name.clone(), custom_type.clone()); + self.types + .insert(custom_type.name.clone(), custom_type.clone()); } /// Check if a type with the given name exists in the context @@ -51,6 +40,6 @@ impl ExtensionContext { } } -impl Context for ExtensionContext { +impl Context for SimpleExtensions { // ExtensionContext implements the Context trait -} \ No newline at end of file +} diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs new file mode 100644 index 00000000..f0e6fad2 --- /dev/null +++ b/src/parse/text/simple_extensions/file.rs @@ -0,0 +1,79 @@ +use url::Url; + +use super::{CustomType, SimpleExtensions, SimpleExtensionsError}; +use crate::parse::{Context, Parse}; +use crate::text; + +/// A parsed and validated [text::simple_extensions::SimpleExtensions]. 
+#[derive(Debug)] +pub struct ExtensionFile { + /// The URI this extension was loaded from + pub uri: Url, + /// The extension data containing types and eventually functions + extension: SimpleExtensions, +} + +impl ExtensionFile { + /// Create a new, empty SimpleExtensions + pub fn empty(uri: Url) -> Self { + Self { + uri, + extension: SimpleExtensions::default(), + } + } + + /// Create a validated SimpleExtensions from raw data and URI + pub fn create( + uri: Url, + extensions: text::simple_extensions::SimpleExtensions, + ) -> Result { + // Parse all types (may contain unresolved Extension(String) references) + let mut extension = SimpleExtensions::default(); + + for type_item in &extensions.types { + let custom_type = type_item.clone().parse(&mut extension)?; + // Add the parsed type to the context so later types can reference it + extension.add_type(&custom_type); + } + + // TODO: Validate that all Extension(String) references in structure + // fields exist Walk through all CustomType.structure fields and check + // that Extension(String) references can be resolved to actual types in + // the registry. 
+ + Ok(Self { uri, extension }) + } + + /// Get a type by name + pub fn get_type(&self, name: &str) -> Option<&CustomType> { + self.extension.get_type(name) + } + + /// Get an iterator over all types in this extension + pub fn types(&self) -> impl Iterator { + self.extension.types() + } + + /// Get a reference to the underlying SimpleExtension + pub fn extension(&self) -> &SimpleExtensions { + &self.extension + } +} + +impl Parse for text::simple_extensions::SimpleExtensions { + type Parsed = ExtensionFile; + type Error = SimpleExtensionsError; + + fn parse(self, _ctx: &mut C) -> Result { + // For parsing without URI context, create a dummy URI + let dummy_uri = Url::parse("file:///unknown").unwrap(); + ExtensionFile::create(dummy_uri, self) + } +} + +impl From for text::simple_extensions::SimpleExtensions { + fn from(_value: ExtensionFile) -> Self { + // TODO: Implement conversion back to text representation + unimplemented!("Conversion from parsed SimpleExtensions back to text representation not yet implemented") + } +} diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index c9de0db5..9f0e7386 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -2,34 +2,19 @@ //! Parsing of [text::simple_extensions] types. -use std::collections::HashMap; use thiserror::Error; -use url::Url; - -use crate::{ - parse::{Context, Parse}, - text, -}; pub mod argument; -pub mod context; -pub mod extension; -pub mod registry; +mod extensions; +mod file; +mod registry; pub mod types; -pub use extension::ExtensionFile; +pub use extensions::SimpleExtensions; +pub use file::ExtensionFile; pub use registry::Registry; pub use types::{ConcreteType, CustomType, ExtensionTypeError}; -/// A parsed and validated [text::simple_extensions::SimpleExtensions]. -/// This replaces the TODO implementation with ExtensionFile functionality. 
-pub struct SimpleExtensions { - /// The URI this extension was loaded from - pub uri: Url, - /// Parsed and validated custom types - types: HashMap, -} - /// Parse errors for [text::simple_extensions::SimpleExtensions]. #[derive(Debug, Error)] pub enum SimpleExtensionsError { @@ -51,65 +36,3 @@ pub enum SimpleExtensionsError { originating: String, }, } - -impl SimpleExtensions { - /// Create a new, empty SimpleExtensions - pub fn empty(uri: Url) -> Self { - Self { - uri, - types: HashMap::new(), - } - } - - /// Create a validated SimpleExtensions from raw data and URI - pub fn create( - uri: Url, - extensions: text::simple_extensions::SimpleExtensions, - ) -> Result { - // Parse all types (may contain unresolved Extension(String) references) - let mut ctx = context::ExtensionContext::new(uri.clone()); - let mut types = HashMap::new(); - - for type_item in &extensions.types { - let custom_type = type_item.clone().parse(&mut ctx)?; - // Add the parsed type to the context so later types can reference it - ctx.add_type(&custom_type); - types.insert(custom_type.name.clone(), custom_type); - } - - // TODO: Validate that all Extension(String) references in structure - // fields exist Walk through all CustomType.structure fields and check - // that Extension(String) references can be resolved to actual types in - // the registry. 
- - Ok(Self { uri, types }) - } - - /// Get a type by name - pub fn get_type(&self, name: &str) -> Option<&CustomType> { - self.types.get(name) - } - - /// Get an iterator over all types in this extension - pub fn types(&self) -> impl Iterator { - self.types.values() - } -} - -impl Parse for text::simple_extensions::SimpleExtensions { - type Parsed = SimpleExtensions; - type Error = SimpleExtensionsError; - - fn parse(self, _ctx: &mut C) -> Result { - // For parsing without URI context, create a dummy URI - let dummy_uri = Url::parse("file:///unknown").unwrap(); - SimpleExtensions::create(dummy_uri, self) - } -} - -impl From for text::simple_extensions::SimpleExtensions { - fn from(_value: SimpleExtensions) -> Self { - // TODO: Implement conversion back to text representation - unimplemented!("Conversion from parsed SimpleExtensions back to text representation not yet implemented") - } -} diff --git a/src/parse/text/simple_extensions/registry.rs b/src/parse/text/simple_extensions/registry.rs index 9f1ffa15..a019b665 100644 --- a/src/parse/text/simple_extensions/registry.rs +++ b/src/parse/text/simple_extensions/registry.rs @@ -14,7 +14,7 @@ use url::Url; -use super::{types::CustomType, extension::ExtensionFile}; +use super::{types::CustomType, ExtensionFile}; /// Extension Registry that manages Substrait extensions /// @@ -46,7 +46,7 @@ impl Registry { // Force evaluation of core extensions LazyLock::force(&EXTENSIONS); - // Convert HashMap to Vec + // Convert HashMap to Vec let extensions: Vec = EXTENSIONS .iter() .map(|(uri, simple_extensions)| { @@ -72,8 +72,10 @@ impl Registry { #[cfg(test)] mod tests { - use super::*; - use crate::text::simple_extensions::*; + use super::ExtensionFile as ParsedSimpleExtensions; + use super::Registry; + use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem}; + use url::Url; fn create_test_extension_with_types() -> SimpleExtensions { SimpleExtensions { @@ -96,7 +98,8 @@ mod tests { fn 
test_new_registry() { let uri = Url::parse("https://example.com/test.yaml").unwrap(); let extension_file = - ExtensionFile::create(uri.clone(), create_test_extension_with_types()).unwrap(); + ParsedSimpleExtensions::create(uri.clone(), create_test_extension_with_types()) + .unwrap(); let extensions = vec![extension_file]; let registry = Registry::new(extensions); @@ -109,7 +112,8 @@ mod tests { fn test_type_lookup() { let uri = Url::parse("https://example.com/test.yaml").unwrap(); let extension_file = - ExtensionFile::create(uri.clone(), create_test_extension_with_types()).unwrap(); + ParsedSimpleExtensions::create(uri.clone(), create_test_extension_with_types()) + .unwrap(); let extensions = vec![extension_file]; let registry = Registry::new(extensions); @@ -151,4 +155,4 @@ mod tests { let unknown_type_via_registry = registry.get_type(&unknown_extension.uri, "unknown"); assert!(unknown_type_via_registry.is_some()); } -} \ No newline at end of file +} diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 38765bbd..f0a973e5 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -5,7 +5,7 @@ //! This module provides a clean, type-safe wrapper around Substrait extension types, //! separating function signature patterns from concrete argument types. 
-use super::context::ExtensionContext; +use super::extensions::SimpleExtensions; use crate::parse::Parse; use crate::text::simple_extensions::{ EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefsItem, @@ -367,11 +367,11 @@ impl From for SimpleExtensionsTypesItem { } } -impl Parse for SimpleExtensionsTypesItem { +impl Parse for SimpleExtensionsTypesItem { type Parsed = CustomType; type Error = ExtensionTypeError; - fn parse(self, _ctx: &mut ExtensionContext) -> Result { + fn parse(self, _ctx: &mut SimpleExtensions) -> Result { let name = self.name; CustomType::validate_name(&name) .map_err(|InvalidTypeName(name)| ExtensionTypeError::InvalidName { name })?; @@ -415,10 +415,8 @@ impl TryFrom for ConcreteType { // Structure representation cannot be nullable if concrete_type.nullable { - return Err(ExtensionTypeError::InvalidName { - name: format!( - "Structure representation '{type_string}' cannot be nullable" - ), + return Err(ExtensionTypeError::StructureCannotBeNullable { + type_string: type_string, }); } @@ -579,6 +577,17 @@ impl ConcreteType { } } + /// Create a new struct type + pub fn r#struct(element_type: ConcreteType, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::Struct { + field_names: vec!["field1".into(), "field2".into()], + field_types: vec![element_type], + }, + nullable, + } + } + /// Create a new map type pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType { ConcreteType { @@ -813,11 +822,10 @@ impl TypeBindings { #[cfg(test)] mod tests { - use super::super::context::ExtensionContext; + use super::super::extensions::SimpleExtensions; use super::*; use crate::text; use crate::text::simple_extensions; - use url::Url; #[test] fn test_builtin_type_parsing() { @@ -1013,7 +1021,7 @@ mod tests { variadic: None, }; - let mut ctx = ExtensionContext::new(Url::parse("https://example.com/test.yaml").unwrap()); + let mut ctx = SimpleExtensions::default(); let custom_type = 
type_item.parse(&mut ctx)?; assert_eq!(custom_type.name, "TestType"); assert_eq!(custom_type.description, Some("A test type".to_string())); @@ -1049,7 +1057,7 @@ mod tests { variadic: None, }; - let mut ctx = ExtensionContext::new(Url::parse("https://example.com/test.yaml").unwrap()); + let mut ctx = SimpleExtensions::default(); let custom_type = type_item.parse(&mut ctx)?; assert_eq!(custom_type.name, "Point"); From 6dc4302e75d1db7e766f796d02ca6f9b6f602d65 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Mon, 18 Aug 2025 17:24:23 -0400 Subject: [PATCH 10/31] now looks better --- src/parse/text/simple_extensions/types.rs | 175 +++++++++++++++------- 1 file changed, 122 insertions(+), 53 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index f0a973e5..ebfbe017 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -8,7 +8,7 @@ use super::extensions::SimpleExtensions; use crate::parse::Parse; use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefsItem, + EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefs, TypeParamDefsItem, TypeParamDefsItemType, }; use serde_json::Value; @@ -16,7 +16,7 @@ use std::collections::HashMap; use std::str::FromStr; use thiserror::Error; -/// Substrait built-in primitive types +/// Substrait built-in primitive types (no parameters required) #[derive(Clone, Debug, PartialEq, Eq)] pub enum BuiltinType { /// Boolean type - `bool` @@ -39,28 +39,79 @@ pub enum BuiltinType { Binary, /// Calendar date - `date` Date, - /// Time of day - `time` (deprecated, use precision_time) + /// Time of day - `time` (deprecated, use CompoundType::PrecisionTime) Time, - /// Date and time - `timestamp` (deprecated, use precision_timestamp) + /// Date and time - `timestamp` (deprecated, use CompoundType::PrecisionTimestamp) Timestamp, - /// Date and time with timezone - 
`timestamp_tz` (deprecated, use precision_timestamp_tz) + /// Date and time with timezone - `timestamp_tz` (deprecated, use CompoundType::PrecisionTimestampTz) TimestampTz, /// Year-month interval - `interval_year` IntervalYear, - /// Day-time interval - `interval_day` - IntervalDay, /// 128-bit UUID - `uuid` Uuid, - /// Fixed-length decimal - `decimal` - Decimal, - /// Variable-length decimal - `decimal` - PrecisionDecimal, - /// Time with precision - `precision_time` - PrecisionTime, - /// Timestamp with precision - `precision_timestamp` - PrecisionTimestamp, - /// Timestamp with timezone and precision - `precision_timestamp_tz` - PrecisionTimestampTz, +} + +/// Parameter for parameterized types +#[derive(Clone, Debug, PartialEq)] +pub enum TypeParameter { + /// Integer parameter (e.g., precision, scale) + Integer(i64), + /// Type parameter (nested type) + Type(Box), + /// String parameter + String(String), +} + +/// Parameterized builtin types that require parameters +#[derive(Clone, Debug, PartialEq)] +pub enum CompoundType { + /// Fixed-length character string FIXEDCHAR + FixedChar { + /// Length (number of characters), must be >= 1 + length: i32, + }, + /// Variable-length character string VARCHAR + VarChar { + /// Maximum length (number of characters), must be >= 1 + length: i32, + }, + /// Fixed-length binary data FIXEDBINARY + FixedBinary { + /// Length (number of bytes), must be >= 1 + length: i32, + }, + /// Fixed-point decimal DECIMAL + Decimal { + /// Precision (total digits), <= 38 + precision: i32, + /// Scale (digits after decimal point), 0 <= S <= P + scale: i32, + }, + /// Time with sub-second precision PRECISIONTIME

+ PrecisionTime { /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Timestamp with sub-second precision PRECISIONTIMESTAMP<P>

+ PrecisionTimestamp { /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Timezone-aware timestamp with precision PRECISIONTIMESTAMPTZ<P>

+ PrecisionTimestampTz { /// Sub-second precision digits (0-12: seconds to picoseconds) + precision: i32, + }, + /// Day-time interval INTERVAL_DAY<P>

+ IntervalDay { /// Sub-second precision digits (0-9: seconds to nanoseconds) + precision: i32, + }, + /// Compound interval INTERVAL_COMPOUND<P>

+ IntervalCompound { + /// Sub-second precision digits + precision: i32, + }, } /// Error when a builtin type name is not recognized @@ -87,13 +138,7 @@ impl FromStr for BuiltinType { "timestamp" => Ok(BuiltinType::Timestamp), "timestamp_tz" => Ok(BuiltinType::TimestampTz), "interval_year" => Ok(BuiltinType::IntervalYear), - "interval_day" => Ok(BuiltinType::IntervalDay), "uuid" => Ok(BuiltinType::Uuid), - "decimal" => Ok(BuiltinType::Decimal), - "precision_decimal" => Ok(BuiltinType::PrecisionDecimal), - "precision_time" => Ok(BuiltinType::PrecisionTime), - "precision_timestamp" => Ok(BuiltinType::PrecisionTimestamp), - "precision_timestamp_tz" => Ok(BuiltinType::PrecisionTimestampTz), _ => Err(UnrecognizedBuiltin(s.to_string())), } } @@ -330,7 +375,7 @@ impl From for SimpleExtensionsTypesItem { let parameters = if value.parameters.is_empty() { None } else { - Some(crate::text::simple_extensions::TypeParamDefs( + Some(TypeParamDefs( value .parameters .into_iter() @@ -446,7 +491,7 @@ impl TryFrom for ConcreteType { } Ok(ConcreteType { - known_type: KnownType::Struct { + known_type: KnownType::NamedStruct { field_names, field_types, }, @@ -519,10 +564,17 @@ pub enum FunctionCallError { /// Known Substrait types (builtin + extension references) #[derive(Clone, Debug, PartialEq)] pub enum KnownType { - /// Built-in Substrait primitive type + /// Simple built-in Substrait primitive type (no parameters) Builtin(BuiltinType), - /// Reference to an extension type by name - Extension(String), + /// Parameterized built-in types + Compound(CompoundType), + /// Extension type with optional parameters + Extension { + /// Extension type name + name: String, + /// Type parameters + parameters: Vec, + }, /// List type with element type List(Box), /// Map type with key and value types @@ -532,11 +584,13 @@ pub enum KnownType { /// Value type value: Box, }, - /// Struct type with named fields - Struct { + /// Struct type (ordered fields without names) + Struct(Vec), + /// Named 
struct type (nstruct - ordered fields with names) + NamedStruct { /// Field names field_names: Vec, - /// Field types + /// Field types (same order as field_names) field_types: Vec, }, /// Type variable (e.g., any1, any2) @@ -561,10 +615,33 @@ impl ConcreteType { } } - /// Create a new extension type reference + /// Create a new compound (parameterized) type + pub fn compound(compound_type: CompoundType, nullable: bool) -> ConcreteType { + ConcreteType { + known_type: KnownType::Compound(compound_type), + nullable, + } + } + + /// Create a new extension type reference (without parameters) pub fn extension(name: String, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Extension(name), + known_type: KnownType::Extension { + name, + parameters: Vec::new(), + }, + nullable, + } + } + + /// Create a new parameterized extension type + pub fn extension_with_params( + name: String, + parameters: Vec, + nullable: bool, + ) -> ConcreteType { + ConcreteType { + known_type: KnownType::Extension { name, parameters }, nullable, } } @@ -577,13 +654,10 @@ impl ConcreteType { } } - /// Create a new struct type - pub fn r#struct(element_type: ConcreteType, nullable: bool) -> ConcreteType { + /// Create a new struct type (ordered fields without names) + pub fn r#struct(field_types: Vec, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Struct { - field_names: vec!["field1".into(), "field2".into()], - field_types: vec![element_type], - }, + known_type: KnownType::Struct(field_types), nullable, } } @@ -599,14 +673,14 @@ impl ConcreteType { } } - /// Create a new struct type - pub fn nstruct( + /// Create a new named struct type (nstruct - ordered fields with names) + pub fn named_struct( field_names: Vec, field_types: Vec, nullable: bool, ) -> ConcreteType { ConcreteType { - known_type: KnownType::Struct { + known_type: KnownType::NamedStruct { field_names, field_types, }, @@ -664,15 +738,11 @@ impl<'a> TryFrom> for ConcreteType { )) } 
ParsedType::Struct(field_types, nullability) => { - let field_names: Vec = (0..field_types.len()) - .map(|i| format!("field{}", i)) - .collect(); let concrete_field_types: Result, _> = field_types .into_iter() .map(ConcreteType::try_from) .collect(); - Ok(ConcreteType::nstruct( - field_names, + Ok(ConcreteType::r#struct( concrete_field_types?, nullability.unwrap_or(false), )) @@ -991,7 +1061,7 @@ mod tests { let ext_type = text::simple_extensions::Type::Variant1(field_map); let concrete = ConcreteType::try_from(ext_type)?; - if let KnownType::Struct { + if let KnownType::NamedStruct { field_names, field_types, } = &concrete.known_type @@ -1003,7 +1073,7 @@ mod tests { ConcreteType::builtin(BuiltinType::Fp64, false) ); } else { - panic!("Expected struct type"); + panic!("Expected named struct type"); } Ok(()) @@ -1063,7 +1133,7 @@ mod tests { if let Some(ConcreteType { known_type: - KnownType::Struct { + KnownType::NamedStruct { field_names, field_types, }, @@ -1088,9 +1158,8 @@ mod tests { fn test_nullable_structure_rejected() { let ext_type = text::simple_extensions::Type::Variant0("i32?".to_string()); let result = ConcreteType::try_from(ext_type); - if let Err(ExtensionTypeError::InvalidName { name }) = result { - assert!(name.contains("Structure representation")); - assert!(name.contains("cannot be nullable")); + if let Err(ExtensionTypeError::StructureCannotBeNullable { type_string }) = result { + assert!(type_string.contains("i32?")); } else { panic!( "Expected nullable structure to be rejected, got: {:?}", From 9c593a109667b2ceefb5a2fd8ea18ca86ee3ad6c Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Tue, 19 Aug 2025 17:06:04 -0400 Subject: [PATCH 11/31] On the way, trying to get validation working --- src/parse/context.rs | 9 +- .../text/simple_extensions/extensions.rs | 68 ++- src/parse/text/simple_extensions/file.rs | 122 +++- src/parse/text/simple_extensions/mod.rs | 25 +- src/parse/text/simple_extensions/registry.rs | 2 +- 
src/parse/text/simple_extensions/types.rs | 533 +++++++++++++----- 6 files changed, 575 insertions(+), 184 deletions(-) diff --git a/src/parse/context.rs b/src/parse/context.rs index 1b56346e..11d1311a 100644 --- a/src/parse/context.rs +++ b/src/parse/context.rs @@ -71,18 +71,11 @@ pub(crate) mod fixtures { /// /// This currently mocks support for simple extensions (does not resolve or /// parse). + #[derive(Default)] pub struct Context { simple_extensions: HashMap, ExtensionFile>, } - impl Default for Context { - fn default() -> Self { - Self { - simple_extensions: Default::default(), - } - } - } - impl super::Context for Context {} impl super::ProtoContext for Context { diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 5e3c59c7..75eb0702 100644 --- a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -2,10 +2,13 @@ //! Parsing context for extension processing. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use super::types::CustomType; -use crate::parse::Context; +use crate::{ + parse::{Context, Parse}, + text::simple_extensions::SimpleExtensions as RawExtensions, +}; /// Parsing context for extension processing /// @@ -40,6 +43,65 @@ impl SimpleExtensions { } } -impl Context for SimpleExtensions { +/// A context for parsing simple extensions. +#[derive(Debug, Default)] +pub struct TypeContext { + /// Types that have been seen so far, not yet resolved. + known: HashSet, + /// Types that have been linked to, not yet resolved. + linked: HashSet, +} + +impl TypeContext { + /// Mark a type as found + pub fn found(&mut self, name: &str) { + self.linked.remove(name); + self.known.insert(name.to_string()); + } + + /// Mark a type as linked to - some other type or function references it, but we haven't seen it. 
+ pub fn linked(&mut self, name: &str) { + if !self.known.contains(name) { + self.linked.insert(name.to_string()); + } + } +} + +impl Context for TypeContext { // ExtensionContext implements the Context trait } + +// Implement parsing for the raw text representation to produce an `ExtensionFile`. +impl Parse for RawExtensions { + type Parsed = SimpleExtensions; + type Error = super::SimpleExtensionsError; + + fn parse(self, ctx: &mut TypeContext) -> Result { + let mut extension = SimpleExtensions::default(); + + for type_item in self.types { + let custom_type = Parse::parse(type_item, ctx)?; + // Add the parsed type to the context so later types can reference it + extension.add_type(&custom_type); + } + + Ok(extension) + } +} + +// Implement conversion from parsed form back to raw text representation. +impl From for RawExtensions { + fn from(value: SimpleExtensions) -> Self { + // Minimal types-only conversion to satisfy tests + let types = value.types().cloned().map(Into::into).collect(); + RawExtensions { + types, + // TODO: Implement conversion back to raw representation + aggregate_functions: vec![], + dependencies: HashMap::new(), + scalar_functions: vec![], + type_variations: vec![], + window_functions: vec![], + } + } +} diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index f0e6fad2..88c027cc 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -1,10 +1,12 @@ use url::Url; use super::{CustomType, SimpleExtensions, SimpleExtensionsError}; -use crate::parse::{Context, Parse}; -use crate::text; +use crate::parse::text::simple_extensions::extensions::TypeContext; +use crate::parse::Parse; +use crate::text::simple_extensions::SimpleExtensions as RawExtensions; +use std::io::Read; -/// A parsed and validated [text::simple_extensions::SimpleExtensions]. +/// A parsed and validated [RawExtensions]. 
#[derive(Debug)] pub struct ExtensionFile { /// The URI this extension was loaded from @@ -23,23 +25,12 @@ impl ExtensionFile { } /// Create a validated SimpleExtensions from raw data and URI - pub fn create( - uri: Url, - extensions: text::simple_extensions::SimpleExtensions, - ) -> Result { + pub fn create(uri: Url, extensions: RawExtensions) -> Result { // Parse all types (may contain unresolved Extension(String) references) - let mut extension = SimpleExtensions::default(); + let mut ctx = TypeContext::default(); + let extension = Parse::parse(extensions, &mut ctx)?; - for type_item in &extensions.types { - let custom_type = type_item.clone().parse(&mut extension)?; - // Add the parsed type to the context so later types can reference it - extension.add_type(&custom_type); - } - - // TODO: Validate that all Extension(String) references in structure - // fields exist Walk through all CustomType.structure fields and check - // that Extension(String) references can be resolved to actual types in - // the registry. + // TODO: Use ctx.known/ctx.linked to validate unresolved references and cross-file links. Ok(Self { uri, extension }) } @@ -58,22 +49,93 @@ impl ExtensionFile { pub fn extension(&self) -> &SimpleExtensions { &self.extension } -} -impl Parse for text::simple_extensions::SimpleExtensions { - type Parsed = ExtensionFile; - type Error = SimpleExtensionsError; + /// Read an extension file from a reader and a URI string. + /// + /// - `uri_str`: a string that parses to Url (e.g., file:///...) used to tag the extension + /// - `reader`: any `Read` instance with the YAML content + /// + /// Returns a parsed and validated `ExtensionFile` or an error. 
+ #[cfg(feature = "extensions")] + pub fn read, R: Read>(uri: U, reader: R) -> Result + where + SimpleExtensionsError: From, + { + let raw: RawExtensions = serde_yaml::from_reader(reader)?; + let uri = uri.try_into()?; + Self::create(uri, raw) + } - fn parse(self, _ctx: &mut C) -> Result { - // For parsing without URI context, create a dummy URI - let dummy_uri = Url::parse("file:///unknown").unwrap(); - ExtensionFile::create(dummy_uri, self) + /// Read an extension file from a string slice. + #[cfg(feature = "extensions")] + pub fn read_from_str, S: AsRef>( + uri: U, + s: S, + ) -> Result + where + SimpleExtensionsError: From, + { + let raw: RawExtensions = serde_yaml::from_str(s.as_ref())?; + let uri = uri.try_into()?; + Self::create(uri, raw) } } -impl From for text::simple_extensions::SimpleExtensions { - fn from(_value: ExtensionFile) -> Self { - // TODO: Implement conversion back to text representation - unimplemented!("Conversion from parsed SimpleExtensions back to text representation not yet implemented") +// Parsing and conversion implementations are defined on `SimpleExtensions` in `extensions.rs`. 
+ +#[cfg(test)] +mod tests { + use crate::{ + parse::text::simple_extensions::types::ParameterType as RawParameterType, + text::simple_extensions::SimpleExtensions as RawExtensions, + }; + + use super::*; + + #[test] + fn yaml_round_trip_integer_param_bounds() { + // A minimal YAML extension file with a single type that has integer bounds on a parameter + let yaml = r#" +%YAML 1.2 +--- +types: + - name: "ParamTest" + parameters: + - name: "K" + type: integer + min: 1 + max: 10 +"#; + + let ext = ExtensionFile::read_from_str("file:///param_test.yaml", yaml).expect("parse ok"); + + // Validate parsed parameter bounds + let ty = ext.get_type("ParamTest").expect("type exists"); + assert_eq!(ty.parameters.len(), 1); + match &ty.parameters[0].param_type { + RawParameterType::Integer { min, max } => { + assert_eq!(min, &Some(1)); + assert_eq!(max, &Some(10)); + } + other => panic!("unexpected param type: {other:?}"), + } + + // Convert back to text::simple_extensions and assert bounds are preserved as f64 + let back: RawExtensions = ext.extension.into(); + let item = back + .types + .into_iter() + .find(|t| t.name == "ParamTest") + .expect("round-tripped type present"); + let param_defs = item.parameters.expect("params present"); + assert_eq!(param_defs.0.len(), 1); + let p = ¶m_defs.0[0]; + assert_eq!(p.name.as_deref(), Some("K")); + assert!(matches!( + p.type_, + crate::text::simple_extensions::TypeParamDefsItemType::Integer + )); + assert_eq!(p.min, Some(1.0)); + assert_eq!(p.max, Some(10.0)); } } diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index 9f0e7386..1b0f73b5 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -2,6 +2,8 @@ //! Parsing of [text::simple_extensions] types. 
+use std::convert::Infallible; + use thiserror::Error; pub mod argument; @@ -11,6 +13,7 @@ mod registry; pub mod types; pub use extensions::SimpleExtensions; +pub use extensions::TypeContext; pub use file::ExtensionFile; pub use registry::Registry; pub use types::{ConcreteType, CustomType, ExtensionTypeError}; @@ -21,6 +24,15 @@ pub enum SimpleExtensionsError { /// Extension type error #[error("Extension type error: {0}")] ExtensionTypeError(#[from] ExtensionTypeError), + /// Failed to parse SimpleExtensions YAML + #[error("YAML parse error: {0}")] + YamlParse(#[from] serde_yaml::Error), + /// I/O error while reading extension content + #[error("io error: {0}")] + Io(#[from] std::io::Error), + /// Invalid URI provided + #[error("invalid URI: {0}")] + InvalidUri(#[from] url::ParseError), /// Unresolved type reference in structure field #[error("Type '{type_name}' referenced in '{originating}' structure not found")] UnresolvedTypeReference { @@ -29,10 +41,11 @@ pub enum SimpleExtensionsError { /// The type that contains the unresolved reference originating: String, }, - /// Structure field cannot be nullable - #[error("Structure representation in type '{originating}' cannot be nullable")] - StructureCannotBeNullable { - /// The type that has a nullable structure - originating: String, - }, +} + +// Needed for certain conversions - e.g. Url -> Url - to succeed. 
+impl From for SimpleExtensionsError { + fn from(_: Infallible) -> Self { + unreachable!() + } } diff --git a/src/parse/text/simple_extensions/registry.rs b/src/parse/text/simple_extensions/registry.rs index a019b665..2564a492 100644 --- a/src/parse/text/simple_extensions/registry.rs +++ b/src/parse/text/simple_extensions/registry.rs @@ -142,7 +142,7 @@ mod tests { // Find the unknown.yaml extension dynamically let unknown_extension = registry .extensions() - .find(|ext| ext.uri.path_segments().map(|s| s.last()) == Some(Some("unknown.yaml"))) + .find(|ext| ext.uri.path_segments().map(|mut s| s.next_back()) == Some(Some("unknown.yaml"))) .expect("Should find unknown.yaml extension"); let unknown_type = unknown_extension.get_type("unknown"); diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index ebfbe017..115940ca 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -5,17 +5,60 @@ //! This module provides a clean, type-safe wrapper around Substrait extension types, //! separating function signature patterns from concrete argument types. -use super::extensions::SimpleExtensions; +use super::extensions::TypeContext; use crate::parse::Parse; use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, Type as ExtType, TypeParamDefs, TypeParamDefsItem, + EnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs, TypeParamDefsItem, TypeParamDefsItemType, }; use serde_json::Value; use std::collections::HashMap; +use std::fmt; use std::str::FromStr; use thiserror::Error; +/// Write a sequence of items separated by a separator, with a start and end +/// delimiter. +/// +/// Start and end are only included in the output if there is at least one item. 
+fn write_separated( + f: &mut fmt::Formatter<'_>, + iter: I, + start: &str, + end: &str, + sep: &str, +) -> fmt::Result +where + I: IntoIterator, + T: fmt::Display, +{ + let mut it = iter.into_iter(); + if let Some(first) = it.next() { + f.write_str(start)?; + write!(f, "{first}")?; + for item in it { + f.write_str(sep)?; + write!(f, "{item}")?; + } + f.write_str(end) + } else { + Ok(()) + } +} + +/// A pair of a key and a value, separated by a separator. For display purposes. +struct KeyValueDisplay(K, V, &'static str); + +impl fmt::Display for KeyValueDisplay +where + K: fmt::Display, + V: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}{}{}", self.0, self.2, self.1) + } +} + /// Substrait built-in primitive types (no parameters required) #[derive(Clone, Debug, PartialEq, Eq)] pub enum BuiltinType { @@ -51,18 +94,51 @@ pub enum BuiltinType { Uuid, } +impl fmt::Display for BuiltinType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + BuiltinType::Boolean => "bool", + BuiltinType::I8 => "i8", + BuiltinType::I16 => "i16", + BuiltinType::I32 => "i32", + BuiltinType::I64 => "i64", + BuiltinType::Fp32 => "fp32", + BuiltinType::Fp64 => "fp64", + BuiltinType::String => "string", + BuiltinType::Binary => "binary", + BuiltinType::Date => "date", + BuiltinType::Time => "time", + BuiltinType::Timestamp => "timestamp", + BuiltinType::TimestampTz => "timestamp_tz", + BuiltinType::IntervalYear => "interval_year", + BuiltinType::Uuid => "uuid", + }; + f.write_str(s) + } +} + /// Parameter for parameterized types #[derive(Clone, Debug, PartialEq)] pub enum TypeParameter { /// Integer parameter (e.g., precision, scale) Integer(i64), /// Type parameter (nested type) - Type(Box), + Type(ConcreteType), /// String parameter String(String), } -/// Parameterized builtin types that require parameters +impl fmt::Display for TypeParameter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + 
match self { + TypeParameter::Integer(i) => write!(f, "{i}"), + TypeParameter::Type(t) => write!(f, "{t}"), + TypeParameter::String(s) => write!(f, "{s}"), + } + } +} + +/// Parameterized builtin types that require non-type parameters #[derive(Clone, Debug, PartialEq)] pub enum CompoundType { /// Fixed-length character string FIXEDCHAR @@ -114,6 +190,30 @@ pub enum CompoundType { }, } +impl fmt::Display for CompoundType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CompoundType::FixedChar { length } => write!(f, "FIXEDCHAR<{length}>"), + CompoundType::VarChar { length } => write!(f, "VARCHAR<{length}>"), + CompoundType::FixedBinary { length } => write!(f, "FIXEDBINARY<{length}>"), + CompoundType::Decimal { precision, scale } => { + write!(f, "DECIMAL<{precision}, {scale}>") + } + CompoundType::PrecisionTime { precision } => write!(f, "PRECISIONTIME<{precision}>"), + CompoundType::PrecisionTimestamp { precision } => { + write!(f, "PRECISIONTIMESTAMP<{precision}>") + } + CompoundType::PrecisionTimestampTz { precision } => { + write!(f, "PRECISIONTIMESTAMPTZ<{precision}>") + } + CompoundType::IntervalDay { precision } => write!(f, "INTERVAL_DAY<{precision}>"), + CompoundType::IntervalCompound { precision } => { + write!(f, "INTERVAL_COMPOUND<{precision}>") + } + } + } +} + /// Error when a builtin type name is not recognized #[derive(Debug, Error)] #[error("Unrecognized builtin type: {0}")] @@ -201,8 +301,7 @@ impl ParameterType { (ParameterType::DataType, Value::String(_)) => true, (ParameterType::Integer { min, max }, Value::Number(n)) => { if let Some(i) = n.as_i64() { - min.map_or(true, |min_val| i >= min_val) - && max.map_or(true, |max_val| i <= max_val) + min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val) } else { false } @@ -214,17 +313,24 @@ impl ParameterType { } } - fn from_yaml( + fn from_raw( t: TypeParamDefsItemType, opts: Option, + min: Option, + max: Option, ) -> Result { Ok(match t { 
TypeParamDefsItemType::DataType => Self::DataType, TypeParamDefsItemType::Boolean => Self::Boolean, - TypeParamDefsItemType::Integer => Self::Integer { - min: None, - max: None, - }, + TypeParamDefsItemType::Integer => { + // TODO: This truncates from float to int; probably fine + let min_i = min.map(|n| n as i64); + let max_i = max.map(|n| n as i64); + Self::Integer { + min: min_i, + max: max_i, + } + } TypeParamDefsItemType::Enumeration => { let options = opts.ok_or(TypeParamError::MissingEnumOptions)?.0; // Extract Vec from EnumOptions Self::Enumeration { options } @@ -266,7 +372,7 @@ impl TryFrom for TypeParam { fn try_from(item: TypeParamDefsItem) -> Result { let name = item.name.ok_or(TypeParamError::MissingName)?; - let param_type = ParameterType::from_yaml(item.type_, item.options)?; + let param_type = ParameterType::from_raw(item.type_, item.options, item.min, item.max)?; Ok(Self { name, @@ -280,10 +386,15 @@ impl TryFrom for TypeParam { #[derive(Debug, Error, PartialEq)] pub enum ExtensionTypeError { /// Extension type name is invalid - #[error("Invalid extension type name: {name}")] - InvalidName { - /// The invalid name - name: String, + #[error("{0}")] + InvalidName(#[from] InvalidTypeName), + /// Any type variable is invalid for concrete types + #[error("Any type variable is invalid for concrete types: any{}{}", id, nullability.then_some("?").unwrap_or(""))] + InvalidAnyTypeVariable { + /// The type variable name + id: u32, + /// Whether the type variable is nullable + nullability: bool, }, /// Parameter validation failed #[error("Invalid parameter: {0}")] @@ -356,8 +467,7 @@ impl CustomType { variadic: Option, description: Option, ) -> Result { - Self::validate_name(&name) - .map_err(|InvalidTypeName(name)| ExtensionTypeError::InvalidName { name })?; + Self::validate_name(&name)?; Ok(Self { name, @@ -395,12 +505,8 @@ impl From for SimpleExtensionsTypesItem { )) }; - // Convert structure back to Type if any - this is a simplified implementation - let 
structure = value.structure.map(|_concrete_type| { - // TODO: Implement proper conversion from ConcreteType back to ExtType - // For now, use a placeholder - ExtType::Variant0("placeholder_structure".to_string()) - }); + // Convert structure back to Type if any + let structure = value.structure.map(Into::into); SimpleExtensionsTypesItem { name: value.name, @@ -412,81 +518,106 @@ impl From for SimpleExtensionsTypesItem { } } -impl Parse for SimpleExtensionsTypesItem { +impl Parse for SimpleExtensionsTypesItem { type Parsed = CustomType; type Error = ExtensionTypeError; - fn parse(self, _ctx: &mut SimpleExtensions) -> Result { + fn parse(self, ctx: &mut TypeContext) -> Result { let name = self.name; - CustomType::validate_name(&name) - .map_err(|InvalidTypeName(name)| ExtensionTypeError::InvalidName { name })?; + CustomType::validate_name(&name)?; + + // Register this type as found + ctx.found(&name); let parameters = if let Some(param_defs) = self.parameters { param_defs .0 .into_iter() - .map(|param| TypeParam::try_from(param)) + .map(TypeParam::try_from) .collect::, _>>()? 
} else { Vec::new() }; + // Parse structure with context, so referenced extension types are recorded as linked let structure = match self.structure { - Some(structure_data) => Some(ConcreteType::try_from(structure_data)?), + Some(structure_data) => Some(Parse::parse(structure_data, ctx)?), None => None, }; - let custom_type = CustomType { + Ok(CustomType { name, parameters, structure, variadic: self.variadic, description: self.description, - }; - - Ok(custom_type) + }) } } -impl TryFrom for ConcreteType { +impl Parse for RawType { + type Parsed = ConcreteType; type Error = ExtensionTypeError; - fn try_from(ext_type: ExtType) -> Result { - match ext_type { - // Case: structure: "BINARY" (alias to another type) - ExtType::Variant0(type_string) => { + fn parse(self, ctx: &mut TypeContext) -> Result { + // Walk a ParsedType and record extension type references in the context + fn visit_parsed_type_references<'a, F>(p: &'a ParsedType<'a>, on_ext: &mut F) + where + F: FnMut(&'a str), + { + match p { + ParsedType::Extension(name, _) | ParsedType::NamedExtension(name, _) => { + on_ext(name) + } + ParsedType::List(t, _) => visit_parsed_type_references(t, on_ext), + ParsedType::Map(k, v, _) => { + visit_parsed_type_references(k, on_ext); + visit_parsed_type_references(v, on_ext); + } + ParsedType::Struct(ts, _) => { + for t in ts { + visit_parsed_type_references(t, on_ext) + } + } + ParsedType::Builtin(..) | ParsedType::TypeVariable(..) 
=> {} + } + } + + match self { + RawType::Variant0(type_string) => { let parsed_type = ParsedType::parse(&type_string); - let concrete_type = ConcreteType::try_from(parsed_type)?; + let mut link = |name: &str| ctx.linked(name); + visit_parsed_type_references(&parsed_type, &mut link); + let concrete = ConcreteType::try_from(parsed_type)?; // Structure representation cannot be nullable - if concrete_type.nullable { - return Err(ExtensionTypeError::StructureCannotBeNullable { - type_string: type_string, - }); + if concrete.nullable { + return Err(ExtensionTypeError::StructureCannotBeNullable { type_string }); } - Ok(concrete_type) + Ok(concrete) } - // Case: structure: { field1: type1, field2: type2 } (named struct) - ExtType::Variant1(field_map) => { + RawType::Variant1(field_map) => { let mut field_names = Vec::new(); let mut field_types = Vec::new(); for (field_name, field_type_value) in field_map { field_names.push(field_name); - // field_type_value is serde_json::Value, need to extract string let type_string = match field_type_value { - Value::String(type_str) => type_str, + serde_json::Value::String(s) => s, _ => { return Err(ExtensionTypeError::InvalidFieldType( "Struct field types must be strings".to_string(), - )); + )) } }; - let parsed_field_type = ParsedType::parse(&type_string); + let parsed_field_type = ParsedType::parse(&type_string); + let mut link = |name: &str| ctx.linked(name); + visit_parsed_type_references(&parsed_field_type, &mut link); let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; + field_types.push(field_concrete_type); } @@ -495,7 +626,7 @@ impl TryFrom for ConcreteType { field_names, field_types, }, - nullable: false, // Structure definitions cannot be nullable + nullable: false, }) } } @@ -503,8 +634,8 @@ impl TryFrom for ConcreteType { } /// Invalid type name error -#[derive(Debug, Error)] -#[error("Invalid type name: {0}")] +#[derive(Debug, Error, PartialEq)] +#[error("{0}")] pub struct InvalidTypeName(String); /// 
Error for invalid Type specifications @@ -530,7 +661,7 @@ pub enum TypeParseError { // TODO: ValidatedType will be updated when we implement proper type validation // TODO: Update this Parse implementation when ValidatedType and ParsedType are converted to owned types -// impl Parse for &extType { +// impl Parse for &RawType { // type Parsed = ValidatedType; // type Error = TypeParseError; // fn parse(self, ctx: &mut ExtensionContext) -> Result { @@ -593,8 +724,33 @@ pub enum KnownType { /// Field types (same order as field_names) field_types: Vec, }, - /// Type variable (e.g., any1, any2) - TypeVariable(u32), +} + +impl fmt::Display for KnownType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KnownType::Builtin(b) => write!(f, "{b}"), + KnownType::Compound(c) => write!(f, "{c}"), + KnownType::Extension { name, parameters } => { + write!(f, "{name}")?; + write_separated(f, parameters.iter(), "<", ">", ", ") + } + KnownType::List(elem) => write!(f, "List<{elem}>"), + KnownType::Map { key, value } => write!(f, "Map<{key}, {value}>"), + KnownType::Struct(types) => write_separated(f, types.iter(), "Struct<", ">", ", "), + KnownType::NamedStruct { + field_names, + field_types, + } => { + let kvs = field_names + .iter() + .zip(field_types.iter()) + .map(|(k, v)| KeyValueDisplay(k, v, ": ")); + + write_separated(f, kvs, "{", "}", ", ") + } + } + } } /// A concrete, fully-resolved type instance @@ -688,14 +844,6 @@ impl ConcreteType { } } - /// Create a new type variable - pub fn type_variable(id: u32, nullable: bool) -> ConcreteType { - ConcreteType { - known_type: KnownType::TypeVariable(id), - nullable, - } - } - /// Check if this type matches another type exactly pub fn matches(&self, other: &ConcreteType) -> bool { self == other @@ -708,53 +856,71 @@ impl ConcreteType { } } +impl fmt::Display for ConcreteType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.known_type)?; + if self.nullable { + 
write!(f, "?")?; + } + Ok(()) + } +} + +impl From for RawType { + fn from(val: ConcreteType) -> Self { + match val.known_type { + KnownType::NamedStruct { + field_names, + field_types, + } => { + let mut map = serde_json::Map::new(); + for (name, ty) in field_names.into_iter().zip(field_types.into_iter()) { + if let Some(v) = map.insert(name, serde_json::Value::String(ty.to_string())) { + // This should not happen - you should not have + // duplicate field names in a NamedStruct + panic!("duplicate value '{v:?}' in NamedStruct"); + } + } + RawType::Variant1(map) + } + _ => RawType::Variant0(val.to_string()), + } + } +} + impl<'a> TryFrom> for ConcreteType { type Error = ExtensionTypeError; fn try_from(parsed_type: ParsedType<'a>) -> Result { match parsed_type { - ParsedType::Builtin(builtin_type, nullability) => Ok(ConcreteType::builtin( - builtin_type, - nullability.unwrap_or(false), - )), - ParsedType::Extension(ext_name, nullability) => Ok(ConcreteType::extension( - ext_name.to_string(), - nullability.unwrap_or(false), - )), + ParsedType::Builtin(builtin_type, nullability) => { + Ok(ConcreteType::builtin(builtin_type, nullability)) + } + ParsedType::Extension(ext_name, nullability) => { + Ok(ConcreteType::extension(ext_name.to_string(), nullability)) + } ParsedType::List(element_type, nullability) => { let element_concrete = ConcreteType::try_from(*element_type)?; - Ok(ConcreteType::list( - element_concrete, - nullability.unwrap_or(false), - )) + Ok(ConcreteType::list(element_concrete, nullability)) } ParsedType::Map(key_type, value_type, nullability) => { let key_concrete = ConcreteType::try_from(*key_type)?; let value_concrete = ConcreteType::try_from(*value_type)?; - Ok(ConcreteType::map( - key_concrete, - value_concrete, - nullability.unwrap_or(false), - )) + Ok(ConcreteType::map(key_concrete, value_concrete, nullability)) } ParsedType::Struct(field_types, nullability) => { let concrete_field_types: Result, _> = field_types .into_iter() 
.map(ConcreteType::try_from) .collect(); - Ok(ConcreteType::r#struct( - concrete_field_types?, - nullability.unwrap_or(false), - )) + Ok(ConcreteType::r#struct(concrete_field_types?, nullability)) + } + ParsedType::TypeVariable(id, nullability) => { + Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability }) + } + ParsedType::NamedExtension(type_str, nullability) => { + Ok(ConcreteType::extension(type_str.to_string(), nullability)) } - ParsedType::TypeVariable(id, nullability) => Ok(ConcreteType::type_variable( - id, - nullability.unwrap_or(false), - )), - ParsedType::NamedExtension(type_str, nullability) => Ok(ConcreteType::extension( - type_str.to_string(), - nullability.unwrap_or(false), - )), } } } @@ -763,29 +929,28 @@ impl<'a> TryFrom> for ConcreteType { #[derive(Clone, Debug, PartialEq)] pub enum ParsedType<'a> { /// Built-in type - Builtin(BuiltinType, Option), + Builtin(BuiltinType, bool), /// Extension type reference - Extension(&'a str, Option), + Extension(&'a str, bool), /// List type - List(Box>, Option), + List(Box>, bool), /// Map type - Map(Box>, Box>, Option), + Map(Box>, Box>, bool), /// Struct type - Struct(Vec>, Option), + Struct(Vec>, bool), /// Type variable (e.g., any1, any2) - TypeVariable(u32, Option), + TypeVariable(u32, bool), /// Named extension type (unresolved) - NamedExtension(&'a str, Option), + NamedExtension(&'a str, bool), } impl<'a> ParsedType<'a> { /// Parse a type string into a ParsedType pub fn parse(type_str: &'a str) -> Self { // Simple parsing implementation - could be more sophisticated - let (base_type, nullable) = if type_str.ends_with('?') { - (&type_str[..type_str.len() - 1], Some(true)) - } else { - (type_str, Some(false)) + let (base_type, nullable) = match type_str.strip_suffix('?') { + Some(base) => (base, true), + None => (type_str, false), }; // Handle type variables like any1, any2, etc. 
@@ -892,9 +1057,8 @@ impl TypeBindings { #[cfg(test)] mod tests { - use super::super::extensions::SimpleExtensions; + use super::super::extensions::TypeContext; use super::*; - use crate::text; use crate::text::simple_extensions; #[test] @@ -910,29 +1074,23 @@ mod tests { #[test] fn test_parsed_type_simple() { let parsed = ParsedType::parse("i32"); - assert_eq!(parsed, ParsedType::Builtin(BuiltinType::I32, Some(false))); + assert_eq!(parsed, ParsedType::Builtin(BuiltinType::I32, false)); let parsed_nullable = ParsedType::parse("i32?"); - assert_eq!( - parsed_nullable, - ParsedType::Builtin(BuiltinType::I32, Some(true)) - ); + assert_eq!(parsed_nullable, ParsedType::Builtin(BuiltinType::I32, true)); } #[test] fn test_parsed_type_variables() { let parsed = ParsedType::parse("any1"); - assert_eq!(parsed, ParsedType::TypeVariable(1, Some(false))); + assert_eq!(parsed, ParsedType::TypeVariable(1, false)); let parsed_nullable = ParsedType::parse("any2?"); - assert_eq!(parsed_nullable, ParsedType::TypeVariable(2, Some(true))); + assert_eq!(parsed_nullable, ParsedType::TypeVariable(2, true)); // Invalid type variable ID (must be >= 1) let parsed_invalid = ParsedType::parse("any0"); - assert_eq!( - parsed_invalid, - ParsedType::NamedExtension("any0", Some(false)) - ); + assert_eq!(parsed_invalid, ParsedType::NamedExtension("any0", false)); } #[test] @@ -1019,6 +1177,111 @@ mod tests { assert!(!enum_param.is_valid_value(&Value::String("INVALID".into()))); } + #[test] + fn test_integer_param_bounds_round_trip() { + // Valid bounds now use lossy cast from f64 to i64; fractional parts are truncated toward zero + let item = simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.0), + max: Some(10.0), + options: None, + optional: None, + }; + let tp = TypeParam::try_from(item).expect("should parse integer bounds"); + match tp.param_type { + ParameterType::Integer { min, 
max } => { + assert_eq!(min, Some(1)); + assert_eq!(max, Some(10)); + } + _ => panic!("expected integer param type"), + } + + // Fractional min is truncated + let trunc = simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.5), + max: None, + options: None, + optional: None, + }; + let tp = TypeParam::try_from(trunc).expect("should parse with truncation"); + match tp.param_type { + ParameterType::Integer { min, max } => { + assert_eq!(min, Some(1)); + assert_eq!(max, None); + } + _ => panic!("expected integer param type"), + } + } + + #[test] + fn test_custom_type_round_trip_alias() -> Result<(), ExtensionTypeError> { + let custom = CustomType::new( + "AliasType".to_string(), + vec![], + Some(ConcreteType::builtin(BuiltinType::I32, false)), + None, + Some("desc".to_string()), + )?; + let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx)?; + assert_eq!(parsed.name, custom.name); + assert_eq!(parsed.description, custom.description); + assert_eq!(parsed.structure, custom.structure); + Ok(()) + } + + #[test] + fn test_custom_type_round_trip_named_struct() -> Result<(), ExtensionTypeError> { + let fields = vec![ + ( + "x".to_string(), + ConcreteType::builtin(BuiltinType::Fp64, false), + ), + ( + "y".to_string(), + ConcreteType::builtin(BuiltinType::Fp64, false), + ), + ]; + let (names, types): (Vec<_>, Vec<_>) = fields.into_iter().unzip(); + let custom = CustomType::new( + "Point".to_string(), + vec![], + Some(ConcreteType::named_struct( + names.clone(), + types.clone(), + false, + )), + None, + None, + )?; + let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx)?; + assert_eq!(parsed.name, custom.name); + if let Some(ConcreteType { + 
known_type: + KnownType::NamedStruct { + field_names, + field_types, + }, + nullable, + }) = parsed.structure + { + assert!(!nullable); + assert_eq!(field_names, names); + assert_eq!(field_types, types); + } else { + panic!("expected named struct after round-trip"); + } + Ok(()) + } + #[test] fn test_custom_type_creation() -> Result<(), ExtensionTypeError> { let custom_type = CustomType::new( @@ -1048,8 +1311,9 @@ mod tests { #[test] fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> { // Test simple type string alias - let ext_type = text::simple_extensions::Type::Variant0("i32".to_string()); - let concrete = ConcreteType::try_from(ext_type)?; + let ext_type = RawType::Variant0("i32".to_string()); + let mut ctx = TypeContext::default(); + let concrete = Parse::parse(ext_type, &mut ctx)?; assert_eq!(concrete, ConcreteType::builtin(BuiltinType::I32, false)); // Test struct type @@ -1058,8 +1322,9 @@ mod tests { "field1".to_string(), serde_json::Value::String("fp64".to_string()), ); - let ext_type = text::simple_extensions::Type::Variant1(field_map); - let concrete = ConcreteType::try_from(ext_type)?; + let ext_type = RawType::Variant1(field_map); + let mut ctx = TypeContext::default(); + let concrete = Parse::parse(ext_type, &mut ctx)?; if let KnownType::NamedStruct { field_names, @@ -1085,14 +1350,12 @@ mod tests { name: "TestType".to_string(), description: Some("A test type".to_string()), parameters: None, - structure: Some(text::simple_extensions::Type::Variant0( - "BINARY".to_string(), - )), // Alias to fp64 + structure: Some(RawType::Variant0("BINARY".to_string())), // Alias to fp64 variadic: None, }; - let mut ctx = SimpleExtensions::default(); - let custom_type = type_item.parse(&mut ctx)?; + let mut ctx = TypeContext::default(); + let custom_type = Parse::parse(type_item, &mut ctx)?; assert_eq!(custom_type.name, "TestType"); assert_eq!(custom_type.description, Some("A test type".to_string())); assert!(custom_type.structure.is_some()); @@ 
-1123,12 +1386,12 @@ mod tests { name: "Point".to_string(), description: Some("A 2D point".to_string()), parameters: None, - structure: Some(text::simple_extensions::Type::Variant1(field_map)), + structure: Some(RawType::Variant1(field_map)), variadic: None, }; - let mut ctx = SimpleExtensions::default(); - let custom_type = type_item.parse(&mut ctx)?; + let mut ctx = TypeContext::default(); + let custom_type = Parse::parse(type_item, &mut ctx)?; assert_eq!(custom_type.name, "Point"); if let Some(ConcreteType { @@ -1156,15 +1419,13 @@ mod tests { #[test] fn test_nullable_structure_rejected() { - let ext_type = text::simple_extensions::Type::Variant0("i32?".to_string()); - let result = ConcreteType::try_from(ext_type); + let ext_type = RawType::Variant0("i32?".to_string()); + let mut ctx = TypeContext::default(); + let result = Parse::parse(ext_type, &mut ctx); if let Err(ExtensionTypeError::StructureCannotBeNullable { type_string }) = result { assert!(type_string.contains("i32?")); } else { - panic!( - "Expected nullable structure to be rejected, got: {:?}", - result - ); + panic!("Expected nullable structure to be rejected, got: {result:?}"); } } } From 7a8584ed80cb8e099df11801ecee6446700bfd6e Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 20 Aug 2025 10:43:18 -0400 Subject: [PATCH 12/31] Extensions should be included when the extensions feature is enabled --- src/extensions.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/extensions.rs b/src/extensions.rs index 23b17bfc..0621e7bb 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -6,6 +6,7 @@ //! included in the packaged crate, ignored by git, and automatically kept //! in-sync. +#[cfg(feature = "extensions")] include!(concat!(env!("OUT_DIR"), "/extensions.in")); #[cfg(test)] @@ -18,6 +19,6 @@ mod tests { fn core_extensions() { // Force evaluation of core extensions. 
LazyLock::force(&EXTENSIONS); - println!("Core extensions: {:#?}", EXTENSIONS); + println!("Core extensions: {EXTENSIONS:#?}"); } } From ac447d8106565d8e8a55b9fae432985ebae071a9 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 20 Aug 2025 10:51:26 -0400 Subject: [PATCH 13/31] Delete some dead code --- src/parse/text/simple_extensions/mod.rs | 6 +- src/parse/text/simple_extensions/types.rs | 199 +--------------------- 2 files changed, 12 insertions(+), 193 deletions(-) diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index 1b0f73b5..5f89d563 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [text::simple_extensions] types. +//! Parsing of [crate::text::simple_extensions] types into [SimpleExtensions]. use std::convert::Infallible; @@ -10,7 +10,7 @@ pub mod argument; mod extensions; mod file; mod registry; -pub mod types; +mod types; pub use extensions::SimpleExtensions; pub use extensions::TypeContext; @@ -18,7 +18,7 @@ pub use file::ExtensionFile; pub use registry::Registry; pub use types::{ConcreteType, CustomType, ExtensionTypeError}; -/// Parse errors for [text::simple_extensions::SimpleExtensions]. +/// Errors for converting from YAML to [SimpleExtensions]. 
#[derive(Debug, Error)] pub enum SimpleExtensionsError { /// Extension type error diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 115940ca..819f01fc 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -12,7 +12,6 @@ use crate::text::simple_extensions::{ TypeParamDefsItemType, }; use serde_json::Value; -use std::collections::HashMap; use std::fmt; use std::str::FromStr; use thiserror::Error; @@ -141,49 +140,49 @@ impl fmt::Display for TypeParameter { /// Parameterized builtin types that require non-type parameters #[derive(Clone, Debug, PartialEq)] pub enum CompoundType { - /// Fixed-length character string FIXEDCHAR + /// Fixed-length character string: `FIXEDCHAR` FixedChar { /// Length (number of characters), must be >= 1 length: i32, }, - /// Variable-length character string VARCHAR + /// Variable-length character string: `VARCHAR` VarChar { /// Maximum length (number of characters), must be >= 1 length: i32, }, - /// Fixed-length binary data FIXEDBINARY + /// Fixed-length binary data: `FIXEDBINARY` FixedBinary { /// Length (number of bytes), must be >= 1 length: i32, }, - /// Fixed-point decimal DECIMAL + /// Fixed-point decimal: `DECIMAL` Decimal { /// Precision (total digits), <= 38 precision: i32, /// Scale (digits after decimal point), 0 <= S <= P scale: i32, }, - /// Time with sub-second precision PRECISIONTIME

+ /// Time with sub-second precision: `PRECISIONTIME

` PrecisionTime { /// Sub-second precision digits (0-12: seconds to picoseconds) precision: i32, }, - /// Timestamp with sub-second precision PRECISIONTIMESTAMP

+ /// Timestamp with sub-second precision: `PRECISIONTIMESTAMP

` PrecisionTimestamp { /// Sub-second precision digits (0-12: seconds to picoseconds) precision: i32, }, - /// Timezone-aware timestamp with precision PRECISIONTIMESTAMPTZ

+ /// Timezone-aware timestamp with precision: `PRECISIONTIMESTAMPTZ

` PrecisionTimestampTz { /// Sub-second precision digits (0-12: seconds to picoseconds) precision: i32, }, - /// Day-time interval INTERVAL_DAY

+ /// Day-time interval: `INTERVAL_DAY

` IntervalDay { /// Sub-second precision digits (0-9: seconds to nanoseconds) precision: i32, }, - /// Compound interval INTERVAL_COMPOUND

+ /// Compound interval: `INTERVAL_COMPOUND

` IntervalCompound { /// Sub-second precision digits precision: i32, @@ -638,60 +637,6 @@ impl Parse for RawType { #[error("{0}")] pub struct InvalidTypeName(String); -/// Error for invalid Type specifications -#[derive(Debug, thiserror::Error)] -pub enum TypeParseError { - /// Extension type name not found in context - #[error("Extension type '{name}' not found")] - ExtensionTypeNotFound { - /// The extension type name that was not found - name: String, - }, - /// Type variable ID is invalid (must be >= 1) - #[error("Type variable 'any{id}' is invalid (must be >= 1)")] - InvalidTypeVariableId { - /// The invalid type variable ID - id: u32, - }, - /// Unimplemented Type variant - #[error("Unimplemented Type variant")] - UnimplementedVariant, -} - -// TODO: ValidatedType will be updated when we implement proper type validation - -// TODO: Update this Parse implementation when ValidatedType and ParsedType are converted to owned types -// impl Parse for &RawType { -// type Parsed = ValidatedType; -// type Error = TypeParseError; -// fn parse(self, ctx: &mut ExtensionContext) -> Result { -// todo!("Update when ValidatedType and ParsedType are owned") -// } -// } - -/// Error for invalid function call specifications -#[derive(Debug, thiserror::Error)] -pub enum FunctionCallError { - /// Type parsing failed - #[error("Type parsing failed: {0}")] - TypeParseError(#[from] TypeParseError), - /// Unsupported ArgumentsItem variant - #[error("Unimplemented ArgumentsItem variant: {variant}")] - UnimplementedVariant { - /// The unsupported variant name - variant: String, - }, -} - -// TODO: Update this Parse implementation when ArgumentPattern is converted to owned type -// impl Parse for &simple_extensions::ArgumentsItem { -// type Parsed = ArgumentPattern; -// type Error = FunctionCallError; -// fn parse(self, ctx: &mut ExtensionContext) -> Result { -// todo!("Update when ArgumentPattern is owned") -// } -// } - /// Known Substrait types (builtin + extension references) 
#[derive(Clone, Debug, PartialEq)] pub enum KnownType { @@ -972,89 +917,6 @@ impl<'a> ParsedType<'a> { } } -/// A pattern for function arguments that can match concrete types or type variables (TODO: Remove lifetime when ArgumentPattern is owned) -#[derive(Clone, Debug, PartialEq)] -pub enum ArgumentPattern { - /// Type variable like any1, any2, etc. - TypeVariable(u32), - /// Concrete type pattern - Concrete(ConcreteType), -} - -/// Result of matching an argument pattern against a concrete type (TODO: Remove lifetime when Match is owned) -#[derive(Clone, Debug, PartialEq)] -pub enum Match { - /// Pattern matched exactly (for concrete patterns) - Concrete, - /// Type variable bound to concrete type - Variable(u32, ConcreteType), - /// Match failed - Fail, -} - -impl ArgumentPattern { - /// Check if this pattern matches the given concrete type - pub fn matches(&self, concrete: &ConcreteType) -> Match { - match self { - ArgumentPattern::TypeVariable(id) => Match::Variable(*id, concrete.clone()), - ArgumentPattern::Concrete(pattern_type) => { - if pattern_type == concrete { - Match::Concrete - } else { - Match::Fail - } - } - } - } -} - -/// Type variable bindings from matching function arguments (TODO: Remove lifetime when TypeBindings is owned) -#[derive(Debug, Clone, PartialEq)] -pub struct TypeBindings { - /// Map of type variable IDs (e.g. 
1 for 'any1') to their concrete types - pub vars: HashMap, -} - -impl TypeBindings { - /// Create type bindings by matching argument patterns against concrete arguments - pub fn new(patterns: &[ArgumentPattern], args: &[ConcreteType]) -> Option { - let mut vars = HashMap::new(); - - if patterns.len() != args.len() { - return None; - } - - for (pattern, arg) in patterns.iter().zip(args.iter()) { - match pattern.matches(arg) { - Match::Concrete => {} // Pattern matched, nothing to bind - Match::Variable(id, concrete_type) => { - // Check if this type variable is already bound to a different type - if let Some(existing_type) = vars.get(&id) { - if existing_type != &concrete_type { - return None; // Conflict: same variable bound to different types - } - } else { - vars.insert(id, concrete_type); - } - } - Match::Fail => return None, // Pattern did not match - } - } - - Some(TypeBindings { vars }) - } - - /// Get the concrete type bound to a type variable - pub fn get_binding(&self, var_id: u32) -> Option<&ConcreteType> { - self.vars.get(&var_id) - } - - /// Check if all type variables are bound - pub fn is_complete(&self, expected_vars: &[u32]) -> bool { - expected_vars.iter().all(|var| self.vars.contains_key(var)) - } -} - #[cfg(test)] mod tests { use super::super::extensions::TypeContext; @@ -1114,49 +976,6 @@ mod tests { ); } - #[test] - fn test_argument_pattern_matching() { - let concrete_int = ConcreteType::builtin(BuiltinType::I32, false); - let concrete_string = ConcreteType::builtin(BuiltinType::String, false); - - // Test concrete pattern matching - let concrete_pattern = ArgumentPattern::Concrete(concrete_int.clone()); - assert_eq!(concrete_pattern.matches(&concrete_int), Match::Concrete); - assert_eq!(concrete_pattern.matches(&concrete_string), Match::Fail); - - // Test type variable pattern - let var_pattern = ArgumentPattern::TypeVariable(1); - assert_eq!( - var_pattern.matches(&concrete_int), - Match::Variable(1, concrete_int.clone()) - ); - } - - #[test] 
- fn test_type_bindings() { - let patterns = vec![ - ArgumentPattern::TypeVariable(1), - ArgumentPattern::TypeVariable(1), // Same variable should bind to same type - ]; - let args = vec![ - ConcreteType::builtin(BuiltinType::I32, false), - ConcreteType::builtin(BuiltinType::I32, false), - ]; - - let bindings = TypeBindings::new(&patterns, &args).unwrap(); - assert_eq!( - bindings.get_binding(1), - Some(&ConcreteType::builtin(BuiltinType::I32, false)) - ); - - // Test conflicting bindings - let conflicting_args = vec![ - ConcreteType::builtin(BuiltinType::I32, false), - ConcreteType::builtin(BuiltinType::String, false), - ]; - assert!(TypeBindings::new(&patterns, &conflicting_args).is_none()); - } - #[test] fn test_parameter_type_validation() { let int_param = ParameterType::Integer { From 9b472194e62c538fefa600ca8c04723845fc677e Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 20 Aug 2025 11:07:45 -0400 Subject: [PATCH 14/31] Start merging argument and types --- src/parse/text/simple_extensions/argument.rs | 14 +++++++-- src/parse/text/simple_extensions/types.rs | 30 +++++++++++++------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/parse/text/simple_extensions/argument.rs b/src/parse/text/simple_extensions/argument.rs index bda09471..e4e5d4fb 100644 --- a/src/parse/text/simple_extensions/argument.rs +++ b/src/parse/text/simple_extensions/argument.rs @@ -177,18 +177,26 @@ impl Parse for simple_extensions::EnumOptions { type Error = EnumOptionsError; fn parse(self, _ctx: &mut C) -> Result { - let options = self.0; + self.try_into() + } +} + +impl TryFrom for EnumOptions { + type Error = EnumOptionsError; + + fn try_from(raw: simple_extensions::EnumOptions) -> Result { + let options = raw.0; if options.is_empty() { return Err(EnumOptionsError::EmptyList); } let mut unique_options = HashSet::new(); - for option in options.iter() { + for option in options.into_iter() { if option.is_empty() { return Err(EnumOptionsError::EmptyOption); } if 
!unique_options.insert(option.clone()) { - return Err(EnumOptionsError::DuplicatedOption(option.clone())); + return Err(EnumOptionsError::DuplicatedOption(option)); } } diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 819f01fc..2a0f3dda 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -5,11 +5,14 @@ //! This module provides a clean, type-safe wrapper around Substrait extension types, //! separating function signature patterns from concrete argument types. +use super::argument::{ + EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError, +}; use super::extensions::TypeContext; use crate::parse::Parse; use crate::text::simple_extensions::{ - EnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs, TypeParamDefsItem, - TypeParamDefsItemType, + EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs, + TypeParamDefsItem, TypeParamDefsItemType, }; use serde_json::Value; use std::fmt; @@ -257,8 +260,8 @@ pub enum ParameterType { }, /// Enumeration parameter Enumeration { - /// Valid enumeration values - options: Vec, + /// Valid enumeration values (validated, deduplicated) + options: ParsedEnumOptions, }, /// Boolean parameter Boolean, @@ -287,9 +290,9 @@ impl ParameterType { } /// Extract raw enum options for enumeration parameters - fn raw_options(&self) -> Option { + fn raw_options(&self) -> Option { match self { - ParameterType::Enumeration { options } => Some(EnumOptions(options.clone())), + ParameterType::Enumeration { options } => Some(options.clone().into()), _ => None, } } @@ -314,7 +317,7 @@ impl ParameterType { fn from_raw( t: TypeParamDefsItemType, - opts: Option, + opts: Option, min: Option, max: Option, ) -> Result { @@ -331,7 +334,8 @@ impl ParameterType { } } TypeParamDefsItemType::Enumeration => { - let options = opts.ok_or(TypeParamError::MissingEnumOptions)?.0; // Extract 
Vec from EnumOptions + let options: ParsedEnumOptions = + opts.ok_or(TypeParamError::MissingEnumOptions)?.try_into()?; Self::Enumeration { options } } TypeParamDefsItemType::String => Self::String, @@ -426,6 +430,9 @@ pub enum TypeParamError { /// Enumeration parameter is missing options #[error("Enumeration parameter is missing options")] MissingEnumOptions, + /// Enumeration parameter has invalid options + #[error("Enumeration parameter has invalid options: {0}")] + InvalidEnumOptions(#[from] ParsedEnumOptionsError), } /// A validated custom extension type definition @@ -921,6 +928,7 @@ impl<'a> ParsedType<'a> { mod tests { use super::super::extensions::TypeContext; use super::*; + use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions; use crate::text::simple_extensions; #[test] @@ -988,9 +996,9 @@ mod tests { assert!(!int_param.is_valid_value(&Value::Number(11.into()))); assert!(!int_param.is_valid_value(&Value::String("not a number".into()))); - let enum_param = ParameterType::Enumeration { - options: vec!["OVERFLOW".to_string(), "ERROR".to_string()], - }; + let raw = simple_extensions::EnumOptions(vec!["OVERFLOW".to_string(), "ERROR".to_string()]); + let parsed = ParsedEnumOptions::try_from(raw).unwrap(); + let enum_param = ParameterType::Enumeration { options: parsed }; assert!(enum_param.is_valid_value(&Value::String("OVERFLOW".into()))); assert!(!enum_param.is_valid_value(&Value::String("INVALID".into()))); From 8a44db4a6e915b16d3e0dd80effc46cb8334ee2e Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 3 Sep 2025 11:52:05 -0400 Subject: [PATCH 15/31] Some updates to type handling; compiles and passes tests --- Cargo.toml | 2 +- src/extensions.rs | 1 - src/parse/text/simple_extensions/file.rs | 8 +- src/parse/text/simple_extensions/mod.rs | 2 + .../text/simple_extensions/parsed_type.rs | 206 +++++++++++++++++ src/parse/text/simple_extensions/types.rs | 209 ++++++------------ 6 files changed, 276 insertions(+), 152 
deletions(-) create mode 100644 src/parse/text/simple_extensions/parsed_type.rs diff --git a/Cargo.toml b/Cargo.toml index 78358392..14848dec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [features] default = [] extensions = ["dep:serde_yaml", "dep:url"] -parse = ["dep:hex", "dep:thiserror", "dep:url", "semver"] +parse = ["dep:hex", "dep:thiserror", "dep:url", "dep:serde_yaml", "semver"] protoc = ["dep:protobuf-src"] semver = ["dep:semver"] serde = ["dep:pbjson", "dep:pbjson-build", "dep:pbjson-types"] diff --git a/src/extensions.rs b/src/extensions.rs index 0621e7bb..f6c72067 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -19,6 +19,5 @@ mod tests { fn core_extensions() { // Force evaluation of core extensions. LazyLock::force(&EXTENSIONS); - println!("Core extensions: {EXTENSIONS:#?}"); } } diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index 88c027cc..20f29665 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -56,7 +56,6 @@ impl ExtensionFile { /// - `reader`: any `Read` instance with the YAML content /// /// Returns a parsed and validated `ExtensionFile` or an error. - #[cfg(feature = "extensions")] pub fn read, R: Read>(uri: U, reader: R) -> Result where SimpleExtensionsError: From, @@ -67,7 +66,6 @@ impl ExtensionFile { } /// Read an extension file from a string slice. 
- #[cfg(feature = "extensions")] pub fn read_from_str, S: AsRef>( uri: U, s: S, @@ -85,10 +83,8 @@ impl ExtensionFile { #[cfg(test)] mod tests { - use crate::{ - parse::text::simple_extensions::types::ParameterType as RawParameterType, - text::simple_extensions::SimpleExtensions as RawExtensions, - }; + use crate::parse::text::simple_extensions::types::ParameterConstraint as RawParameterType; + use crate::text::simple_extensions::SimpleExtensions as RawExtensions; use super::*; diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index 5f89d563..d257abda 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -9,12 +9,14 @@ use thiserror::Error; pub mod argument; mod extensions; mod file; +mod parsed_type; mod registry; mod types; pub use extensions::SimpleExtensions; pub use extensions::TypeContext; pub use file::ExtensionFile; +pub use parsed_type::TypeExpr; pub use registry::Registry; pub use types::{ConcreteType, CustomType, ExtensionTypeError}; diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs new file mode 100644 index 00000000..34cedebc --- /dev/null +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: Apache-2.0 + +//! Parsed type AST used by the simple extensions type parser. + +use std::str::FromStr; + +use crate::parse::text::simple_extensions::types::CompoundType; + +use super::types::BuiltinType; + +/// A parsed type expression from a type string, with lifetime tied to the original string. 
+#[derive(Clone, Debug, PartialEq)] +pub enum TypeExpr<'a> { + /// A type with a name, optional parameters, and nullability + Simple(&'a str, Vec>, bool), + /// A user-defined extension type, indicated by `u!Name`, with optional + /// parameters and nullability + UserDefined(&'a str, Vec>, bool), + /// Type variable (e.g., any1, any2) + TypeVariable(u32, bool), +} + +/// A parsed parameter to a parameterized type +#[derive(Clone, Debug, PartialEq)] +pub enum TypeExprParam<'a> { + /// A nested type parameter + Type(TypeExpr<'a>), + /// An integer literal parameter + Integer(i64), + /// A string literal parameter (unquoted) + String(&'a str), +} + +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum TypeParseError { + #[error("Parameter list {0} Must start and end with angle brackets")] + ExpectedClosingAngleBracket(String), +} + +impl<'a> TypeExpr<'a> { + /// Parse a type string into a ParsedType + pub fn parse(type_str: &'a str) -> Result { + // Handle type variables like any1, any2, etc. 
+ if let Some(suffix) = type_str.strip_prefix("any") { + let (middle, nullable) = match suffix.strip_suffix('?') { + Some(middle) => (middle, true), + None => (suffix, false), + }; + + if let Ok(id) = middle.parse::() { + return Ok(TypeExpr::TypeVariable(id, nullable)); + } + } + + let (user_defined, rest) = match type_str.strip_prefix("u!") { + Some(right) => (true, right), + None => (false, type_str), + }; + + let (name_and_nullable, params): (&'a str, Vec>) = + match rest.split_once('<') { + Some((n, p)) => match p.strip_suffix('>') { + Some(p) => (n, parse_params(p)?), + None => return Err(TypeParseError::ExpectedClosingAngleBracket(p.to_string())), + }, + None => (rest, vec![]), + }; + + let (name, nullable) = match name_and_nullable.strip_suffix('?') { + Some(name) => (name, true), + None => (name_and_nullable, false), + }; + + if user_defined { + Ok(TypeExpr::UserDefined(name, params, nullable)) + } else { + Ok(TypeExpr::Simple(name, params, nullable)) + } + } + + /// Visit all extension type references contained in a parsed type, calling `on_ext` + /// for each encountered extension name (including named extension forms). + pub fn visit_references(&self, on_ext: &mut F) + where + F: FnMut(&str), + { + match self { + TypeExpr::UserDefined(name, params, _) => { + // Strip u! prefix when reporting linkage + on_ext(name); + for p in params { + if let TypeExprParam::Type(t) = p { + t.visit_references(on_ext); + } + } + } + TypeExpr::Simple(name, params, _) => { + let lower = name.to_ascii_lowercase(); + if BuiltinType::from_str(&lower).is_err() && !CompoundType::is_name(name) { + on_ext(name); + } + for p in params { + if let TypeExprParam::Type(t) = p { + t.visit_references(on_ext); + } + } + } + TypeExpr::TypeVariable(..) 
=> {} + } + } +} + +fn parse_params<'a>(s: &'a str) -> Result>, TypeParseError> { + let mut result = Vec::new(); + let mut start = 0; + let mut depth = 0; + + for (i, c) in s.char_indices() { + match c { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + result.push(parse_param(s[start..i].trim())?); + start = i + 1; + } + _ => {} + } + } + + if depth != 0 { + return Err(TypeParseError::ExpectedClosingAngleBracket(s.to_string())); + } + + if start < s.len() { + result.push(parse_param(s[start..].trim())?); + } + + Ok(result) +} + +fn parse_param<'a>(s: &'a str) -> Result, TypeParseError> { + if let Ok(i) = s.parse::() { + return Ok(TypeExprParam::Integer(i)); + } + Ok(TypeExprParam::Type(TypeExpr::parse(s)?)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parsed_type_simple() { + let parsed = TypeExpr::parse("i32").unwrap(); + assert_eq!(parsed, TypeExpr::Simple("i32", vec![], false)); + + let parsed_nullable = TypeExpr::parse("i32?").unwrap(); + assert_eq!(parsed_nullable, TypeExpr::Simple("i32", vec![], true)); + } + + #[test] + fn test_parsed_type_variables() { + let parsed = TypeExpr::parse("any1").unwrap(); + assert_eq!(parsed, TypeExpr::TypeVariable(1, false)); + + let parsed_nullable = TypeExpr::parse("any2?").unwrap(); + assert_eq!(parsed_nullable, TypeExpr::TypeVariable(2, true)); + } + + #[test] + fn test_user_defined_and_params() { + match TypeExpr::parse("u!geo?>").unwrap() { + TypeExpr::UserDefined(name, params, nullable) => { + assert_eq!(name, "geo"); + assert!(nullable); + assert_eq!( + params[0], + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], true)) + ); + assert_eq!( + params[1], + TypeExprParam::Type(TypeExpr::Simple( + "point", + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + ], + false + )) + ); + } + other => panic!("unexpected: {other:?}"), + } + assert_eq!( + TypeExpr::parse("Map?").unwrap(), + 
TypeExpr::Simple( + "Map", + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + TypeExprParam::Type(TypeExpr::Simple("string", vec![], false)), + ], + true, + ) + ); + } +} diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 2a0f3dda..7513dd6e 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -9,6 +9,8 @@ use super::argument::{ EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError, }; use super::extensions::TypeContext; +use super::TypeExpr; +use crate::parse::text::simple_extensions::parsed_type::TypeParseError; use crate::parse::Parse; use crate::text::simple_extensions::{ EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs, @@ -192,6 +194,26 @@ pub enum CompoundType { }, } +impl CompoundType { + /// Check if a string is a valid name for a compound built-in type. + /// + /// Only matches lowercase. + pub fn is_name(s: &str) -> bool { + matches!( + s, + "fixedchar" + | "varchar" + | "fixedbinary" + | "decimal" + | "precisiontime" + | "precisiontimestamp" + | "precisiontimestamptz" + | "interval_day" + | "interval_compound" + ) + } +} + impl fmt::Display for CompoundType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -248,7 +270,7 @@ impl FromStr for BuiltinType { /// Parameter type information for type definitions #[derive(Clone, Debug, PartialEq)] -pub enum ParameterType { +pub enum ParameterConstraint { /// Data type parameter DataType, /// Integer parameter with range constraints @@ -269,22 +291,24 @@ pub enum ParameterType { String, } -impl ParameterType { +impl ParameterConstraint { /// Convert back to raw TypeParamDefsItemType fn raw_type(&self) -> TypeParamDefsItemType { match self { - ParameterType::DataType => TypeParamDefsItemType::DataType, - ParameterType::Boolean => TypeParamDefsItemType::Boolean, - ParameterType::Integer { .. 
} => TypeParamDefsItemType::Integer, - ParameterType::Enumeration { .. } => TypeParamDefsItemType::Enumeration, - ParameterType::String => TypeParamDefsItemType::String, + ParameterConstraint::DataType => TypeParamDefsItemType::DataType, + ParameterConstraint::Boolean => TypeParamDefsItemType::Boolean, + ParameterConstraint::Integer { .. } => TypeParamDefsItemType::Integer, + ParameterConstraint::Enumeration { .. } => TypeParamDefsItemType::Enumeration, + ParameterConstraint::String => TypeParamDefsItemType::String, } } /// Extract raw bounds for integer parameters (min, max) fn raw_bounds(&self) -> (Option, Option) { match self { - ParameterType::Integer { min, max } => (min.map(|i| i as f64), max.map(|i| i as f64)), + ParameterConstraint::Integer { min, max } => { + (min.map(|i| i as f64), max.map(|i| i as f64)) + } _ => (None, None), } } @@ -292,7 +316,7 @@ impl ParameterType { /// Extract raw enum options for enumeration parameters fn raw_options(&self) -> Option { match self { - ParameterType::Enumeration { options } => Some(options.clone().into()), + ParameterConstraint::Enumeration { options } => Some(options.clone().into()), _ => None, } } @@ -300,17 +324,17 @@ impl ParameterType { /// Check if a parameter value is valid for this parameter type pub fn is_valid_value(&self, value: &Value) -> bool { match (self, value) { - (ParameterType::DataType, Value::String(_)) => true, - (ParameterType::Integer { min, max }, Value::Number(n)) => { + (ParameterConstraint::DataType, Value::String(_)) => true, + (ParameterConstraint::Integer { min, max }, Value::Number(n)) => { if let Some(i) = n.as_i64() { min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val) } else { false } } - (ParameterType::Enumeration { options }, Value::String(s)) => options.contains(s), - (ParameterType::Boolean, Value::Bool(_)) => true, - (ParameterType::String, Value::String(_)) => true, + (ParameterConstraint::Enumeration { options }, Value::String(s)) => 
options.contains(s), + (ParameterConstraint::Boolean, Value::Bool(_)) => true, + (ParameterConstraint::String, Value::String(_)) => true, _ => false, } } @@ -349,14 +373,14 @@ pub struct TypeParam { /// Parameter name (e.g., "K" for a type variable) pub name: String, /// Parameter type constraints - pub param_type: ParameterType, + pub param_type: ParameterConstraint, /// Human-readable description pub description: Option, } impl TypeParam { /// Create a new type parameter - pub fn new(name: String, param_type: ParameterType, description: Option) -> Self { + pub fn new(name: String, param_type: ParameterConstraint, description: Option) -> Self { Self { name, param_type, @@ -375,7 +399,8 @@ impl TryFrom for TypeParam { fn try_from(item: TypeParamDefsItem) -> Result { let name = item.name.ok_or(TypeParamError::MissingName)?; - let param_type = ParameterType::from_raw(item.type_, item.options, item.min, item.max)?; + let param_type = + ParameterConstraint::from_raw(item.type_, item.options, item.min, item.max)?; Ok(Self { name, @@ -411,6 +436,9 @@ pub enum ExtensionTypeError { /// The type string that was nullable type_string: String, }, + /// Error parsing type + #[error("Error parsing type: {0}")] + ParseType(#[from] TypeParseError), } /// Error types for TypeParam validation @@ -566,34 +594,11 @@ impl Parse for RawType { type Error = ExtensionTypeError; fn parse(self, ctx: &mut TypeContext) -> Result { - // Walk a ParsedType and record extension type references in the context - fn visit_parsed_type_references<'a, F>(p: &'a ParsedType<'a>, on_ext: &mut F) - where - F: FnMut(&'a str), - { - match p { - ParsedType::Extension(name, _) | ParsedType::NamedExtension(name, _) => { - on_ext(name) - } - ParsedType::List(t, _) => visit_parsed_type_references(t, on_ext), - ParsedType::Map(k, v, _) => { - visit_parsed_type_references(k, on_ext); - visit_parsed_type_references(v, on_ext); - } - ParsedType::Struct(ts, _) => { - for t in ts { - visit_parsed_type_references(t, 
on_ext) - } - } - ParsedType::Builtin(..) | ParsedType::TypeVariable(..) => {} - } - } - match self { RawType::Variant0(type_string) => { - let parsed_type = ParsedType::parse(&type_string); + let parsed_type = TypeExpr::parse(&type_string)?; let mut link = |name: &str| ctx.linked(name); - visit_parsed_type_references(&parsed_type, &mut link); + parsed_type.visit_references(&mut link); let concrete = ConcreteType::try_from(parsed_type)?; // Structure representation cannot be nullable @@ -619,9 +624,9 @@ impl Parse for RawType { } }; - let parsed_field_type = ParsedType::parse(&type_string); + let parsed_field_type = TypeExpr::parse(&type_string)?; let mut link = |name: &str| ctx.linked(name); - visit_parsed_type_references(&parsed_field_type, &mut link); + parsed_field_type.visit_references(&mut link); let field_concrete_type = ConcreteType::try_from(parsed_field_type)?; field_types.push(field_concrete_type); @@ -840,90 +845,28 @@ impl From for RawType { } } -impl<'a> TryFrom> for ConcreteType { +impl<'a> TryFrom> for ConcreteType { type Error = ExtensionTypeError; - fn try_from(parsed_type: ParsedType<'a>) -> Result { + fn try_from(parsed_type: TypeExpr<'a>) -> Result { match parsed_type { - ParsedType::Builtin(builtin_type, nullability) => { - Ok(ConcreteType::builtin(builtin_type, nullability)) - } - ParsedType::Extension(ext_name, nullability) => { - Ok(ConcreteType::extension(ext_name.to_string(), nullability)) - } - ParsedType::List(element_type, nullability) => { - let element_concrete = ConcreteType::try_from(*element_type)?; - Ok(ConcreteType::list(element_concrete, nullability)) - } - ParsedType::Map(key_type, value_type, nullability) => { - let key_concrete = ConcreteType::try_from(*key_type)?; - let value_concrete = ConcreteType::try_from(*value_type)?; - Ok(ConcreteType::map(key_concrete, value_concrete, nullability)) - } - ParsedType::Struct(field_types, nullability) => { - let concrete_field_types: Result, _> = field_types - .into_iter() - 
.map(ConcreteType::try_from) - .collect(); - Ok(ConcreteType::r#struct(concrete_field_types?, nullability)) + TypeExpr::Simple(name, _params, nullable) => { + // Try builtin first + match BuiltinType::from_str(&name.to_ascii_lowercase()) { + Ok(b) => Ok(ConcreteType::builtin(b, nullable)), + Err(_) => Ok(ConcreteType::extension(name.to_string(), nullable)), + } } - ParsedType::TypeVariable(id, nullability) => { + TypeExpr::UserDefined(name, _params, nullable) => Ok( + ConcreteType::extension_with_params(name.to_string(), vec![], nullable), + ), + TypeExpr::TypeVariable(id, nullability) => { Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability }) } - ParsedType::NamedExtension(type_str, nullability) => { - Ok(ConcreteType::extension(type_str.to_string(), nullability)) - } } } } -/// A parsed type from a type string, with lifetime tied to the original string -#[derive(Clone, Debug, PartialEq)] -pub enum ParsedType<'a> { - /// Built-in type - Builtin(BuiltinType, bool), - /// Extension type reference - Extension(&'a str, bool), - /// List type - List(Box>, bool), - /// Map type - Map(Box>, Box>, bool), - /// Struct type - Struct(Vec>, bool), - /// Type variable (e.g., any1, any2) - TypeVariable(u32, bool), - /// Named extension type (unresolved) - NamedExtension(&'a str, bool), -} - -impl<'a> ParsedType<'a> { - /// Parse a type string into a ParsedType - pub fn parse(type_str: &'a str) -> Self { - // Simple parsing implementation - could be more sophisticated - let (base_type, nullable) = match type_str.strip_suffix('?') { - Some(base) => (base, true), - None => (type_str, false), - }; - - // Handle type variables like any1, any2, etc. 
- if let Some(suffix) = base_type.strip_prefix("any") { - if let Ok(id) = suffix.parse::() { - if id >= 1 { - return ParsedType::TypeVariable(id, nullable); - } - } - } - - // Try to parse as builtin type - if let Ok(builtin_type) = BuiltinType::from_str(base_type) { - return ParsedType::Builtin(builtin_type, nullable); - } - - // Otherwise, treat as extension type - ParsedType::NamedExtension(base_type, nullable) - } -} - #[cfg(test)] mod tests { use super::super::extensions::TypeContext; @@ -941,28 +884,6 @@ mod tests { assert!(BuiltinType::from_str("invalid").is_err()); } - #[test] - fn test_parsed_type_simple() { - let parsed = ParsedType::parse("i32"); - assert_eq!(parsed, ParsedType::Builtin(BuiltinType::I32, false)); - - let parsed_nullable = ParsedType::parse("i32?"); - assert_eq!(parsed_nullable, ParsedType::Builtin(BuiltinType::I32, true)); - } - - #[test] - fn test_parsed_type_variables() { - let parsed = ParsedType::parse("any1"); - assert_eq!(parsed, ParsedType::TypeVariable(1, false)); - - let parsed_nullable = ParsedType::parse("any2?"); - assert_eq!(parsed_nullable, ParsedType::TypeVariable(2, true)); - - // Invalid type variable ID (must be >= 1) - let parsed_invalid = ParsedType::parse("any0"); - assert_eq!(parsed_invalid, ParsedType::NamedExtension("any0", false)); - } - #[test] fn test_concrete_type_creation() { let int_type = ConcreteType::builtin(BuiltinType::I32, false); @@ -986,7 +907,7 @@ mod tests { #[test] fn test_parameter_type_validation() { - let int_param = ParameterType::Integer { + let int_param = ParameterConstraint::Integer { min: Some(1), max: Some(10), }; @@ -998,7 +919,7 @@ mod tests { let raw = simple_extensions::EnumOptions(vec!["OVERFLOW".to_string(), "ERROR".to_string()]); let parsed = ParsedEnumOptions::try_from(raw).unwrap(); - let enum_param = ParameterType::Enumeration { options: parsed }; + let enum_param = ParameterConstraint::Enumeration { options: parsed }; 
assert!(enum_param.is_valid_value(&Value::String("OVERFLOW".into()))); assert!(!enum_param.is_valid_value(&Value::String("INVALID".into()))); @@ -1018,7 +939,7 @@ mod tests { }; let tp = TypeParam::try_from(item).expect("should parse integer bounds"); match tp.param_type { - ParameterType::Integer { min, max } => { + ParameterConstraint::Integer { min, max } => { assert_eq!(min, Some(1)); assert_eq!(max, Some(10)); } @@ -1037,7 +958,7 @@ mod tests { }; let tp = TypeParam::try_from(trunc).expect("should parse with truncation"); match tp.param_type { - ParameterType::Integer { min, max } => { + ParameterConstraint::Integer { min, max } => { assert_eq!(min, Some(1)); assert_eq!(max, None); } From 5b1e5df416f4d3663baf54e970f513e6542ea64b Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Tue, 16 Sep 2025 11:02:46 -0400 Subject: [PATCH 16/31] Update to match URN change --- .../text/simple_extensions/extensions.rs | 54 ++++++++----------- src/parse/text/simple_extensions/file.rs | 23 ++++++-- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 9d17ba46..2f0942f4 100644 --- a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -15,29 +15,16 @@ use crate::{ /// Parsing context for extension processing /// /// The context provides access to types defined in the same extension file during parsing. -/// This allows type references to be resolved within the same extension file. -#[derive(Debug)] +/// This allows type references to be resolved within the same extension file. The corresponding +/// URN is tracked by [`ExtensionFile`](super::file::ExtensionFile) so this structure can focus on +/// validated type information. 
+#[derive(Clone, Debug, Default)] pub struct SimpleExtensions { - /// URN identifying this extension file - urn: Urn, /// Types defined in this extension file types: HashMap, } impl SimpleExtensions { - /// Create a new simple extension container for the provided URN. - pub fn new(urn: Urn) -> Self { - Self { - urn, - types: HashMap::new(), - } - } - - /// Returns the URN attached to this extension file. - pub fn urn(&self) -> &Urn { - &self.urn - } - /// Add a type to the context pub fn add_type(&mut self, custom_type: &CustomType) { self.types @@ -58,6 +45,11 @@ impl SimpleExtensions { pub fn types(&self) -> impl Iterator { self.types.values() } + + /// Consume the parsed extension and return its types. + pub(crate) fn into_types(self) -> HashMap { + self.types + } } /// A context for parsing simple extensions. @@ -90,13 +82,13 @@ impl Context for TypeContext { // Implement parsing for the raw text representation to produce an `ExtensionFile`. impl Parse for RawExtensions { - type Parsed = SimpleExtensions; + type Parsed = (Urn, SimpleExtensions); type Error = super::SimpleExtensionsError; fn parse(self, ctx: &mut TypeContext) -> Result { - let RawExtensions { types, urn, .. } = self; + let RawExtensions { urn, types, .. } = self; let urn = Urn::from_str(&urn)?; - let mut extension = SimpleExtensions::new(urn); + let mut extension = SimpleExtensions::default(); for type_item in types { let custom_type = Parse::parse(type_item, ctx)?; @@ -104,26 +96,26 @@ impl Parse for RawExtensions { extension.add_type(&custom_type); } - Ok(extension) + Ok((urn, extension)) } } -// Implement conversion from parsed form back to raw text representation. 
-impl From for RawExtensions { - fn from(value: SimpleExtensions) -> Self { - let SimpleExtensions { urn, types } = value; - let urn = urn.to_string(); - // Minimal types-only conversion to satisfy tests - let types = types.into_values().map(Into::into).collect(); +impl From<(Urn, SimpleExtensions)> for RawExtensions { + fn from((urn, extension): (Urn, SimpleExtensions)) -> Self { + let types = extension + .into_types() + .into_values() + .map(Into::into) + .collect(); + RawExtensions { - types, - // TODO: Implement conversion back to raw representation + urn: urn.to_string(), aggregate_functions: vec![], dependencies: HashMap::new(), scalar_functions: vec![], type_variations: vec![], + types, window_functions: vec![], - urn, } } } diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index 12371603..fc030078 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -6,6 +6,10 @@ use crate::urn::Urn; use std::io::Read; /// A parsed and validated [RawExtensions]. +/// +/// `ExtensionFile` owns the canonical URN for a simple extension file along with the parsed +/// [`SimpleExtensions`](super::SimpleExtensions) data. Keeping the URN here (instead of on the inner +/// type map) lets us thread it through I/O, registries, and conversions without duplicating state. 
#[derive(Debug)] pub struct ExtensionFile { /// The URN this extension was loaded from @@ -17,7 +21,7 @@ pub struct ExtensionFile { impl ExtensionFile { /// Create a new, empty SimpleExtensions pub fn empty(urn: Urn) -> Self { - let extension = SimpleExtensions::new(urn.clone()); + let extension = SimpleExtensions::default(); Self { urn, extension } } @@ -25,8 +29,7 @@ impl ExtensionFile { pub fn create(extensions: RawExtensions) -> Result { // Parse all types (may contain unresolved Extension(String) references) let mut ctx = TypeContext::default(); - let extension = Parse::parse(extensions, &mut ctx)?; - let urn = extension.urn().clone(); + let (urn, extension) = Parse::parse(extensions, &mut ctx)?; // TODO: Use ctx.known/ctx.linked to validate unresolved references and cross-file links. @@ -53,6 +56,17 @@ impl ExtensionFile { &self.extension } + /// Convert the parsed extension file back into the raw text representation by value. + pub fn into_raw(self) -> RawExtensions { + let ExtensionFile { urn, extension } = self; + RawExtensions::from((urn, extension)) + } + + /// Convert the parsed extension file back into the raw text representation by reference. + pub fn to_raw(&self) -> RawExtensions { + RawExtensions::from((self.urn.clone(), self.extension.clone())) + } + /// Read an extension file from a reader. 
/// - `reader`: any `Read` instance with the YAML content /// @@ -74,7 +88,6 @@ impl ExtensionFile { #[cfg(test)] mod tests { use crate::parse::text::simple_extensions::types::ParameterConstraint as RawParameterType; - use crate::text::simple_extensions::SimpleExtensions as RawExtensions; use super::*; @@ -109,7 +122,7 @@ types: } // Convert back to text::simple_extensions and assert bounds are preserved as f64 - let back: RawExtensions = ext.extension.into(); + let back = ext.to_raw(); let item = back .types .into_iter() From 6444f5cf35d3b1fc7a1ba3f0bafbf644cee599e3 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Tue, 16 Sep 2025 16:57:54 -0400 Subject: [PATCH 17/31] Some renames --- .../text/simple_extensions/parsed_type.rs | 21 +- src/parse/text/simple_extensions/types.rs | 622 ++++++++++++++---- 2 files changed, 515 insertions(+), 128 deletions(-) diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index 34cedebc..1e435688 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -2,11 +2,7 @@ //! Parsed type AST used by the simple extensions type parser. -use std::str::FromStr; - -use crate::parse::text::simple_extensions::types::CompoundType; - -use super::types::BuiltinType; +use crate::parse::text::simple_extensions::types::is_builtin_type_name; /// A parsed type expression from a type string, with lifetime tied to the original string. 
#[derive(Clone, Debug, PartialEq)] @@ -95,8 +91,7 @@ impl<'a> TypeExpr<'a> { } } TypeExpr::Simple(name, params, _) => { - let lower = name.to_ascii_lowercase(); - if BuiltinType::from_str(&lower).is_err() && !CompoundType::is_name(name) { + if !is_builtin_type_name(name) { on_ext(name); } for p in params { @@ -203,4 +198,16 @@ mod tests { ) ); } + + #[test] + fn test_visit_references_builtin_case_insensitive() { + let parsed = TypeExpr::parse("DECIMAL<10,2>").unwrap(); + let mut refs = Vec::new(); + parsed.visit_references(&mut |name| refs.push(name.to_string())); + assert!(refs.is_empty()); + + let parsed_list = TypeExpr::parse("List").unwrap(); + parsed_list.visit_references(&mut |name| refs.push(name.to_string())); + assert!(refs.is_empty()); + } } diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 33c4d5c4..3fc0a564 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -10,6 +10,7 @@ use super::argument::{ EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError, }; use super::extensions::TypeContext; +use super::parsed_type::TypeExprParam; use crate::parse::Parse; use crate::parse::text::simple_extensions::parsed_type::TypeParseError; use crate::text::simple_extensions::{ @@ -17,6 +18,7 @@ use crate::text::simple_extensions::{ TypeParamDefsItem, TypeParamDefsItemType, }; use serde_json::Value; +use std::convert::TryFrom; use std::fmt; use std::str::FromStr; use thiserror::Error; @@ -63,9 +65,9 @@ where } } -/// Substrait built-in primitive types (no parameters required) +/// Substrait primitive built-in types (no parameters required) #[derive(Clone, Debug, PartialEq, Eq)] -pub enum BuiltinType { +pub enum PrimitiveType { /// Boolean type - `bool` Boolean, /// 8-bit signed integer - `i8` @@ -86,36 +88,27 @@ pub enum BuiltinType { Binary, /// Calendar date - `date` Date, - /// Time of day - `time` (deprecated, use 
CompoundType::PrecisionTime) - Time, - /// Date and time - `timestamp` (deprecated, use CompoundType::PrecisionTimestamp) - Timestamp, - /// Date and time with timezone - `timestamp_tz` (deprecated, use CompoundType::PrecisionTimestampTz) - TimestampTz, /// Year-month interval - `interval_year` IntervalYear, /// 128-bit UUID - `uuid` Uuid, } -impl fmt::Display for BuiltinType { +impl fmt::Display for PrimitiveType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { - BuiltinType::Boolean => "bool", - BuiltinType::I8 => "i8", - BuiltinType::I16 => "i16", - BuiltinType::I32 => "i32", - BuiltinType::I64 => "i64", - BuiltinType::Fp32 => "fp32", - BuiltinType::Fp64 => "fp64", - BuiltinType::String => "string", - BuiltinType::Binary => "binary", - BuiltinType::Date => "date", - BuiltinType::Time => "time", - BuiltinType::Timestamp => "timestamp", - BuiltinType::TimestampTz => "timestamp_tz", - BuiltinType::IntervalYear => "interval_year", - BuiltinType::Uuid => "uuid", + PrimitiveType::Boolean => "bool", + PrimitiveType::I8 => "i8", + PrimitiveType::I16 => "i16", + PrimitiveType::I32 => "i32", + PrimitiveType::I64 => "i64", + PrimitiveType::Fp32 => "fp32", + PrimitiveType::Fp64 => "fp64", + PrimitiveType::String => "string", + PrimitiveType::Binary => "binary", + PrimitiveType::Date => "date", + PrimitiveType::IntervalYear => "interval_year", + PrimitiveType::Uuid => "uuid", }; f.write_str(s) } @@ -142,9 +135,10 @@ impl fmt::Display for TypeParameter { } } -/// Parameterized builtin types that require non-type parameters +/// Parameterized builtin types that require non-type parameters, e.g. 
numbers +/// or enum #[derive(Clone, Debug, PartialEq)] -pub enum CompoundType { +pub enum BuiltinParameterized { /// Fixed-length character string: `FIXEDCHAR` FixedChar { /// Length (number of characters), must be >= 1 @@ -194,8 +188,8 @@ pub enum CompoundType { }, } -impl CompoundType { - /// Check if a string is a valid name for a compound built-in type. +impl BuiltinParameterized { + /// Check if a string is a valid name for a parameterized builtin type. /// /// Only matches lowercase. pub fn is_name(s: &str) -> bool { @@ -214,55 +208,102 @@ impl CompoundType { } } -impl fmt::Display for CompoundType { +impl fmt::Display for BuiltinParameterized { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - CompoundType::FixedChar { length } => write!(f, "FIXEDCHAR<{length}>"), - CompoundType::VarChar { length } => write!(f, "VARCHAR<{length}>"), - CompoundType::FixedBinary { length } => write!(f, "FIXEDBINARY<{length}>"), - CompoundType::Decimal { precision, scale } => { + BuiltinParameterized::FixedChar { length } => { + write!(f, "FIXEDCHAR<{length}>") + } + BuiltinParameterized::VarChar { length } => { + write!(f, "VARCHAR<{length}>") + } + BuiltinParameterized::FixedBinary { length } => { + write!(f, "FIXEDBINARY<{length}>") + } + BuiltinParameterized::Decimal { precision, scale } => { write!(f, "DECIMAL<{precision}, {scale}>") } - CompoundType::PrecisionTime { precision } => write!(f, "PRECISIONTIME<{precision}>"), - CompoundType::PrecisionTimestamp { precision } => { + BuiltinParameterized::PrecisionTime { precision } => { + write!(f, "PRECISIONTIME<{precision}>") + } + BuiltinParameterized::PrecisionTimestamp { precision } => { write!(f, "PRECISIONTIMESTAMP<{precision}>") } - CompoundType::PrecisionTimestampTz { precision } => { + BuiltinParameterized::PrecisionTimestampTz { precision } => { write!(f, "PRECISIONTIMESTAMPTZ<{precision}>") } - CompoundType::IntervalDay { precision } => write!(f, "INTERVAL_DAY<{precision}>"), - 
CompoundType::IntervalCompound { precision } => { + BuiltinParameterized::IntervalDay { precision } => { + write!(f, "INTERVAL_DAY<{precision}>") + } + BuiltinParameterized::IntervalCompound { precision } => { write!(f, "INTERVAL_COMPOUND<{precision}>") } } } } +/// Unified representation of simple builtin types (primitive or parameterized). +/// Does not include container types like List, Map, or Struct. +#[derive(Clone, Debug, PartialEq)] +pub enum BuiltinKind { + /// Primitive builtins like `i32` + Primitive(PrimitiveType), + /// Parameterized builtins like `DECIMAL` + Parameterized(BuiltinParameterized), +} + +impl fmt::Display for BuiltinKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BuiltinKind::Primitive(p) => write!(f, "{p}"), + BuiltinKind::Parameterized(p) => write!(f, "{p}"), + } + } +} + +impl From for BuiltinKind { + fn from(value: PrimitiveType) -> Self { + BuiltinKind::Primitive(value) + } +} + +impl From for BuiltinKind { + fn from(value: BuiltinParameterized) -> Self { + BuiltinKind::Parameterized(value) + } +} + +/// Check if a name corresponds to any built-in type (primitive, parameterized, +/// or container) +pub fn is_builtin_type_name(name: &str) -> bool { + let lower = name.to_ascii_lowercase(); + PrimitiveType::from_str(&lower).is_ok() + || BuiltinParameterized::is_name(&lower) + || matches!(lower.as_str(), "list" | "map" | "struct") +} + /// Error when a builtin type name is not recognized #[derive(Debug, Error)] #[error("Unrecognized builtin type: {0}")] pub struct UnrecognizedBuiltin(String); -impl FromStr for BuiltinType { +impl FromStr for PrimitiveType { type Err = UnrecognizedBuiltin; fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { - "bool" => Ok(BuiltinType::Boolean), - "i8" => Ok(BuiltinType::I8), - "i16" => Ok(BuiltinType::I16), - "i32" => Ok(BuiltinType::I32), - "i64" => Ok(BuiltinType::I64), - "fp32" => Ok(BuiltinType::Fp32), - "fp64" => Ok(BuiltinType::Fp64), - "string" => 
Ok(BuiltinType::String), - "binary" => Ok(BuiltinType::Binary), - "date" => Ok(BuiltinType::Date), - "time" => Ok(BuiltinType::Time), - "timestamp" => Ok(BuiltinType::Timestamp), - "timestamp_tz" => Ok(BuiltinType::TimestampTz), - "interval_year" => Ok(BuiltinType::IntervalYear), - "uuid" => Ok(BuiltinType::Uuid), + "bool" => Ok(PrimitiveType::Boolean), + "i8" => Ok(PrimitiveType::I8), + "i16" => Ok(PrimitiveType::I16), + "i32" => Ok(PrimitiveType::I32), + "i64" => Ok(PrimitiveType::I64), + "fp32" => Ok(PrimitiveType::Fp32), + "fp64" => Ok(PrimitiveType::Fp64), + "string" => Ok(PrimitiveType::String), + "binary" => Ok(PrimitiveType::Binary), + "date" => Ok(PrimitiveType::Date), + "interval_year" => Ok(PrimitiveType::IntervalYear), + "uuid" => Ok(PrimitiveType::Uuid), _ => Err(UnrecognizedBuiltin(s.to_string())), } } @@ -430,6 +471,38 @@ pub enum ExtensionTypeError { /// Field type is invalid #[error("Invalid structure field type: {0}")] InvalidFieldType(String), + /// Type parameter count is invalid for the given type name + #[error("Type '{type_name}' expects {expected} parameters, got {actual}")] + InvalidParameterCount { + /// The type name being validated + type_name: String, + /// Human-readable description of the expected parameter count + expected: &'static str, + /// The actual number of parameters provided + actual: usize, + }, + /// Type parameter is of the wrong kind for the given position + #[error("Type '{type_name}' parameter {index} must be {expected}")] + InvalidParameterKind { + /// The type name being validated + type_name: String, + /// Zero-based index of the offending parameter + index: usize, + /// Expected parameter kind (e.g., integer, type) + expected: &'static str, + }, + /// Provided parameter value does not fit within the expected bounds + #[error("Type '{type_name}' parameter {index} value {value} is out of range for {expected}")] + InvalidParameterValue { + /// The type name being validated + type_name: String, + /// Zero-based index 
of the offending parameter + index: usize, + /// Provided parameter value + value: i64, + /// Description of the expected range or type + expected: &'static str, + }, /// Structure representation cannot be nullable #[error("Structure representation cannot be nullable: {type_string}")] StructureCannotBeNullable { @@ -463,7 +536,7 @@ pub enum TypeParamError { InvalidEnumOptions(#[from] ParsedEnumOptionsError), } -/// A validated custom extension type definition +/// A validated Simple Extension type definition #[derive(Clone, Debug, PartialEq)] pub struct CustomType { /// Type name @@ -633,7 +706,7 @@ impl Parse for RawType { } Ok(ConcreteType { - known_type: KnownType::NamedStruct { + kind: ConcreteTypeKind::NamedStruct { field_names, field_types, }, @@ -651,11 +724,9 @@ pub struct InvalidTypeName(String); /// Known Substrait types (builtin + extension references) #[derive(Clone, Debug, PartialEq)] -pub enum KnownType { - /// Simple built-in Substrait primitive type (no parameters) - Builtin(BuiltinType), - /// Parameterized built-in types - Compound(CompoundType), +pub enum ConcreteTypeKind { + /// Built-in Substrait type (primitive or parameterized) + Builtin(BuiltinKind), /// Extension type with optional parameters Extension { /// Extension type name @@ -683,19 +754,18 @@ pub enum KnownType { }, } -impl fmt::Display for KnownType { +impl fmt::Display for ConcreteTypeKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - KnownType::Builtin(b) => write!(f, "{b}"), - KnownType::Compound(c) => write!(f, "{c}"), - KnownType::Extension { name, parameters } => { + ConcreteTypeKind::Builtin(b) => write!(f, "{b}"), + ConcreteTypeKind::Extension { name, parameters } => { write!(f, "{name}")?; write_separated(f, parameters.iter(), "<", ">", ", ") } - KnownType::List(elem) => write!(f, "List<{elem}>"), - KnownType::Map { key, value } => write!(f, "Map<{key}, {value}>"), - KnownType::Struct(types) => write_separated(f, types.iter(), "Struct<", ">", 
", "), - KnownType::NamedStruct { + ConcreteTypeKind::List(elem) => write!(f, "List<{elem}>"), + ConcreteTypeKind::Map { key, value } => write!(f, "Map<{key}, {value}>"), + ConcreteTypeKind::Struct(types) => write_separated(f, types.iter(), "Struct<", ">", ", "), + ConcreteTypeKind::NamedStruct { field_names, field_types, } => { @@ -713,25 +783,28 @@ impl fmt::Display for KnownType { /// A concrete, fully-resolved type instance #[derive(Clone, Debug, PartialEq)] pub struct ConcreteType { - /// The known type information - pub known_type: KnownType, + /// The resolved type shape + pub kind: ConcreteTypeKind, /// Whether this type is nullable pub nullable: bool, } impl ConcreteType { - /// Create a new builtin type - pub fn builtin(builtin_type: BuiltinType, nullable: bool) -> ConcreteType { + /// Create a new primitive builtin type + pub fn builtin(builtin_type: PrimitiveType, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Builtin(builtin_type), + kind: ConcreteTypeKind::Builtin(BuiltinKind::Primitive(builtin_type)), nullable, } } - /// Create a new compound (parameterized) type - pub fn compound(compound_type: CompoundType, nullable: bool) -> ConcreteType { + /// Create a new parameterized builtin type + pub fn parameterized_builtin( + builtin_type: BuiltinParameterized, + nullable: bool, + ) -> ConcreteType { ConcreteType { - known_type: KnownType::Compound(compound_type), + kind: ConcreteTypeKind::Builtin(BuiltinKind::Parameterized(builtin_type)), nullable, } } @@ -739,7 +812,7 @@ impl ConcreteType { /// Create a new extension type reference (without parameters) pub fn extension(name: String, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Extension { + kind: ConcreteTypeKind::Extension { name, parameters: Vec::new(), }, @@ -754,7 +827,7 @@ impl ConcreteType { nullable: bool, ) -> ConcreteType { ConcreteType { - known_type: KnownType::Extension { name, parameters }, + kind: ConcreteTypeKind::Extension { name, 
parameters }, nullable, } } @@ -762,7 +835,7 @@ impl ConcreteType { /// Create a new list type pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::List(Box::new(element_type)), + kind: ConcreteTypeKind::List(Box::new(element_type)), nullable, } } @@ -770,7 +843,7 @@ impl ConcreteType { /// Create a new struct type (ordered fields without names) pub fn r#struct(field_types: Vec, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Struct(field_types), + kind: ConcreteTypeKind::Struct(field_types), nullable, } } @@ -778,7 +851,7 @@ impl ConcreteType { /// Create a new map type pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType { ConcreteType { - known_type: KnownType::Map { + kind: ConcreteTypeKind::Map { key: Box::new(key_type), value: Box::new(value_type), }, @@ -793,7 +866,7 @@ impl ConcreteType { nullable: bool, ) -> ConcreteType { ConcreteType { - known_type: KnownType::NamedStruct { + kind: ConcreteTypeKind::NamedStruct { field_names, field_types, }, @@ -809,13 +882,13 @@ impl ConcreteType { /// Check if this type is compatible with another type (considering nullability) pub fn is_compatible_with(&self, other: &ConcreteType) -> bool { // Types must match exactly, but nullable types can accept non-nullable values - self.known_type == other.known_type && (self.nullable || !other.nullable) + self.kind == other.kind && (self.nullable || !other.nullable) } } impl fmt::Display for ConcreteType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.known_type)?; + write!(f, "{}", self.kind)?; if self.nullable { write!(f, "?")?; } @@ -825,8 +898,8 @@ impl fmt::Display for ConcreteType { impl From for RawType { fn from(val: ConcreteType) -> Self { - match val.known_type { - KnownType::NamedStruct { + match val.kind { + ConcreteTypeKind::NamedStruct { field_names, field_types, } => { @@ -845,21 +918,260 @@ impl From for 
RawType { } } +fn expect_integer_param( + type_name: &str, + index: usize, + param: &TypeExprParam<'_>, +) -> Result { + match param { + TypeExprParam::Integer(value) => { + i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue { + type_name: type_name.to_string(), + index, + value: *value, + expected: "an i32", + }) + } + _ => Err(ExtensionTypeError::InvalidParameterKind { + type_name: type_name.to_string(), + index, + expected: "an integer", + }), + } +} + +fn expect_type_argument<'a>( + type_name: &str, + index: usize, + param: TypeExprParam<'a>, +) -> Result { + match param { + TypeExprParam::Type(t) => ConcreteType::try_from(t), + TypeExprParam::Integer(_) => Err(ExtensionTypeError::InvalidParameterKind { + type_name: type_name.to_string(), + index, + expected: "a type", + }), + TypeExprParam::String(_) => Err(ExtensionTypeError::InvalidParameterKind { + type_name: type_name.to_string(), + index, + expected: "a type", + }), + } +} + +fn type_expr_param_to_type_parameter<'a>( + param: TypeExprParam<'a>, +) -> Result { + Ok(match param { + TypeExprParam::Integer(v) => TypeParameter::Integer(v), + TypeExprParam::String(s) => TypeParameter::String(s.to_string()), + TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?), + }) +} + +fn parse_parameterized_builtin<'a>( + display_name: &str, + lower_name: &str, + params: &[TypeExprParam<'a>], +) -> Result, ExtensionTypeError> { + match lower_name { + "fixedchar" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::FixedChar { length })) + } + "varchar" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = 
expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::VarChar { length })) + } + "fixedbinary" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let length = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::FixedBinary { length })) + } + "decimal" => { + if params.len() != 2 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "2", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + let scale = expect_integer_param(display_name, 1, ¶ms[1])?; + Ok(Some(BuiltinParameterized::Decimal { precision, scale })) + } + "precisiontime" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::PrecisionTime { precision })) + } + "precisiontimestamp" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::PrecisionTimestamp { + precision, + })) + } + "precisiontimestamptz" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::PrecisionTimestampTz { + precision, + })) + } + "interval_day" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: 
params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::IntervalDay { precision })) + } + "interval_compound" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: display_name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let precision = expect_integer_param(display_name, 0, ¶ms[0])?; + Ok(Some(BuiltinParameterized::IntervalCompound { + precision, + })) + } + _ => Ok(None), + } +} + impl<'a> TryFrom> for ConcreteType { type Error = ExtensionTypeError; fn try_from(parsed_type: TypeExpr<'a>) -> Result { match parsed_type { - TypeExpr::Simple(name, _params, nullable) => { - // Try builtin first - match BuiltinType::from_str(&name.to_ascii_lowercase()) { - Ok(b) => Ok(ConcreteType::builtin(b, nullable)), - Err(_) => Ok(ConcreteType::extension(name.to_string(), nullable)), + TypeExpr::Simple(name, params, nullable) => { + let lower = name.to_ascii_lowercase(); + + match lower.as_str() { + "list" => { + if params.len() != 1 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "1", + actual: params.len(), + }); + } + let element = + expect_type_argument(name, 0, params.into_iter().next().unwrap())?; + return Ok(ConcreteType::list(element, nullable)); + } + "map" => { + if params.len() != 2 { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "2", + actual: params.len(), + }); + } + let mut iter = params.into_iter(); + let key = expect_type_argument(name, 0, iter.next().unwrap())?; + let value = expect_type_argument(name, 1, iter.next().unwrap())?; + return Ok(ConcreteType::map(key, value, nullable)); + } + "struct" => { + let field_types = params + .into_iter() + .enumerate() + .map(|(idx, param)| expect_type_argument(name, idx, param)) + .collect::, _>>()?; + return Ok(ConcreteType::r#struct(field_types, nullable)); + } + _ => {} + } + + if let 
Some(parameterized) = + parse_parameterized_builtin(name, lower.as_str(), ¶ms)? + { + return Ok(ConcreteType::parameterized_builtin(parameterized, nullable)); + } + + match PrimitiveType::from_str(&lower) { + Ok(builtin) => { + if !params.is_empty() { + return Err(ExtensionTypeError::InvalidParameterCount { + type_name: name.to_string(), + expected: "0", + actual: params.len(), + }); + } + Ok(ConcreteType::builtin(builtin, nullable)) + } + Err(_) => { + let parameters = params + .into_iter() + .map(type_expr_param_to_type_parameter) + .collect::, _>>()?; + Ok(ConcreteType::extension_with_params( + name.to_string(), + parameters, + nullable, + )) + } } } - TypeExpr::UserDefined(name, _params, nullable) => Ok( - ConcreteType::extension_with_params(name.to_string(), vec![], nullable), - ), + TypeExpr::UserDefined(name, params, nullable) => { + let parameters = params + .into_iter() + .map(type_expr_param_to_type_parameter) + .collect::, _>>()?; + Ok(ConcreteType::extension_with_params( + name.to_string(), + parameters, + nullable, + )) + } TypeExpr::TypeVariable(id, nullability) => { Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability }) } @@ -871,26 +1183,27 @@ impl<'a> TryFrom> for ConcreteType { mod tests { use super::super::extensions::TypeContext; use super::*; + use crate::parse::text::simple_extensions::TypeExpr; use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions; use crate::text::simple_extensions; #[test] fn test_builtin_type_parsing() { - assert_eq!(BuiltinType::from_str("i32").unwrap(), BuiltinType::I32); + assert_eq!(PrimitiveType::from_str("i32").unwrap(), PrimitiveType::I32); assert_eq!( - BuiltinType::from_str("string").unwrap(), - BuiltinType::String + PrimitiveType::from_str("string").unwrap(), + PrimitiveType::String ); - assert!(BuiltinType::from_str("invalid").is_err()); + assert!(PrimitiveType::from_str("invalid").is_err()); } #[test] fn test_concrete_type_creation() { - let int_type = 
ConcreteType::builtin(BuiltinType::I32, false); + let int_type = ConcreteType::builtin(PrimitiveType::I32, false); assert_eq!( int_type, ConcreteType { - known_type: KnownType::Builtin(BuiltinType::I32), + kind: ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::I32)), nullable: false } ); @@ -899,12 +1212,80 @@ mod tests { assert_eq!( list_type, ConcreteType { - known_type: KnownType::List(Box::new(int_type)), + kind: ConcreteTypeKind::List(Box::new(int_type)), nullable: true } ); } + #[test] + fn test_list_type_parameters_preserved() { + let parsed = TypeExpr::parse("List").unwrap(); + let concrete = ConcreteType::try_from(parsed).unwrap(); + assert_eq!( + concrete, + ConcreteType::list(ConcreteType::builtin(PrimitiveType::I32, false), false) + ); + } + + #[test] + fn test_decimal_type_case_insensitive() { + let parsed = TypeExpr::parse("DECIMAL<10,2>").unwrap(); + let concrete = ConcreteType::try_from(parsed).unwrap(); + match concrete.kind { + ConcreteTypeKind::Builtin(BuiltinKind::Parameterized( + BuiltinParameterized::Decimal { precision, scale }, + )) => { + assert_eq!(precision, 10); + assert_eq!(scale, 2); + } + other => panic!("unexpected type: {other:?}"), + } + assert!(!concrete.nullable); + } + + #[test] + fn test_extension_parameters_preserved() { + let parsed_udf = TypeExpr::parse("u!geo, 10>").unwrap(); + let udf_type = ConcreteType::try_from(parsed_udf).unwrap(); + match &udf_type.kind { + ConcreteTypeKind::Extension { name, parameters } => { + assert_eq!(name, "geo"); + assert_eq!(parameters.len(), 2); + match ¶meters[0] { + TypeParameter::Type(inner) => match &inner.kind { + ConcreteTypeKind::List(element) => { + assert_eq!(**element, ConcreteType::builtin(PrimitiveType::I32, false)); + } + other => panic!("unexpected list element: {other:?}"), + }, + other => panic!("unexpected parameter: {other:?}"), + } + assert_eq!(parameters[1], TypeParameter::Integer(10)); + } + other => panic!("unexpected type: {other:?}"), + } + + let 
parsed_simple = TypeExpr::parse("Geo>").unwrap(); + let simple_type = ConcreteType::try_from(parsed_simple).unwrap(); + match &simple_type.kind { + ConcreteTypeKind::Extension { name, parameters } => { + assert_eq!(name, "Geo"); + assert_eq!(parameters.len(), 1); + match ¶meters[0] { + TypeParameter::Type(inner) => match &inner.kind { + ConcreteTypeKind::List(element) => { + assert_eq!(**element, ConcreteType::builtin(PrimitiveType::I32, false)); + } + other => panic!("unexpected list element: {other:?}"), + }, + other => panic!("unexpected parameter: {other:?}"), + } + } + other => panic!("unexpected type: {other:?}"), + } + } + #[test] fn test_parameter_type_validation() { let int_param = ParameterConstraint::Integer { @@ -971,7 +1352,7 @@ mod tests { let custom = CustomType::new( "AliasType".to_string(), vec![], - Some(ConcreteType::builtin(BuiltinType::I32, false)), + Some(ConcreteType::builtin(PrimitiveType::I32, false)), None, Some("desc".to_string()), )?; @@ -989,11 +1370,11 @@ mod tests { let fields = vec![ ( "x".to_string(), - ConcreteType::builtin(BuiltinType::Fp64, false), + ConcreteType::builtin(PrimitiveType::Fp64, false), ), ( "y".to_string(), - ConcreteType::builtin(BuiltinType::Fp64, false), + ConcreteType::builtin(PrimitiveType::Fp64, false), ), ]; let (names, types): (Vec<_>, Vec<_>) = fields.into_iter().unzip(); @@ -1013,8 +1394,8 @@ mod tests { let parsed = Parse::parse(item, &mut ctx)?; assert_eq!(parsed.name, custom.name); if let Some(ConcreteType { - known_type: - KnownType::NamedStruct { + kind: + ConcreteTypeKind::NamedStruct { field_names, field_types, }, @@ -1035,7 +1416,7 @@ mod tests { let custom_type = CustomType::new( "MyType".to_string(), vec![], - Some(ConcreteType::builtin(BuiltinType::I32, false)), + Some(ConcreteType::builtin(PrimitiveType::I32, false)), None, Some("A custom type".to_string()), )?; @@ -1062,7 +1443,7 @@ mod tests { let ext_type = RawType::Variant0("i32".to_string()); let mut ctx = TypeContext::default(); let 
concrete = Parse::parse(ext_type, &mut ctx)?; - assert_eq!(concrete, ConcreteType::builtin(BuiltinType::I32, false)); + assert_eq!(concrete, ConcreteType::builtin(PrimitiveType::I32, false)); // Test struct type let mut field_map = serde_json::Map::new(); @@ -1074,16 +1455,16 @@ mod tests { let mut ctx = TypeContext::default(); let concrete = Parse::parse(ext_type, &mut ctx)?; - if let KnownType::NamedStruct { + if let ConcreteTypeKind::NamedStruct { field_names, field_types, - } = &concrete.known_type + } = &concrete.kind { assert_eq!(field_names, &vec!["field1".to_string()]); assert_eq!(field_types.len(), 1); assert_eq!( field_types[0], - ConcreteType::builtin(BuiltinType::Fp64, false) + ConcreteType::builtin(PrimitiveType::Fp64, false) ); } else { panic!("Expected named struct type"); @@ -1110,8 +1491,8 @@ mod tests { if let Some(structure) = &custom_type.structure { assert_eq!( - structure.known_type, - KnownType::Builtin(BuiltinType::Binary) + structure.kind, + ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::Binary)) ); } @@ -1143,8 +1524,8 @@ mod tests { assert_eq!(custom_type.name, "Point"); if let Some(ConcreteType { - known_type: - KnownType::NamedStruct { + kind: + ConcreteTypeKind::NamedStruct { field_names, field_types, }, @@ -1155,11 +1536,10 @@ mod tests { assert!(field_names.contains(&"y".to_string())); assert_eq!(field_types.len(), 2); // Note: HashMap iteration order is not guaranteed, so we just check the types exist - assert!( - field_types - .iter() - .all(|t| matches!(t.known_type, KnownType::Builtin(BuiltinType::Fp64))) - ); + assert!(field_types.iter().all(|t| matches!( + t.kind, + ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::Fp64)) + ))); } else { panic!("Expected struct type"); } From 162e8ab28202e943343018710a20d3c40a8e3ad7 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 17 Sep 2025 11:31:16 -0400 Subject: [PATCH 18/31] Update tests --- src/parse/text/simple_extensions/file.rs | 47 +- 
.../text/simple_extensions/parsed_type.rs | 102 ++- src/parse/text/simple_extensions/registry.rs | 97 +-- src/parse/text/simple_extensions/types.rs | 704 ++++++++++-------- 4 files changed, 543 insertions(+), 407 deletions(-) diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index fc030078..f023b14b 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -87,14 +87,10 @@ impl ExtensionFile { #[cfg(test)] mod tests { - use crate::parse::text::simple_extensions::types::ParameterConstraint as RawParameterType; - use super::*; + use crate::parse::text::simple_extensions::types::ParameterConstraint as RawParameterType; - #[test] - fn yaml_round_trip_integer_param_bounds() { - // A minimal YAML extension file with a single type that has integer bounds on a parameter - let yaml = r#" + const YAML_PARAM_TEST: &str = r#" %YAML 1.2 --- urn: extension:example.com:param_test @@ -107,37 +103,40 @@ types: max: 10 "#; - let ext = ExtensionFile::read_from_str(yaml).expect("parse ok"); + #[test] + fn yaml_round_trip_integer_param_bounds() { + let ext = ExtensionFile::read_from_str(YAML_PARAM_TEST).expect("parse ok"); assert_eq!(ext.urn().to_string(), "extension:example.com:param_test"); - // Validate parsed parameter bounds let ty = ext.get_type("ParamTest").expect("type exists"); - assert_eq!(ty.parameters.len(), 1); - match &ty.parameters[0].param_type { - RawParameterType::Integer { min, max } => { - assert_eq!(min, &Some(1)); - assert_eq!(max, &Some(10)); - } - other => panic!("unexpected param type: {other:?}"), + match &ty.parameters[..] 
{ + [param] => match ¶m.param_type { + RawParameterType::Integer { + min: actual_min, + max: actual_max, + } => { + assert_eq!(actual_min, &Some(1)); + assert_eq!(actual_max, &Some(10)); + } + other => panic!("unexpected param type: {other:?}"), + }, + other => panic!("unexpected parameters: {other:?}"), } - // Convert back to text::simple_extensions and assert bounds are preserved as f64 let back = ext.to_raw(); + assert_eq!(back.urn, "extension:example.com:param_test"); let item = back .types .into_iter() .find(|t| t.name == "ParamTest") .expect("round-tripped type present"); - let param_defs = item.parameters.expect("params present"); - assert_eq!(param_defs.0.len(), 1); - let p = ¶m_defs.0[0]; - assert_eq!(p.name.as_deref(), Some("K")); + let param = item.parameters.unwrap().0.into_iter().next().unwrap(); + assert_eq!(param.name.as_deref(), Some("K")); assert!(matches!( - p.type_, + param.type_, crate::text::simple_extensions::TypeParamDefsItemType::Integer )); - assert_eq!(p.min, Some(1.0)); - assert_eq!(p.max, Some(10.0)); - assert_eq!(back.urn, "extension:example.com:param_test"); + assert_eq!(param.min, Some(1.0)); + assert_eq!(param.max, Some(10.0)); } } diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index 1e435688..6418f9b8 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -144,50 +144,67 @@ fn parse_param<'a>(s: &'a str) -> Result, TypeParseError> { mod tests { use super::*; - #[test] - fn test_parsed_type_simple() { - let parsed = TypeExpr::parse("i32").unwrap(); - assert_eq!(parsed, TypeExpr::Simple("i32", vec![], false)); - - let parsed_nullable = TypeExpr::parse("i32?").unwrap(); - assert_eq!(parsed_nullable, TypeExpr::Simple("i32", vec![], true)); + fn parse(expr: &str) -> TypeExpr<'_> { + TypeExpr::parse(expr).expect("parse succeeds") } #[test] - fn test_parsed_type_variables() { - let parsed = 
TypeExpr::parse("any1").unwrap(); - assert_eq!(parsed, TypeExpr::TypeVariable(1, false)); + fn test_simple_types() { + let cases = vec![ + ("i32", TypeExpr::Simple("i32", vec![], false)), + ("i32?", TypeExpr::Simple("i32", vec![], true)), + ("MAP", TypeExpr::Simple("MAP", vec![], false)), + ]; + + for (expr, expected) in cases { + assert_eq!(parse(expr), expected, "unexpected parse for {expr}"); + } + } - let parsed_nullable = TypeExpr::parse("any2?").unwrap(); - assert_eq!(parsed_nullable, TypeExpr::TypeVariable(2, true)); + #[test] + fn test_type_variables() { + let cases = vec![ + ("any1", TypeExpr::TypeVariable(1, false)), + ("any2?", TypeExpr::TypeVariable(2, true)), + ]; + + for (expr, expected) in cases { + assert_eq!( + parse(expr), + expected, + "unexpected variable parse for {expr}" + ); + } } #[test] - fn test_user_defined_and_params() { - match TypeExpr::parse("u!geo?>").unwrap() { + fn test_user_defined_and_parameters() { + let expr = "u!geo?>"; + match parse(expr) { TypeExpr::UserDefined(name, params, nullable) => { - assert_eq!(name, "geo"); - assert!(nullable); - assert_eq!( - params[0], - TypeExprParam::Type(TypeExpr::Simple("i32", vec![], true)) - ); + assert_eq!(name, "geo", "unexpected name for {expr}"); + assert!(nullable, "{expr} should be nullable"); assert_eq!( - params[1], - TypeExprParam::Type(TypeExpr::Simple( - "point", - vec![ - TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), - TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), - ], - false - )) + params, + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], true)), + TypeExprParam::Type(TypeExpr::Simple( + "point", + vec![ + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + TypeExprParam::Type(TypeExpr::Simple("i32", vec![], false)), + ], + false, + )), + ] ); } - other => panic!("unexpected: {other:?}"), + other => panic!("unexpected parse result: {other:?}"), } + + let map_expr = "Map?"; assert_eq!( - TypeExpr::parse("Map?").unwrap(), + 
parse(map_expr), TypeExpr::Simple( "Map", vec![ @@ -195,19 +212,24 @@ mod tests { TypeExprParam::Type(TypeExpr::Simple("string", vec![], false)), ], true, - ) + ), + "unexpected map parse" ); } #[test] fn test_visit_references_builtin_case_insensitive() { - let parsed = TypeExpr::parse("DECIMAL<10,2>").unwrap(); - let mut refs = Vec::new(); - parsed.visit_references(&mut |name| refs.push(name.to_string())); - assert!(refs.is_empty()); - - let parsed_list = TypeExpr::parse("List").unwrap(); - parsed_list.visit_references(&mut |name| refs.push(name.to_string())); - assert!(refs.is_empty()); + let cases = vec![ + ("DECIMAL<10,2>", Vec::::new()), + ("List", Vec::::new()), + ("u!custom", vec!["custom".to_string()]), + ("Geo", vec!["Geo".to_string()]), + ]; + + for (expr, expected_refs) in cases { + let mut refs = Vec::new(); + parse(expr).visit_references(&mut |name| refs.push(name.to_string())); + assert_eq!(refs, expected_refs, "unexpected references for {expr}"); + } } } diff --git a/src/parse/text/simple_extensions/registry.rs b/src/parse/text/simple_extensions/registry.rs index 7174f65c..542d4bfa 100644 --- a/src/parse/text/simple_extensions/registry.rs +++ b/src/parse/text/simple_extensions/registry.rs @@ -73,65 +73,78 @@ impl Registry { #[cfg(test)] mod tests { - use super::ExtensionFile as ParsedSimpleExtensions; - use super::Registry; + use super::{ExtensionFile, Registry}; use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem}; use crate::urn::Urn; use std::str::FromStr; - fn create_test_extension_with_types() -> SimpleExtensions { - SimpleExtensions { + fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile { + let types = type_names + .iter() + .map(|name| SimpleExtensionsTypesItem { + name: (*name).to_string(), + description: None, + parameters: None, + structure: None, + variadic: None, + }) + .collect(); + + let raw = SimpleExtensions { scalar_functions: vec![], aggregate_functions: vec![], window_functions: vec![], 
dependencies: Default::default(), type_variations: vec![], - types: vec![SimpleExtensionsTypesItem { - name: "test_type".to_string(), - description: Some("A test type".to_string()), - parameters: None, - structure: None, - variadic: None, - }], - urn: "extension:example.com:test".to_string(), - } + types, + urn: urn.to_string(), + }; + + ExtensionFile::create(raw).expect("valid extension file") } #[test] - fn test_new_registry() { - let urn = Urn::from_str("extension:example.com:test").unwrap(); - let extension_file = - ParsedSimpleExtensions::create(create_test_extension_with_types()).unwrap(); - let extensions = vec![extension_file]; - - let registry = Registry::new(extensions); - assert_eq!(registry.extensions().count(), 1); - let extension_urns: Vec<&Urn> = registry.extensions().map(|ext| ext.urn()).collect(); - assert!(extension_urns.contains(&&urn)); + fn test_registry_iteration() { + let urns = vec![ + "extension:example.com:first", + "extension:example.com:second", + ]; + let registry = Registry::new( + urns.iter() + .map(|urn| extension_file(urn, &["type"])) + .collect(), + ); + + let collected: Vec<&Urn> = registry.extensions().map(|ext| ext.urn()).collect(); + assert_eq!(collected.len(), 2); + for urn in urns { + assert!( + collected + .iter() + .any(|candidate| candidate.to_string() == urn) + ); + } } #[test] fn test_type_lookup() { let urn = Urn::from_str("extension:example.com:test").unwrap(); - let extension_file = - ParsedSimpleExtensions::create(create_test_extension_with_types()).unwrap(); - let extensions = vec![extension_file]; - - let registry = Registry::new(extensions); - - // Test successful type lookup - let found_type = registry.get_type(&urn, "test_type"); - assert!(found_type.is_some()); - assert_eq!(found_type.unwrap().name, "test_type"); - - // Test missing type lookup - let missing_type = registry.get_type(&urn, "nonexistent_type"); - assert!(missing_type.is_none()); - - // Test missing extension lookup - let wrong_urn = 
Urn::from_str("extension:example.com:wrong").unwrap(); - let missing_extension = registry.get_type(&wrong_urn, "test_type"); - assert!(missing_extension.is_none()); + let registry = Registry::new(vec![extension_file(&urn.to_string(), &["test_type"])]); + let other_urn = Urn::from_str("extension:example.com:other").unwrap(); + + let cases = vec![ + (&urn, "test_type", true), + (&urn, "missing", false), + (&other_urn, "test_type", false), + ]; + + for (query_urn, type_name, expected) in cases { + assert_eq!( + registry.get_type(query_urn, type_name).is_some(), + expected, + "unexpected lookup result for {query_urn}:{type_name}" + ); + } } #[cfg(feature = "extensions")] diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 3fc0a564..139b7e36 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -200,8 +200,11 @@ impl BuiltinParameterized { | "fixedbinary" | "decimal" | "precisiontime" + | "time" | "precisiontimestamp" + | "timestamp" | "precisiontimestamptz" + | "timestamp_tz" | "interval_day" | "interval_compound" ) @@ -764,7 +767,9 @@ impl fmt::Display for ConcreteTypeKind { } ConcreteTypeKind::List(elem) => write!(f, "List<{elem}>"), ConcreteTypeKind::Map { key, value } => write!(f, "Map<{key}, {value}>"), - ConcreteTypeKind::Struct(types) => write_separated(f, types.iter(), "Struct<", ">", ", "), + ConcreteTypeKind::Struct(types) => { + write_separated(f, types.iter(), "Struct<", ">", ", ") + } ConcreteTypeKind::NamedStruct { field_names, field_types, @@ -1021,7 +1026,7 @@ fn parse_parameterized_builtin<'a>( let scale = expect_integer_param(display_name, 1, ¶ms[1])?; Ok(Some(BuiltinParameterized::Decimal { precision, scale })) } - "precisiontime" => { + "precisiontime" | "time" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1032,7 +1037,7 @@ fn parse_parameterized_builtin<'a>( let 
precision = expect_integer_param(display_name, 0, ¶ms[0])?; Ok(Some(BuiltinParameterized::PrecisionTime { precision })) } - "precisiontimestamp" => { + "precisiontimestamp" | "timestamp" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1041,11 +1046,9 @@ fn parse_parameterized_builtin<'a>( }); } let precision = expect_integer_param(display_name, 0, ¶ms[0])?; - Ok(Some(BuiltinParameterized::PrecisionTimestamp { - precision, - })) + Ok(Some(BuiltinParameterized::PrecisionTimestamp { precision })) } - "precisiontimestamptz" => { + "precisiontimestamptz" | "timestamp_tz" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1078,9 +1081,7 @@ fn parse_parameterized_builtin<'a>( }); } let precision = expect_integer_param(display_name, 0, ¶ms[0])?; - Ok(Some(BuiltinParameterized::IntervalCompound { - precision, - })) + Ok(Some(BuiltinParameterized::IntervalCompound { precision })) } _ => Ok(None), } @@ -1186,103 +1187,222 @@ mod tests { use crate::parse::text::simple_extensions::TypeExpr; use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions; use crate::text::simple_extensions; + use std::iter::FromIterator; - #[test] - fn test_builtin_type_parsing() { - assert_eq!(PrimitiveType::from_str("i32").unwrap(), PrimitiveType::I32); - assert_eq!( - PrimitiveType::from_str("string").unwrap(), - PrimitiveType::String + /// Create a [ConcreteType] from a [BuiltinParameterized] + fn concretize(builtin: BuiltinParameterized) -> ConcreteType { + ConcreteType::parameterized_builtin(builtin, false) + } + + /// Parse a string into a [ConcreteType] + fn parse_type(expr: &str) -> ConcreteType { + let parsed = TypeExpr::parse(expr).unwrap(); + ConcreteType::try_from(parsed).unwrap() + } + + /// Parse a string into a builtin [ConcreteType], with no unresolved + /// extension references + fn parse_simple(s: &str) -> 
ConcreteType { + let parsed = TypeExpr::parse(s).unwrap(); + + let mut refs = Vec::new(); + parsed.visit_references(&mut |name| refs.push(name.to_string())); + assert!(refs.is_empty(), "{s} should not add an extension reference"); + + ConcreteType::try_from(parsed).unwrap() + } + + /// Create a type parameter from a type expression string + fn type_param(expr: &str) -> TypeParameter { + TypeParameter::Type(parse_type(expr)) + } + + /// Create an extension type + fn extension(name: &str, parameters: Vec, nullable: bool) -> ConcreteType { + ConcreteType::extension_with_params(name.to_string(), parameters, nullable) + } + + /// Convert a custom type to raw and back, ensuring round-trip consistency + fn round_trip(custom: &CustomType) { + let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx).unwrap(); + assert_eq!(&parsed, custom); + } + + /// Create a raw named struct (e.g. straight from YAML) from field name and + /// type pairs + fn raw_named_struct(fields: &[(&str, &str)]) -> RawType { + let map = serde_json::Map::from_iter( + fields + .iter() + .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))), ); - assert!(PrimitiveType::from_str("invalid").is_err()); + RawType::Variant1(map) } #[test] - fn test_concrete_type_creation() { - let int_type = ConcreteType::builtin(PrimitiveType::I32, false); - assert_eq!( - int_type, - ConcreteType { - kind: ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::I32)), - nullable: false - } - ); + fn test_primitive_type_parsing() { + let cases = vec![ + ("bool", Some(PrimitiveType::Boolean)), + ("i32", Some(PrimitiveType::I32)), + ("STRING", Some(PrimitiveType::String)), + ("uuid", Some(PrimitiveType::Uuid)), + ("timestamp", None), + ("invalid", None), + ]; - let list_type = ConcreteType::list(int_type.clone(), true); - assert_eq!( - list_type, - ConcreteType { - kind: 
ConcreteTypeKind::List(Box::new(int_type)), - nullable: true + for (input, expected) in cases { + match expected { + Some(expected_type) => { + assert_eq!( + PrimitiveType::from_str(input).unwrap(), + expected_type, + "expected primitive type for {input}" + ); + } + None => { + assert!( + PrimitiveType::from_str(input).is_err(), + "expected parsing {input} to fail" + ); + } } - ); + } } #[test] - fn test_list_type_parameters_preserved() { - let parsed = TypeExpr::parse("List").unwrap(); - let concrete = ConcreteType::try_from(parsed).unwrap(); - assert_eq!( - concrete, - ConcreteType::list(ConcreteType::builtin(PrimitiveType::I32, false), false) - ); + fn test_parameterized_builtin_types() { + let cases = vec![ + ( + "time<9>", + concretize(BuiltinParameterized::PrecisionTime { precision: 9 }), + ), + ( + "timestamp<3>", + concretize(BuiltinParameterized::PrecisionTimestamp { precision: 3 }), + ), + ( + "timestamp_tz<4>", + concretize(BuiltinParameterized::PrecisionTimestampTz { precision: 4 }), + ), + ( + "precisiontime<2>", + concretize(BuiltinParameterized::PrecisionTime { precision: 2 }), + ), + ( + "precisiontimestamp<1>", + concretize(BuiltinParameterized::PrecisionTimestamp { precision: 1 }), + ), + ( + "precisiontimestamptz<5>", + concretize(BuiltinParameterized::PrecisionTimestampTz { precision: 5 }), + ), + ( + "DECIMAL<10,2>", + concretize(BuiltinParameterized::Decimal { + precision: 10, + scale: 2, + }), + ), + ( + "fixedchar<12>", + concretize(BuiltinParameterized::FixedChar { length: 12 }), + ), + ( + "VarChar<42>", + concretize(BuiltinParameterized::VarChar { length: 42 }), + ), + ( + "fixedbinary<8>", + concretize(BuiltinParameterized::FixedBinary { length: 8 }), + ), + ( + "interval_day<7>", + concretize(BuiltinParameterized::IntervalDay { precision: 7 }), + ), + ( + "interval_compound<6>", + concretize(BuiltinParameterized::IntervalCompound { precision: 6 }), + ), + ]; + + for (expr, expected) in cases { + let found = parse_simple(expr); + 
assert_eq!(found, expected, "unexpected type for {expr}"); + } } #[test] - fn test_decimal_type_case_insensitive() { - let parsed = TypeExpr::parse("DECIMAL<10,2>").unwrap(); - let concrete = ConcreteType::try_from(parsed).unwrap(); - match concrete.kind { - ConcreteTypeKind::Builtin(BuiltinKind::Parameterized( - BuiltinParameterized::Decimal { precision, scale }, - )) => { - assert_eq!(precision, 10); - assert_eq!(scale, 2); - } - other => panic!("unexpected type: {other:?}"), + fn test_container_types() { + let cases = vec![ + ( + "List", + ConcreteType::list(ConcreteType::builtin(PrimitiveType::I32, false), false), + ), + ( + "List", + ConcreteType::list(ConcreteType::builtin(PrimitiveType::Fp64, true), false), + ), + ( + "Map?", + ConcreteType::map( + ConcreteType::builtin(PrimitiveType::I64, false), + ConcreteType::builtin(PrimitiveType::String, true), + true, + ), + ), + ( + "Struct?", + ConcreteType::r#struct( + vec![ + ConcreteType::builtin(PrimitiveType::I8, false), + ConcreteType::builtin(PrimitiveType::String, true), + ], + true, + ), + ), + ]; + + for (expr, expected) in cases { + assert_eq!(parse_type(expr), expected, "unexpected parse for {expr}"); } - assert!(!concrete.nullable); } #[test] - fn test_extension_parameters_preserved() { - let parsed_udf = TypeExpr::parse("u!geo, 10>").unwrap(); - let udf_type = ConcreteType::try_from(parsed_udf).unwrap(); - match &udf_type.kind { - ConcreteTypeKind::Extension { name, parameters } => { - assert_eq!(name, "geo"); - assert_eq!(parameters.len(), 2); - match ¶meters[0] { - TypeParameter::Type(inner) => match &inner.kind { - ConcreteTypeKind::List(element) => { - assert_eq!(**element, ConcreteType::builtin(PrimitiveType::I32, false)); - } - other => panic!("unexpected list element: {other:?}"), - }, - other => panic!("unexpected parameter: {other:?}"), - } - assert_eq!(parameters[1], TypeParameter::Integer(10)); - } - other => panic!("unexpected type: {other:?}"), - } + fn test_extension_types() { + let cases 
= vec![ + ( + "u!geo, 10>", + extension( + "geo", + vec![type_param("List"), TypeParameter::Integer(10)], + false, + ), + ), + ( + "Geo?>", + extension("Geo", vec![type_param("List")], true), + ), + ( + "Custom", + extension( + "Custom", + vec![ + type_param("string?"), + TypeParameter::Type(ConcreteType::builtin(PrimitiveType::Boolean, false)), + ], + false, + ), + ), + ]; - let parsed_simple = TypeExpr::parse("Geo>").unwrap(); - let simple_type = ConcreteType::try_from(parsed_simple).unwrap(); - match &simple_type.kind { - ConcreteTypeKind::Extension { name, parameters } => { - assert_eq!(name, "Geo"); - assert_eq!(parameters.len(), 1); - match ¶meters[0] { - TypeParameter::Type(inner) => match &inner.kind { - ConcreteTypeKind::List(element) => { - assert_eq!(**element, ConcreteType::builtin(PrimitiveType::I32, false)); - } - other => panic!("unexpected list element: {other:?}"), - }, - other => panic!("unexpected parameter: {other:?}"), - } - } - other => panic!("unexpected type: {other:?}"), + for (expr, expected) in cases { + assert_eq!( + parse_type(expr), + expected, + "unexpected extension for {expr}" + ); } } @@ -1292,81 +1412,77 @@ mod tests { min: Some(1), max: Some(10), }; + let enum_param = ParameterConstraint::Enumeration { + options: ParsedEnumOptions::try_from(simple_extensions::EnumOptions(vec![ + "OVERFLOW".to_string(), + "ERROR".to_string(), + ])) + .unwrap(), + }; - assert!(int_param.is_valid_value(&Value::Number(5.into()))); - assert!(!int_param.is_valid_value(&Value::Number(0.into()))); - assert!(!int_param.is_valid_value(&Value::Number(11.into()))); - assert!(!int_param.is_valid_value(&Value::String("not a number".into()))); - - let raw = simple_extensions::EnumOptions(vec!["OVERFLOW".to_string(), "ERROR".to_string()]); - let parsed = ParsedEnumOptions::try_from(raw).unwrap(); - let enum_param = ParameterConstraint::Enumeration { options: parsed }; + let cases = vec![ + (&int_param, Value::Number(5.into()), true), + (&int_param, 
Value::Number(0.into()), false), + (&int_param, Value::Number(11.into()), false), + (&int_param, Value::String("not a number".into()), false), + (&enum_param, Value::String("OVERFLOW".into()), true), + (&enum_param, Value::String("INVALID".into()), false), + ]; - assert!(enum_param.is_valid_value(&Value::String("OVERFLOW".into()))); - assert!(!enum_param.is_valid_value(&Value::String("INVALID".into()))); + for (param, value, expected) in cases { + assert_eq!( + param.is_valid_value(&value), + expected, + "unexpected validation result for {value:?}" + ); + } } #[test] fn test_integer_param_bounds_round_trip() { - // Valid bounds now use lossy cast from f64 to i64; fractional parts are truncated toward zero - let item = simple_extensions::TypeParamDefsItem { - name: Some("K".to_string()), - description: None, - type_: simple_extensions::TypeParamDefsItemType::Integer, - min: Some(1.0), - max: Some(10.0), - options: None, - optional: None, - }; - let tp = TypeParam::try_from(item).expect("should parse integer bounds"); - match tp.param_type { - ParameterConstraint::Integer { min, max } => { - assert_eq!(min, Some(1)); - assert_eq!(max, Some(10)); - } - _ => panic!("expected integer param type"), - } + let cases = vec![ + ( + "bounded", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.0), + max: Some(10.0), + options: None, + optional: None, + }, + (Some(1), Some(10)), + ), + ( + "truncated", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: Some(1.5), + max: None, + options: None, + optional: None, + }, + (Some(1), None), + ), + ]; - // Fractional min is truncated - let trunc = simple_extensions::TypeParamDefsItem { - name: Some("K".to_string()), - description: None, - type_: simple_extensions::TypeParamDefsItemType::Integer, - min: Some(1.5), 
- max: None, - options: None, - optional: None, - }; - let tp = TypeParam::try_from(trunc).expect("should parse with truncation"); - match tp.param_type { - ParameterConstraint::Integer { min, max } => { - assert_eq!(min, Some(1)); - assert_eq!(max, None); + for (label, item, (expected_min, expected_max)) in cases { + let tp = TypeParam::try_from(item).expect("should parse integer bounds"); + match tp.param_type { + ParameterConstraint::Integer { min, max } => { + assert_eq!(min, expected_min, "min mismatch for {label}"); + assert_eq!(max, expected_max, "max mismatch for {label}"); + } + _ => panic!("expected integer param type for {label}"), } - _ => panic!("expected integer param type"), } } #[test] - fn test_custom_type_round_trip_alias() -> Result<(), ExtensionTypeError> { - let custom = CustomType::new( - "AliasType".to_string(), - vec![], - Some(ConcreteType::builtin(PrimitiveType::I32, false)), - None, - Some("desc".to_string()), - )?; - let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); - let mut ctx = TypeContext::default(); - let parsed = Parse::parse(item, &mut ctx)?; - assert_eq!(parsed.name, custom.name); - assert_eq!(parsed.description, custom.description); - assert_eq!(parsed.structure, custom.structure); - Ok(()) - } - - #[test] - fn test_custom_type_round_trip_named_struct() -> Result<(), ExtensionTypeError> { + fn test_custom_type_round_trip() -> Result<(), ExtensionTypeError> { let fields = vec![ ( "x".to_string(), @@ -1378,184 +1494,170 @@ mod tests { ), ]; let (names, types): (Vec<_>, Vec<_>) = fields.into_iter().unzip(); - let custom = CustomType::new( - "Point".to_string(), - vec![], - Some(ConcreteType::named_struct( - names.clone(), - types.clone(), - false, - )), - None, - None, - )?; - let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into(); - let mut ctx = TypeContext::default(); - let parsed = Parse::parse(item, &mut ctx)?; - assert_eq!(parsed.name, custom.name); - if let 
Some(ConcreteType { - kind: - ConcreteTypeKind::NamedStruct { - field_names, - field_types, - }, - nullable, - }) = parsed.structure - { - assert!(!nullable); - assert_eq!(field_names, names); - assert_eq!(field_types, types); - } else { - panic!("expected named struct after round-trip"); + + let cases = vec![ + CustomType::new( + "AliasType".to_string(), + vec![], + Some(ConcreteType::builtin(PrimitiveType::I32, false)), + None, + Some("a test alias type".to_string()), + )?, + CustomType::new( + "Point".to_string(), + vec![TypeParam::new( + "T".to_string(), + ParameterConstraint::DataType, + Some("a numeric value".to_string()), + )], + Some(ConcreteType::named_struct(names, types, false)), + None, + None, + )?, + ]; + + for custom in cases { + round_trip(&custom); } - Ok(()) - } - #[test] - fn test_custom_type_creation() -> Result<(), ExtensionTypeError> { - let custom_type = CustomType::new( - "MyType".to_string(), - vec![], - Some(ConcreteType::builtin(PrimitiveType::I32, false)), - None, - Some("A custom type".to_string()), - )?; - - assert_eq!(custom_type.name, "MyType"); - assert_eq!(custom_type.parameters.len(), 0); - assert!(custom_type.structure.is_some()); Ok(()) } #[test] fn test_invalid_type_names() { - // Empty name should be invalid - assert!(CustomType::validate_name("").is_err()); - // Name with whitespace should be invalid - assert!(CustomType::validate_name("bad name").is_err()); - // Valid name should pass - assert!(CustomType::validate_name("GoodName").is_ok()); - } + let cases = vec![ + ("", false), + ("bad name", false), + ("GoodName", true), + ("also_good", true), + ]; - #[test] - fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> { - // Test simple type string alias - let ext_type = RawType::Variant0("i32".to_string()); - let mut ctx = TypeContext::default(); - let concrete = Parse::parse(ext_type, &mut ctx)?; - assert_eq!(concrete, ConcreteType::builtin(PrimitiveType::I32, false)); - - // Test struct type - let mut 
field_map = serde_json::Map::new(); - field_map.insert( - "field1".to_string(), - serde_json::Value::String("fp64".to_string()), - ); - let ext_type = RawType::Variant1(field_map); - let mut ctx = TypeContext::default(); - let concrete = Parse::parse(ext_type, &mut ctx)?; - - if let ConcreteTypeKind::NamedStruct { - field_names, - field_types, - } = &concrete.kind - { - assert_eq!(field_names, &vec!["field1".to_string()]); - assert_eq!(field_types.len(), 1); + for (name, expected_ok) in cases { + let result = CustomType::validate_name(name); assert_eq!( - field_types[0], - ConcreteType::builtin(PrimitiveType::Fp64, false) + result.is_ok(), + expected_ok, + "unexpected validation for {name}" ); - } else { - panic!("Expected named struct type"); } - - Ok(()) } #[test] - fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> { - let type_item = simple_extensions::SimpleExtensionsTypesItem { - name: "TestType".to_string(), - description: Some("A test type".to_string()), - parameters: None, - structure: Some(RawType::Variant0("BINARY".to_string())), // Alias to fp64 - variadic: None, - }; - - let mut ctx = TypeContext::default(); - let custom_type = Parse::parse(type_item, &mut ctx)?; - assert_eq!(custom_type.name, "TestType"); - assert_eq!(custom_type.description, Some("A test type".to_string())); - assert!(custom_type.structure.is_some()); + fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> { + let cases = vec![ + ( + "alias", + RawType::Variant0("i32".to_string()), + ConcreteType::builtin(PrimitiveType::I32, false), + ), + ( + "named_struct", + raw_named_struct(&[("field1", "fp64"), ("field2", "i32?")]), + ConcreteType::named_struct( + vec!["field1".to_string(), "field2".to_string()], + vec![ + ConcreteType::builtin(PrimitiveType::Fp64, false), + ConcreteType::builtin(PrimitiveType::I32, true), + ], + false, + ), + ), + ]; - if let Some(structure) = &custom_type.structure { - assert_eq!( - structure.kind, - 
ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::Binary)) - ); + for (label, raw, expected) in cases { + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(raw, &mut ctx)?; + assert_eq!(parsed, expected, "unexpected type for {label}"); } Ok(()) } #[test] - fn test_custom_type_with_struct() -> Result<(), ExtensionTypeError> { - let mut field_map = serde_json::Map::new(); - field_map.insert( - "x".to_string(), - serde_json::Value::String("fp64".to_string()), - ); - field_map.insert( - "y".to_string(), - serde_json::Value::String("fp64".to_string()), - ); - - let type_item = simple_extensions::SimpleExtensionsTypesItem { - name: "Point".to_string(), - description: Some("A 2D point".to_string()), - parameters: None, - structure: Some(RawType::Variant1(field_map)), - variadic: None, - }; - - let mut ctx = TypeContext::default(); - let custom_type = Parse::parse(type_item, &mut ctx)?; - assert_eq!(custom_type.name, "Point"); - - if let Some(ConcreteType { - kind: - ConcreteTypeKind::NamedStruct { - field_names, - field_types, + fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> { + let cases = vec![ + ( + "alias", + simple_extensions::SimpleExtensionsTypesItem { + name: "Alias".to_string(), + description: Some("Alias type".to_string()), + parameters: None, + structure: Some(RawType::Variant0("BINARY".to_string())), + variadic: None, }, - .. 
- }) = &custom_type.structure - { - assert!(field_names.contains(&"x".to_string())); - assert!(field_names.contains(&"y".to_string())); - assert_eq!(field_types.len(), 2); - // Note: HashMap iteration order is not guaranteed, so we just check the types exist - assert!(field_types.iter().all(|t| matches!( - t.kind, - ConcreteTypeKind::Builtin(BuiltinKind::Primitive(PrimitiveType::Fp64)) - ))); - } else { - panic!("Expected struct type"); + "Alias", + Some("Alias type"), + Some(ConcreteType::builtin(PrimitiveType::Binary, false)), + ), + ( + "named_struct", + simple_extensions::SimpleExtensionsTypesItem { + name: "Point".to_string(), + description: Some("A 2D point".to_string()), + parameters: None, + structure: Some(raw_named_struct(&[("x", "fp64"), ("y", "fp64?")])), + variadic: None, + }, + "Point", + Some("A 2D point"), + Some(ConcreteType::named_struct( + vec!["x".to_string(), "y".to_string()], + vec![ + ConcreteType::builtin(PrimitiveType::Fp64, false), + ConcreteType::builtin(PrimitiveType::Fp64, true), + ], + false, + )), + ), + ( + "no_structure", + simple_extensions::SimpleExtensionsTypesItem { + name: "Opaque".to_string(), + description: None, + parameters: None, + structure: None, + variadic: Some(true), + }, + "Opaque", + None, + None, + ), + ]; + + for (label, item, expected_name, expected_description, expected_structure) in cases { + let mut ctx = TypeContext::default(); + let parsed = Parse::parse(item, &mut ctx)?; + assert_eq!(parsed.name, expected_name); + assert_eq!( + parsed.description.as_deref(), + expected_description, + "description mismatch for {label}" + ); + assert_eq!( + parsed.structure, expected_structure, + "structure mismatch for {label}" + ); } Ok(()) } + /// A type defined with a structure cannot be defined as nullable; e.g. if + /// you define 'Integer' as an alias for 'i64?', then what do you mean by + /// 'INTEGER?' 
- is that now equal to `i64??` #[test] fn test_nullable_structure_rejected() { - let ext_type = RawType::Variant0("i32?".to_string()); - let mut ctx = TypeContext::default(); - let result = Parse::parse(ext_type, &mut ctx); - if let Err(ExtensionTypeError::StructureCannotBeNullable { type_string }) = result { - assert!(type_string.contains("i32?")); - } else { - panic!("Expected nullable structure to be rejected, got: {result:?}"); + let cases = vec![RawType::Variant0("i32?".to_string())]; + + for raw in cases { + let mut ctx = TypeContext::default(); + let result = Parse::parse(raw, &mut ctx); + match result { + Err(ExtensionTypeError::StructureCannotBeNullable { type_string }) => { + assert!(type_string.contains('?')); + } + other => panic!("Expected nullable structure error, got {other:?}"), + } } } } From 421d95732734de36afcdd138c512dc6b54546de4 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Thu, 18 Sep 2025 11:15:19 -0400 Subject: [PATCH 19/31] Removed string / enum type parameter, that's not a thing --- src/parse/text/simple_extensions/parsed_type.rs | 2 -- src/parse/text/simple_extensions/types.rs | 9 --------- 2 files changed, 11 deletions(-) diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index 6418f9b8..c9c77bdf 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -23,8 +23,6 @@ pub enum TypeExprParam<'a> { Type(TypeExpr<'a>), /// An integer literal parameter Integer(i64), - /// A string literal parameter (unquoted) - String(&'a str), } #[derive(Debug, PartialEq, thiserror::Error)] diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 139b7e36..4dc18e81 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -121,8 +121,6 @@ pub enum TypeParameter { Integer(i64), /// Type parameter (nested type) Type(ConcreteType), - /// 
String parameter - String(String), } impl fmt::Display for TypeParameter { @@ -130,7 +128,6 @@ impl fmt::Display for TypeParameter { match self { TypeParameter::Integer(i) => write!(f, "{i}"), TypeParameter::Type(t) => write!(f, "{t}"), - TypeParameter::String(s) => write!(f, "{s}"), } } } @@ -957,11 +954,6 @@ fn expect_type_argument<'a>( index, expected: "a type", }), - TypeExprParam::String(_) => Err(ExtensionTypeError::InvalidParameterKind { - type_name: type_name.to_string(), - index, - expected: "a type", - }), } } @@ -970,7 +962,6 @@ fn type_expr_param_to_type_parameter<'a>( ) -> Result { Ok(match param { TypeExprParam::Integer(v) => TypeParameter::Integer(v), - TypeExprParam::String(s) => TypeParameter::String(s.to_string()), TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?), }) } From 46b64684c18f98e46051a2b3a0d17f838a936fff Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 12:01:17 -0400 Subject: [PATCH 20/31] Time should be primitive, precisiontime parameterized --- .../text/simple_extensions/parsed_type.rs | 3 ++ src/parse/text/simple_extensions/types.rs | 50 +++++++++++-------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index c9c77bdf..af4030c5 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -152,6 +152,9 @@ mod tests { ("i32", TypeExpr::Simple("i32", vec![], false)), ("i32?", TypeExpr::Simple("i32", vec![], true)), ("MAP", TypeExpr::Simple("MAP", vec![], false)), + ("timestamp", TypeExpr::Simple("timestamp", vec![], false)), + ("timestamp_tz?", TypeExpr::Simple("timestamp_tz", vec![], true)), + ("time", TypeExpr::Simple("time", vec![], false)), ]; for (expr, expected) in cases { diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 4dc18e81..eebcfc45 100644 --- 
a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -86,8 +86,14 @@ pub enum PrimitiveType { String, /// Variable-length binary data - `binary` Binary, + /// Naive Timestamp + Timestamp, + /// Timestamp with time zone - `timestamp_tz` + TimestampTz, /// Calendar date - `date` Date, + /// Time of day - `time` + Time, /// Year-month interval - `interval_year` IntervalYear, /// 128-bit UUID - `uuid` @@ -106,7 +112,10 @@ impl fmt::Display for PrimitiveType { PrimitiveType::Fp64 => "fp64", PrimitiveType::String => "string", PrimitiveType::Binary => "binary", + PrimitiveType::Timestamp => "timestamp", + PrimitiveType::TimestampTz => "timestamp_tz", PrimitiveType::Date => "date", + PrimitiveType::Time => "time", PrimitiveType::IntervalYear => "interval_year", PrimitiveType::Uuid => "uuid", }; @@ -197,11 +206,8 @@ impl BuiltinParameterized { | "fixedbinary" | "decimal" | "precisiontime" - | "time" | "precisiontimestamp" - | "timestamp" | "precisiontimestamptz" - | "timestamp_tz" | "interval_day" | "interval_compound" ) @@ -301,7 +307,10 @@ impl FromStr for PrimitiveType { "fp64" => Ok(PrimitiveType::Fp64), "string" => Ok(PrimitiveType::String), "binary" => Ok(PrimitiveType::Binary), + "timestamp" => Ok(PrimitiveType::Timestamp), + "timestamp_tz" => Ok(PrimitiveType::TimestampTz), "date" => Ok(PrimitiveType::Date), + "time" => Ok(PrimitiveType::Time), "interval_year" => Ok(PrimitiveType::IntervalYear), "uuid" => Ok(PrimitiveType::Uuid), _ => Err(UnrecognizedBuiltin(s.to_string())), @@ -1017,7 +1026,10 @@ fn parse_parameterized_builtin<'a>( let scale = expect_integer_param(display_name, 1, &params[1])?; Ok(Some(BuiltinParameterized::Decimal { precision, scale })) } - "precisiontime" | "time" => { + // Should we accept both "precision_time" and "precisiontime"? The + // docs/spec say PRECISIONTIME. The protos use underscores, so it could + // show up in generated code, although maybe that's out of spec. 
+ "precisiontime" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1028,7 +1040,7 @@ fn parse_parameterized_builtin<'a>( let precision = expect_integer_param(display_name, 0, &params[0])?; Ok(Some(BuiltinParameterized::PrecisionTime { precision })) } - "precisiontimestamp" | "timestamp" => { + "precisiontimestamp" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1039,7 +1051,7 @@ fn parse_parameterized_builtin<'a>( let precision = expect_integer_param(display_name, 0, &params[0])?; Ok(Some(BuiltinParameterized::PrecisionTimestamp { precision })) } - "precisiontimestamptz" | "timestamp_tz" => { + "precisiontimestamptz" => { if params.len() != 1 { return Err(ExtensionTypeError::InvalidParameterCount { type_name: display_name.to_string(), @@ -1198,7 +1210,7 @@ mod tests { let mut refs = Vec::new(); parsed.visit_references(&mut |name| refs.push(name.to_string())); - assert!(refs.is_empty(), "{s} should not add an extension reference"); + assert!(refs.is_empty(), "{s} should add a builtin type"); ConcreteType::try_from(parsed).unwrap() } @@ -1236,10 +1248,20 @@ mod tests { fn test_primitive_type_parsing() { let cases = vec![ ("bool", Some(PrimitiveType::Boolean)), + ("i8", Some(PrimitiveType::I8)), + ("i16", Some(PrimitiveType::I16)), ("i32", Some(PrimitiveType::I32)), + ("i64", Some(PrimitiveType::I64)), + ("fp32", Some(PrimitiveType::Fp32)), + ("fp64", Some(PrimitiveType::Fp64)), ("STRING", Some(PrimitiveType::String)), + ("binary", Some(PrimitiveType::Binary)), ("uuid", Some(PrimitiveType::Uuid)), - ("timestamp", None), + ("date", Some(PrimitiveType::Date)), + ("interval_year", Some(PrimitiveType::IntervalYear)), + ("time", Some(PrimitiveType::Time)), + ("timestamp", Some(PrimitiveType::Timestamp)), + ("timestamp_tz", Some(PrimitiveType::TimestampTz)), ("invalid", None), ]; @@ -1265,18 +1287,6 @@ mod tests { #[test] fn 
test_parameterized_builtin_types() { let cases = vec![ - ( - "time<9>", - concretize(BuiltinParameterized::PrecisionTime { precision: 9 }), - ), - ( - "timestamp<3>", - concretize(BuiltinParameterized::PrecisionTimestamp { precision: 3 }), - ), - ( - "timestamp_tz<4>", - concretize(BuiltinParameterized::PrecisionTimestampTz { precision: 4 }), - ), ( "precisiontime<2>", concretize(BuiltinParameterized::PrecisionTime { precision: 2 }), From 83a4249f3433dd69ccef34b0f5db271893367d5d Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 14:25:29 -0400 Subject: [PATCH 21/31] Enforce ranges on fixed-range types --- .../text/simple_extensions/parsed_type.rs | 5 +- src/parse/text/simple_extensions/types.rs | 117 ++++++++++++++++-- 2 files changed, 108 insertions(+), 14 deletions(-) diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index af4030c5..0ac61e88 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -153,7 +153,10 @@ mod tests { ("i32?", TypeExpr::Simple("i32", vec![], true)), ("MAP", TypeExpr::Simple("MAP", vec![], false)), ("timestamp", TypeExpr::Simple("timestamp", vec![], false)), - ("timestamp_tz?", TypeExpr::Simple("timestamp_tz", vec![], true)), + ( + "timestamp_tz?", + TypeExpr::Simple("timestamp_tz", vec![], true), + ), ("time", TypeExpr::Simple("time", vec![], false)), ]; diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index eebcfc45..e4a6981d 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -20,6 +20,7 @@ use crate::text::simple_extensions::{ use serde_json::Value; use std::convert::TryFrom; use std::fmt; +use std::ops::RangeInclusive; use std::str::FromStr; use thiserror::Error; @@ -501,7 +502,7 @@ pub enum ExtensionTypeError { expected: &'static str, }, /// Provided parameter value does not fit within the 
expected bounds - #[error("Type '{type_name}' parameter {index} value {value} is out of range for {expected}")] + #[error("Type '{type_name}' parameter {index} value {value} is not within {expected}")] InvalidParameterValue { /// The type name being validated type_name: String, @@ -512,6 +513,18 @@ pub enum ExtensionTypeError { /// Description of the expected range or type expected: &'static str, }, + /// Provided parameter value does not fit within the expected bounds + #[error("Type '{type_name}' parameter {index} value {value} is out of range {expected:?}")] + InvalidParameterRange { + /// The type name being validated + type_name: String, + /// Zero-based index of the offending parameter + index: usize, + /// Provided parameter value + value: i64, + /// Description of the expected range or type + expected: RangeInclusive<i32>, + }, /// Structure representation cannot be nullable #[error("Structure representation cannot be nullable: {type_string}")] StructureCannotBeNullable { @@ -933,8 +946,9 @@ fn expect_integer_param( type_name: &str, index: usize, param: &TypeExprParam<'_>, + range: Option<RangeInclusive<i32>>, ) -> Result<i32, ExtensionTypeError> { - match param { + let value = match param { TypeExprParam::Integer(value) => { i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue { type_name: type_name.to_string(), @@ -948,7 +962,20 @@ fn expect_integer_param( index, expected: "an integer", }), + }?; + + if let Some(range) = &range { + if !range.contains(&value) { + return Err(ExtensionTypeError::InvalidParameterRange { + type_name: type_name.to_string(), + index, + value: i64::from(value), + expected: range.clone(), + }); + } } + + Ok(value) } fn expect_type_argument<'a>( @@ -989,7 +1016,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let length = expect_integer_param(display_name, 0, &params[0])?; + let length = expect_integer_param(display_name, 0, &params[0], None)?; Ok(Some(BuiltinParameterized::FixedChar { length })) } "varchar" => { @@ -1000,7 +1027,7 @@ fn 
parse_parameterized_builtin<'a>( actual: params.len(), }); } - let length = expect_integer_param(display_name, 0, &params[0])?; + let length = expect_integer_param(display_name, 0, &params[0], None)?; Ok(Some(BuiltinParameterized::VarChar { length })) } "fixedbinary" => { @@ -1011,7 +1038,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let length = expect_integer_param(display_name, 0, &params[0])?; + let length = expect_integer_param(display_name, 0, &params[0], None)?; Ok(Some(BuiltinParameterized::FixedBinary { length })) } "decimal" => { @@ -1022,8 +1049,8 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; - let scale = expect_integer_param(display_name, 1, &params[1])?; + let precision = expect_integer_param(display_name, 0, &params[0], Some(1..=38))?; + let scale = expect_integer_param(display_name, 1, &params[1], Some(0..=precision))?; Ok(Some(BuiltinParameterized::Decimal { precision, scale })) } // Should we accept both "precision_time" and "precisiontime"? 
The @@ -1037,7 +1064,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; + let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?; Ok(Some(BuiltinParameterized::PrecisionTime { precision })) } "precisiontimestamp" => { @@ -1048,7 +1075,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; + let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?; Ok(Some(BuiltinParameterized::PrecisionTimestamp { precision })) } "precisiontimestamptz" => { @@ -1059,7 +1086,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; + let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?; Ok(Some(BuiltinParameterized::PrecisionTimestampTz { precision, })) @@ -1072,7 +1099,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; + let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=9))?; Ok(Some(BuiltinParameterized::IntervalDay { precision })) } "interval_compound" => { @@ -1083,7 +1110,7 @@ fn parse_parameterized_builtin<'a>( actual: params.len(), }); } - let precision = expect_integer_param(display_name, 0, &params[0])?; + let precision = expect_integer_param(display_name, 0, &params[0], None)?; Ok(Some(BuiltinParameterized::IntervalCompound { precision })) } _ => Ok(None), @@ -1203,6 +1230,12 @@ mod tests { ConcreteType::try_from(parsed).unwrap() } + /// Parse a string into a [ConcreteType], returning the result + fn parse_type_result(expr: &str) -> Result<ConcreteType, ExtensionTypeError> { + let parsed = TypeExpr::parse(expr).unwrap(); + ConcreteType::try_from(parsed) + } + /// Parse a string into a builtin [ConcreteType], with no unresolved /// extension references fn parse_simple(s: &str) -> ConcreteType { @@ -1210,7 
+1243,7 @@ mod tests { let mut refs = Vec::new(); parsed.visit_references(&mut |name| refs.push(name.to_string())); - assert!(refs.is_empty(), "{s} should add a builtin type"); + assert!(refs.is_empty(), "{s} should not add an extension reference"); ConcreteType::try_from(parsed).unwrap() } @@ -1334,6 +1367,64 @@ mod tests { } } + #[test] + fn test_parameterized_builtin_range_errors() { + use ExtensionTypeError::InvalidParameterRange; + + let cases = vec![ + ("precisiontime<13>", "precisiontime", 0, 13, 0..=12), + ("precisiontime<-1>", "precisiontime", 0, -1, 0..=12), + ( + "precisiontimestamp<13>", + "precisiontimestamp", + 0, + 13, + 0..=12, + ), + ( + "precisiontimestamp<-1>", + "precisiontimestamp", + 0, + -1, + 0..=12, + ), + ( + "precisiontimestamptz<20>", + "precisiontimestamptz", + 0, + 20, + 0..=12, + ), + ("interval_day<10>", "interval_day", 0, 10, 0..=9), + ("DECIMAL<39,0>", "DECIMAL", 0, 39, 1..=38), + ("DECIMAL<0,0>", "DECIMAL", 0, 0, 1..=38), + ("DECIMAL<10,-1>", "DECIMAL", 1, -1, 0..=10), + ("DECIMAL<10,12>", "DECIMAL", 1, 12, 0..=10), + ]; + + for (expr, expected_type, expected_index, expected_value, expected_range) in cases { + match parse_type_result(expr) { + Ok(value) => panic!("expected error parsing {expr}, got {value:?}"), + Err(InvalidParameterRange { + type_name, + index, + value, + expected, + }) => { + assert_eq!(type_name, expected_type, "unexpected type for {expr}"); + assert_eq!(index, expected_index, "unexpected index for {expr}"); + assert_eq!( + value, + i64::from(expected_value), + "unexpected value for {expr}" + ); + assert_eq!(expected, expected_range, "unexpected range for {expr}"); + } + Err(other) => panic!("expected InvalidParameterRange for {expr}, got {other:?}"), + } + } + } + #[test] fn test_container_types() { let cases = vec![ From f415dc8cc14771ab577242af5e9b5da57ce51071 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 14:49:41 -0400 Subject: [PATCH 22/31] Enforce float / int bounds --- 
.../text/simple_extensions/extensions.rs | 2 +- src/parse/text/simple_extensions/types.rs | 62 +++++++++++++++---- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 2f0942f4..9f5d484c 100644 --- a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -55,7 +55,7 @@ impl SimpleExtensions { /// A context for parsing simple extensions. #[derive(Debug, Default)] pub struct TypeContext { - /// Types that have been seen so far, not yet resolved. + /// Types that have been seen so far, now resolved. known: HashSet, /// Types that have been linked to, not yet resolved. linked: HashSet, diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index e4a6981d..39785894 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -400,9 +400,19 @@ impl ParameterConstraint { TypeParamDefsItemType::DataType => Self::DataType, TypeParamDefsItemType::Boolean => Self::Boolean, TypeParamDefsItemType::Integer => { - // TODO: This truncates from float to int; probably fine - let min_i = min.map(|n| n as i64); - let max_i = max.map(|n| n as i64); + if let Some(min_f) = min { + if min_f.fract() != 0.0 { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + } + if let Some(max_f) = max { + if max_f.fract() != 0.0 { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); + } + } + + let min_i = min.map(|v| v as i64); + let max_i = max.map(|v| v as i64); Self::Integer { min: min_i, max: max_i, @@ -1544,10 +1554,10 @@ mod tests { options: None, optional: None, }, - (Some(1), Some(10)), + Ok((Some(1), Some(10))), ), ( - "truncated", + "fractional_min", simple_extensions::TypeParamDefsItem { name: Some("K".to_string()), description: None, @@ -1557,18 +1567,44 @@ mod tests { options: None, optional: None, }, - (Some(1), 
None), + Err(TypeParamError::InvalidIntegerBounds { + min: Some(1.5), + max: None, + }), + ), + ( + "fractional_max", + simple_extensions::TypeParamDefsItem { + name: Some("K".to_string()), + description: None, + type_: simple_extensions::TypeParamDefsItemType::Integer, + min: None, + max: Some(9.9), + options: None, + optional: None, + }, + Err(TypeParamError::InvalidIntegerBounds { + min: None, + max: Some(9.9), + }), ), ]; - for (label, item, (expected_min, expected_max)) in cases { - let tp = TypeParam::try_from(item).expect("should parse integer bounds"); - match tp.param_type { - ParameterConstraint::Integer { min, max } => { - assert_eq!(min, expected_min, "min mismatch for {label}"); - assert_eq!(max, expected_max, "max mismatch for {label}"); + for (label, item, expected) in cases { + match (TypeParam::try_from(item), expected) { + (Ok(tp), Ok((expected_min, expected_max))) => match tp.param_type { + ParameterConstraint::Integer { min, max } => { + assert_eq!(min, expected_min, "min mismatch for {label}"); + assert_eq!(max, expected_max, "max mismatch for {label}"); + } + _ => panic!("expected integer param type for {label}"), + }, + (Err(actual_err), Err(expected_err)) => { + assert_eq!(actual_err, expected_err, "unexpected error for {label}"); + } + (result, expected) => { + panic!("unexpected result for {label}: got {result:?}, expected {expected:?}") } - _ => panic!("expected integer param type for {label}"), } } } From 0cbfc348ab741f59770d0f70014421a8080ade15 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 15:11:16 -0400 Subject: [PATCH 23/31] Validate missing types --- .../text/simple_extensions/extensions.rs | 7 ++++++ src/parse/text/simple_extensions/file.rs | 22 +++++++++++++++++++ src/parse/text/simple_extensions/mod.rs | 5 ++--- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 9f5d484c..5f63752f 100644 --- 
a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -96,6 +96,13 @@ impl Parse for RawExtensions { extension.add_type(&custom_type); } + if let Some(missing) = ctx.linked.iter().next() { + // TODO: Track originating type(s) to improve this error message. + return Err(super::SimpleExtensionsError::UnresolvedTypeReference { + type_name: missing.clone(), + }); + } + Ok((urn, extension)) } } diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index f023b14b..7b507597 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -103,6 +103,15 @@ types: max: 10 "#; + const YAML_UNRESOLVED_TYPE: &str = r#" +%YAML 1.2 +--- +urn: extension:example.com:unresolved +types: + - name: "Alias" + structure: List> +"#; + #[test] fn yaml_round_trip_integer_param_bounds() { let ext = ExtensionFile::read_from_str(YAML_PARAM_TEST).expect("parse ok"); @@ -139,4 +148,17 @@ types: assert_eq!(param.min, Some(1.0)); assert_eq!(param.max, Some(10.0)); } + + #[test] + fn unresolved_type_reference_errors() { + let err = ExtensionFile::read_from_str(YAML_UNRESOLVED_TYPE) + .expect_err("expected unresolved type reference error"); + + match err { + SimpleExtensionsError::UnresolvedTypeReference { type_name } => { + assert_eq!(type_name, "MissingType"); + } + other => panic!("unexpected error type: {other:?}"), + } + } } diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index cd9e5886..fed447b4 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -36,12 +36,11 @@ pub enum SimpleExtensionsError { #[error("invalid urn")] InvalidUrn(#[from] crate::urn::InvalidUrn), /// Unresolved type reference in structure field - #[error("Type '{type_name}' referenced in '{originating}' structure not found")] + #[error("Type '{type_name}' referenced in structure not found")] 
UnresolvedTypeReference { /// The type name that could not be resolved type_name: String, - /// The type that contains the unresolved reference - originating: String, + // TODO: the location in the file where this came from would be nice }, } From 93b41efc2937c33234374117f34476e334e76a75 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 15:16:13 -0400 Subject: [PATCH 24/31] Add error for type variations --- .../text/simple_extensions/parsed_type.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index 0ac61e88..b624aa3b 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -29,6 +29,8 @@ pub enum TypeExprParam<'a> { pub enum TypeParseError { #[error("Parameter list {0} Must start and end with angle brackets")] ExpectedClosingAngleBracket(String), + #[error("Type variation syntax is not supported: {0}")] + UnsupportedVariation(String), } impl<'a> TypeExpr<'a> { @@ -60,6 +62,10 @@ impl<'a> TypeExpr<'a> { None => (rest, vec![]), }; + if name_and_nullable.contains('[') || name_and_nullable.contains(']') { + return Err(TypeParseError::UnsupportedVariation(type_str.to_string())); + } + let (name, nullable) = match name_and_nullable.strip_suffix('?') { Some(name) => (name, true), None => (name_and_nullable, false), @@ -236,4 +242,16 @@ mod tests { assert_eq!(refs, expected_refs, "unexpected references for {expr}"); } } + + #[test] + fn test_variation_not_supported() { + let cases = vec!["i32[1]", "Foo?[1]", "u!bar[2]" ]; + + for expr in cases { + match TypeExpr::parse(expr) { + Err(TypeParseError::UnsupportedVariation(s)) => assert_eq!(s, expr), + other => panic!("expected UnsupportedVariation for {expr}, got {other:?}"), + } + } + } } From 470aa00b6f23a7581d7b55f4e5ccc80ff1b0d7b3 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 15:27:11 -0400 Subject: 
[PATCH 25/31] Fix casing on type display, add tests --- src/parse/text/simple_extensions/types.rs | 41 +++++++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index 39785894..c4e82369 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -794,10 +794,10 @@ impl fmt::Display for ConcreteTypeKind { write!(f, "{name}")?; write_separated(f, parameters.iter(), "<", ">", ", ") } - ConcreteTypeKind::List(elem) => write!(f, "List<{elem}>"), - ConcreteTypeKind::Map { key, value } => write!(f, "Map<{key}, {value}>"), + ConcreteTypeKind::List(elem) => write!(f, "list<{elem}>"), + ConcreteTypeKind::Map { key, value } => write!(f, "map<{key}, {value}>"), ConcreteTypeKind::Struct(types) => { - write_separated(f, types.iter(), "Struct<", ">", ", ") + write_separated(f, types.iter(), "struct<", ">", ", ") } ConcreteTypeKind::NamedStruct { field_names, @@ -1540,6 +1540,41 @@ mod tests { } } + #[test] + fn test_type_round_trip_display() { + let cases = vec![ + ("i32", None), + ("I64?", Some("i64?")), + ("list", None), + ("List", Some("list")), + ("map>", None), + ( + "struct", + None, + ), + ( + "Struct, Map>>", + Some("struct, map>>") + ), + ( + "Map, Struct>>", + Some("map, struct>>"), + ), + ("u!custom", Some("custom")), + ]; + + for (input, expected) in cases { + let parsed = TypeExpr::parse(input).unwrap(); + let concrete = ConcreteType::try_from(parsed).unwrap(); + let actual = concrete.to_string(); + if let Some(expected_display) = expected { + assert_eq!(actual, expected_display, "unexpected display for {input}"); + } else { + assert_eq!(actual, input, "unexpected canonical output for {input}"); + } + } + } + #[test] fn test_integer_param_bounds_round_trip() { let cases = vec![ From 202747a70bd74fb61ed6ce728304ae8f7f4e9e44 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 19 Sep 2025 15:48:46 -0400 
Subject: [PATCH 26/31] Stable field ordering for structures --- src/parse/text/simple_extensions/types.rs | 77 +++++++++++++++++++++-- 1 file changed, 71 insertions(+), 6 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index c4e82369..eb0c662c 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -714,10 +714,43 @@ impl Parse for RawType { Ok(concrete) } RawType::Variant1(field_map) => { + // Here we have the internal structure of a custom type, + // specified by (field name, type) pairs. Note that in the YAML + // itself, these are a map - and thus, while the text has an + // order, the data implicitly does not. In Rust, the field map + // is of type serde_json::Map, which also does not preserve + // order. + // + // So while it might be surprising in some ways that we are not + // preserving the order of fields as specified in the YAML, the + // nature of YAML somewhat precludes that. + // + // To give an example: we are considering these two equivalent: + // + // ```yaml + // types: + // - name: point1 + // structure: + // longitude: i32 + // latitude: i32 + // - name: point2 + // structure: + // latitude: i32 + // longitude: i32 + // ``` + // + // In Rust, these come in as a `serde_json::Map`, which does not + // preserve insertion order. Here, we chose to store keys in + // lexicographic order, with an explicit sort so that our + // internal representation and any round-trip output remain + // deterministic. 
+ let mut entries: Vec<_> = field_map.into_iter().collect(); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + let mut field_names = Vec::new(); let mut field_types = Vec::new(); - for (field_name, field_type_value) in field_map { + for (field_name, field_type_value) in entries { field_names.push(field_name); let type_string = match field_type_value { @@ -1284,6 +1317,10 @@ mod tests { .iter() .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))), ); + + // Named struct YAML/json objects are inherently unordered; we sort the + // fields lexicographically when parsing so round-tripped output is + // deterministic. This test locks in that behaviour. RawType::Variant1(map) } @@ -1548,13 +1585,10 @@ mod tests { ("list", None), ("List", Some("list")), ("map>", None), - ( - "struct", - None, - ), + ("struct", None), ( "Struct, Map>>", - Some("struct, map>>") + Some("struct, map>>"), ), ( "Map, Struct>>", @@ -1575,6 +1609,37 @@ mod tests { } } + /// Test that named struct field order is stable and sorted + /// lexicographically when round-tripping through RawType. + /// + /// Normally, order in structs in SQL / relational algebra is significant; + /// but the spec doesn't say that, and it starts as a YAML map, which + /// doesn't generally preserve order, so for both implementation ease and + /// test stability, we sort the fields when parsing named structs. 
+ /// (Preserving order would be difficult - most YAML parsers don't preserve + /// map order) + #[test] + fn test_named_struct_field_order_stability() -> Result<(), ExtensionTypeError> { + let mut raw_fields = serde_json::Map::new(); + raw_fields.insert("beta".to_string(), Value::String("i32".to_string())); + raw_fields.insert("alpha".to_string(), Value::String("string?".to_string())); + + let raw = RawType::Variant1(raw_fields); + let mut ctx = TypeContext::default(); + let concrete = Parse::parse(raw, &mut ctx)?; + + let round_tripped: RawType = concrete.into(); + match round_tripped { + RawType::Variant1(result_map) => { + let keys: Vec<_> = result_map.keys().collect(); + assert_eq!(keys, vec!["alpha", "beta"], "field order should be sorted"); + } + other => panic!("expected Variant1, got {other:?}"), + } + + Ok(()) + } + #[test] fn test_integer_param_bounds_round_trip() { let cases = vec![ From ac479d98a030a2eed5170923172f4de63c4fdd74 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 10 Oct 2025 09:24:32 +0200 Subject: [PATCH 27/31] fix(parse): apply Substrait identifier rules for CustomType names --- src/parse/text/simple_extensions/types.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index eb0c662c..f6887add 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -585,13 +585,22 @@ pub struct CustomType { impl CustomType { /// Check if this type name is valid according to Substrait naming rules + /// (see the `Identifier` rule in `substrait/grammar/SubstraitLexer.g4`). + /// Identifiers are case-insensitive and must start with an ASCII letter, + /// `_`, or `$`, followed by ASCII letters, digits, `_`, or `$`. + // + // Note: I'm not sure if `$` is actually something we want to allow, or if + // `_` is, but it's in the grammar so I'm allowing it here. 
pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> { - if name.is_empty() { + let mut chars = name.chars(); + let first = chars + .next() + .ok_or_else(|| InvalidTypeName(name.to_string()))?; + if !(first.is_ascii_alphabetic() || first == '_' || first == '$') { return Err(InvalidTypeName(name.to_string())); } - // Basic validation - could be extended with more rules - if name.contains(|c: char| c.is_whitespace()) { + if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$') { return Err(InvalidTypeName(name.to_string())); } @@ -1756,8 +1765,14 @@ mod tests { let cases = vec![ ("", false), ("bad name", false), + ("9bad", false), + ("bad-name", false), + ("bad.name", false), ("GoodName", true), ("also_good", true), + ("_underscore", true), + ("$dollar", true), + ("CamelCase123", true), ]; for (name, expected_ok) in cases { From c0a60b408b0dd0bcc95526c4be635d8a09210c78 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 10 Oct 2025 10:57:55 +0200 Subject: [PATCH 28/31] docs(parse): add SPDX header, clarify errors, and guard duplicate extension types --- Cargo.toml | 2 ++ src/parse/text/simple_extensions/argument.rs | 6 ++++-- src/parse/text/simple_extensions/extensions.rs | 13 ++++++++++--- src/parse/text/simple_extensions/file.rs | 2 ++ src/parse/text/simple_extensions/mod.rs | 7 ++++++- src/parse/text/simple_extensions/parsed_type.rs | 4 ++-- src/parse/text/simple_extensions/registry.rs | 2 -- src/parse/text/simple_extensions/types.rs | 2 +- 8 files changed, 27 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ea23c13c..a8d6aa85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,8 @@ pbjson = { version = "0.8.0", optional = true } pbjson-types = { version = "0.8.0", optional = true } prost = "0.14.1" prost-types = "0.14.1" +# Required by generated text schemas: the typify-generated code emits +# ::regress::Regex for `pattern` validations. 
regress = "0.10.4" semver = { version = "1.0.27", optional = true } serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/parse/text/simple_extensions/argument.rs b/src/parse/text/simple_extensions/argument.rs index e4e5d4fb..8e52f2bc 100644 --- a/src/parse/text/simple_extensions/argument.rs +++ b/src/parse/text/simple_extensions/argument.rs @@ -237,7 +237,8 @@ pub struct ValueArg { /// A fully defined type or a type expression. /// - /// todo: implement parsed [simple_extensions::Type]. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` parser) + /// so the caller does not have to interpret the raw string. value: simple_extensions::Type, /// Whether this argument is required to be a constant for invocation. @@ -319,7 +320,8 @@ pub struct TypeArg { /// A partially or completely parameterized type. E.g. `List` or `K`. /// - /// todo: implement parsed [simple_extensions::Type]. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` parser) + /// so the caller does not have to interpret the raw string. 
type_: String, } diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 5f63752f..087083c3 100644 --- a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -5,7 +5,7 @@ use std::collections::{HashMap, HashSet}; use std::str::FromStr; -use super::types::CustomType; +use super::{SimpleExtensionsError, types::CustomType}; use crate::{ parse::{Context, Parse}, text::simple_extensions::SimpleExtensions as RawExtensions, @@ -26,9 +26,16 @@ pub struct SimpleExtensions { impl SimpleExtensions { /// Add a type to the context - pub fn add_type(&mut self, custom_type: &CustomType) { + pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> { + if self.types.contains_key(&custom_type.name) { + return Err(SimpleExtensionsError::DuplicateTypeName { + name: custom_type.name.clone(), + }); + } + self.types .insert(custom_type.name.clone(), custom_type.clone()); + Ok(()) } /// Check if a type with the given name exists in the context @@ -93,7 +100,7 @@ impl Parse for RawExtensions { for type_item in types { let custom_type = Parse::parse(type_item, ctx)?; // Add the parsed type to the context so later types can reference it - extension.add_type(&custom_type); + extension.add_type(&custom_type)?; } if let Some(missing) = ctx.linked.iter().next() { diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index 7b507597..2a24545b 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + use super::{CustomType, SimpleExtensions, SimpleExtensionsError}; use crate::parse::Parse; use crate::parse::text::simple_extensions::extensions::TypeContext; diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index fed447b4..a3e019cc 100644 --- 
a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -14,7 +14,6 @@ mod registry; mod types; pub use extensions::SimpleExtensions; -pub use extensions::TypeContext; pub use file::ExtensionFile; pub use parsed_type::TypeExpr; pub use registry::Registry; @@ -42,6 +41,12 @@ pub enum SimpleExtensionsError { type_name: String, // TODO: the location in the file where this came from would be nice }, + /// Duplicate type definition within the same extension + #[error("duplicate type definition for `{name}`")] + DuplicateTypeName { + /// The repeated type name + name: String, + }, } // Needed for certain conversions - e.g. Urn -> Urn - to succeed. diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index b624aa3b..a238a1bb 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -27,7 +27,7 @@ pub enum TypeExprParam<'a> { #[derive(Debug, PartialEq, thiserror::Error)] pub enum TypeParseError { - #[error("Parameter list {0} Must start and end with angle brackets")] + #[error("missing closing angle bracket in parameter list: {0}")] ExpectedClosingAngleBracket(String), #[error("Type variation syntax is not supported: {0}")] UnsupportedVariation(String), @@ -245,7 +245,7 @@ mod tests { #[test] fn test_variation_not_supported() { - let cases = vec!["i32[1]", "Foo?[1]", "u!bar[2]" ]; + let cases = vec!["i32[1]", "Foo?[1]", "u!bar[2]"]; for expr in cases { match TypeExpr::parse(expr) { diff --git a/src/parse/text/simple_extensions/registry.rs b/src/parse/text/simple_extensions/registry.rs index 542d4bfa..c6bd39b7 100644 --- a/src/parse/text/simple_extensions/registry.rs +++ b/src/parse/text/simple_extensions/registry.rs @@ -10,8 +10,6 @@ //! //! This module is only available when the `parse` feature is enabled. 
-#![cfg(feature = "parse")] - use super::{ExtensionFile, types::CustomType}; use crate::urn::Urn; diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index f6887add..baa97ac2 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -793,7 +793,7 @@ impl Parse for RawType { /// Invalid type name error #[derive(Debug, Error, PartialEq)] -#[error("{0}")] +#[error("invalid type name `{0}`")] pub struct InvalidTypeName(String); /// Known Substrait types (builtin + extension references) From ca1e133f07dcce0c5dc82a0c880225c745622f31 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 10 Oct 2025 11:23:36 +0200 Subject: [PATCH 29/31] fix: clippy warning about nested ifs --- src/parse/text/simple_extensions/types.rs | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index baa97ac2..de7bde59 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -400,15 +400,15 @@ impl ParameterConstraint { TypeParamDefsItemType::DataType => Self::DataType, TypeParamDefsItemType::Boolean => Self::Boolean, TypeParamDefsItemType::Integer => { - if let Some(min_f) = min { - if min_f.fract() != 0.0 { - return Err(TypeParamError::InvalidIntegerBounds { min, max }); - } + if let Some(min_f) = min + && min_f.fract() != 0.0 + { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); } - if let Some(max_f) = max { - if max_f.fract() != 0.0 { - return Err(TypeParamError::InvalidIntegerBounds { min, max }); - } + if let Some(max_f) = max + && max_f.fract() != 0.0 + { + return Err(TypeParamError::InvalidIntegerBounds { min, max }); } let min_i = min.map(|v| v as i64); @@ -1016,15 +1016,15 @@ fn expect_integer_param( }), }?; - if let Some(range) = &range { - if !range.contains(&value) { - return 
Err(ExtensionTypeError::InvalidParameterRange { - type_name: type_name.to_string(), - index, - value: i64::from(value), - expected: range.clone(), - }); - } + if let Some(range) = &range + && !range.contains(&value) + { + return Err(ExtensionTypeError::InvalidParameterRange { + type_name: type_name.to_string(), + index, + value: i64::from(value), + expected: range.clone(), + }); } Ok(value) From 95270ab03118bcff4d8d2d5384d649a42c6d2074 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 10 Oct 2025 11:41:42 +0200 Subject: [PATCH 30/31] doc: update extensions types and links --- src/parse/text/mod.rs | 8 ++- src/parse/text/simple_extensions/argument.rs | 51 ++++++++++--------- .../text/simple_extensions/extensions.rs | 20 +++++--- src/parse/text/simple_extensions/file.rs | 25 ++++----- src/parse/text/simple_extensions/mod.rs | 14 ++++- .../text/simple_extensions/parsed_type.rs | 2 +- 6 files changed, 73 insertions(+), 47 deletions(-) diff --git a/src/parse/text/mod.rs b/src/parse/text/mod.rs index 4f7b641e..73219362 100644 --- a/src/parse/text/mod.rs +++ b/src/parse/text/mod.rs @@ -1,5 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [text](crate::text) types. +//! Utilities for working with Substrait *text* objects. +//! +//! The generated [`crate::text`] module exposes the raw YAML-derived structs +//! (e.g. [`crate::text::simple_extensions::SimpleExtensions`]). This module +//! provides parsing helpers that validate those raw values and offer +//! higher-level wrappers for validation, lookups, and combining into protobuf +//! objects. pub mod simple_extensions; diff --git a/src/parse/text/simple_extensions/argument.rs b/src/parse/text/simple_extensions/argument.rs index 8e52f2bc..f08a3587 100644 --- a/src/parse/text/simple_extensions/argument.rs +++ b/src/parse/text/simple_extensions/argument.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [simple_extensions::ArgumentsItem]. +//! 
Parsing of type arguments: [`simple_extensions::ArgumentsItem`]. use std::{collections::HashSet, ops::Deref}; @@ -11,21 +11,24 @@ use crate::{ text::simple_extensions, }; -/// A parsed [simple_extensions::ArgumentsItem]. +/// A parsed [`simple_extensions::ArgumentsItem`]. #[derive(Clone, Debug)] pub enum ArgumentsItem { - /// Arguments that support a fixed set of declared values as constant arguments. + /// Arguments that support a fixed set of declared values as constant + /// arguments. EnumArgument(EnumerationArg), /// Arguments that refer to a data value. ValueArgument(ValueArg), - /// Arguments that are used only to inform the evaluation and/or type derivation of the function. + /// Arguments that are used only to inform the evaluation and/or type + /// derivation of the function. TypeArgument(TypeArg), } impl ArgumentsItem { - /// Parses an `Option` field, rejecting it if an empty string is provided. + /// Parses an `Option` field, rejecting it if an empty string is + /// provided. #[inline] fn parse_optional_string( name: &str, @@ -75,7 +78,7 @@ impl From for simple_extensions::ArgumentsItem { } } -/// Parse errors for [simple_extensions::ArgumentsItem]. +/// Parse errors for [`simple_extensions::ArgumentsItem`]. #[derive(Debug, Error, PartialEq)] pub enum ArgumentsItemError { /// Invalid enumeration options. @@ -103,21 +106,21 @@ pub struct EnumerationArg { impl EnumerationArg { /// Returns the name of this argument. /// - /// See [simple_extensions::EnumerationArg::name]. + /// See [`simple_extensions::EnumerationArg::name`]. pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::EnumerationArg::description]. + /// See [`simple_extensions::EnumerationArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } /// Returns the options of this argument. /// - /// See [simple_extensions::EnumerationArg::options]. 
+ /// See [`simple_extensions::EnumerationArg::options`]. pub fn options(&self) -> &EnumOptions { &self.options } @@ -210,7 +213,7 @@ impl From for simple_extensions::EnumOptions { } } -/// Parse errors for [simple_extensions::EnumOptions]. +/// Parse errors for [`simple_extensions::EnumOptions`]. #[derive(Debug, Error, PartialEq)] pub enum EnumOptionsError { /// Empty list. @@ -237,27 +240,27 @@ pub struct ValueArg { /// A fully defined type or a type expression. /// - /// TODO: parse this to a typed representation (likely using the `TypeExpr` parser) - /// so the caller does not have to interpret the raw string. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` + /// parser) so the caller does not have to interpret the raw string. value: simple_extensions::Type, - /// Whether this argument is required to be a constant for invocation. - /// For example, in some system a regular expression pattern would only be accepted as a literal - /// and not a column value reference. + /// Whether this argument is required to be a constant for invocation. For + /// example, in some system a regular expression pattern would only be + /// accepted as a literal and not a column value reference. constant: Option, } impl ValueArg { /// Returns the name of this argument. /// - /// See [simple_extensions::ValueArg::name]. + /// See [`simple_extensions::ValueArg::name`]. pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::ValueArg::description]. + /// See [`simple_extensions::ValueArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } @@ -265,7 +268,7 @@ impl ValueArg { /// Returns the constant of this argument. /// Defaults to `false` if the underlying value is `None`. /// - /// See [simple_extensions::ValueArg::constant]. + /// See [`simple_extensions::ValueArg::constant`]. 
pub fn constant(&self) -> bool { self.constant.unwrap_or(false) } @@ -309,10 +312,10 @@ impl From for ArgumentsItem { } } -/// Arguments that are used only to inform the evaluation and/or type derivation of the function. +/// A type argument to a parameterized type, e.g. the `T` in `List`. #[derive(Clone, Debug, PartialEq)] pub struct TypeArg { - /// A human-readable name for this argument to help clarify use. + /// A human-readable name for this argument to clarify use. name: Option, /// Additional description for this argument. @@ -320,22 +323,22 @@ pub struct TypeArg { /// A partially or completely parameterized type. E.g. `List` or `K`. /// - /// TODO: parse this to a typed representation (likely using the `TypeExpr` parser) - /// so the caller does not have to interpret the raw string. + /// TODO: parse this to a typed representation (likely using the `TypeExpr` + /// parser) so the caller does not have to interpret the raw string. type_: String, } impl TypeArg { /// Returns the name of this argument. /// - /// See [simple_extensions::TypeArg::name]. + /// See [`simple_extensions::TypeArg::name`]. pub fn name(&self) -> Option<&String> { self.name.as_ref() } /// Returns the description of this argument. /// - /// See [simple_extensions::TypeArg::description]. + /// See [`simple_extensions::TypeArg::description`]. pub fn description(&self) -> Option<&String> { self.description.as_ref() } diff --git a/src/parse/text/simple_extensions/extensions.rs b/src/parse/text/simple_extensions/extensions.rs index 087083c3..56344c00 100644 --- a/src/parse/text/simple_extensions/extensions.rs +++ b/src/parse/text/simple_extensions/extensions.rs @@ -12,12 +12,14 @@ use crate::{ urn::Urn, }; -/// Parsing context for extension processing +/// The contents (types) in an [`ExtensionFile`](super::file::ExtensionFile). /// -/// The context provides access to types defined in the same extension file during parsing. 
-/// This allows type references to be resolved within the same extension file. The corresponding -/// URN is tracked by [`ExtensionFile`](super::file::ExtensionFile) so this structure can focus on -/// validated type information. +/// This structure stores and provides access to the individual objects defined +/// in an [`ExtensionFile`](super::file::ExtensionFile); [`SimpleExtensions`] +/// represents the contents of an extensions file. +/// +/// Currently, only the [`CustomType`]s are included; any scalar, window, or +/// aggregate functions are not yet included. #[derive(Clone, Debug, Default)] pub struct SimpleExtensions { /// Types defined in this extension file @@ -25,7 +27,7 @@ pub struct SimpleExtensions { } impl SimpleExtensions { - /// Add a type to the context + /// Add a type to the context. Name must be unique. pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> { if self.types.contains_key(&custom_type.name) { return Err(SimpleExtensionsError::DuplicateTypeName { @@ -59,7 +61,8 @@ impl SimpleExtensions { } } -/// A context for parsing simple extensions. +/// A context for parsing simple extensions, tracking what type names are +/// resolved or unresolved. #[derive(Debug, Default)] pub struct TypeContext { /// Types that have been seen so far, now resolved. @@ -75,7 +78,8 @@ impl TypeContext { self.known.insert(name.to_string()); } - /// Mark a type as linked to - some other type or function references it, but we haven't seen it. + /// Mark a type as linked to - some other type or function references it, + /// but we haven't seen it. 
pub fn linked(&mut self, name: &str) { if !self.known.contains(name) { self.linked.insert(name.to_string()); diff --git a/src/parse/text/simple_extensions/file.rs b/src/parse/text/simple_extensions/file.rs index 2a24545b..bc2a3dcf 100644 --- a/src/parse/text/simple_extensions/file.rs +++ b/src/parse/text/simple_extensions/file.rs @@ -7,11 +7,10 @@ use crate::text::simple_extensions::SimpleExtensions as RawExtensions; use crate::urn::Urn; use std::io::Read; -/// A parsed and validated [RawExtensions]. +/// A parsed and validated [`RawExtensions`]: a simple extensions file. /// -/// `ExtensionFile` owns the canonical URN for a simple extension file along with the parsed -/// [`SimpleExtensions`](super::SimpleExtensions) data. Keeping the URN here (instead of on the inner -/// type map) lets us thread it through I/O, registries, and conversions without duplicating state. +/// An [`ExtensionFile`] has a canonical [`Urn`] and a parsed set of +/// [`SimpleExtensions`] data. It represents the extensions file as a whole. #[derive(Debug)] pub struct ExtensionFile { /// The URN this extension was loaded from @@ -21,13 +20,13 @@ pub struct ExtensionFile { } impl ExtensionFile { - /// Create a new, empty SimpleExtensions + /// Create a new, empty [`ExtensionFile`] with an empty set of [`SimpleExtensions`]. pub fn empty(urn: Urn) -> Self { let extension = SimpleExtensions::default(); Self { urn, extension } } - /// Create a validated SimpleExtensions from raw data + /// Create an [`ExtensionFile`] from raw simple extension data. pub fn create(extensions: RawExtensions) -> Result { // Parse all types (may contain unresolved Extension(String) references) let mut ctx = TypeContext::default(); @@ -48,31 +47,33 @@ impl ExtensionFile { self.extension.types() } - /// Returns the URN for this extension file. + /// Returns the [`Urn`] for this extension file.
pub fn urn(&self) -> &Urn { &self.urn } - /// Get a reference to the underlying SimpleExtension + /// Get a reference to the underlying [`SimpleExtensions`]. pub fn extension(&self) -> &SimpleExtensions { &self.extension } - /// Convert the parsed extension file back into the raw text representation by value. + /// Convert the parsed extension file back into the raw text representation + /// by value. pub fn into_raw(self) -> RawExtensions { let ExtensionFile { urn, extension } = self; RawExtensions::from((urn, extension)) } - /// Convert the parsed extension file back into the raw text representation by reference. + /// Convert the parsed extension file back into the raw text representation + /// by reference. pub fn to_raw(&self) -> RawExtensions { RawExtensions::from((self.urn.clone(), self.extension.clone())) } /// Read an extension file from a reader. - /// - `reader`: any `Read` instance with the YAML content + /// - `reader`: any [`Read`] instance with the YAML content /// - /// Returns a parsed and validated `ExtensionFile` or an error. + /// Returns a parsed and validated [`ExtensionFile`] or an error. pub fn read(reader: R) -> Result { let raw: RawExtensions = serde_yaml::from_reader(reader)?; Self::create(raw) diff --git a/src/parse/text/simple_extensions/mod.rs b/src/parse/text/simple_extensions/mod.rs index a3e019cc..0e66e7c8 100644 --- a/src/parse/text/simple_extensions/mod.rs +++ b/src/parse/text/simple_extensions/mod.rs @@ -1,6 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 -//! Parsing of [crate::text::simple_extensions] types into [SimpleExtensions]. +//! Rustic types for validating and working with Substrait simple extensions. +//! +//! The raw YAML structs live in [`crate::text::simple_extensions`]. This +//! module parses those values into the typed representations used by this +//! crate: +//! * [`ExtensionFile`] – a fully validated extension document (URN plus its +//! definitions). +//! 
* [`SimpleExtensions`] – the validated objects declared by a single +//! extension file. +//! * [`CustomType`] / [`ConcreteType`] – type definitions and resolved type +//! structures used when checking function signatures. +//! * [`Registry`] – a reusable lookup structure that stores validated extension +//! files and exposes typed access to their contents. use std::convert::Infallible; diff --git a/src/parse/text/simple_extensions/parsed_type.rs b/src/parse/text/simple_extensions/parsed_type.rs index a238a1bb..1a784330 100644 --- a/src/parse/text/simple_extensions/parsed_type.rs +++ b/src/parse/text/simple_extensions/parsed_type.rs @@ -34,7 +34,7 @@ pub enum TypeParseError { } impl<'a> TypeExpr<'a> { - /// Parse a type string into a ParsedType + /// Parse a type string into a [`TypeExpr`]. pub fn parse(type_str: &'a str) -> Result { // Handle type variables like any1, any2, etc. if let Some(suffix) = type_str.strip_prefix("any") { From 3e58a42ea2328264fe11e131cf64dd5515d3f3a2 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 10 Oct 2025 16:25:31 +0200 Subject: [PATCH 31/31] fix: update usage of typify-generated types to match typify upgrade --- src/parse/text/simple_extensions/types.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/parse/text/simple_extensions/types.rs b/src/parse/text/simple_extensions/types.rs index de7bde59..2385e9cc 100644 --- a/src/parse/text/simple_extensions/types.rs +++ b/src/parse/text/simple_extensions/types.rs @@ -709,7 +709,7 @@ impl Parse for RawType { fn parse(self, ctx: &mut TypeContext) -> Result { match self { - RawType::Variant0(type_string) => { + RawType::String(type_string) => { let parsed_type = TypeExpr::parse(&type_string)?; let mut link = |name: &str| ctx.linked(name); parsed_type.visit_references(&mut link); @@ -722,7 +722,7 @@ impl Parse for RawType { Ok(concrete) } - RawType::Variant1(field_map) => { + RawType::Object(field_map) => { // Here we have the internal 
structure of a custom type, // specified by (field name, type) pairs. Note that in the YAML // itself, these are a map - and thus, while the text has an @@ -987,9 +987,9 @@ impl From for RawType { panic!("duplicate value '{v:?}' in NamedStruct"); } } - RawType::Variant1(map) + RawType::Object(map) } - _ => RawType::Variant0(val.to_string()), + _ => RawType::String(val.to_string()), } } } @@ -1330,7 +1330,7 @@ mod tests { // Named struct YAML/json objects are inherently unordered; we sort the // fields lexicographically when parsing so round-tripped output is // deterministic. This test locks in that behaviour. - RawType::Variant1(map) + RawType::Object(map) } #[test] @@ -1633,17 +1633,17 @@ mod tests { raw_fields.insert("beta".to_string(), Value::String("i32".to_string())); raw_fields.insert("alpha".to_string(), Value::String("string?".to_string())); - let raw = RawType::Variant1(raw_fields); + let raw = RawType::Object(raw_fields); let mut ctx = TypeContext::default(); let concrete = Parse::parse(raw, &mut ctx)?; let round_tripped: RawType = concrete.into(); match round_tripped { - RawType::Variant1(result_map) => { + RawType::Object(result_map) => { let keys: Vec<_> = result_map.keys().collect(); assert_eq!(keys, vec!["alpha", "beta"], "field order should be sorted"); } - other => panic!("expected Variant1, got {other:?}"), + other => panic!("expected Object, got {other:?}"), } Ok(()) @@ -1790,7 +1790,7 @@ mod tests { let cases = vec![ ( "alias", - RawType::Variant0("i32".to_string()), + RawType::String("i32".to_string()), ConcreteType::builtin(PrimitiveType::I32, false), ), ( @@ -1825,7 +1825,7 @@ mod tests { name: "Alias".to_string(), description: Some("Alias type".to_string()), parameters: None, - structure: Some(RawType::Variant0("BINARY".to_string())), + structure: Some(RawType::String("BINARY".to_string())), variadic: None, }, "Alias", @@ -1890,7 +1890,7 @@ mod tests { /// 'INTEGER?' 
- is that now equal to `i64??` #[test] fn test_nullable_structure_rejected() { - let cases = vec![RawType::Variant0("i32?".to_string())]; + let cases = vec![RawType::String("i32?".to_string())]; for raw in cases { let mut ctx = TypeContext::default();