Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/adbc_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def read_adbc_named_table(name: str, conn):
table = filter(
table,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"gte",
"extension:io.substrait:functions_comparison:gte",
expressions=[column("ints"), literal(3, i64())],
),
)
Expand Down
27 changes: 9 additions & 18 deletions examples/builder_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def basic_example():
table = filter(
table,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"lt",
"extension:io.substrait:functions_comparison:lt",
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
Expand Down Expand Up @@ -172,8 +171,7 @@ def advanced_example():
table = filter(
table,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"lt",
"extension:io.substrait:functions_comparison:lt",
expressions=[column("id"), literal(100, i64(nullable=False))],
),
)
Expand Down Expand Up @@ -205,8 +203,7 @@ def advanced_example():
adult_users = filter(
users,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"gt",
"extension:io.substrait:functions_comparison:gt",
expressions=[column("age"), literal(25, i64())],
),
)
Expand All @@ -221,8 +218,7 @@ def advanced_example():
column("salary"),
# Add a calculated field (this would show function options if available)
scalar_function(
"extension:io.substrait:functions_arithmetic",
"multiply",
"extension:io.substrait:functions_arithmetic:multiply",
expressions=[column("salary"), literal(1.1, fp64())],
alias="salary_with_bonus",
),
Expand Down Expand Up @@ -254,8 +250,7 @@ def advanced_example():
high_value_orders = filter(
orders,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"gt",
"extension:io.substrait:functions_comparison:gt",
expressions=[column("amount"), literal(50.0, fp64())],
),
)
Expand Down Expand Up @@ -286,17 +281,14 @@ def expression_only_example():
print("=== Expression-Only Example ===")
# Show complex expression structure
complex_expr = scalar_function(
"extension:io.substrait:functions_arithmetic",
"multiply",
"extension:io.substrait:functions_arithmetic:multiply",
expressions=[
scalar_function(
"extension:io.substrait:functions_arithmetic",
"add",
"extension:io.substrait:functions_arithmetic:add",
expressions=[
column("base_salary"),
scalar_function(
"extension:io.substrait:functions_arithmetic",
"multiply",
"extension:io.substrait:functions_arithmetic:multiply",
expressions=[
column("base_salary"),
literal(0.15, fp64()), # 15% bonus
Expand All @@ -305,8 +297,7 @@ def expression_only_example():
],
),
scalar_function(
"extension:io.substrait:functions_arithmetic",
"subtract",
"extension:io.substrait:functions_arithmetic:subtract",
expressions=[
literal(1.0, fp64()),
literal(0.25, fp64()), # 25% tax rate
Expand Down
3 changes: 1 addition & 2 deletions examples/duckdb_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def read_duckdb_named_table(name: str, conn):
table = filter(
table,
expression=scalar_function(
"extension:io.substrait:functions_comparison",
"equal",
"extension:io.substrait:functions_comparison:equal",
expressions=[column("c_nationkey"), literal(3, i32())],
),
)
Expand Down
42 changes: 30 additions & 12 deletions src/substrait/builders/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ def resolve(


def scalar_function(
urn: str,
function: str,
function: Union[str, Iterable[str]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tokoko I am open-minded about this approach, but I would strongly prefer not to introduce string parsing if possible. Instead, maybe the API could be something like:

Suggested change
function: Union[str, Iterable[str]],
function: Union[ExtensionID, Iterable[ExtensionID]],

where

@dataclass
class ExtensionID:
  urn: str
  function: str

What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also be open to simply having separate functions for the single-function case and the list case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of parsing strings either; I simply wanted to avoid cluttering the API too much. How about using a NamedTuple instead of a dataclass? It would allow the user to pass plain tuples as well.

from typing import NamedTuple, Union, Iterable

class ExtensionFunction(NamedTuple):
    """A (urn, function) pair; plain 2-tuples unpack the same way."""

    # URN string of the extension.
    urn: str
    # Function name paired with that URN.
    function: str


def process_func(func: Union[ExtensionFunction, Iterable[ExtensionFunction]]):
    """Print the URN and name of one function reference, or of each in an iterable.

    Args:
        func: Either a single ``(urn, function)`` pair (an
            ``ExtensionFunction`` or a plain tuple) or an iterable of such
            pairs. A value counts as a single pair when it is subscriptable
            and its first element is a string.

    Empty and non-subscriptable iterables (for example generators) are
    treated as collections of pairs rather than raising while probing the
    first element.
    """
    try:
        first = func[0]
    except (TypeError, IndexError):
        # TypeError: not subscriptable (e.g. a generator).
        # IndexError: empty sequence. Neither can be a single pair.
        is_single = False
    else:
        is_single = isinstance(first, str)

    functions = [func] if is_single else func
    for urn, name in functions:
        print(urn)
        print(name)

process_func(ExtensionFunction("sample_urn", "sample_func"))
process_func(("sample_urn", "sample_func"))
process_func([("sample_urn", "sample_func1"), ("sample_urn", "sample_func2")])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am okay with that approach as an API for allowing multiple parameters, though I still have hesitations about the PR as a whole, as expressed in the comment below.

expressions: Iterable[ExtendedExpressionOrUnbound],
alias: Union[Iterable[str], str] = None,
):
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
functions = [function] if isinstance(function, str) else function

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
Expand All @@ -225,11 +225,17 @@ def resolve(

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(urn, function, signature)
for f in functions:
urn, name = f.rsplit(":", 1)
func = registry.lookup_function(urn, name, signature)
if func:
break

if not func:
raise Exception(f"Unknown function {function} for {signature}")

resolved_func, return_type = func

func_extension_urns = [
ste.SimpleExtensionURN(
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
Expand Down Expand Up @@ -288,7 +294,7 @@ def resolve(
),
output_names=_alias_or_inferred(
alias,
function,
name,
[e.referred_expr[0].output_names[0] for e in bound_expressions],
),
)
Expand All @@ -303,12 +309,12 @@ def resolve(


def aggregate_function(
urn: str,
function: str,
function: Union[str, Iterable[str]],
expressions: Iterable[ExtendedExpressionOrUnbound],
alias: Union[Iterable[str], str] = None,
):
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
functions = [function] if isinstance(function, str) else function

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
Expand All @@ -323,11 +329,17 @@ def resolve(

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(urn, function, signature)
for f in functions:
urn, name = f.rsplit(":", 1)
func = registry.lookup_function(urn, name, signature)
if func:
break

if not func:
raise Exception(f"Unknown function {function} for {signature}")

resolved_func, return_type = func

func_extension_urns = [
ste.SimpleExtensionURN(
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
Expand Down Expand Up @@ -382,7 +394,7 @@ def resolve(
),
output_names=_alias_or_inferred(
alias,
"IfThen",
name,
[e.referred_expr[0].output_names[0] for e in bound_expressions],
),
)
Expand All @@ -398,13 +410,13 @@ def resolve(

# TODO bounds, sorts
def window_function(
urn: str,
function: str,
function: Union[str, Iterable[str]],
expressions: Iterable[ExtendedExpressionOrUnbound],
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
alias: Union[Iterable[str], str] = None,
):
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
functions = [function] if isinstance(function, str) else function

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
Expand All @@ -423,11 +435,17 @@ def resolve(

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(urn, function, signature)
for f in functions:
urn, name = f.rsplit(":", 1)
func = registry.lookup_function(urn, name, signature)
if func:
break

if not func:
raise Exception(f"Unknown function {function} for {signature}")

resolved_func, return_type = func

func_extension_urns = [
ste.SimpleExtensionURN(
extension_urn_anchor=registry.lookup_urn(urn), urn=urn
Expand Down Expand Up @@ -495,7 +513,7 @@ def resolve(
),
output_names=_alias_or_inferred(
alias,
function,
name,
[e.referred_expr[0].output_names[0] for e in bound_expressions],
),
)
Expand Down
30 changes: 15 additions & 15 deletions src/substrait/sql/sql_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@

SchemaResolver = Callable[[str], stt.NamedStruct]

function_mapping = {
"Plus": ("extension:io.substrait:functions_arithmetic", "add"),
"Minus": ("extension:io.substrait:functions_arithmetic", "subtract"),
"Gt": ("extension:io.substrait:functions_comparison", "gt"),
"GtEq": ("extension:io.substrait:functions_comparison", "gte"),
"Lt": ("extension:io.substrait:functions_comparison", "lt"),
"Eq": ("extension:io.substrait:functions_comparison", "equal"),
scalar_function_mapping = {
"Plus": ["extension:io.substrait:functions_arithmetic:add"],
"Minus": ["extension:io.substrait:functions_arithmetic:subtract"],
"Gt": ["extension:io.substrait:functions_comparison:gt"],
"GtEq": ["extension:io.substrait:functions_comparison:gte"],
"Lt": ["extension:io.substrait:functions_comparison:lt"],
"Eq": ["extension:io.substrait:functions_comparison:equal"],
}

aggregate_function_mapping = {
"SUM": ("extension:io.substrait:functions_arithmetic", "sum")
"SUM": ["extension:io.substrait:functions_arithmetic:sum"]
}

window_function_mapping = {
"row_number": ("extension:io.substrait:functions_arithmetic", "row_number"),
"row_number": ["extension:io.substrait:functions_arithmetic:row_number"],
}


Expand Down Expand Up @@ -106,8 +106,8 @@ def translate_expression(
groupings=groupings,
),
]
func = function_mapping[ast["op"]]
return scalar_function(func[0], func[1], expressions=expressions, alias=alias)
func = scalar_function_mapping[ast["op"]]
return scalar_function(func, expressions=expressions, alias=alias)
elif op == "Value":
return literal(
int(ast["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias
Expand All @@ -125,8 +125,8 @@ def translate_expression(
]
name = ast["name"][0]["Identifier"]["value"]

if name in function_mapping:
func = function_mapping[name]
if name in scalar_function_mapping:
func = scalar_function_mapping[name]
return scalar_function(func[0], func[1], *expressions, alias=alias)
elif name in aggregate_function_mapping:
# All measures need to be extracted out because substrait calculates measures in a separate rel
Expand All @@ -140,7 +140,7 @@ def translate_expression(
random_name = "".join(
random.choices(string.ascii_uppercase + string.digits, k=5)
) # TODO make this deterministic
aggr = aggregate_function(func[0], func[1], expressions, alias=random_name)
aggr = aggregate_function(func, expressions, alias=random_name)
measures.append((aggr, ast, random_name))
return column(random_name, alias=alias)
elif name in window_function_mapping:
Expand All @@ -158,7 +158,7 @@ def translate_expression(
]

return window_function(
func[0], func[1], expressions, partitions=partitions, alias=alias
func, expressions, partitions=partitions, alias=alias
)

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@

def test_aggregate_count():
e = aggregate_function(
"extension:test:urn",
"count",
"extension:test:urn:count",
expressions=[
literal(
10,
Expand Down
3 changes: 1 addition & 2 deletions tests/builders/extended_expression/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def test_cast_with_extension():

actual = cast(
input=scalar_function(
"extension:test:functions",
"add",
"extension:test:functions:add",
expressions=[literal(1, i8()), literal(2, i8())],
),
type=i16(),
Expand Down
3 changes: 1 addition & 2 deletions tests/builders/extended_expression/test_if_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def test_if_then_with_extension():
ifs=[
(
scalar_function(
"extension:io.substrait:functions_comparison",
"gt",
"extension:io.substrait:functions_comparison:gt",
expressions=[
column("order_total"),
literal(
Expand Down
3 changes: 1 addition & 2 deletions tests/builders/extended_expression/test_multi_or_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ def test_multi_or_list_with_extension():
actual = multi_or_list(
value=[
scalar_function(
"extension:test:functions",
"add",
"extension:test:functions:add",
expressions=[literal(1, i8()), literal(2, i8())],
),
literal(10, i8()),
Expand Down
13 changes: 7 additions & 6 deletions tests/builders/extended_expression/test_scalar_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@

def test_sclar_add():
e = scalar_function(
"extension:test:urn",
"test_func",
"extension:test:urn:test_func",
expressions=[
literal(
10,
Expand Down Expand Up @@ -117,17 +116,19 @@ def test_sclar_add():
base_schema=named_struct,
)

print(e)
print("---------------")
print(expected)

assert e == expected


def test_nested_scalar_calls():
e = scalar_function(
"extension:test:urn",
"is_positive",
"extension:test:urn:is_positive",
expressions=[
scalar_function(
"extension:test:urn",
"test_func",
"extension:test:urn:test_func",
expressions=[
literal(
10,
Expand Down
3 changes: 1 addition & 2 deletions tests/builders/extended_expression/test_singular_or_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def test_singular_or_list_with_extension():

actual = singular_or_list(
value=scalar_function(
"extension:test:functions",
"add",
"extension:test:functions:add",
expressions=[literal(1, i8()), literal(2, i8())],
),
options=[literal(3, i8()), literal(4, i8())],
Expand Down
3 changes: 1 addition & 2 deletions tests/builders/extended_expression/test_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def test_switch_with_extension():

actual = switch(
match=scalar_function(
"extension:test:functions",
"add",
"extension:test:functions:add",
expressions=[literal(1, i8()), literal(2, i8())],
),
ifs=[
Expand Down
Loading