Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ParquetStorageBackend,
)
from ._typing import DataFrame, LazyFrame, Validation
from .columns import Any as AnyColumn
from .columns import Column, column_from_dict
from .config import Config
from .exc import (
Expand Down Expand Up @@ -814,7 +815,9 @@ def cast(
further down the line might fail because of the cast and/or missing columns.
"""
lf = df.lazy().select(
pl.col(name).cast(col.dtype) for name, col in cls.columns().items()
# Skip casting for Any columns since they accept any type
pl.col(name) if isinstance(col, AnyColumn) else pl.col(name).cast(col.dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of building special treatment for Any here, how about we move ownership of casting into the Column itself? I.e. Column gets a cast method, and the default implementation is:

def cast(self, col: pl.Expr) -> pl.Expr:
    return col.cast(self.dtype)

In Any, we then implement the override:

def cast(self, col: pl.Expr) -> pl.Expr:
    return col

I think this would be neat because you never have to think about special casting logic outside the column implementations themselves

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this solution as well, and then I thought about the dy.Integer column. How should we manage this case ? Maybe this Column.cast function should take as well the type of the input expression, meaning that we need to wrapped it with pipe_with_schema.
I am away from the computer right now, I can have a deeper look on Tuesday.

for name, col in cls.columns().items()
)
if isinstance(df, pl.DataFrame):
return lf.collect() # type: ignore
Expand Down
6 changes: 6 additions & 0 deletions tests/column_types/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ class AnySchema(dy.Schema):
def test_any_dtype_passes(data: dict[str, Any]) -> None:
df = pl.DataFrame(data)
assert AnySchema.is_valid(df)


def test_any_cast() -> None:
df = pl.DataFrame({"a": 0})
result = AnySchema.cast(df)
assert result["a"].dtype == pl.Int64
Loading