From 100a36cabeca5493f5838b3b91c6d0e314c60870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Sun, 8 Feb 2026 13:14:15 -0300 Subject: [PATCH] Add DuckDB plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a DuckDB connector plugin following the same patterns as the Snowflake plugin. DuckDB is an embedded analytical database that runs queries locally and synchronously, so the connector executes queries in create() and get() always returns SUCCEEDED. Features: - In-memory and file-based database support - Parameterized SQL queries with typed inputs - Extension installation and loading (httpfs, json, etc.) - Query results returned as pandas DataFrames via temp parquet files - Automatic cleanup of temporary result files Signed-off-by: André Ahlert --- plugins/duckdb/README.md | 146 +++++++++ plugins/duckdb/pyproject.toml | 73 +++++ .../src/flyteplugins/duckdb/__init__.py | 52 ++++ .../src/flyteplugins/duckdb/connector.py | 150 +++++++++ .../src/flyteplugins/duckdb/dataframe.py | 63 ++++ .../duckdb/src/flyteplugins/duckdb/task.py | 83 +++++ plugins/duckdb/tests/__init__.py | 1 + plugins/duckdb/tests/test_connector.py | 291 ++++++++++++++++++ plugins/duckdb/tests/test_task.py | 98 ++++++ 9 files changed, 957 insertions(+) create mode 100644 plugins/duckdb/README.md create mode 100644 plugins/duckdb/pyproject.toml create mode 100644 plugins/duckdb/src/flyteplugins/duckdb/__init__.py create mode 100644 plugins/duckdb/src/flyteplugins/duckdb/connector.py create mode 100644 plugins/duckdb/src/flyteplugins/duckdb/dataframe.py create mode 100644 plugins/duckdb/src/flyteplugins/duckdb/task.py create mode 100644 plugins/duckdb/tests/__init__.py create mode 100644 plugins/duckdb/tests/test_connector.py create mode 100644 plugins/duckdb/tests/test_task.py diff --git a/plugins/duckdb/README.md b/plugins/duckdb/README.md new file mode 100644 index 000000000..8afe9315b --- /dev/null +++ b/plugins/duckdb/README.md 
@@ -0,0 +1,147 @@ +# DuckDB Plugin for Flyte + +Run DuckDB SQL queries as Flyte tasks with parameterized inputs, extension support, and DataFrame output. + +DuckDB is an embedded analytical database (like SQLite for OLAP). Queries execute locally and synchronously, so no remote credentials or connection setup is required. + +## Installation + +```bash +pip install flyteplugins-duckdb +``` + +## Quick start + +```python +import pandas as pd +from flyteplugins.duckdb import DuckDB, DuckDBConfig + +import flyte + +config = DuckDBConfig() + +query = DuckDB( + name="count_rows", + query_template="SELECT COUNT(*) AS total FROM 'data.parquet'", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +## In-memory queries + +By default, DuckDB runs in-memory. This is ideal for ad-hoc analytics and querying files directly: + +```python +config = DuckDBConfig() # defaults to database_path=":memory:" + +task = DuckDB( + name="analyze", + query_template="SELECT * FROM 'sales.parquet' WHERE amount > 100", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +## File-based databases + +To query a persistent DuckDB database file: + +```python +config = DuckDBConfig(database_path="/data/analytics.duckdb") + +task = DuckDB( + name="query_db", + query_template="SELECT * FROM customers LIMIT 10", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +## Parameterized queries + +Use `%(name)s` placeholders and typed `inputs`: + +```python +lookup = DuckDB( + name="lookup_user", + query_template="SELECT * FROM 'users.parquet' WHERE id = %(user_id)s", + plugin_config=config, + inputs={"user_id": int}, + output_dataframe_type=pd.DataFrame, +) +``` + +## Extensions + +DuckDB supports extensions for additional functionality. 
Install and load them via `DuckDBConfig.extensions`: + +```python +config = DuckDBConfig(extensions=["httpfs"]) + +task = DuckDB( + name="query_s3", + query_template="SELECT * FROM 's3://bucket/data.parquet' LIMIT 100", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +Common extensions: +- `httpfs` - Read files from HTTP/S3 +- `spatial` - Geospatial functions +- `json` - JSON processing +- `excel` - Read Excel files + +## Reading results as DataFrames + +Set `output_dataframe_type` to get query results as a pandas DataFrame: + +```python +import pandas as pd + +select_task = DuckDB( + name="get_data", + query_template="SELECT * FROM 'data.parquet'", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +## Full example + +```python +import pandas as pd +from flyteplugins.duckdb import DuckDB, DuckDBConfig + +import flyte + +config = DuckDBConfig(extensions=["httpfs"]) + +analyze_task = DuckDB( + name="analyze_sales", + query_template="SELECT region, SUM(amount) as total FROM 'sales.parquet' GROUP BY region", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) + +duckdb_env = flyte.TaskEnvironment.from_task("duckdb_env", analyze_task) + +env = flyte.TaskEnvironment( + name="example_env", + image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-duckdb"), + depends_on=[duckdb_env], +) + + +@env.task +async def main() -> float: + df = await analyze_task() + return df["total"].sum().item() + + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.with_runcontext(mode="remote").run(main) + print(run.url) +``` diff --git a/plugins/duckdb/pyproject.toml b/plugins/duckdb/pyproject.toml new file mode 100644 index 000000000..0cf5e971d --- /dev/null +++ b/plugins/duckdb/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "flyteplugins-duckdb" +dynamic = ["version"] +description = "DuckDB plugin for flyte" +readme = "README.md" +authors = [{ name = "Andre Ahlert", email = 
"andreahlert@users.noreply.github.com" }] +requires-python = ">=3.10" +dependencies = [ + "flyte[connector]", + "duckdb", +] + +[project.entry-points."flyte.connectors"] +duckdb = "flyteplugins.duckdb.connector:DuckDBConnector" + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402"] +"tests/*" = ["ASYNC230", "ASYNC240"] diff --git a/plugins/duckdb/src/flyteplugins/duckdb/__init__.py b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py new file mode 100644 index 000000000..aaa2a3df2 --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py @@ -0,0 +1,52 @@ +""" +Key features: + +- Run SQL queries against DuckDB (in-memory or file-based) +- Parameterized SQL queries with typed inputs +- Query Parquet, CSV, and JSON files directly +- Load DuckDB extensions (httpfs, spatial, etc.) 
+- Returns query results as DataFrames + +Basic usage example: +```python +import flyte +from flyte.io import DataFrame +from flyteplugins.duckdb import DuckDB, DuckDBConfig + +config = DuckDBConfig() + +count_rows = DuckDB( + name="count_rows", + query_template="SELECT COUNT(*) AS total FROM 'data.parquet'", + plugin_config=config, + output_dataframe_type=DataFrame, +) + +flyte.TaskEnvironment.from_task("duckdb_env", count_rows) + +if __name__ == "__main__": + flyte.init_from_config() + + # Run locally (connector runs in-process) + run = flyte.with_runcontext(mode="local").run(count_rows) + + # Run remotely (connector runs on the control plane) + run = flyte.with_runcontext(mode="remote").run(count_rows) + + print(run.url) +``` +""" + +from flyte.io._dataframe.dataframe import DataFrameTransformerEngine + +from flyteplugins.duckdb.connector import DuckDBConnector +from flyteplugins.duckdb.dataframe import ( + DuckDBToPandasDecodingHandler, + PandasToDuckDBEncodingHandler, +) +from flyteplugins.duckdb.task import DuckDB, DuckDBConfig + +DataFrameTransformerEngine.register(PandasToDuckDBEncodingHandler()) +DataFrameTransformerEngine.register(DuckDBToPandasDecodingHandler()) + +__all__ = ["DuckDB", "DuckDBConfig", "DuckDBConnector"] diff --git a/plugins/duckdb/src/flyteplugins/duckdb/connector.py b/plugins/duckdb/src/flyteplugins/duckdb/connector.py new file mode 100644 index 000000000..7f53b2cb2 --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/connector.py @@ -0,0 +1,150 @@ +import asyncio +import os +import tempfile +import uuid +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from flyte import logger +from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta +from flyte.io import DataFrame +from flyteidl2.core.execution_pb2 import TaskExecution +from flyteidl2.core.tasks_pb2 import TaskTemplate +from google.protobuf import json_format + +import duckdb + +TASK_TYPE = "duckdb" + + +@dataclass +class 
@dataclass
class DuckDBJobMetadata(ResourceMeta):
    """
    Metadata describing one DuckDB query executed by the connector.

    Attributes:
        query_id: Unique identifier generated for the query.
        result_uri: Path to the temporary parquet file holding the query results,
            or None when no result file was written.
        has_output: Whether the task interface declares at least one output.
    """

    query_id: str
    result_uri: Optional[str] = None
    has_output: bool = False


def _bind_named_parameters(query: str, inputs: Dict[str, Any]) -> "tuple[str, list]":
    """
    Rewrite ``%(name)s`` placeholders to ``?`` and collect the matching values.

    Values are appended in the order the placeholders appear in ``query`` (not in
    the iteration order of ``inputs``), and a placeholder that occurs several
    times contributes its value once per occurrence, so the positional list
    always lines up with the ``?`` markers.

    Args:
        query: SQL text that may contain ``%(name)s`` placeholders.
        inputs: Mapping of placeholder name to the value to bind.

    Returns:
        A ``(positional_query, params)`` pair suitable for ``conn.execute``.
    """
    import re  # function-scope import keeps this helper self-contained

    params: list = []

    def _substitute(match):
        name = match.group(1)
        if name not in inputs:
            # Leave placeholders without a matching input untouched, as the
            # previous str.replace-based implementation did.
            return match.group(0)
        params.append(inputs[name])
        return "?"

    positional_query = re.sub(r"%\(([^)]+)\)s", _substitute, query)
    return positional_query, params


class DuckDBConnector(AsyncConnector):
    """Connector that runs DuckDB SQL tasks inside the connector process."""

    name = "DuckDB Connector"
    task_type_name = TASK_TYPE
    metadata_type = DuckDBJobMetadata

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> DuckDBJobMetadata:
        """
        Execute a DuckDB query.

        The query runs to completion inside this call (on the event loop's default
        executor, so the loop itself is not blocked). When the task interface
        declares outputs, the result is written to a parquet file in the system
        temporary directory.

        Args:
            task_template: The Flyte task template containing the SQL query and configuration.
            inputs: Optional dictionary of input parameters for parameterized queries.

        Returns:
            A DuckDBJobMetadata object containing the query ID and result file path.
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else {}

        database_path = custom.get("database_path", ":memory:")
        extensions = custom.get("extensions", [])

        query = task_template.sql.statement
        has_output = task_template.interface.outputs is not None and len(task_template.interface.outputs.variables) > 0
        query_id = str(uuid.uuid4())

        def _execute_query() -> Optional[str]:
            conn = duckdb.connect(database=database_path)
            try:
                for ext in extensions:
                    conn.install_extension(ext)
                    conn.load_extension(ext)

                if inputs:
                    # Bind by placeholder position so that reordered or repeated
                    # placeholders receive the right values.
                    positional_query, params = _bind_named_parameters(query, inputs)
                    result = conn.execute(positional_query, params)
                else:
                    result = conn.execute(query)

                if not has_output:
                    return None

                result_uri = os.path.join(tempfile.gettempdir(), f"duckdb_result_{query_id}.parquet")
                result.fetchdf().to_parquet(result_uri)
                return result_uri
            finally:
                # Always release the connection, including when the query fails.
                conn.close()

        loop = asyncio.get_running_loop()
        result_uri = await loop.run_in_executor(None, _execute_query)

        logger.info(f"DuckDB query executed with ID: {query_id}")

        return DuckDBJobMetadata(
            query_id=query_id,
            result_uri=result_uri,
            has_output=has_output,
        )

    async def get(
        self,
        resource_meta: DuckDBJobMetadata,
        **kwargs,
    ) -> Resource:
        """
        Get the status of a DuckDB query.

        DuckDB queries complete synchronously in create(), so this always returns SUCCEEDED.

        Args:
            resource_meta: The DuckDBJobMetadata from the create() call.

        Returns:
            A Resource object with SUCCEEDED status and optional outputs.
        """
        outputs = None
        if resource_meta.has_output and resource_meta.result_uri:
            # NOTE(review): result_uri is a path under tempfile.gettempdir() of the
            # process that ran create(); whoever decodes this DataFrame must be able
            # to read that filesystem - confirm for remote deployments.
            outputs = {"results": DataFrame(uri=f"duckdb://{resource_meta.result_uri}")}

        return Resource(phase=TaskExecution.SUCCEEDED, message="Query completed", outputs=outputs)

    async def delete(
        self,
        resource_meta: DuckDBJobMetadata,
        **kwargs,
    ):
        """
        Clean up temporary result files.

        A missing file (or an unset ``result_uri``) is not an error.

        Args:
            resource_meta: The DuckDBJobMetadata containing the result file path.
        """

        def _cleanup():
            if not resource_meta.result_uri:
                return
            try:
                # Remove directly instead of exists()-then-remove(), which can race
                # with another cleanup of the same file.
                os.remove(resource_meta.result_uri)
            except FileNotFoundError:
                return
            logger.info(f"Cleaned up temporary result file: {resource_meta.result_uri}")

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _cleanup)


ConnectorRegistry.register(DuckDBConnector())
class PandasToDuckDBEncodingHandler(DataFrameEncoder):
    """Encoder registered for ``pd.DataFrame`` values under the ``duckdb`` protocol."""

    def __init__(self):
        super().__init__(pd.DataFrame, DUCKDB, "")

    async def encode(
        self,
        dataframe: DataFrame,
        structured_dataset_type: types_pb2.StructuredDatasetType,
    ) -> literals_pb2.StructuredDataset:
        """Write the wrapped frame via ``_write_to_duckdb`` and return a literal pointing at its URI."""
        _write_to_duckdb(dataframe)

        target_uri = typing.cast(str, dataframe.uri)
        dataset_metadata = literals_pb2.StructuredDatasetMetadata(
            structured_dataset_type=structured_dataset_type,
        )
        return literals_pb2.StructuredDataset(uri=target_uri, metadata=dataset_metadata)


class DuckDBToPandasDecodingHandler(DataFrameDecoder):
    """Decoder registered for ``pd.DataFrame`` values under the ``duckdb`` protocol."""

    def __init__(self):
        super().__init__(pd.DataFrame, DUCKDB, "")

    async def decode(
        self,
        flyte_value: literals_pb2.StructuredDataset,
        current_task_metadata: literals_pb2.StructuredDatasetMetadata,
    ) -> "pd.DataFrame":
        """Load the frame referenced by ``flyte_value`` by delegating to ``_read_from_duckdb``."""
        frame = _read_from_duckdb(flyte_value, current_task_metadata)
        return frame
@dataclass
class DuckDBConfig:
    """
    Configure a DuckDB Task using a `DuckDBConfig` object.

    Args:
        database_path: Path to a DuckDB database file, or ":memory:" for an in-memory database.
        extensions: Optional list of DuckDB extensions to install and load before executing
            the query (e.g., ["httpfs", "spatial"]).
    """

    database_path: str = ":memory:"
    extensions: Optional[List[str]] = None


class DuckDB(AsyncConnectorExecutorMixin, TaskTemplate):
    """Task template that carries a DuckDB SQL statement and its `DuckDBConfig`."""

    _TASK_TYPE = "duckdb"

    def __init__(
        self,
        name: str,
        query_template: str,
        plugin_config: Optional[DuckDBConfig] = None,
        inputs: Optional[Dict[str, Type]] = None,
        output_dataframe_type: Optional[Type] = None,
        **kwargs,
    ):
        """
        Task to run SQL queries against DuckDB.

        Args:
            name: The name of this task.
            query_template: The SQL query to run. This can be parameterized using Python's
                printf-style string formatting with named parameters (e.g. %(param_name)s).
                Every run of whitespace (including newlines and tabs) is collapsed to a
                single space, so whitespace inside SQL string literals is collapsed too
                and a ``--`` line comment is no longer terminated by its newline.
            plugin_config: Optional `DuckDBConfig` object. Defaults to in-memory database.
            inputs: Name and type of inputs specified as a dictionary.
            output_dataframe_type: If the query produces data, specify the output dataframe type.
                When set, the task exposes a single output named "results".
            **kwargs: Forwarded unchanged to the base task template constructor.
        """
        input_spec = {key: (typ, None) for key, typ in inputs.items()} if inputs else {}
        output_spec = {"results": output_dataframe_type} if output_dataframe_type is not None else {}

        super().__init__(
            name=name,
            interface=NativeInterface(input_spec, output_spec),
            task_type=self._TASK_TYPE,
            image=None,
            **kwargs,
        )

        self.output_dataframe_type = output_dataframe_type
        self.plugin_config = plugin_config or DuckDBConfig()
        # r"\s+" already matches newlines and tabs, so no separate replace() is needed.
        self.query_template = re.sub(r"\s+", " ", query_template).strip()

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        """Return the connector settings; "extensions" is included only when non-empty."""
        config: Dict[str, Any] = {
            "database_path": self.plugin_config.database_path,
        }

        if self.plugin_config.extensions:
            config["extensions"] = self.plugin_config.extensions

        return config

    def sql(self, sctx: SerializationContext) -> Optional[tasks_pb2.Sql]:
        """Return the normalized query wrapped in a `tasks_pb2.Sql` message with the ANSI dialect."""
        return tasks_pb2.Sql(statement=self.query_template, dialect=tasks_pb2.Sql.Dialect.ANSI)
def _make_output_variable_map():
    """Build a VariableMap declaring one structured-dataset output named "results"."""
    results_entry = VariableEntry(
        key="results",
        value=Variable(type=LiteralType(structured_dataset_type=StructuredDatasetType())),
    )
    return VariableMap(variables=[results_entry])


def _in_memory_template(statement, with_output=False):
    """Build a TaskTemplate for ``statement`` against an in-memory database."""
    template = TaskTemplate()
    template.sql.CopyFrom(Sql(statement=statement, dialect=Sql.Dialect.ANSI))
    template.metadata.runtime.version = "1.0.0"

    custom = struct_pb2.Struct()
    custom["database_path"] = ":memory:"
    template.custom.CopyFrom(custom)

    if with_output:
        template.interface.outputs.CopyFrom(_make_output_variable_map())
    return template


def test_metadata_creation():
    """All three metadata fields round-trip through the constructor."""
    meta = DuckDBJobMetadata(
        query_id="test-query-123",
        result_uri="/tmp/duckdb_result_test.parquet",
        has_output=True,
    )
    assert meta.query_id == "test-query-123"
    assert meta.result_uri == "/tmp/duckdb_result_test.parquet"
    assert meta.has_output is True


def test_metadata_defaults():
    """Only query_id is required; the other fields default to None / False."""
    meta = DuckDBJobMetadata(query_id="test-query-456")
    assert meta.query_id == "test-query-456"
    assert meta.result_uri is None
    assert meta.has_output is False


class TestDuckDBConnector:
    @pytest.fixture
    def connector(self):
        """A fresh DuckDBConnector."""
        return DuckDBConnector()

    @pytest.fixture
    def task_template_minimal(self):
        """Template for ``SELECT 1`` with no declared outputs."""
        return _in_memory_template("SELECT 1")

    @pytest.fixture
    def task_template_with_output(self):
        """Template for a ten-row query with a declared "results" output."""
        return _in_memory_template("SELECT * FROM range(10)", with_output=True)

    @pytest.fixture
    def task_template_with_extensions(self):
        """Template for ``SELECT 1`` that requests the "json" extension."""
        template = TaskTemplate()
        template.sql.CopyFrom(Sql(statement="SELECT 1", dialect=Sql.Dialect.ANSI))
        template.metadata.runtime.version = "1.0.0"

        custom = struct_pb2.Struct()
        custom["database_path"] = ":memory:"

        # The extension names travel as a list value inside the Struct.
        requested = struct_pb2.ListValue()
        requested.values.add().string_value = "json"
        custom.fields["extensions"].CopyFrom(struct_pb2.Value(list_value=requested))

        template.custom.CopyFrom(custom)
        return template

    def test_connector_class_attributes(self, connector):
        """The connector exposes the expected name, task type and metadata type."""
        assert connector.name == "DuckDB Connector"
        assert connector.task_type_name == "duckdb"
        assert connector.metadata_type == DuckDBJobMetadata

    @pytest.mark.asyncio
    async def test_create_minimal(self, connector, task_template_minimal):
        """No inputs and no outputs: a query id is assigned and no file is written."""
        meta = await connector.create(task_template_minimal, inputs=None)

        assert meta.query_id is not None
        assert meta.has_output is False
        assert meta.result_uri is None

    @pytest.mark.asyncio
    async def test_create_with_output(self, connector, task_template_with_output):
        """A declared output yields an existing parquet file."""
        meta = await connector.create(task_template_with_output, inputs=None)

        assert meta.query_id is not None
        assert meta.has_output is True
        assert meta.result_uri is not None
        assert meta.result_uri.endswith(".parquet")

        result_file = pathlib.Path(meta.result_uri)
        assert result_file.exists()

        # Clean up
        result_file.unlink()

    @pytest.mark.asyncio
    async def test_create_with_inputs(self, connector):
        """A named placeholder is bound from ``inputs``."""
        template = _in_memory_template(
            "SELECT * FROM range(10) WHERE range > %(min_val)s",
            with_output=True,
        )

        meta = await connector.create(template, inputs={"min_val": 5})

        assert meta.has_output is True
        assert meta.result_uri is not None
        result_file = pathlib.Path(meta.result_uri)
        assert result_file.exists()

        # Verify the result has correct data
        import pandas as pd_lib

        frame = pd_lib.read_parquet(meta.result_uri)
        assert len(frame) == 4  # values 6, 7, 8, 9
        assert all(frame["range"] > 5)

        # Clean up
        result_file.unlink()

    @pytest.mark.asyncio
    async def test_create_without_output(self, connector):
        """A DDL statement with no declared outputs writes no result file."""
        template = _in_memory_template("CREATE TABLE test (id INTEGER)")

        meta = await connector.create(template, inputs=None)

        assert meta.has_output is False
        assert meta.result_uri is None

    @pytest.mark.asyncio
    async def test_get_succeeded_with_output(self, connector):
        """get() reports SUCCEEDED and wraps the result path in a duckdb:// DataFrame."""
        meta = DuckDBJobMetadata(
            query_id="test-123",
            result_uri="/tmp/duckdb_result_test.parquet",
            has_output=True,
        )

        resource = await connector.get(meta)

        assert resource.phase == TaskExecution.SUCCEEDED
        assert resource.message == "Query completed"
        assert resource.outputs is not None
        assert "results" in resource.outputs

        results = resource.outputs["results"]
        assert isinstance(results, DataFrame)
        assert "duckdb://" in results.uri
        assert "/tmp/duckdb_result_test.parquet" in results.uri

    @pytest.mark.asyncio
    async def test_get_succeeded_without_output(self, connector):
        """get() reports SUCCEEDED with no outputs when none were produced."""
        meta = DuckDBJobMetadata(query_id="test-456", has_output=False)

        resource = await connector.get(meta)

        assert resource.phase == TaskExecution.SUCCEEDED
        assert resource.message == "Query completed"
        assert resource.outputs is None

    @pytest.mark.asyncio
    async def test_delete_cleans_up_file(self, connector):
        """delete() removes an existing result file."""
        # Stand-in for a result file
        result_file = pathlib.Path(tempfile.gettempdir()) / "duckdb_result_delete_test.parquet"
        result_file.write_text("test")
        assert result_file.exists()

        meta = DuckDBJobMetadata(
            query_id="test-delete",
            result_uri=str(result_file),
            has_output=True,
        )
        await connector.delete(meta)

        assert not result_file.exists()

    @pytest.mark.asyncio
    async def test_delete_no_file(self, connector):
        """delete() tolerates a result path that does not exist."""
        meta = DuckDBJobMetadata(
            query_id="test-no-file",
            result_uri="/tmp/nonexistent_file.parquet",
            has_output=True,
        )

        # Should not raise
        await connector.delete(meta)

    @pytest.mark.asyncio
    async def test_delete_no_result_uri(self, connector):
        """delete() tolerates metadata without a result path."""
        meta = DuckDBJobMetadata(query_id="test-no-uri", has_output=False)

        # Should not raise
        await connector.delete(meta)

    @pytest.mark.asyncio
    async def test_create_with_extensions(self, connector, task_template_with_extensions):
        """create() succeeds when an extension is requested."""
        # json extension is bundled with DuckDB, so this should work without network
        meta = await connector.create(task_template_with_extensions, inputs=None)
        assert meta.query_id is not None

    @pytest.mark.asyncio
    async def test_end_to_end_flow(self, connector):
        """Exercise create -> get -> delete against a single query."""
        template = _in_memory_template("SELECT 42 AS answer", with_output=True)

        # Create
        meta = await connector.create(template, inputs=None)
        assert meta.has_output is True
        assert pathlib.Path(meta.result_uri).exists()

        # Get
        resource = await connector.get(meta)
        assert resource.phase == TaskExecution.SUCCEEDED
        assert resource.outputs is not None

        # Verify result content
        import pandas as pd_lib

        frame = pd_lib.read_parquet(meta.result_uri)
        assert frame["answer"].iloc[0] == 42

        # Delete
        await connector.delete(meta)
        assert not pathlib.Path(meta.result_uri).exists()
class TestDuckDBTask:
    """Construction-time behaviour of the DuckDB task."""

    def test_minimal_creation(self):
        duckdb_task = _make_task()
        settings = duckdb_task.plugin_config
        assert duckdb_task._TASK_TYPE == "duckdb"
        assert duckdb_task.query_template == "SELECT 1"
        assert settings.database_path == ":memory:"
        assert settings.extensions is None

    def test_custom_config(self):
        duckdb_task = _make_task(
            plugin_config=DuckDBConfig(database_path="/data/test.duckdb", extensions=["httpfs", "json"]),
        )
        settings = duckdb_task.plugin_config
        assert settings.database_path == "/data/test.duckdb"
        assert settings.extensions == ["httpfs", "json"]

    def test_whitespace_normalization(self):
        multiline_query = """
            SELECT *
            FROM users
            WHERE id = 1
        """
        duckdb_task = _make_task(query_template=multiline_query)
        assert duckdb_task.query_template == "SELECT * FROM users WHERE id = 1"

    def test_tab_normalization(self):
        duckdb_task = _make_task(query_template="SELECT\t*\tFROM\tusers")
        assert duckdb_task.query_template == "SELECT * FROM users"


class TestCustomConfig:
    """Serialized connector settings returned by ``custom_config``."""

    def test_default_config(self):
        serialized = _make_task().custom_config(SCTX)

        assert serialized["database_path"] == ":memory:"
        assert "extensions" not in serialized

    def test_custom_database_path(self):
        settings = DuckDBConfig(database_path="/data/analytics.duckdb")
        serialized = _make_task(plugin_config=settings).custom_config(SCTX)

        assert serialized["database_path"] == "/data/analytics.duckdb"

    def test_with_extensions(self):
        settings = DuckDBConfig(extensions=["httpfs", "spatial"])
        serialized = _make_task(plugin_config=settings).custom_config(SCTX)

        assert serialized["extensions"] == ["httpfs", "spatial"]

    def test_no_extensions_by_default(self):
        serialized = _make_task().custom_config(SCTX)

        assert "extensions" not in serialized

    def test_full_config(self):
        settings = DuckDBConfig(database_path="/tmp/test.duckdb", extensions=["httpfs"])
        serialized = _make_task(plugin_config=settings).custom_config(SCTX)

        expected = {
            "database_path": "/tmp/test.duckdb",
            "extensions": ["httpfs"],
        }
        assert serialized == expected


class TestSql:
    """The ``Sql`` message produced by ``sql``."""

    def test_sql_returns_ansi_dialect(self):
        message = _make_task(query_template="SELECT * FROM users").sql(SCTX)

        assert message.statement == "SELECT * FROM users"
        assert message.dialect == Sql.Dialect.ANSI

    def test_sql_with_parameterized_query(self):
        parameterized = "SELECT * FROM users WHERE id = %(user_id)s"
        message = _make_task(query_template=parameterized).sql(SCTX)

        assert message.statement == "SELECT * FROM users WHERE id = %(user_id)s"