From 70add7ba5e36dbfa1fbcdb4fe9bdb7163b154ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Sun, 8 Feb 2026 08:26:02 -0300 Subject: [PATCH] feat: Add Pandera data validation plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port the Pandera plugin from flytekit v1 to the Flyte v2 SDK, enabling automatic runtime validation of pandas DataFrames against Pandera schemas as data flows between tasks. The plugin registers pandera.typing.DataFrame as a custom type with the TypeEngine, wrapping the DataFrameTransformerEngine to add schema validation on both serialization and deserialization. Features: - Automatic validation via pandera.typing.DataFrame type annotations - Configurable error handling (raise or warn) via ValidationConfig - HTML validation reports using great_tables for Flyte Decks - Validation memo to skip redundant re-validation in local execution Signed-off-by: André Ahlert --- plugins/pandera/README.md | 63 ++++ plugins/pandera/pyproject.toml | 73 ++++ .../src/flyteplugins/pandera/__init__.py | 19 ++ .../src/flyteplugins/pandera/config.py | 17 + .../src/flyteplugins/pandera/renderer.py | 320 ++++++++++++++++++ .../src/flyteplugins/pandera/transformer.py | 152 +++++++++ plugins/pandera/tests/__init__.py | 0 plugins/pandera/tests/conftest.py | 37 ++ plugins/pandera/tests/test_transformer.py | 258 ++++++++++++++ 9 files changed, 939 insertions(+) create mode 100644 plugins/pandera/README.md create mode 100644 plugins/pandera/pyproject.toml create mode 100644 plugins/pandera/src/flyteplugins/pandera/__init__.py create mode 100644 plugins/pandera/src/flyteplugins/pandera/config.py create mode 100644 plugins/pandera/src/flyteplugins/pandera/renderer.py create mode 100644 plugins/pandera/src/flyteplugins/pandera/transformer.py create mode 100644 plugins/pandera/tests/__init__.py create mode 100644 plugins/pandera/tests/conftest.py create mode 100644 plugins/pandera/tests/test_transformer.py diff 
--git a/plugins/pandera/README.md b/plugins/pandera/README.md new file mode 100644 index 000000000..a400860c3 --- /dev/null +++ b/plugins/pandera/README.md @@ -0,0 +1,63 @@ +# Flyte Pandera Plugin + +This plugin integrates [Pandera](https://pandera.readthedocs.io/) with [Flyte](https://flyte.org/), enabling automatic runtime validation of pandas DataFrames against Pandera schemas as data flows between tasks. + +## Installation + +```bash +pip install flyteplugins-pandera +``` + +## Usage + +Define a Pandera schema using `DataFrameModel` and use `pandera.typing.DataFrame` as your task's type annotation: + +```python +import flyte +import pandas as pd +import pandera as pa +from pandera.typing import DataFrame, Series + +env = flyte.TaskEnvironment(name="my-env") + + +class UserSchema(pa.DataFrameModel): + name: Series[str] + age: Series[int] = pa.Field(ge=0, le=120) + email: Series[str] + + +@env.task +async def generate_users() -> DataFrame[UserSchema]: + return pd.DataFrame({ + "name": ["Alice", "Bob"], + "age": [25, 30], + "email": ["alice@example.com", "bob@example.com"], + }) + + +@env.task +async def process_users(df: DataFrame[UserSchema]) -> DataFrame[UserSchema]: + df["age"] = df["age"] + 1 + return df +``` + +DataFrames are automatically validated against the schema on both input and output. If validation fails, a detailed error report is generated. 
+ +## Configuration + +Control validation behavior using `ValidationConfig` with `typing.Annotated`: + +```python +from typing import Annotated +from flyteplugins.pandera import ValidationConfig + +@env.task +async def lenient_task( + df: Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")] +) -> DataFrame[UserSchema]: + return df +``` + +- `on_error="raise"` (default): Raises an exception on validation failure +- `on_error="warn"`: Logs a warning and continues with the original data diff --git a/plugins/pandera/pyproject.toml b/plugins/pandera/pyproject.toml new file mode 100644 index 000000000..f3b48159a --- /dev/null +++ b/plugins/pandera/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "flyteplugins-pandera" +dynamic = ["version"] +description = "Pandera data validation plugin for Flyte" +readme = "README.md" +authors = [{ name = "Flyte Contributors", email = "admin@flyte.org" }] +requires-python = ">=3.10" +dependencies = [ + "pandera", + "flyte", + "great_tables", +] + +[project.entry-points."flyte.plugins.types"] +pandera = "flyteplugins.pandera.transformer:register_pandera_transformer" + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415", "ASYNC240"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = 
["E402"] diff --git a/plugins/pandera/src/flyteplugins/pandera/__init__.py b/plugins/pandera/src/flyteplugins/pandera/__init__.py new file mode 100644 index 000000000..2e8759edb --- /dev/null +++ b/plugins/pandera/src/flyteplugins/pandera/__init__.py @@ -0,0 +1,19 @@ +""" +Pandera data validation plugin for Flyte. + +This plugin integrates Pandera's runtime data validation with Flyte's type system, +enabling automatic validation of pandas DataFrames against Pandera schemas when +data flows between tasks. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + PanderaTransformer + PandasReportRenderer + ValidationConfig +""" + +from .config import ValidationConfig as ValidationConfig +from .renderer import PandasReportRenderer as PandasReportRenderer +from .transformer import PanderaTransformer as PanderaTransformer diff --git a/plugins/pandera/src/flyteplugins/pandera/config.py b/plugins/pandera/src/flyteplugins/pandera/config.py new file mode 100644 index 000000000..68654aa23 --- /dev/null +++ b/plugins/pandera/src/flyteplugins/pandera/config.py @@ -0,0 +1,17 @@ +"""Pandera validation configuration.""" + +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class ValidationConfig: + """Configuration for Pandera validation behavior. + + Attributes: + on_error: Determines how validation errors are handled. + "raise" will raise the SchemaError/SchemaErrors exception. + "warn" will log a warning and continue with the original data. 
+ """ + + on_error: Literal["raise", "warn"] = "raise" diff --git a/plugins/pandera/src/flyteplugins/pandera/renderer.py b/plugins/pandera/src/flyteplugins/pandera/renderer.py new file mode 100644 index 000000000..96b78cb9e --- /dev/null +++ b/plugins/pandera/src/flyteplugins/pandera/renderer.py @@ -0,0 +1,320 @@ +"""Pandera validation report renderer for Flyte Decks.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +from flyte._utils import lazy_module + +if TYPE_CHECKING: + import great_tables as gt + import pandas + + import pandera +else: + gt = lazy_module("great_tables") + pandas = lazy_module("pandas") + pandera = lazy_module("pandera") + + +SCHEMA_ERROR_KEY = "SCHEMA" +DATA_ERROR_KEY = "DATA" +SCHEMA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "failure_case", "error"] +DATA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "index", "failure_case", "error"] +DATA_ERROR_DISPLAY_ORDER = ["column", "error_code", "percent_valid", "check", "failure_cases", "error"] + +DATA_PREVIEW_HEAD = 5 +FAILURE_CASE_LIMIT = 10 +ERROR_COLUMN_MAX_WIDTH = 200 + + +@dataclass +class PandasReport: + summary: "pandas.DataFrame" + data_preview: "pandas.DataFrame" + schema_error_df: Optional["pandas.DataFrame"] = None + data_error_df: Optional["pandas.DataFrame"] = None + + +class PandasReportRenderer: + """Renders Pandera validation reports as HTML for Flyte Decks.""" + + def __init__(self, title: str = "Pandera Error Report"): + self._title = title + + def _create_success_report(self, data: "pandas.DataFrame", schema: "pandera.DataFrameSchema") -> PandasReport: + summary = pandas.DataFrame( + [ + {"Metadata": "Schema Name", "Value": schema.name}, + {"Metadata": "Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"}, + {"Metadata": "Total schema errors", "Value": 0}, + {"Metadata": "Total data errors", "Value": 0}, + {"Metadata": "Schema Object", "Value": f"```\n{schema!r}\n```"}, + ] + ) + return 
PandasReport( + summary=summary, + data_preview=data.head(DATA_PREVIEW_HEAD), + ) + + @staticmethod + def _reshape_long_failure_cases(long_failure_cases: "pandas.DataFrame") -> "pandas.DataFrame": + return ( + long_failure_cases.pivot( + index=["schema_context", "check", "index"], columns="column", values="failure_case" + ) + .apply(lambda s: s.to_dict(), axis="columns") + .rename("failure_case") + .reset_index(["index", "check"]) + .reset_index(drop=True)[["check", "index", "failure_case"]] + ) + + def _prepare_data_error_df( + self, + data: "pandas.DataFrame", + data_errors: dict, + failure_cases: "pandas.DataFrame", + ) -> "pandas.DataFrame": + def num_failure_cases(series): + return len(series) + + def _failure_cases(series): + series = series.astype(str) + out = ", ".join(str(x) for x in series.iloc[:FAILURE_CASE_LIMIT]) + if len(series) > FAILURE_CASE_LIMIT: + out += f" ... (+{len(series) - FAILURE_CASE_LIMIT} more)" + return out + + data_errors = pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in data_errors.items()) + long_failure_case_selector = (failure_cases["schema_context"] == "DataFrameSchema") & ( + failure_cases["column"].notna() + ) + long_failure_cases = failure_cases[long_failure_case_selector] + + data_error_df = [data_errors.merge(failure_cases, how="inner", on=["column", "check"])] + if long_failure_cases.shape[0] > 0: + reshaped_failure_cases = self._reshape_long_failure_cases(long_failure_cases) + long_data_errors = data_errors.assign( + column=data_errors.column.where(~(data_errors.column == data_errors.schema), "NA") + ) + data_error_df.append( + long_data_errors.merge(reshaped_failure_cases, how="inner", on=["check"]).assign(column="NA") + ) + + data_error_df = pandas.concat(data_error_df) + out_df = ( + data_error_df[DATA_ERROR_COLUMNS] + .groupby(["column", "error_code", "check", "error"]) + .failure_case.agg([num_failure_cases, _failure_cases]) + .reset_index() + .rename(columns={"_failure_cases": "failure_cases"}) + 
.assign(percent_valid=lambda df: 1 - (df["num_failure_cases"] / data.shape[0])) + ) + return out_df + + def _create_error_report( + self, + data: "pandas.DataFrame", + schema: "pandera.DataFrameSchema", + error: "pandera.errors.SchemaErrors", + ) -> PandasReport: + failure_cases = error.failure_cases + error_dict = error.args[0] + + schema_errors = error_dict.get(SCHEMA_ERROR_KEY) + data_errors = error_dict.get(DATA_ERROR_KEY) + + if schema_errors is None: + schema_error_df = None + total_schema_errors = 0 + else: + schema_error_df = ( + pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in schema_errors.items()) + .merge(failure_cases, how="left", on=["column", "check"])[SCHEMA_ERROR_COLUMNS] + .drop(["schema"], axis="columns") + ) + total_schema_errors = schema_error_df.shape[0] + + if data_errors is None: + data_error_df = None + total_data_errors = 0 + else: + data_error_df = self._prepare_data_error_df(data, data_errors, failure_cases) + total_data_errors = data_error_df.shape[0] + + summary = pandas.DataFrame( + [ + {"Metadata": "Schema Name", "Value": schema.name}, + {"Metadata": "Data Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"}, + {"Metadata": "Total schema errors", "Value": total_schema_errors}, + {"Metadata": "Total data errors", "Value": total_data_errors}, + {"Metadata": "Schema Object", "Value": f"```\n{error.schema!r}\n```"}, + ] + ) + return PandasReport( + summary=summary, + data_preview=data.head(DATA_PREVIEW_HEAD), + schema_error_df=schema_error_df, + data_error_df=data_error_df, + ) + + def _format_summary_df(self, df: "pandas.DataFrame") -> str: + return ( + gt.GT(df) + .tab_header( + title=gt.md("**Summary**"), + subtitle="A high-level overview of the schema errors found in the DataFrame.", + ) + .cols_width(cases={"Metadata": "20%", "Value": "80%"}) + .fmt_markdown(["Value"]) + .tab_stub(rowname_col="Metadata") + .tab_stubhead(label="Metadata") + .tab_style(style=gt.style.text(align="left"), 
locations=gt.loc.header()) + .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header()) + .tab_style( + style=gt.style.text(weight="bold"), + locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.stub()], + ) + .as_raw_html() + ) + + def _format_data_preview_df(self, df: "pandas.DataFrame") -> str: + return ( + gt.GT(df) + .tab_header( + title=gt.md("**Data Preview**"), + subtitle=f"A preview of the first {min(DATA_PREVIEW_HEAD, df.shape[0])} rows of the data.", + ) + .tab_style(style=gt.style.text(align="left"), locations=gt.loc.header()) + .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header()) + .tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels()) + .tab_style(style=gt.style.text(align="left"), locations=[gt.loc.body(), gt.loc.column_labels()]) + .as_raw_html() + ) + + @staticmethod + def _format_error(x: str) -> str: + if len(x) > ERROR_COLUMN_MAX_WIDTH: + x = f"{x[:ERROR_COLUMN_MAX_WIDTH]}..." + return f"```\n{x}\n```" + + def _format_schema_error_df(self, df: "pandas.DataFrame") -> str: + df = df.assign( + error=lambda df: df["error"].map(self._format_error), + error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"), + check=lambda df: df["check"].map(lambda x: f"`{x}`"), + ) + return ( + gt.GT(df) + .tab_header( + title=gt.md("**Schema-level Errors**"), + subtitle="Schema-level metadata errors, e.g. 
column names, dtypes.", + ) + .fmt_markdown(["error_code", "check", "error"]) + .tab_style(style=gt.style.text(align="left"), locations=gt.loc.header()) + .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header()) + .tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels()) + .as_raw_html() + ) + + def _format_data_error_df(self, df: "pandas.DataFrame") -> str: + df = df.assign( + error=lambda df: df["error"].map(self._format_error), + error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"), + check=lambda df: df["check"].map(lambda x: f"`{x}`"), + )[DATA_ERROR_DISPLAY_ORDER] + + return ( + gt.GT(df) + .tab_header( + title=gt.md("**Data-level Errors**"), + subtitle="Data-level value errors, e.g. null values, out-of-range values.", + ) + .fmt_markdown(["error_code", "check", "error"]) + .fmt_percent("percent_valid", decimals=2) + .data_color(columns=["percent_valid"], palette="RdYlGn", domain=[0, 1], alpha=0.2) + .tab_stub(groupname_col="column", rowname_col="error_code") + .tab_stubhead(label="column") + .tab_style( + style=gt.style.text(align="left"), + locations=[gt.loc.header(), gt.loc.column_labels(), gt.loc.body()], + ) + .tab_style(style=gt.style.text(align="center"), locations=gt.loc.body(columns="percent_valid")) + .tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header()) + .tab_style( + style=gt.style.text(weight="bold"), + locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.row_groups()], + ) + .tab_style( + style=gt.style.fill(color="#f4f4f4"), + locations=[gt.loc.row_groups(), gt.loc.stub()], + ) + .as_raw_html() + ) + + def to_html( + self, + data: "pandas.DataFrame", + schema: "pandera.DataFrameSchema", + error: Optional["pandera.errors.SchemaErrors"] = None, + ) -> str: + error_segments = "" + if error is None: + report_dfs = self._create_success_report(data, schema) + top_message = "Data validation succeeded." 
+ else: + report_dfs = self._create_error_report(data, schema, error) + top_message = "Data validation failed." + + if report_dfs.schema_error_df is not None: + error_segments += f"
{self._format_schema_error_df(report_dfs.schema_error_df)}" + if report_dfs.data_error_df is not None: + error_segments += f"
{self._format_data_error_df(report_dfs.data_error_df)}" + + return f""" + + + + + + Pandera Report + + + + +
+
+ Pandera Logo +

Pandera Report

+
+

{top_message}

+ {self._format_summary_df(report_dfs.summary)} +
+ {self._format_data_preview_df(report_dfs.data_preview)} + {error_segments} +
+ + + """ diff --git a/plugins/pandera/src/flyteplugins/pandera/transformer.py b/plugins/pandera/src/flyteplugins/pandera/transformer.py new file mode 100644 index 000000000..525937a0c --- /dev/null +++ b/plugins/pandera/src/flyteplugins/pandera/transformer.py @@ -0,0 +1,152 @@ +"""Pandera type transformer for Flyte v2 SDK. + +This module provides a TypeTransformer that validates pandas DataFrames against +Pandera schemas during serialization and deserialization in Flyte tasks. +""" + +import functools +import typing +from typing import Type, Union + +from flyte._logging import logger +from flyte._utils import lazy_module +from flyte.io._dataframe.dataframe import DataFrameTransformerEngine +from flyte.types import TypeEngine, TypeTransformer +from flyteidl2.core import literals_pb2, types_pb2 + +from .config import ValidationConfig +from .renderer import PandasReportRenderer + +if typing.TYPE_CHECKING: + import pandas + + import pandera +else: + pandas = lazy_module("pandas") + pandera = lazy_module("pandera") + +T = typing.TypeVar("T") + + +class PanderaTransformer(TypeTransformer["pandera.typing.DataFrame"]): + """Type transformer for pandera.typing.DataFrame. + + Wraps the DataFrameTransformerEngine to add automatic Pandera schema + validation when data flows between Flyte tasks. 
+ """ + + _VALIDATION_MEMO: typing.ClassVar[set] = set() + + def __init__(self): + super().__init__("Pandera Transformer", pandera.typing.DataFrame) + self._sd_transformer = DataFrameTransformerEngine() + + @staticmethod + def _get_pandera_schema( + t: Type["pandera.typing.DataFrame"], + ) -> tuple["pandera.DataFrameSchema", ValidationConfig]: + config = ValidationConfig() + if typing.get_origin(t) is typing.Annotated: + t, *args = typing.get_args(t) + for arg in args: + if isinstance(arg, ValidationConfig): + config = arg + break + + type_args = typing.get_args(t) + if type_args: + schema_model, *_ = type_args + schema = schema_model.to_schema() + else: + schema = pandera.DataFrameSchema() + return schema, config + + def assert_type(self, t: Type[T], v: T): + if not hasattr(t, "__origin__") and not isinstance(v, (t, pandas.DataFrame)): + raise TypeError(f"Type of Val '{v}' is not an instance of {t}") + + def get_literal_type(self, t: Type["pandera.typing.DataFrame"]) -> types_pb2.LiteralType: + if typing.get_origin(t) is typing.Annotated: + t, _ = typing.get_args(t) + return self._sd_transformer.get_literal_type(t) + + def _validate_and_report( + self, + df: "pandas.DataFrame", + schema: "pandera.DataFrameSchema", + config: ValidationConfig, + ) -> "pandas.DataFrame": + """Validate a DataFrame against a Pandera schema and generate a report. + + Returns the validated DataFrame (which may have coerced types). + Raises SchemaErrors if validation fails and config.on_error == "raise". 
+ """ + renderer = PandasReportRenderer(title=f"Pandera Report: {schema.name}") + try: + val = schema.validate(df, lazy=True) + except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc: + renderer.to_html(df, schema, exc) + val = df + if config.on_error == "raise": + raise + elif config.on_error == "warn": + logger.warning(str(exc)) + else: + raise ValueError(f"Invalid on_error value: {config.on_error}") + else: + renderer.to_html(val, schema) + return val + + async def to_literal( + self, + python_val: Union["pandas.DataFrame", typing.Any], + python_type: Type["pandera.typing.DataFrame"], + expected: types_pb2.LiteralType, + ) -> literals_pb2.Literal: + if not isinstance(python_val, pandas.DataFrame): + raise AssertionError( + f"Only pandas DataFrame objects can be returned from a Pandera-validated task, got {type(python_val)}" + ) + + schema, config = self._get_pandera_schema(python_type) + val = self._validate_and_report(python_val, schema, config) + + lv = await self._sd_transformer.to_literal(val, pandas.DataFrame, expected) + + # Cache the URI + schema name to skip re-validation during local execution + # where to_literal is followed immediately by to_python_value + if lv.scalar and lv.scalar.structured_dataset: + self._VALIDATION_MEMO.add((lv.scalar.structured_dataset.uri, schema.name)) + + return lv + + async def to_python_value( + self, + lv: literals_pb2.Literal, + expected_python_type: Type["pandera.typing.DataFrame"], + ) -> "pandera.typing.DataFrame": + if not (lv and lv.scalar and lv.scalar.structured_dataset): + raise AssertionError("Can only convert a literal structured dataset to a pandera schema") + + df = await self._sd_transformer.to_python_value(lv, pandas.DataFrame) + schema, config = self._get_pandera_schema(expected_python_type) + + # Skip validation if we already validated this data in to_literal (local execution) + if (lv.scalar.structured_dataset.uri, schema.name) in self._VALIDATION_MEMO: + return df + + return 
self._validate_and_report(df, schema, config) + + +@functools.lru_cache(maxsize=None) +def register_pandera_transformer(): + """Register the Pandera transformer with Flyte's type engine. + + This function is called automatically via the flyte.plugins.types entry point + when flyte.init() is called with load_plugin_type_transformers=True (the default). + """ + TypeEngine.register(PanderaTransformer()) + + +# Also register at module import time for backwards compatibility +register_pandera_transformer() diff --git a/plugins/pandera/tests/__init__.py b/plugins/pandera/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/pandera/tests/conftest.py b/plugins/pandera/tests/conftest.py new file mode 100644 index 000000000..9e30eb29c --- /dev/null +++ b/plugins/pandera/tests/conftest.py @@ -0,0 +1,37 @@ +import os +from unittest.mock import patch + +import pytest +import pytest_asyncio +from flyte._cache.local_cache import LocalTaskCache +from flyte._context import RawDataPath, internal_ctx + + +@pytest.fixture +def ctx_with_test_raw_data_path(): + """Pytest fixture to set a RawDataPath in the internal_ctx.""" + raw_data_path = RawDataPath.from_local_folder() + ctx = internal_ctx() + new_context = ctx.new_raw_data_path(raw_data_path=raw_data_path) + with new_context as ctx: + yield ctx + + +@pytest_asyncio.fixture(autouse=True) +async def isolate_local_cache(tmp_path): + """ + Global fixture to isolate LocalTaskCache for each test. + Uses temporary directory to avoid polluting local development cache. 
+ """ + with patch.object(LocalTaskCache, "_get_cache_path", return_value=str(tmp_path / "test_cache.db")): + LocalTaskCache._initialized = False + yield + await LocalTaskCache.close() + + +@pytest.fixture(autouse=True) +def patch_os_exit(monkeypatch): + def mock_exit(code): + raise SystemExit(code) + + monkeypatch.setattr(os, "_exit", mock_exit) diff --git a/plugins/pandera/tests/test_transformer.py b/plugins/pandera/tests/test_transformer.py new file mode 100644 index 000000000..8ef5d0ff3 --- /dev/null +++ b/plugins/pandera/tests/test_transformer.py @@ -0,0 +1,258 @@ +"""Tests for the Pandera plugin transformer.""" + +import typing + +import pandas as pd +import pandera as pa +import pytest +from flyte.types import TypeEngine +from flyteidl2.core import literals_pb2 +from pandera.typing import DataFrame, Series + +# Import to ensure registration +import flyteplugins.pandera.transformer # noqa: F401 +from flyteplugins.pandera.config import ValidationConfig +from flyteplugins.pandera.transformer import PanderaTransformer + +# ============================================================================ +# Test schemas +# ============================================================================ + + +class UserSchema(pa.DataFrameModel): + name: Series[str] + age: Series[int] = pa.Field(ge=0, le=120) + email: Series[str] + + +class StrictSchema(pa.DataFrameModel): + value: Series[float] = pa.Field(ge=0.0, le=1.0) + label: Series[str] + + +# ============================================================================ +# Sample data +# ============================================================================ + +VALID_USER_DATA = { + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + "email": ["alice@example.com", "bob@example.com", "charlie@example.com"], +} + +INVALID_USER_DATA = { + "name": ["Alice", "Bob", "Charlie"], + "age": [25, -5, 200], # -5 and 200 are out of range + "email": ["alice@example.com", "bob@example.com", "charlie@example.com"], +} + 
VALID_STRICT_DATA = {
    "value": [0.1, 0.5, 0.9],
    "label": ["low", "mid", "high"],
}


# ============================================================================
# Registration tests
# ============================================================================


def test_pandera_transformer_registered():
    """The TypeEngine should resolve DataFrame[Model] to the Pandera transformer."""
    resolved = TypeEngine.get_transformer(DataFrame[UserSchema])
    assert isinstance(resolved, PanderaTransformer)


def test_pandera_transformer_name():
    """A freshly constructed transformer carries the expected display name."""
    assert PanderaTransformer().name == "Pandera Transformer"


# ============================================================================
# Literal type tests
# ============================================================================


def test_get_literal_type_basic():
    """A schema-parameterized pandera DataFrame maps to a structured dataset type."""
    literal_type = PanderaTransformer().get_literal_type(DataFrame[UserSchema])
    assert literal_type.structured_dataset_type is not None


def test_get_literal_type_unannotated():
    """A bare pandera DataFrame also maps to a structured dataset type."""
    literal_type = PanderaTransformer().get_literal_type(DataFrame)
    assert literal_type.structured_dataset_type is not None


# ============================================================================
# Schema extraction tests
# ============================================================================


def test_get_pandera_schema_with_model():
    """A DataFrameModel parameter yields a schema exposing the model's columns."""
    schema, cfg = PanderaTransformer._get_pandera_schema(DataFrame[UserSchema])
    assert schema is not None
    assert "name" in schema.columns
    assert "age" in schema.columns
    assert cfg.on_error == "raise"


def test_get_pandera_schema_without_model():
    """A bare pandera DataFrame yields an empty, permissive schema."""
    schema, cfg = PanderaTransformer._get_pandera_schema(DataFrame)
    assert isinstance(schema, pa.DataFrameSchema)
    assert cfg.on_error == "raise"


def test_get_pandera_schema_with_config():
    """A ValidationConfig attached via Annotated is picked up."""
    wrapped = typing.Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")]
    schema, cfg = PanderaTransformer._get_pandera_schema(wrapped)
    assert schema is not None
    assert cfg.on_error == "warn"


# ============================================================================
# Validation tests (to_literal)
# ============================================================================


@pytest.mark.asyncio
async def test_to_literal_valid_data(ctx_with_test_raw_data_path):
    """Serializing conforming data succeeds and produces a dataset URI."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(VALID_USER_DATA)
    literal_type = tx.get_literal_type(DataFrame[UserSchema])

    literal = await tx.to_literal(frame, DataFrame[UserSchema], literal_type)
    assert literal.scalar.structured_dataset.uri is not None


@pytest.mark.asyncio
async def test_to_literal_invalid_data_raises(ctx_with_test_raw_data_path):
    """Serializing non-conforming data raises a Pandera schema error."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(INVALID_USER_DATA)
    literal_type = tx.get_literal_type(DataFrame[UserSchema])

    with pytest.raises((pa.errors.SchemaError, pa.errors.SchemaErrors)):
        await tx.to_literal(frame, DataFrame[UserSchema], literal_type)


@pytest.mark.asyncio
async def test_to_literal_invalid_data_warn(ctx_with_test_raw_data_path):
    """With on_error='warn', non-conforming data serializes without raising."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(INVALID_USER_DATA)
    wrapped = typing.Annotated[DataFrame[UserSchema], ValidationConfig(on_error="warn")]
    literal_type = tx.get_literal_type(wrapped)

    # A warning is logged instead of an exception being raised.
    literal = await tx.to_literal(frame, wrapped, literal_type)
    assert literal.scalar.structured_dataset.uri is not None


@pytest.mark.asyncio
async def test_to_literal_non_dataframe_raises(ctx_with_test_raw_data_path):
    """Passing anything other than a pandas DataFrame is rejected up front."""
    tx = PanderaTransformer()
    literal_type = tx.get_literal_type(DataFrame[UserSchema])

    with pytest.raises(AssertionError, match="Only pandas DataFrame"):
        await tx.to_literal("not a dataframe", DataFrame[UserSchema], literal_type)


# ============================================================================
# Deserialization tests (to_python_value)
# ============================================================================


@pytest.mark.asyncio
async def test_roundtrip_valid_data(ctx_with_test_raw_data_path):
    """Encode then decode conforming data; shape and columns survive the trip."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(VALID_USER_DATA)
    literal_type = tx.get_literal_type(DataFrame[UserSchema])

    literal = await tx.to_literal(frame, DataFrame[UserSchema], literal_type)

    # Empty the memo so decoding performs a genuine re-validation.
    PanderaTransformer._VALIDATION_MEMO.clear()

    decoded = await tx.to_python_value(literal, DataFrame[UserSchema])
    assert isinstance(decoded, pd.DataFrame)
    assert decoded.shape == frame.shape
    assert list(decoded.columns) == list(frame.columns)


@pytest.mark.asyncio
async def test_roundtrip_strict_schema(ctx_with_test_raw_data_path):
    """Roundtrip a second schema; column values survive the trip."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(VALID_STRICT_DATA)
    literal_type = tx.get_literal_type(DataFrame[StrictSchema])

    literal = await tx.to_literal(frame, DataFrame[StrictSchema], literal_type)

    PanderaTransformer._VALIDATION_MEMO.clear()

    decoded = await tx.to_python_value(literal, DataFrame[StrictSchema])
    assert isinstance(decoded, pd.DataFrame)
    assert list(decoded["value"]) == VALID_STRICT_DATA["value"]
    assert list(decoded["label"]) == VALID_STRICT_DATA["label"]


@pytest.mark.asyncio
async def test_to_python_value_invalid_literal_raises():
    """An empty literal (no structured dataset) is rejected."""
    tx = PanderaTransformer()

    with pytest.raises(AssertionError, match="Can only convert a literal structured dataset"):
        await tx.to_python_value(literals_pb2.Literal(), DataFrame[UserSchema])


# ============================================================================
# Validation memo (caching) tests
# ============================================================================


@pytest.mark.asyncio
async def test_validation_memo_skips_revalidation(ctx_with_test_raw_data_path):
    """Decoding data just encoded in the same process skips re-validation."""
    tx = PanderaTransformer()
    frame = pd.DataFrame(VALID_USER_DATA)
    literal_type = tx.get_literal_type(DataFrame[UserSchema])

    # Encoding records the (uri, schema name) pair in the memo.
    literal = await tx.to_literal(frame, DataFrame[UserSchema], literal_type)

    # Decoding therefore bypasses validation and still returns a DataFrame.
    decoded = await tx.to_python_value(literal, DataFrame[UserSchema])
    assert isinstance(decoded, pd.DataFrame)


# ============================================================================
# Assert type tests
# ============================================================================


def test_assert_type_valid():
    """assert_type accepts a genuine pandas DataFrame without raising."""
    tx = PanderaTransformer()
    tx.assert_type(pd.DataFrame, pd.DataFrame(VALID_USER_DATA))


def test_assert_type_invalid():
    """assert_type rejects a value that is not a DataFrame."""
    with pytest.raises(TypeError):
        PanderaTransformer().assert_type(pd.DataFrame, "not a dataframe")