diff --git a/plugins/pandera/README.md b/plugins/pandera/README.md new file mode 100644 index 000000000..a400860c3 --- /dev/null +++ b/plugins/pandera/README.md @@ -0,0 +1,63 @@ +# Flyte Pandera Plugin + +This plugin integrates [Pandera](https://pandera.readthedocs.io/) with [Flyte](https://flyte.org/), enabling automatic runtime validation of pandas DataFrames against Pandera schemas as data flows between tasks. + +## Installation + +```bash +pip install flyteplugins-pandera +``` + +## Usage + +Define a Pandera schema using `DataFrameModel` and use `pandera.typing.DataFrame` as your task's type annotation: + +```python +import flyte +import pandas as pd +import pandera as pa +from pandera.typing import DataFrame, Series + +env = flyte.TaskEnvironment(name="my-env") + + +class UserSchema(pa.DataFrameModel): + name: Series[str] + age: Series[int] = pa.Field(ge=0, le=120) + email: Series[str] + + +@env.task +async def generate_users() -> DataFrame[UserSchema]: + return pd.DataFrame({ + "name": ["Alice", "Bob"], + "age": [25, 30], + "email": ["alice@example.com", "bob@example.com"], + }) + + +@env.task +async def process_users(df: DataFrame[UserSchema]) -> DataFrame[UserSchema]: + df["age"] = df["age"] + 1 + return df +``` + +DataFrames are automatically validated against the schema on both input and output. If validation fails, a detailed error report is generated. 
+ +## Configuration + +Control validation behavior using `ValidationConfig` with `typing.Annotated`: + +```python +from typing import Annotated +from flyteplugins.pandera import ValidationConfig + +@env.task +async def lenient_task( + df: Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")] +) -> DataFrame[UserSchema]: + return df +``` + +- `on_error="raise"` (default): Raises an exception on validation failure +- `on_error="warn"`: Logs a warning and continues with the original data diff --git a/plugins/pandera/pyproject.toml b/plugins/pandera/pyproject.toml new file mode 100644 index 000000000..f3b48159a --- /dev/null +++ b/plugins/pandera/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "flyteplugins-pandera" +dynamic = ["version"] +description = "Pandera data validation plugin for Flyte" +readme = "README.md" +authors = [{ name = "Flyte Contributors", email = "admin@flyte.org" }] +requires-python = ">=3.10" +dependencies = [ + "pandera", + "flyte", + "great_tables", +] + +[project.entry-points."flyte.plugins.types"] +pandera = "flyteplugins.pandera.transformer:register_pandera_transformer" + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415", "ASYNC240"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = 
["E402"]

# --- plugins/pandera/src/flyteplugins/pandera/__init__.py ---

"""
Pandera data validation plugin for Flyte.

This plugin integrates Pandera's runtime data validation with Flyte's type system,
enabling automatic validation of pandas DataFrames against Pandera schemas when
data flows between tasks.

.. autosummary::
   :template: custom.rst
   :toctree: generated/

   PanderaTransformer
   PandasReportRenderer
   ValidationConfig
"""

# Re-export the public API; the `X as X` form marks the names as intentional
# re-exports for type checkers.
from .config import ValidationConfig as ValidationConfig
from .renderer import PandasReportRenderer as PandasReportRenderer
from .transformer import PanderaTransformer as PanderaTransformer

# --- plugins/pandera/src/flyteplugins/pandera/config.py ---

"""Pandera validation configuration."""

from dataclasses import dataclass
from typing import Literal


@dataclass
class ValidationConfig:
    """Configuration for Pandera validation behavior.

    Attributes:
        on_error: Determines how validation errors are handled.
            "raise" will raise the SchemaError/SchemaErrors exception.
            "warn" will log a warning and continue with the original data.
    """

    on_error: Literal["raise", "warn"] = "raise"
"""Pandera validation report renderer for Flyte Decks.

NOTE(review): reconstructed from a line-mangled patch. The HTML tags inside
``to_html``'s template were lost in the mangling (only the text nodes
survived), so that template is a best-effort reconstruction — confirm
against the original file.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

from flyte._utils import lazy_module

if TYPE_CHECKING:
    import great_tables as gt
    import pandas

    import pandera
else:
    # Defer heavy imports until first attribute access.
    gt = lazy_module("great_tables")
    pandas = lazy_module("pandas")
    pandera = lazy_module("pandera")


# Keys of the dict carried by pandera.errors.SchemaErrors.args[0].
SCHEMA_ERROR_KEY = "SCHEMA"
DATA_ERROR_KEY = "DATA"
SCHEMA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "failure_case", "error"]
DATA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "index", "failure_case", "error"]
DATA_ERROR_DISPLAY_ORDER = ["column", "error_code", "percent_valid", "check", "failure_cases", "error"]

DATA_PREVIEW_HEAD = 5
FAILURE_CASE_LIMIT = 10
ERROR_COLUMN_MAX_WIDTH = 200


@dataclass
class PandasReport:
    """Intermediate DataFrames that back one rendered report."""

    summary: "pandas.DataFrame"
    data_preview: "pandas.DataFrame"
    schema_error_df: Optional["pandas.DataFrame"] = None
    data_error_df: Optional["pandas.DataFrame"] = None


class PandasReportRenderer:
    """Renders Pandera validation reports as HTML for Flyte Decks."""

    def __init__(self, title: str = "Pandera Error Report"):
        # NOTE(review): _title is not referenced elsewhere in this module —
        # confirm whether it was meant to feed the report header.
        self._title = title

    def _create_success_report(self, data: "pandas.DataFrame", schema: "pandera.DataFrameSchema") -> PandasReport:
        """Build the report shown when validation raised no errors."""
        summary = pandas.DataFrame(
            [
                {"Metadata": "Schema Name", "Value": schema.name},
                {"Metadata": "Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"},
                {"Metadata": "Total schema errors", "Value": 0},
                {"Metadata": "Total data errors", "Value": 0},
                {"Metadata": "Schema Object", "Value": f"```\n{schema!r}\n```"},
            ]
        )
        return PandasReport(
            summary=summary,
            data_preview=data.head(DATA_PREVIEW_HEAD),
        )

    @staticmethod
    def _reshape_long_failure_cases(long_failure_cases: "pandas.DataFrame") -> "pandas.DataFrame":
        """Pivot wide-format failure cases into one dict-valued row per (check, index)."""
        return (
            long_failure_cases.pivot(
                index=["schema_context", "check", "index"], columns="column", values="failure_case"
            )
            .apply(lambda s: s.to_dict(), axis="columns")
            .rename("failure_case")
            .reset_index(["index", "check"])
            .reset_index(drop=True)[["check", "index", "failure_case"]]
        )

    def _prepare_data_error_df(
        self,
        data: "pandas.DataFrame",
        data_errors: dict,
        failure_cases: "pandas.DataFrame",
    ) -> "pandas.DataFrame":
        """Aggregate data-level errors into one row per (column, error_code, check, error)."""

        def num_failure_cases(series):
            # Name becomes the aggregated column label via .agg([...]).
            return len(series)

        def _failure_cases(series):
            # Truncated, comma-joined preview of the failing values.
            series = series.astype(str)
            out = ", ".join(str(x) for x in series.iloc[:FAILURE_CASE_LIMIT])
            if len(series) > FAILURE_CASE_LIMIT:
                out += f" ... (+{len(series) - FAILURE_CASE_LIMIT} more)"
            return out

        data_errors = pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in data_errors.items())
        long_failure_case_selector = (failure_cases["schema_context"] == "DataFrameSchema") & (
            failure_cases["column"].notna()
        )
        long_failure_cases = failure_cases[long_failure_case_selector]

        data_error_df = [data_errors.merge(failure_cases, how="inner", on=["column", "check"])]
        if long_failure_cases.shape[0] > 0:
            reshaped_failure_cases = self._reshape_long_failure_cases(long_failure_cases)
            # NOTE(review): the `where(...)` below is overwritten by the
            # trailing `.assign(column="NA")` — confirm which was intended.
            long_data_errors = data_errors.assign(
                column=data_errors.column.where(~(data_errors.column == data_errors.schema), "NA")
            )
            data_error_df.append(
                long_data_errors.merge(reshaped_failure_cases, how="inner", on=["check"]).assign(column="NA")
            )

        data_error_df = pandas.concat(data_error_df)
        out_df = (
            data_error_df[DATA_ERROR_COLUMNS]
            .groupby(["column", "error_code", "check", "error"])
            .failure_case.agg([num_failure_cases, _failure_cases])
            .reset_index()
            .rename(columns={"_failure_cases": "failure_cases"})
            .assign(percent_valid=lambda df: 1 - (df["num_failure_cases"] / data.shape[0]))
        )
        return out_df

    def _create_error_report(
        self,
        data: "pandas.DataFrame",
        schema: "pandera.DataFrameSchema",
        error: "pandera.errors.SchemaErrors",
    ) -> PandasReport:
        """Build the report for a failed validation from the raised SchemaErrors."""
        failure_cases = error.failure_cases
        error_dict = error.args[0]

        schema_errors = error_dict.get(SCHEMA_ERROR_KEY)
        data_errors = error_dict.get(DATA_ERROR_KEY)

        if schema_errors is None:
            schema_error_df = None
            total_schema_errors = 0
        else:
            schema_error_df = (
                pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in schema_errors.items())
                .merge(failure_cases, how="left", on=["column", "check"])[SCHEMA_ERROR_COLUMNS]
                .drop(["schema"], axis="columns")
            )
            total_schema_errors = schema_error_df.shape[0]

        if data_errors is None:
            data_error_df = None
            total_data_errors = 0
        else:
            data_error_df = self._prepare_data_error_df(data, data_errors, failure_cases)
            total_data_errors = data_error_df.shape[0]

        summary = pandas.DataFrame(
            [
                {"Metadata": "Schema Name", "Value": schema.name},
                {"Metadata": "Data Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"},
                {"Metadata": "Total schema errors", "Value": total_schema_errors},
                {"Metadata": "Total data errors", "Value": total_data_errors},
                {"Metadata": "Schema Object", "Value": f"```\n{error.schema!r}\n```"},
            ]
        )
        return PandasReport(
            summary=summary,
            data_preview=data.head(DATA_PREVIEW_HEAD),
            schema_error_df=schema_error_df,
            data_error_df=data_error_df,
        )

    def _format_summary_df(self, df: "pandas.DataFrame") -> str:
        """Render the summary table to raw HTML via great_tables."""
        return (
            gt.GT(df)
            .tab_header(
                title=gt.md("**Summary**"),
                subtitle="A high-level overview of the schema errors found in the DataFrame.",
            )
            .cols_width(cases={"Metadata": "20%", "Value": "80%"})
            .fmt_markdown(["Value"])
            .tab_stub(rowname_col="Metadata")
            .tab_stubhead(label="Metadata")
            .tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
            .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
            .tab_style(
                style=gt.style.text(weight="bold"),
                locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.stub()],
            )
            .as_raw_html()
        )

    def _format_data_preview_df(self, df: "pandas.DataFrame") -> str:
        """Render the head-of-data preview table to raw HTML."""
        return (
            gt.GT(df)
            .tab_header(
                title=gt.md("**Data Preview**"),
                subtitle=f"A preview of the first {min(DATA_PREVIEW_HEAD, df.shape[0])} rows of the data.",
            )
            .tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
            .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
            .tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels())
            .tab_style(style=gt.style.text(align="left"), locations=[gt.loc.body(), gt.loc.column_labels()])
            .as_raw_html()
        )

    @staticmethod
    def _format_error(x: str) -> str:
        """Truncate a long error message and wrap it in a markdown code fence."""
        if len(x) > ERROR_COLUMN_MAX_WIDTH:
            x = f"{x[:ERROR_COLUMN_MAX_WIDTH]}..."
        return f"```\n{x}\n```"

    def _format_schema_error_df(self, df: "pandas.DataFrame") -> str:
        """Render the schema-level error table to raw HTML."""
        df = df.assign(
            error=lambda df: df["error"].map(self._format_error),
            error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"),
            check=lambda df: df["check"].map(lambda x: f"`{x}`"),
        )
        return (
            gt.GT(df)
            .tab_header(
                title=gt.md("**Schema-level Errors**"),
                subtitle="Schema-level metadata errors, e.g. column names, dtypes.",
            )
            .fmt_markdown(["error_code", "check", "error"])
            .tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
            .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
            .tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels())
            .as_raw_html()
        )

    def _format_data_error_df(self, df: "pandas.DataFrame") -> str:
        """Render the data-level error table (grouped by column) to raw HTML."""
        df = df.assign(
            error=lambda df: df["error"].map(self._format_error),
            error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"),
            check=lambda df: df["check"].map(lambda x: f"`{x}`"),
        )[DATA_ERROR_DISPLAY_ORDER]

        return (
            gt.GT(df)
            .tab_header(
                title=gt.md("**Data-level Errors**"),
                subtitle="Data-level value errors, e.g. null values, out-of-range values.",
            )
            .fmt_markdown(["error_code", "check", "error"])
            .fmt_percent("percent_valid", decimals=2)
            .data_color(columns=["percent_valid"], palette="RdYlGn", domain=[0, 1], alpha=0.2)
            .tab_stub(groupname_col="column", rowname_col="error_code")
            .tab_stubhead(label="column")
            .tab_style(
                style=gt.style.text(align="left"),
                locations=[gt.loc.header(), gt.loc.column_labels(), gt.loc.body()],
            )
            .tab_style(style=gt.style.text(align="center"), locations=gt.loc.body(columns="percent_valid"))
            .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
            .tab_style(
                style=gt.style.text(weight="bold"),
                locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.row_groups()],
            )
            .tab_style(
                style=gt.style.fill(color="#f4f4f4"),
                locations=[gt.loc.row_groups(), gt.loc.stub()],
            )
            .as_raw_html()
        )

    def to_html(
        self,
        data: "pandas.DataFrame",
        schema: "pandera.DataFrameSchema",
        error: Optional["pandera.errors.SchemaErrors"] = None,
    ) -> str:
        """Render a full HTML report; success layout when ``error`` is None."""
        error_segments = ""
        if error is None:
            report_dfs = self._create_success_report(data, schema)
            top_message = "Data validation succeeded."
        else:
            report_dfs = self._create_error_report(data, schema, error)
            top_message = "Data validation failed."

        if report_dfs.schema_error_df is not None:
            error_segments += f"{self._format_schema_error_df(report_dfs.schema_error_df)}"
        if report_dfs.data_error_df is not None:
            error_segments += f"{self._format_data_error_df(report_dfs.data_error_df)}"

        # NOTE(review): the original markup was lost in the mangled patch;
        # the surviving text nodes were: title "Pandera Report", img alt
        # "Pandera Logo", heading "Pandera Report", then {top_message},
        # summary, a separator, data preview, and the error segments.
        return f"""
        <html>
        <head>
            <title>Pandera Report</title>
        </head>
        <body>
            <div>
                <img alt="Pandera Logo" />
                <h1>Pandera Report</h1>
            </div>
            <div>
                <p>{top_message}</p>
                {self._format_summary_df(report_dfs.summary)}
                <br>
                {self._format_data_preview_df(report_dfs.data_preview)}
                {error_segments}
            </div>
        </body>
        </html>
        """
"""Pandera type transformer for Flyte v2 SDK.

This module provides a TypeTransformer that validates pandas DataFrames against
Pandera schemas during serialization and deserialization in Flyte tasks.
"""

import functools
import typing
from typing import Type, Union

from flyte._logging import logger
from flyte._utils import lazy_module
from flyte.io._dataframe.dataframe import DataFrameTransformerEngine
from flyte.types import TypeEngine, TypeTransformer
from flyteidl2.core import literals_pb2, types_pb2

from .config import ValidationConfig
from .renderer import PandasReportRenderer

if typing.TYPE_CHECKING:
    import pandas

    import pandera
else:
    pandas = lazy_module("pandas")
    pandera = lazy_module("pandera")

T = typing.TypeVar("T")


class PanderaTransformer(TypeTransformer["pandera.typing.DataFrame"]):
    """Type transformer for pandera.typing.DataFrame.

    Wraps the DataFrameTransformerEngine to add automatic Pandera schema
    validation when data flows between Flyte tasks.
    """

    # (uri, schema name) pairs already validated in to_literal; consulted by
    # to_python_value to skip redundant re-validation during local execution.
    _VALIDATION_MEMO: typing.ClassVar[set] = set()

    def __init__(self):
        super().__init__("Pandera Transformer", pandera.typing.DataFrame)
        self._sd_transformer = DataFrameTransformerEngine()

    @staticmethod
    def _get_pandera_schema(
        t: Type["pandera.typing.DataFrame"],
    ) -> tuple["pandera.DataFrameSchema", ValidationConfig]:
        """Extract the Pandera schema and ValidationConfig from a type annotation.

        Falls back to a default (empty) DataFrameSchema and a default config
        when the annotation carries neither.
        """
        config = ValidationConfig()
        if typing.get_origin(t) is typing.Annotated:
            t, *args = typing.get_args(t)
            # First ValidationConfig in the metadata wins.
            for arg in args:
                if isinstance(arg, ValidationConfig):
                    config = arg
                    break

        type_args = typing.get_args(t)
        if type_args:
            schema_model, *_ = type_args
            schema = schema_model.to_schema()
        else:
            schema = pandera.DataFrameSchema()
        return schema, config

    def assert_type(self, t: Type[T], v: T):
        """Accept any pandas.DataFrame (or instance of the annotated type)."""
        if not hasattr(t, "__origin__") and not isinstance(v, (t, pandas.DataFrame)):
            raise TypeError(f"Type of Val '{v}' is not an instance of {t}")

    def get_literal_type(self, t: Type["pandera.typing.DataFrame"]) -> types_pb2.LiteralType:
        """Delegate to the structured-dataset engine, unwrapping Annotated types.

        Fix: use ``t, *_`` instead of ``t, _`` so that Annotated types carrying
        more than one metadata argument (e.g. a ValidationConfig plus another
        annotation) no longer raise ValueError on unpacking; this also matches
        how _get_pandera_schema unwraps the same annotation.
        """
        if typing.get_origin(t) is typing.Annotated:
            t, *_ = typing.get_args(t)
        return self._sd_transformer.get_literal_type(t)

    def _validate_and_report(
        self,
        df: "pandas.DataFrame",
        schema: "pandera.DataFrameSchema",
        config: ValidationConfig,
    ) -> "pandas.DataFrame":
        """Validate a DataFrame against a Pandera schema and generate a report.

        Returns the validated DataFrame (which may have coerced types).
        Raises SchemaErrors if validation fails and config.on_error == "raise".
        """
        renderer = PandasReportRenderer(title=f"Pandera Report: {schema.name}")
        try:
            val = schema.validate(df, lazy=True)
        except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
            renderer.to_html(df, schema, exc)
            val = df
            if config.on_error == "raise":
                raise
            elif config.on_error == "warn":
                logger.warning(str(exc))
            else:
                raise ValueError(f"Invalid on_error value: {config.on_error}")
        else:
            renderer.to_html(val, schema)
        return val

    async def to_literal(
        self,
        python_val: Union["pandas.DataFrame", typing.Any],
        python_type: Type["pandera.typing.DataFrame"],
        expected: types_pb2.LiteralType,
    ) -> literals_pb2.Literal:
        """Validate then serialize a DataFrame to a structured-dataset literal."""
        if not isinstance(python_val, pandas.DataFrame):
            raise AssertionError(
                f"Only pandas DataFrame objects can be returned from a Pandera-validated task, got {type(python_val)}"
            )

        schema, config = self._get_pandera_schema(python_type)
        val = self._validate_and_report(python_val, schema, config)

        lv = await self._sd_transformer.to_literal(val, pandas.DataFrame, expected)

        # Cache the URI + schema name to skip re-validation during local execution
        # where to_literal is followed immediately by to_python_value
        if lv.scalar and lv.scalar.structured_dataset:
            self._VALIDATION_MEMO.add((lv.scalar.structured_dataset.uri, schema.name))

        return lv

    async def to_python_value(
        self,
        lv: literals_pb2.Literal,
        expected_python_type: Type["pandera.typing.DataFrame"],
    ) -> "pandera.typing.DataFrame":
        """Deserialize a structured-dataset literal and validate the result."""
        if not (lv and lv.scalar and lv.scalar.structured_dataset):
            raise AssertionError("Can only convert a literal structured dataset to a pandera schema")

        df = await self._sd_transformer.to_python_value(lv, pandas.DataFrame)
        schema, config = self._get_pandera_schema(expected_python_type)

        # Skip validation if we already validated this data in to_literal (local execution)
        if (lv.scalar.structured_dataset.uri, schema.name) in self._VALIDATION_MEMO:
            return df

        return self._validate_and_report(df, schema, config)


@functools.lru_cache(maxsize=None)
def register_pandera_transformer():
    """Register the Pandera transformer with Flyte's type engine.

    This function is called automatically via the flyte.plugins.types entry point
    when flyte.init() is called with load_plugin_type_transformers=True (the default).
    The lru_cache makes registration idempotent.
    """
    TypeEngine.register(PanderaTransformer())


# Also register at module import time for backwards compatibility
register_pandera_transformer()
# --- plugins/pandera/tests/conftest.py ---

import os
from unittest.mock import patch

import pytest
import pytest_asyncio
from flyte._cache.local_cache import LocalTaskCache
from flyte._context import RawDataPath, internal_ctx


@pytest.fixture
def ctx_with_test_raw_data_path():
    """Pytest fixture to set a RawDataPath in the internal_ctx."""
    raw_data_path = RawDataPath.from_local_folder()
    ctx = internal_ctx()
    new_context = ctx.new_raw_data_path(raw_data_path=raw_data_path)
    with new_context as ctx:
        yield ctx


@pytest_asyncio.fixture(autouse=True)
async def isolate_local_cache(tmp_path):
    """
    Global fixture to isolate LocalTaskCache for each test.
    Uses a temporary directory to avoid polluting the local development cache.
    """
    with patch.object(LocalTaskCache, "_get_cache_path", return_value=str(tmp_path / "test_cache.db")):
        LocalTaskCache._initialized = False
        yield
        await LocalTaskCache.close()


@pytest.fixture(autouse=True)
def patch_os_exit(monkeypatch):
    """Turn os._exit into a catchable SystemExit so tests cannot kill pytest."""

    def mock_exit(code):
        raise SystemExit(code)

    monkeypatch.setattr(os, "_exit", mock_exit)


# --- plugins/pandera/tests/test_transformer.py ---

"""Tests for the Pandera plugin transformer."""

import typing

import pandas as pd
import pandera as pa
import pytest
from flyte.types import TypeEngine
from flyteidl2.core import literals_pb2
from pandera.typing import DataFrame, Series

# Import to ensure registration
import flyteplugins.pandera.transformer  # noqa: F401
from flyteplugins.pandera.config import ValidationConfig
from flyteplugins.pandera.transformer import PanderaTransformer

# ============================================================================
# Test schemas
# ============================================================================


class UserSchema(pa.DataFrameModel):
    name: Series[str]
    age: Series[int] = pa.Field(ge=0, le=120)
    email: Series[str]


class StrictSchema(pa.DataFrameModel):
    value: Series[float] = pa.Field(ge=0.0, le=1.0)
    label: Series[str]


# ============================================================================
# Sample data
# ============================================================================

VALID_USER_DATA = {
    "name": ["Alice", "Bob", "Charlie"],
    "age": [25, 30, 35],
    "email": ["alice@example.com", "bob@example.com", "charlie@example.com"],
}

INVALID_USER_DATA = {
    "name": ["Alice", "Bob", "Charlie"],
    "age": [25, -5, 200],  # -5 and 200 are out of range
    "email": ["alice@example.com", "bob@example.com", "charlie@example.com"],
}

VALID_STRICT_DATA = {
    "value": [0.1, 0.5, 0.9],
    "label": ["low", "mid", "high"],
}


# ============================================================================
# Registration tests
# ============================================================================


def test_pandera_transformer_registered():
    """Test that the Pandera transformer is registered with the TypeEngine."""
    transformer = TypeEngine.get_transformer(DataFrame[UserSchema])
    assert isinstance(transformer, PanderaTransformer)


def test_pandera_transformer_name():
    """Test the transformer has the correct name."""
    transformer = PanderaTransformer()
    assert transformer.name == "Pandera Transformer"


# ============================================================================
# Literal type tests
# ============================================================================


def test_get_literal_type_basic():
    """Test that get_literal_type works for basic pandera DataFrame."""
    transformer = PanderaTransformer()
    lt = transformer.get_literal_type(DataFrame[UserSchema])
    assert lt.structured_dataset_type is not None


def test_get_literal_type_unannotated():
    """Test get_literal_type for an unannotated pandera DataFrame."""
    transformer = PanderaTransformer()
    lt = transformer.get_literal_type(DataFrame)
    assert lt.structured_dataset_type is not None


# ============================================================================
# Schema extraction tests
# ============================================================================


def test_get_pandera_schema_with_model():
    """Test schema extraction from pandera DataFrameModel."""
    schema, config = PanderaTransformer._get_pandera_schema(DataFrame[UserSchema])
    assert schema is not None
    assert "name" in schema.columns
    assert "age" in schema.columns
    assert config.on_error == "raise"


def test_get_pandera_schema_without_model():
    """Test schema extraction for unannotated pandera DataFrame."""
    schema, config = PanderaTransformer._get_pandera_schema(DataFrame)
    assert isinstance(schema, pa.DataFrameSchema)
    assert config.on_error == "raise"


def test_get_pandera_schema_with_config():
    """Test schema extraction with custom ValidationConfig."""
    annotated_type = typing.Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")]
    schema, config = PanderaTransformer._get_pandera_schema(annotated_type)
    assert schema is not None
    assert config.on_error == "warn"


# ============================================================================
# Validation tests (to_literal)
# ============================================================================


@pytest.mark.asyncio
async def test_to_literal_valid_data(ctx_with_test_raw_data_path):
    """Test to_literal with valid data passes validation."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(VALID_USER_DATA)
    lt = transformer.get_literal_type(DataFrame[UserSchema])

    lit = await transformer.to_literal(df, DataFrame[UserSchema], lt)
    assert lit.scalar.structured_dataset.uri is not None


@pytest.mark.asyncio
async def test_to_literal_invalid_data_raises(ctx_with_test_raw_data_path):
    """Test to_literal with invalid data raises SchemaErrors."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(INVALID_USER_DATA)
    lt = transformer.get_literal_type(DataFrame[UserSchema])

    with pytest.raises((pa.errors.SchemaError, pa.errors.SchemaErrors)):
        await transformer.to_literal(df, DataFrame[UserSchema], lt)


@pytest.mark.asyncio
async def test_to_literal_invalid_data_warn(ctx_with_test_raw_data_path):
    """Test to_literal with invalid data and on_error=warn logs warning."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(INVALID_USER_DATA)
    annotated_type = typing.Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")]
    lt = transformer.get_literal_type(annotated_type)

    # Should not raise, but should warn
    lit = await transformer.to_literal(df, annotated_type, lt)
    assert lit.scalar.structured_dataset.uri is not None


@pytest.mark.asyncio
async def test_to_literal_non_dataframe_raises(ctx_with_test_raw_data_path):
    """Test to_literal raises when given a non-DataFrame value."""
    transformer = PanderaTransformer()
    lt = transformer.get_literal_type(DataFrame[UserSchema])

    with pytest.raises(AssertionError, match="Only pandas DataFrame"):
        await transformer.to_literal("not a dataframe", DataFrame[UserSchema], lt)


# ============================================================================
# Deserialization tests (to_python_value)
# ============================================================================


@pytest.mark.asyncio
async def test_roundtrip_valid_data(ctx_with_test_raw_data_path):
    """Test roundtrip encoding/decoding with valid data."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(VALID_USER_DATA)
    lt = transformer.get_literal_type(DataFrame[UserSchema])

    # Encode
    lit = await transformer.to_literal(df, DataFrame[UserSchema], lt)

    # Clear validation memo to force re-validation
    PanderaTransformer._VALIDATION_MEMO.clear()

    # Decode
    restored_df = await transformer.to_python_value(lit, DataFrame[UserSchema])
    assert isinstance(restored_df, pd.DataFrame)
    assert restored_df.shape == df.shape
    assert list(restored_df.columns) == list(df.columns)


@pytest.mark.asyncio
async def test_roundtrip_strict_schema(ctx_with_test_raw_data_path):
    """Test roundtrip with a different schema."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(VALID_STRICT_DATA)
    lt = transformer.get_literal_type(DataFrame[StrictSchema])

    lit = await transformer.to_literal(df, DataFrame[StrictSchema], lt)

    PanderaTransformer._VALIDATION_MEMO.clear()

    restored_df = await transformer.to_python_value(lit, DataFrame[StrictSchema])
    assert isinstance(restored_df, pd.DataFrame)
    assert list(restored_df["value"]) == VALID_STRICT_DATA["value"]
    assert list(restored_df["label"]) == VALID_STRICT_DATA["label"]


@pytest.mark.asyncio
async def test_to_python_value_invalid_literal_raises():
    """Test to_python_value raises on invalid literal."""
    transformer = PanderaTransformer()

    with pytest.raises(AssertionError, match="Can only convert a literal structured dataset"):
        await transformer.to_python_value(literals_pb2.Literal(), DataFrame[UserSchema])


# ============================================================================
# Validation memo (caching) tests
# ============================================================================


@pytest.mark.asyncio
async def test_validation_memo_skips_revalidation(ctx_with_test_raw_data_path):
    """Test that validation memo prevents duplicate validation."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(VALID_USER_DATA)
    lt = transformer.get_literal_type(DataFrame[UserSchema])

    # Encode (adds to memo)
    lit = await transformer.to_literal(df, DataFrame[UserSchema], lt)

    # Decode should skip validation (in memo)
    restored_df = await transformer.to_python_value(lit, DataFrame[UserSchema])
    assert isinstance(restored_df, pd.DataFrame)


# ============================================================================
# Assert type tests
# ============================================================================


def test_assert_type_valid():
    """Test assert_type with valid DataFrame."""
    transformer = PanderaTransformer()
    df = pd.DataFrame(VALID_USER_DATA)
    # Should not raise
    transformer.assert_type(pd.DataFrame, df)


def test_assert_type_invalid():
    """Test assert_type with invalid type."""
    transformer = PanderaTransformer()
    with pytest.raises(TypeError):
        transformer.assert_type(pd.DataFrame, "not a dataframe")