From 4094d144bf856ab4f594021e237560e717e20916 Mon Sep 17 00:00:00 2001 From: tokoko Date: Fri, 31 Oct 2025 00:05:46 +0400 Subject: [PATCH 1/2] feat: allow passing multiple functions to function builders --- src/substrait/builders/extended_expression.py | 42 +++++++++++++------ src/substrait/sql/sql_to_substrait.py | 30 ++++++------- .../test_aggregate_function.py | 3 +- .../builders/extended_expression/test_cast.py | 3 +- .../extended_expression/test_if_then.py | 3 +- .../extended_expression/test_multi_or_list.py | 3 +- .../test_scalar_function.py | 13 +++--- .../test_singular_or_list.py | 3 +- .../extended_expression/test_switch.py | 3 +- .../test_window_function.py | 2 +- tests/builders/plan/test_aggregate.py | 3 +- tests/test_uri_urn_migration.py | 11 ++--- 12 files changed, 64 insertions(+), 55 deletions(-) diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index abda416..07d9ec3 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -205,12 +205,12 @@ def resolve( def scalar_function( - urn: str, - function: str, + function: Union[str, Iterable[str]], expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" + functions = [function] if isinstance(function, str) else function def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry @@ -225,11 +225,17 @@ def resolve( signature = [typ for es in expression_schemas for typ in es.types] - func = registry.lookup_function(urn, function, signature) + for f in functions: + urn, name = f.rsplit(":", 1) + func = registry.lookup_function(urn, name, signature) + if func: + break if not func: raise Exception(f"Unknown function {function} for {signature}") + resolved_func, return_type = func + func_extension_urns = [ ste.SimpleExtensionURN( extension_urn_anchor=registry.lookup_urn(urn), urn=urn @@ -288,7 +294,7 @@ def resolve( ), output_names=_alias_or_inferred( alias, - function, + name, [e.referred_expr[0].output_names[0] for e in bound_expressions], ), ) @@ -303,12 +309,12 @@ def resolve( def aggregate_function( - urn: str, - function: str, + function: Union[str, Iterable[str]], expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" + functions = [function] if isinstance(function, str) else function def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry @@ -323,11 +329,17 @@ def resolve( signature = [typ for es in expression_schemas for typ in es.types] - func = registry.lookup_function(urn, function, signature) + for f in functions: + urn, name = f.rsplit(":", 1) + func = registry.lookup_function(urn, name, signature) + if func: + break if not func: raise Exception(f"Unknown function {function} for {signature}") + resolved_func, return_type = func + func_extension_urns = [ ste.SimpleExtensionURN( extension_urn_anchor=registry.lookup_urn(urn), urn=urn @@ -382,7 +394,7 @@ def resolve( ), output_names=_alias_or_inferred( alias, - "IfThen", + name, [e.referred_expr[0].output_names[0] for e in bound_expressions], ), ) @@ -398,13 +410,13 @@ def resolve( # TODO bounds, sorts def window_function( - urn: str, - function: str, + function: Union[str, Iterable[str]], expressions: Iterable[ExtendedExpressionOrUnbound], partitions: Iterable[ExtendedExpressionOrUnbound] = [], alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" + functions = [function] if isinstance(function, str) else function def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry @@ -423,11 +435,17 @@ def resolve( signature = [typ for es in expression_schemas for typ in es.types] - func = registry.lookup_function(urn, function, signature) + for f in functions: + urn, name = f.rsplit(":", 1) + func = registry.lookup_function(urn, name, signature) + if func: + break if not func: raise Exception(f"Unknown function {function} for {signature}") + resolved_func, return_type = func + func_extension_urns = [ ste.SimpleExtensionURN( extension_urn_anchor=registry.lookup_urn(urn), urn=urn @@ -495,7 +513,7 @@ def resolve( ), output_names=_alias_or_inferred( alias, - function, + name, [e.referred_expr[0].output_names[0] for e in bound_expressions], ), ) diff --git a/src/substrait/sql/sql_to_substrait.py b/src/substrait/sql/sql_to_substrait.py index 1b2b6c2..b4a6917 100644 --- a/src/substrait/sql/sql_to_substrait.py +++ b/src/substrait/sql/sql_to_substrait.py @@ -27,21 +27,21 @@ SchemaResolver = Callable[[str], stt.NamedStruct] -function_mapping = { - "Plus": ("extension:io.substrait:functions_arithmetic", "add"), - "Minus": ("extension:io.substrait:functions_arithmetic", "subtract"), - "Gt": ("extension:io.substrait:functions_comparison", "gt"), - "GtEq": ("extension:io.substrait:functions_comparison", "gte"), - "Lt": ("extension:io.substrait:functions_comparison", "lt"), - "Eq": ("extension:io.substrait:functions_comparison", "equal"), +scalar_function_mapping = { + "Plus": ["extension:io.substrait:functions_arithmetic:add"], + "Minus": ["extension:io.substrait:functions_arithmetic:subtract"], + "Gt": ["extension:io.substrait:functions_comparison:gt"], + "GtEq": ["extension:io.substrait:functions_comparison:gte"], + "Lt": ["extension:io.substrait:functions_comparison:lt"], + "Eq": ["extension:io.substrait:functions_comparison:equal"], } aggregate_function_mapping = { - "SUM": ("extension:io.substrait:functions_arithmetic", "sum") + "SUM": ["extension:io.substrait:functions_arithmetic:sum"] } window_function_mapping = { - "row_number": ("extension:io.substrait:functions_arithmetic", "row_number"), + "row_number": ["extension:io.substrait:functions_arithmetic:row_number"], } @@ -106,8 +106,8 @@ def translate_expression( groupings=groupings, ), ] - func = function_mapping[ast["op"]] - return scalar_function(func[0], func[1], expressions=expressions, alias=alias) + func = scalar_function_mapping[ast["op"]] + return scalar_function(func, expressions=expressions, alias=alias) elif op == "Value": return literal( int(ast["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias @@ -125,8 +125,8 @@ def translate_expression( ] name = ast["name"][0]["Identifier"]["value"] - if name in function_mapping: - func = function_mapping[name] + if name in scalar_function_mapping: + func = scalar_function_mapping[name] return scalar_function(func[0], func[1], *expressions, alias=alias) elif name in aggregate_function_mapping: # All measures need to be extracted out because substrait calculates measures in a separate rel @@ -140,7 +140,7 @@ def translate_expression( random_name = "".join( random.choices(string.ascii_uppercase + string.digits, k=5) ) # TODO make this deterministic - aggr = aggregate_function(func[0], func[1], expressions, alias=random_name) + aggr = aggregate_function(func, expressions, alias=random_name) measures.append((aggr, ast, random_name)) return column(random_name, alias=alias) elif name in window_function_mapping: @@ -158,7 +158,7 @@ def translate_expression( ] return window_function( - func[0], func[1], expressions, partitions=partitions, alias=alias + func, expressions, partitions=partitions, alias=alias ) else: diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py index cbbe2ea..39fe605 100644 --- a/tests/builders/extended_expression/test_aggregate_function.py +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -45,8 +45,7 @@ def test_aggregate_count(): e = aggregate_function( - "extension:test:urn", - "count", + "extension:test:urn:count", expressions=[ literal( 10, diff --git a/tests/builders/extended_expression/test_cast.py b/tests/builders/extended_expression/test_cast.py index 704f80d..4025d5f 100644 --- a/tests/builders/extended_expression/test_cast.py +++ b/tests/builders/extended_expression/test_cast.py @@ -70,8 +70,7 @@ def test_cast_with_extension(): actual = cast( input=scalar_function( - "extension:test:functions", - "add", + "extension:test:functions:add", expressions=[literal(1, i8()), literal(2, i8())], ), type=i16(), diff --git a/tests/builders/extended_expression/test_if_then.py b/tests/builders/extended_expression/test_if_then.py index 81d27d6..77c3d46 100644 --- a/tests/builders/extended_expression/test_if_then.py +++ b/tests/builders/extended_expression/test_if_then.py @@ -110,8 +110,7 @@ def test_if_then_with_extension(): ifs=[ ( scalar_function( - "extension:io.substrait:functions_comparison", - "gt", + "extension:io.substrait:functions_comparison:gt", expressions=[ column("order_total"), literal( diff --git a/tests/builders/extended_expression/test_multi_or_list.py b/tests/builders/extended_expression/test_multi_or_list.py index 701847f..44ab205 100644 --- a/tests/builders/extended_expression/test_multi_or_list.py +++ b/tests/builders/extended_expression/test_multi_or_list.py @@ -108,8 +108,7 @@ def test_multi_or_list_with_extension(): actual = multi_or_list( value=[ scalar_function( - "extension:test:functions", - "add", + "extension:test:functions:add", expressions=[literal(1, i8()), literal(2, i8())], ), literal(10, i8()), diff --git a/tests/builders/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py index c7819b9..6693692 100644 --- a/tests/builders/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -47,8 +47,7 @@ def test_sclar_add(): e = scalar_function( - "extension:test:urn", - "test_func", + "extension:test:urn:test_func", expressions=[ literal( 10, @@ -117,17 +116,19 @@ def test_sclar_add(): base_schema=named_struct, ) + print(e) + print("---------------") + print(expected) + assert e == expected def test_nested_scalar_calls(): e = scalar_function( - "extension:test:urn", - "is_positive", + "extension:test:urn:is_positive", expressions=[ scalar_function( - "extension:test:urn", - "test_func", + "extension:test:urn:test_func", expressions=[ literal( 10, diff --git a/tests/builders/extended_expression/test_singular_or_list.py b/tests/builders/extended_expression/test_singular_or_list.py index 82e998e..806f5de 100644 --- a/tests/builders/extended_expression/test_singular_or_list.py +++ b/tests/builders/extended_expression/test_singular_or_list.py @@ -76,8 +76,7 @@ def test_singular_or_list_with_extension(): actual = singular_or_list( value=scalar_function( - "extension:test:functions", - "add", + "extension:test:functions:add", expressions=[literal(1, i8()), literal(2, i8())], ), options=[literal(3, i8()), literal(4, i8())], diff --git a/tests/builders/extended_expression/test_switch.py b/tests/builders/extended_expression/test_switch.py index 0ec8b71..eda6165 100644 --- a/tests/builders/extended_expression/test_switch.py +++ b/tests/builders/extended_expression/test_switch.py @@ -100,8 +100,7 @@ def test_switch_with_extension(): actual = switch( match=scalar_function( - "extension:test:functions", - "add", + "extension:test:functions:add", expressions=[literal(1, i8()), literal(2, i8())], ), ifs=[ diff --git a/tests/builders/extended_expression/test_window_function.py b/tests/builders/extended_expression/test_window_function.py index 2abef07..29c3107 100644 --- a/tests/builders/extended_expression/test_window_function.py +++ b/tests/builders/extended_expression/test_window_function.py @@ -49,7 +49,7 @@ def test_row_number(): - e = window_function("extension:test:urn", "row_number", expressions=[], alias="rn")( + e = window_function("extension:test:urn:row_number", expressions=[], alias="rn")( named_struct, registry ) diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 0b30685..efc0031 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -43,8 +43,7 @@ def test_aggregate(): group_expr = column("id") measure_expr = aggregate_function( - "extension:test:urn", - "count", + "extension:test:urn:count", expressions=[column("is_applicable")], alias=["count"], ) diff --git a/tests/test_uri_urn_migration.py b/tests/test_uri_urn_migration.py index 184abd6..5e04905 100644 --- a/tests/test_uri_urn_migration.py +++ b/tests/test_uri_urn_migration.py @@ -49,8 +49,7 @@ def test_extended_expression_outputs_both_uri_and_urn(): named_struct = stt.NamedStruct(names=["value"], struct=struct) func_expr = scalar_function( - "extension:test:functions", - "test_func", + "extension:test:functions:test_func", expressions=[ literal( 10, @@ -137,8 +136,7 @@ def test_project_outputs_both_uri_and_urn(): table = read_named_table("table", named_struct) add_expr = scalar_function( - "extension:test:math", - "add", + "extension:test:math:add", expressions=[column("a"), column("b")], alias=["add"], ) @@ -213,8 +211,7 @@ def test_filter_outputs_both_uri_and_urn(): table = read_named_table("table", named_struct) gt_expr = scalar_function( - "extension:test:comparison", - "greater_than", + "extension:test:comparison:greater_than", expressions=[column("value"), literal(100, i64(nullable=False))], ) @@ -287,7 +284,7 @@ def test_aggregate_with_aggregate_function(): table = read_named_table("table", named_struct) sum_expr = aggregate_function( - "extension:test:aggregate", "sum", expressions=[column("value")], alias=["sum"] + "extension:test:aggregate:sum", expressions=[column("value")], alias=["sum"] ) actual = aggregate(table, grouping_expressions=[column("id")], measures=[sum_expr])( From 5c3e317838a88e8adb6613613fe5d089504ace30 Mon Sep 17 00:00:00 2001 From: tokoko Date: Fri, 31 Oct 2025 00:12:04 +0400 Subject: [PATCH 2/2] fix examples --- examples/adbc_example.py | 3 +-- examples/builder_example.py | 27 +++++++++------------------ examples/duckdb_example.py | 3 +-- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/adbc_example.py b/examples/adbc_example.py index 532033a..f67bc5d 100644 --- a/examples/adbc_example.py +++ b/examples/adbc_example.py @@ -45,8 +45,7 @@ def read_adbc_named_table(name: str, conn): table = filter( table, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "gte", + "extension:io.substrait:functions_comparison:gte", expressions=[column("ints"), literal(3, i64())], ), ) diff --git a/examples/builder_example.py b/examples/builder_example.py index b4f0cf3..b7bd686 100644 --- a/examples/builder_example.py +++ b/examples/builder_example.py @@ -29,8 +29,7 @@ def basic_example(): table = filter( table, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "lt", + "extension:io.substrait:functions_comparison:lt", expressions=[column("id"), literal(100, i64(nullable=False))], ), ) @@ -172,8 +171,7 @@ def advanced_example(): table = filter( table, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "lt", + "extension:io.substrait:functions_comparison:lt", expressions=[column("id"), literal(100, i64(nullable=False))], ), ) @@ -205,8 +203,7 @@ def advanced_example(): adult_users = filter( users, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "gt", + "extension:io.substrait:functions_comparison:gt", expressions=[column("age"), literal(25, i64())], ), ) @@ -221,8 +218,7 @@ def advanced_example(): column("salary"), # Add a calculated field (this would show function options if available) scalar_function( - "extension:io.substrait:functions_arithmetic", - "multiply", + "extension:io.substrait:functions_arithmetic:multiply", expressions=[column("salary"), literal(1.1, fp64())], alias="salary_with_bonus", ), @@ -254,8 +250,7 @@ def advanced_example(): high_value_orders = filter( orders, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "gt", + "extension:io.substrait:functions_comparison:gt", expressions=[column("amount"), literal(50.0, fp64())], ), ) @@ -286,17 +281,14 @@ def expression_only_example(): print("=== Expression-Only Example ===") # Show complex expression structure complex_expr = scalar_function( - "extension:io.substrait:functions_arithmetic", - "multiply", + "extension:io.substrait:functions_arithmetic:multiply", expressions=[ scalar_function( - "extension:io.substrait:functions_arithmetic", - "add", + "extension:io.substrait:functions_arithmetic:add", expressions=[ column("base_salary"), scalar_function( - "extension:io.substrait:functions_arithmetic", - "multiply", + "extension:io.substrait:functions_arithmetic:multiply", expressions=[ column("base_salary"), literal(0.15, fp64()), # 15% bonus @@ -305,8 +297,7 @@ def expression_only_example(): ], ), scalar_function( - "extension:io.substrait:functions_arithmetic", - "subtract", + "extension:io.substrait:functions_arithmetic:subtract", expressions=[ literal(1.0, fp64()), literal(0.25, fp64()), # 25% tax rate diff --git a/examples/duckdb_example.py b/examples/duckdb_example.py index ccd8172..e68873a 100644 --- a/examples/duckdb_example.py +++ b/examples/duckdb_example.py @@ -41,8 +41,7 @@ def read_duckdb_named_table(name: str, conn): table = filter( table, expression=scalar_function( - "extension:io.substrait:functions_comparison", - "equal", + "extension:io.substrait:functions_comparison:equal", expressions=[column("c_nationkey"), literal(3, i32())], ), )