diff --git a/plugins/duckdb/README.md b/plugins/duckdb/README.md
new file mode 100644
index 000000000..8afe9315b
--- /dev/null
+++ b/plugins/duckdb/README.md
@@ -0,0 +1,146 @@
+# DuckDB Plugin for Flyte
+
+Run DuckDB SQL queries as Flyte tasks with parameterized inputs, extension support, and DataFrame output.
+
+DuckDB is an embedded analytical database (like SQLite for OLAP). Queries execute locally and synchronously, so no remote credentials or connection setup is required.
+
+## Installation
+
+```bash
+pip install flyteplugins-duckdb
+```
+
+## Quick start
+
+```python
+import pandas as pd
+from flyteplugins.duckdb import DuckDB, DuckDBConfig
+
+import flyte
+
+config = DuckDBConfig()
+
+query = DuckDB(
+    name="count_rows",
+    query_template="SELECT COUNT(*) AS total FROM 'data.parquet'",
+    plugin_config=config,
+    output_dataframe_type=pd.DataFrame,
+)
+```
+
+## In-memory queries
+
+By default, DuckDB runs in-memory. This is ideal for ad-hoc analytics and querying files directly:
+
+```python
+config = DuckDBConfig()  # defaults to database_path=":memory:"
+
+task = DuckDB(
+    name="analyze",
+    query_template="SELECT * FROM 'sales.parquet' WHERE amount > 100",
+    plugin_config=config,
+    output_dataframe_type=pd.DataFrame,
+)
+```
+
+## File-based databases
+
+To query a persistent DuckDB database file:
+
+```python
+config = DuckDBConfig(database_path="/data/analytics.duckdb")
+
+task = DuckDB(
+    name="query_db",
+    query_template="SELECT * FROM customers LIMIT 10",
+    plugin_config=config,
+    output_dataframe_type=pd.DataFrame,
+)
+```
+
+## Parameterized queries
+
+Use `%(name)s` placeholders and typed `inputs`:
+
+```python
+lookup = DuckDB(
+    name="lookup_user",
+    query_template="SELECT * FROM 'users.parquet' WHERE id = %(user_id)s",
+    plugin_config=config,
+    inputs={"user_id": int},
+    output_dataframe_type=pd.DataFrame,
+)
+```
+
+## Extensions
+
+DuckDB supports extensions for additional functionality.
Install and load them via `DuckDBConfig.extensions`: + +```python +config = DuckDBConfig(extensions=["httpfs"]) + +task = DuckDB( + name="query_s3", + query_template="SELECT * FROM 's3://bucket/data.parquet' LIMIT 100", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +Common extensions: +- `httpfs` - Read files from HTTP/S3 +- `spatial` - Geospatial functions +- `json` - JSON processing +- `excel` - Read Excel files + +## Reading results as DataFrames + +Set `output_dataframe_type` to get query results as a pandas DataFrame: + +```python +import pandas as pd + +select_task = DuckDB( + name="get_data", + query_template="SELECT * FROM 'data.parquet'", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) +``` + +## Full example + +```python +import pandas as pd +from flyteplugins.duckdb import DuckDB, DuckDBConfig + +import flyte + +config = DuckDBConfig(extensions=["httpfs"]) + +analyze_task = DuckDB( + name="analyze_sales", + query_template="SELECT region, SUM(amount) as total FROM 'sales.parquet' GROUP BY region", + plugin_config=config, + output_dataframe_type=pd.DataFrame, +) + +duckdb_env = flyte.TaskEnvironment.from_task("duckdb_env", analyze_task) + +env = flyte.TaskEnvironment( + name="example_env", + image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-duckdb"), + depends_on=[duckdb_env], +) + + +@env.task +async def main() -> float: + df = await analyze_task() + return df["total"].sum().item() + + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.with_runcontext(mode="remote").run(main) + print(run.url) +``` diff --git a/plugins/duckdb/pyproject.toml b/plugins/duckdb/pyproject.toml new file mode 100644 index 000000000..0cf5e971d --- /dev/null +++ b/plugins/duckdb/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "flyteplugins-duckdb" +dynamic = ["version"] +description = "DuckDB plugin for flyte" +readme = "README.md" +authors = [{ name = "Andre Ahlert", email = "andreahlert@users.noreply.github.com" }] +requires-python = ">=3.10" +dependencies = [ + "flyte[connector]", + "duckdb", +] + +[project.entry-points."flyte.connectors"] +duckdb = "flyteplugins.duckdb.connector:DuckDBConnector" + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402"] +"tests/*" = ["ASYNC230", "ASYNC240"] diff --git a/plugins/duckdb/src/flyteplugins/duckdb/__init__.py b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py new file mode 100644 index 000000000..aaa2a3df2 --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py @@ -0,0 +1,52 @@ +""" +Key features: + +- Run SQL queries against DuckDB (in-memory or file-based) +- Parameterized SQL queries with typed inputs +- Query Parquet, CSV, and JSON files directly +- Load DuckDB 
extensions (httpfs, spatial, etc.) +- Returns query results as DataFrames + +Basic usage example: +```python +import flyte +from flyte.io import DataFrame +from flyteplugins.duckdb import DuckDB, DuckDBConfig + +config = DuckDBConfig() + +count_rows = DuckDB( + name="count_rows", + query_template="SELECT COUNT(*) AS total FROM 'data.parquet'", + plugin_config=config, + output_dataframe_type=DataFrame, +) + +flyte.TaskEnvironment.from_task("duckdb_env", count_rows) + +if __name__ == "__main__": + flyte.init_from_config() + + # Run locally (connector runs in-process) + run = flyte.with_runcontext(mode="local").run(count_rows) + + # Run remotely (connector runs on the control plane) + run = flyte.with_runcontext(mode="remote").run(count_rows) + + print(run.url) +``` +""" + +from flyte.io._dataframe.dataframe import DataFrameTransformerEngine + +from flyteplugins.duckdb.connector import DuckDBConnector +from flyteplugins.duckdb.dataframe import ( + DuckDBToPandasDecodingHandler, + PandasToDuckDBEncodingHandler, +) +from flyteplugins.duckdb.task import DuckDB, DuckDBConfig + +DataFrameTransformerEngine.register(PandasToDuckDBEncodingHandler()) +DataFrameTransformerEngine.register(DuckDBToPandasDecodingHandler()) + +__all__ = ["DuckDB", "DuckDBConfig", "DuckDBConnector"] diff --git a/plugins/duckdb/src/flyteplugins/duckdb/connector.py b/plugins/duckdb/src/flyteplugins/duckdb/connector.py new file mode 100644 index 000000000..7f53b2cb2 --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/connector.py @@ -0,0 +1,150 @@ +import asyncio +import os +import tempfile +import uuid +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from flyte import logger +from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta +from flyte.io import DataFrame +from flyteidl2.core.execution_pb2 import TaskExecution +from flyteidl2.core.tasks_pb2 import TaskTemplate +from google.protobuf import json_format + +import duckdb + +TASK_TYPE = "duckdb" + + +@dataclass +class DuckDBJobMetadata(ResourceMeta): + """ + Metadata for a DuckDB query job. + + Attributes: + query_id: Unique identifier for tracking the query. + result_uri: Path to the temporary parquet file containing query results. + has_output: Indicates if the query produces output. + """ + + query_id: str + result_uri: Optional[str] = None + has_output: bool = False + + +class DuckDBConnector(AsyncConnector): + name = "DuckDB Connector" + task_type_name = TASK_TYPE + metadata_type = DuckDBJobMetadata + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> DuckDBJobMetadata: + """ + Execute a DuckDB query. + + DuckDB queries run locally and synchronously. The query executes entirely within + this method, and the result (if any) is written to a temporary parquet file. + + Args: + task_template: The Flyte task template containing the SQL query and configuration. + inputs: Optional dictionary of input parameters for parameterized queries. + + Returns: + A DuckDBJobMetadata object containing the query ID and result file path. 
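+
+        Note:
+            ``%(name)s`` placeholders are rewritten to positional ``?`` markers and the
+            input values are bound in the order they appear in ``inputs``; placeholders
+            should therefore appear in the query in that same order, each exactly once.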
+ """ + custom = json_format.MessageToDict(task_template.custom) if task_template.custom else {} + + database_path = custom.get("database_path", ":memory:") + extensions = custom.get("extensions", []) + + query = task_template.sql.statement + has_output = task_template.interface.outputs is not None and len(task_template.interface.outputs.variables) > 0 + query_id = str(uuid.uuid4()) + + def _execute_query(): + conn = duckdb.connect(database=database_path) + try: + for ext in extensions: + conn.install_extension(ext) + conn.load_extension(ext) + + if inputs: + params = list(inputs.values()) + param_names = list(inputs.keys()) + positional_query = query + for name in param_names: + positional_query = positional_query.replace(f"%({name})s", "?") + result = conn.execute(positional_query, params) + else: + result = conn.execute(query) + + result_uri = None + if has_output: + df = result.fetchdf() + result_uri = os.path.join(tempfile.gettempdir(), f"duckdb_result_{query_id}.parquet") + df.to_parquet(result_uri) + + return result_uri + finally: + conn.close() + + loop = asyncio.get_running_loop() + result_uri = await loop.run_in_executor(None, _execute_query) + + logger.info(f"DuckDB query executed with ID: {query_id}") + + return DuckDBJobMetadata( + query_id=query_id, + result_uri=result_uri, + has_output=has_output, + ) + + async def get( + self, + resource_meta: DuckDBJobMetadata, + **kwargs, + ) -> Resource: + """ + Get the status of a DuckDB query. + + DuckDB queries complete synchronously in create(), so this always returns SUCCEEDED. + + Args: + resource_meta: The DuckDBJobMetadata from the create() call. + + Returns: + A Resource object with SUCCEEDED status and optional outputs. + """ + outputs = None + if resource_meta.has_output and resource_meta.result_uri: + outputs = {"results": DataFrame(uri=f"duckdb://{resource_meta.result_uri}")} + + return Resource(phase=TaskExecution.SUCCEEDED, message="Query completed", outputs=outputs) + + async def delete( + self, + resource_meta: DuckDBJobMetadata, + **kwargs, + ): + """ + Clean up temporary result files. + + Args: + resource_meta: The DuckDBJobMetadata containing the result file path. 
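+
+        Note:
+            The result file is removed only if it still exists, so calling delete for a
+            query without output (or more than once) is a safe no-op.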
+ """ + + def _cleanup(): + if resource_meta.result_uri and os.path.exists(resource_meta.result_uri): + os.remove(resource_meta.result_uri) + logger.info(f"Cleaned up temporary result file: {resource_meta.result_uri}") + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _cleanup) + + +ConnectorRegistry.register(DuckDBConnector()) diff --git a/plugins/duckdb/src/flyteplugins/duckdb/dataframe.py b/plugins/duckdb/src/flyteplugins/duckdb/dataframe.py new file mode 100644 index 000000000..6a2bcb70f --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/dataframe.py @@ -0,0 +1,63 @@ +import typing + +from flyte._utils import lazy_module +from flyte.io._dataframe.dataframe import DataFrame, DataFrameDecoder, DataFrameEncoder +from flyteidl2.core import literals_pb2, types_pb2 + +if typing.TYPE_CHECKING: + import pandas as pd +else: + pd = lazy_module("pandas") + +DUCKDB = "duckdb" +PROTOCOL_PREFIX = "duckdb://" + + +def _read_from_duckdb( + flyte_value: literals_pb2.StructuredDataset, + current_task_metadata: literals_pb2.StructuredDatasetMetadata, +) -> "pd.DataFrame": + uri = flyte_value.uri + if not uri: + raise ValueError("flyte_value.uri cannot be empty.") + + parquet_path = uri.removeprefix(PROTOCOL_PREFIX) + return pd.read_parquet(parquet_path) + + +def _write_to_duckdb(dataframe: DataFrame): + if not dataframe.uri: + raise ValueError("dataframe.uri cannot be None.") + + uri = typing.cast(str, dataframe.uri) + parquet_path = uri.removeprefix(PROTOCOL_PREFIX) + df = typing.cast("pd.DataFrame", dataframe.val) + df.to_parquet(parquet_path) + + +class PandasToDuckDBEncodingHandler(DataFrameEncoder): + def __init__(self): + super().__init__(pd.DataFrame, DUCKDB, "") + + async def encode( + self, + dataframe: DataFrame, + structured_dataset_type: types_pb2.StructuredDatasetType, + ) -> literals_pb2.StructuredDataset: + _write_to_duckdb(dataframe) + return literals_pb2.StructuredDataset( + uri=typing.cast(str, dataframe.uri), + metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type), + ) + + +class DuckDBToPandasDecodingHandler(DataFrameDecoder): + def __init__(self): + super().__init__(pd.DataFrame, DUCKDB, "") + + async def decode( + self, + flyte_value: literals_pb2.StructuredDataset, + current_task_metadata: literals_pb2.StructuredDatasetMetadata, + ) -> "pd.DataFrame": + return _read_from_duckdb(flyte_value, current_task_metadata) diff --git a/plugins/duckdb/src/flyteplugins/duckdb/task.py b/plugins/duckdb/src/flyteplugins/duckdb/task.py new file mode 100644 index 000000000..caddd937a --- /dev/null +++ b/plugins/duckdb/src/flyteplugins/duckdb/task.py @@ -0,0 +1,83 @@ +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +from flyte.connectors import AsyncConnectorExecutorMixin +from flyte.extend import TaskTemplate +from flyte.models import NativeInterface, SerializationContext +from flyteidl2.core import tasks_pb2 + + +@dataclass +class DuckDBConfig(object): + """ + Configure a DuckDB Task using a `DuckDBConfig` object. + + Args: + database_path: Path to a DuckDB database file, or ":memory:" for an in-memory database. + extensions: Optional list of DuckDB extensions to install and load before executing + the query (e.g., ["httpfs", "parquet"]). 
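+
+    Example:
+        # In-memory database with the httpfs extension installed and loaded
+        # before the query executes.
+        config = DuckDBConfig(extensions=["httpfs"])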
+ """ + + database_path: str = ":memory:" + extensions: Optional[List[str]] = None + + +class DuckDB(AsyncConnectorExecutorMixin, TaskTemplate): + _TASK_TYPE = "duckdb" + + def __init__( + self, + name: str, + query_template: str, + plugin_config: Optional[DuckDBConfig] = None, + inputs: Optional[Dict[str, Type]] = None, + output_dataframe_type: Optional[Type] = None, + **kwargs, + ): + """ + Task to run SQL queries against DuckDB. + + DuckDB is an embedded analytical database (like SQLite for OLAP). Queries execute + locally and synchronously, so no remote credentials or polling are required. + + Args: + name: The name of this task. + query_template: The SQL query to run. This can be parameterized using Python's + printf-style string formatting with named parameters (e.g. %(param_name)s). + plugin_config: Optional `DuckDBConfig` object. Defaults to in-memory database. + inputs: Name and type of inputs specified as a dictionary. + output_dataframe_type: If the query produces data, specify the output dataframe type. + """ + outputs = None + if output_dataframe_type is not None: + outputs = {"results": output_dataframe_type} + + super().__init__( + name=name, + interface=NativeInterface( + {k: (v, None) for k, v in inputs.items()} if inputs else {}, + outputs or {}, + ), + task_type=self._TASK_TYPE, + image=None, + **kwargs, + ) + + self.output_dataframe_type = output_dataframe_type + self.plugin_config = plugin_config or DuckDBConfig() + self.query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() + + def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]: + config = { + "database_path": self.plugin_config.database_path, + } + + if self.plugin_config.extensions: + config["extensions"] = self.plugin_config.extensions + + return config + + def sql(self, sctx: SerializationContext) -> Optional[str]: + sql = tasks_pb2.Sql(statement=self.query_template, dialect=tasks_pb2.Sql.Dialect.ANSI) + return sql diff --git a/plugins/duckdb/tests/__init__.py b/plugins/duckdb/tests/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/plugins/duckdb/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/plugins/duckdb/tests/test_connector.py b/plugins/duckdb/tests/test_connector.py new file mode 100644 index 000000000..21af37d0e --- /dev/null +++ b/plugins/duckdb/tests/test_connector.py @@ -0,0 +1,291 @@ +import pathlib +import tempfile + +import pytest +from flyte.io import DataFrame +from flyteidl2.core.execution_pb2 import TaskExecution +from flyteidl2.core.interface_pb2 import Variable, VariableEntry, VariableMap +from flyteidl2.core.tasks_pb2 import Sql, TaskTemplate +from flyteidl2.core.types_pb2 import LiteralType, StructuredDatasetType +from google.protobuf import struct_pb2 + +from flyteplugins.duckdb.connector import ( + DuckDBConnector, + DuckDBJobMetadata, +) + + +def _make_output_variable_map(): + return VariableMap( + variables=[ + VariableEntry( + key="results", + value=Variable(type=LiteralType(structured_dataset_type=StructuredDatasetType())), + ) + ] + ) + + +def test_metadata_creation(): + """Test creating DuckDBJobMetadata instance.""" + metadata = DuckDBJobMetadata( + query_id="test-query-123", + result_uri="/tmp/duckdb_result_test.parquet", + has_output=True, + ) + assert metadata.query_id == "test-query-123" + assert metadata.result_uri == "/tmp/duckdb_result_test.parquet" + assert metadata.has_output is True + + +def test_metadata_defaults(): + """Test DuckDBJobMetadata default values.""" + metadata 
= DuckDBJobMetadata(query_id="test-query-456") + assert metadata.query_id == "test-query-456" + assert metadata.result_uri is None + assert metadata.has_output is False + + +class TestDuckDBConnector: + @pytest.fixture + def connector(self): + """Create a DuckDBConnector instance.""" + return DuckDBConnector() + + @pytest.fixture + def task_template_minimal(self): + """Create a minimal task template for testing.""" + template = TaskTemplate() + template.sql.CopyFrom(Sql(statement="SELECT 1", dialect=Sql.Dialect.ANSI)) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + template.custom.CopyFrom(custom) + + return template + + @pytest.fixture + def task_template_with_output(self): + """Create a task template with output variables.""" + template = TaskTemplate() + template.sql.CopyFrom(Sql(statement="SELECT * FROM range(10)", dialect=Sql.Dialect.ANSI)) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + template.custom.CopyFrom(custom) + + template.interface.outputs.CopyFrom(_make_output_variable_map()) + + return template + + @pytest.fixture + def task_template_with_extensions(self): + """Create a task template with extensions configured.""" + template = TaskTemplate() + template.sql.CopyFrom(Sql(statement="SELECT 1", dialect=Sql.Dialect.ANSI)) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + + # Extensions as a list in the Struct + extensions_list = struct_pb2.ListValue() + extensions_list.values.add().string_value = "json" + custom.fields["extensions"].CopyFrom(struct_pb2.Value(list_value=extensions_list)) + + template.custom.CopyFrom(custom) + + return template + + def test_connector_class_attributes(self, connector): + """Test that the connector has the correct class attributes.""" + assert connector.name == "DuckDB Connector" + assert connector.task_type_name == "duckdb" + assert connector.metadata_type == DuckDBJobMetadata + + @pytest.mark.asyncio + async def test_create_minimal(self, connector, task_template_minimal): + """Test creating a DuckDB query without inputs and without output.""" + metadata = await connector.create(task_template_minimal, inputs=None) + + assert metadata.query_id is not None + assert metadata.has_output is False + assert metadata.result_uri is None + + @pytest.mark.asyncio + async def test_create_with_output(self, connector, task_template_with_output): + """Test creating a DuckDB query with output produces a parquet file.""" + metadata = await connector.create(task_template_with_output, inputs=None) + + assert metadata.query_id is not None + assert metadata.has_output is True + assert metadata.result_uri is not None + assert metadata.result_uri.endswith(".parquet") + assert pathlib.Path(metadata.result_uri).exists() + + # Clean up + pathlib.Path(metadata.result_uri).unlink() + + @pytest.mark.asyncio + async def test_create_with_inputs(self, connector): + """Test creating a DuckDB query with parameterized inputs.""" + template = TaskTemplate() + template.sql.CopyFrom( + Sql(statement="SELECT * FROM range(10) WHERE range > %(min_val)s", dialect=Sql.Dialect.ANSI) + ) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + template.custom.CopyFrom(custom) + + template.interface.outputs.CopyFrom(_make_output_variable_map()) + + metadata = await connector.create(template, inputs={"min_val": 5}) + + 
assert metadata.has_output is True + assert metadata.result_uri is not None + assert pathlib.Path(metadata.result_uri).exists() + + # Verify the result has correct data + import pandas as pd_lib + + df = pd_lib.read_parquet(metadata.result_uri) + assert len(df) == 4 # values 6, 7, 8, 9 + assert all(df["range"] > 5) + + # Clean up + pathlib.Path(metadata.result_uri).unlink() + + @pytest.mark.asyncio + async def test_create_without_output(self, connector): + """Test creating a DuckDB query that produces no output.""" + template = TaskTemplate() + template.sql.CopyFrom(Sql(statement="CREATE TABLE test (id INTEGER)", dialect=Sql.Dialect.ANSI)) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + template.custom.CopyFrom(custom) + + metadata = await connector.create(template, inputs=None) + + assert metadata.has_output is False + assert metadata.result_uri is None + + @pytest.mark.asyncio + async def test_get_succeeded_with_output(self, connector): + """Test getting a completed DuckDB query with output.""" + metadata = DuckDBJobMetadata( + query_id="test-123", + result_uri="/tmp/duckdb_result_test.parquet", + has_output=True, + ) + + resource = await connector.get(metadata) + + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.message == "Query completed" + assert resource.outputs is not None + assert "results" in resource.outputs + assert isinstance(resource.outputs["results"], DataFrame) + assert "duckdb://" in resource.outputs["results"].uri + assert "/tmp/duckdb_result_test.parquet" in resource.outputs["results"].uri + + @pytest.mark.asyncio + async def test_get_succeeded_without_output(self, connector): + """Test getting a completed DuckDB query without output.""" + metadata = DuckDBJobMetadata( + query_id="test-456", + has_output=False, + ) + + resource = await connector.get(metadata) + + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.message == "Query completed" + assert resource.outputs is None + + @pytest.mark.asyncio + async def test_delete_cleans_up_file(self, connector): + """Test that delete removes the temporary result file.""" + # Create a temporary file to simulate a result + tmp_path = pathlib.Path(tempfile.gettempdir()) / "duckdb_result_delete_test.parquet" + tmp_path.write_text("test") + + assert tmp_path.exists() + + metadata = DuckDBJobMetadata( + query_id="test-delete", + result_uri=str(tmp_path), + has_output=True, + ) + + await connector.delete(metadata) + + assert not tmp_path.exists() + + @pytest.mark.asyncio + async def test_delete_no_file(self, connector): + """Test that delete handles missing files gracefully.""" + metadata = DuckDBJobMetadata( + query_id="test-no-file", + result_uri="/tmp/nonexistent_file.parquet", + has_output=True, + ) + + # Should not raise + await connector.delete(metadata) + + @pytest.mark.asyncio + async def test_delete_no_result_uri(self, connector): + """Test that delete handles None result_uri gracefully.""" + metadata = DuckDBJobMetadata( + query_id="test-no-uri", + has_output=False, + ) + + # Should not raise + await connector.delete(metadata) + + @pytest.mark.asyncio + async def test_create_with_extensions(self, connector, task_template_with_extensions): + """Test that extensions are installed and loaded.""" + # json extension is bundled with DuckDB, so this should work without network + metadata = await connector.create(task_template_with_extensions, inputs=None) + assert metadata.query_id is not None + + @pytest.mark.asyncio + async 
def test_end_to_end_flow(self, connector): + """Test the complete create -> get -> delete flow.""" + template = TaskTemplate() + template.sql.CopyFrom(Sql(statement="SELECT 42 AS answer", dialect=Sql.Dialect.ANSI)) + template.metadata.runtime.version = "1.0.0" + + custom = struct_pb2.Struct() + custom["database_path"] = ":memory:" + template.custom.CopyFrom(custom) + + template.interface.outputs.CopyFrom(_make_output_variable_map()) + + # Create + metadata = await connector.create(template, inputs=None) + assert metadata.has_output is True + assert pathlib.Path(metadata.result_uri).exists() + + # Get + resource = await connector.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.outputs is not None + + # Verify result content + import pandas as pd_lib + + df = pd_lib.read_parquet(metadata.result_uri) + assert df["answer"].iloc[0] == 42 + + # Delete + await connector.delete(metadata) + assert not pathlib.Path(metadata.result_uri).exists() diff --git a/plugins/duckdb/tests/test_task.py b/plugins/duckdb/tests/test_task.py new file mode 100644 index 000000000..f94376f5a --- /dev/null +++ b/plugins/duckdb/tests/test_task.py @@ -0,0 +1,98 @@ +from flyte.models import SerializationContext +from flyteidl2.core.tasks_pb2 import Sql + +from flyteplugins.duckdb.task import DuckDB, DuckDBConfig + +SCTX = SerializationContext(version="test") + + +def _make_task(**kwargs) -> DuckDB: + defaults = { + "name": "test", + "query_template": "SELECT 1", + } + defaults.update(kwargs) + return DuckDB(**defaults) + + +class TestDuckDBTask: + def test_minimal_creation(self): + task = _make_task() + assert task._TASK_TYPE == "duckdb" + assert task.query_template == "SELECT 1" + assert task.plugin_config.database_path == ":memory:" + assert task.plugin_config.extensions is None + + def test_custom_config(self): + config = DuckDBConfig(database_path="/data/test.duckdb", extensions=["httpfs", "json"]) + task = _make_task(plugin_config=config) + assert task.plugin_config.database_path == "/data/test.duckdb" + assert task.plugin_config.extensions == ["httpfs", "json"] + + def test_whitespace_normalization(self): + task = _make_task( + query_template=""" + SELECT * + FROM users + WHERE id = 1 + """ + ) + assert task.query_template == "SELECT * FROM users WHERE id = 1" + + def test_tab_normalization(self): + task = _make_task(query_template="SELECT\t*\tFROM\tusers") + assert task.query_template == "SELECT * FROM users" + + +class TestCustomConfig: + def test_default_config(self): + task = _make_task() + config = task.custom_config(SCTX) + + assert config["database_path"] == ":memory:" + assert "extensions" not in config + + def test_custom_database_path(self): + db_config = DuckDBConfig(database_path="/data/analytics.duckdb") + task = _make_task(plugin_config=db_config) + config = task.custom_config(SCTX) + + assert config["database_path"] == "/data/analytics.duckdb" + + def test_with_extensions(self): + db_config = DuckDBConfig(extensions=["httpfs", "spatial"]) + task = _make_task(plugin_config=db_config) + config = task.custom_config(SCTX) + + assert config["extensions"] == ["httpfs", "spatial"] + + def test_no_extensions_by_default(self): + task = _make_task() + config = task.custom_config(SCTX) + + assert "extensions" not in config + + def test_full_config(self): + db_config = DuckDBConfig(database_path="/tmp/test.duckdb", extensions=["httpfs"]) + task = _make_task(plugin_config=db_config) + config = task.custom_config(SCTX) + + assert config == { + "database_path": "/tmp/test.duckdb", 
+ "extensions": ["httpfs"], + } + + +class TestSql: + def test_sql_returns_ansi_dialect(self): + task = _make_task(query_template="SELECT * FROM users") + sql = task.sql(SCTX) + + assert sql.statement == "SELECT * FROM users" + assert sql.dialect == Sql.Dialect.ANSI + + def test_sql_with_parameterized_query(self): + task = _make_task(query_template="SELECT * FROM users WHERE id = %(user_id)s") + sql = task.sql(SCTX) + + assert sql.statement == "SELECT * FROM users WHERE id = %(user_id)s"
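+
+    # Additional sketch of a check: the dialect should stay ANSI for parameterized
+    # statements as well, mirroring test_sql_returns_ansi_dialect above.
+    def test_sql_parameterized_query_keeps_dialect(self):
+        task = _make_task(query_template="SELECT * FROM users WHERE id = %(user_id)s")
+        sql = task.sql(SCTX)
+
+        assert sql.dialect == Sql.Dialect.ANSI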