From 1139b1d29918ee7fe9695c29608569eea60c56b2 Mon Sep 17 00:00:00 2001
From: BohuTANG
Date: Fri, 28 Nov 2025 15:15:54 +0800
Subject: [PATCH] fix: update integration tests and UDF signatures

- Update ai_parse_document to return uri in metadata
- Update ai_list_files to handle directory entries in pattern matching
- Fix unit tests for opendal API changes
- Update UDF signatures to use stage_location parameter
---
 README.md                                      |  6 +--
 databend_aiserver/udfs/docparse.py             | 24 ++++-----
 databend_aiserver/udfs/stage.py                | 15 ++++--
 .../integration/test_docparse_integration.py   |  6 +--
 tests/integration/test_listing_integration.py  | 50 ++++++++++++++-----
 tests/unit/test_docparse_path.py               |  2 +-
 tests/unit/test_opendal_api.py                 | 16 +++---
 7 files changed, 74 insertions(+), 45 deletions(-)

diff --git a/README.md b/README.md
index c7a93c8..0f8170d 100644
--- a/README.md
+++ b/README.md
@@ -22,15 +22,15 @@ uv run databend-aiserver --port 8815
 ### 1. Register Functions in Databend
 
 ```sql
-CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, max_files INT)
-RETURNS TABLE (stage_name VARCHAR, path VARCHAR, fullpath VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
+CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, pattern VARCHAR, max_files INT)
+RETURNS TABLE (stage_name VARCHAR, path VARCHAR, uri VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
 LANGUAGE PYTHON HANDLER = 'ai_list_files' ADDRESS = '';
 
 CREATE OR REPLACE FUNCTION ai_embed_1024(text VARCHAR)
 RETURNS VECTOR(1024)
 LANGUAGE PYTHON HANDLER = 'ai_embed_1024' ADDRESS = '';
 
-CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, path VARCHAR)
+CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, file_path VARCHAR)
 RETURNS VARIANT
 LANGUAGE PYTHON HANDLER = 'ai_parse_document' ADDRESS = '';
 ```
diff --git a/databend_aiserver/udfs/docparse.py b/databend_aiserver/udfs/docparse.py
index 1b28a22..d09ec3d 100644
--- a/databend_aiserver/udfs/docparse.py
+++ b/databend_aiserver/udfs/docparse.py
@@ -218,8 +218,8 @@ def _chunk_document(doc: Any) -> Tuple[List[Dict[str, Any]], int]:
 
 
 def _format_response(
-    path: str,
-    full_path: str,
+    file_path: str,
+    uri: str,
     pages: List[Dict[str, Any]],
     file_size: int,
     timings: Dict[str, float],
@@ -232,9 +232,9 @@ def _format_response(
         "chunk_size": DEFAULT_CHUNK_SIZE,
         "duration_ms": timings.get("total", 0.0),
         "file_size": file_size,
-        "filename": Path(path).name,
+        "filename": Path(file_path).name,
         "num_tokens": num_tokens,
-        "path": full_path,
+        "uri": uri,
         "timings_ms": timings,
         "version": 1,
     }
@@ -257,27 +257,27 @@ def _format_response(
     result_type="VARIANT",
     io_threads=4,
 )
-def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
+def ai_parse_document(stage_location: StageLocation, file_path: str) -> Dict[str, Any]:
     """Parse a document and return Snowflake-compatible layout output."""
     try:
         t_total_ns = perf_counter_ns()
         runtime = get_runtime()
         logger.info(
             "ai_parse_document start path=%s runtime_device=%s kind=%s",
-            path,
+            file_path,
             runtime.capabilities.preferred_device,
             runtime.capabilities.device_kind,
         )
 
         backend = _get_doc_parser_backend()
 
         t_convert_start_ns = perf_counter_ns()
-        result, file_size = backend.convert(stage_location, path)
+        result, file_size = backend.convert(stage_location, file_path)
         t_convert_end_ns = perf_counter_ns()
 
         pages, num_tokens = _chunk_document(result.document)
         t_chunk_end_ns = perf_counter_ns()
 
-        full_path = 
resolve_full_path(stage_location, path) + uri = resolve_full_path(stage_location, file_path) timings = { "convert": (t_convert_end_ns - t_convert_start_ns) / 1_000_000.0, @@ -286,12 +286,12 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any } payload = _format_response( - path, full_path, pages, file_size, timings, num_tokens + file_path, uri, pages, file_size, timings, num_tokens ) logger.info( "ai_parse_document path=%s backend=%s chunks=%s duration_ms=%.1f", - path, + file_path, getattr(backend, "name", "unknown"), len(pages), timings["total"], @@ -301,8 +301,8 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any except Exception as exc: # pragma: no cover return { "metadata": { - "path": path, - "filename": Path(path).name, + "file_path": file_path, + "filename": Path(file_path).name, }, "chunks": [], "error_information": [{"message": str(exc), "type": exc.__class__.__name__}], diff --git a/databend_aiserver/udfs/stage.py b/databend_aiserver/udfs/stage.py index 7cbd9e0..8ebb2b4 100644 --- a/databend_aiserver/udfs/stage.py +++ b/databend_aiserver/udfs/stage.py @@ -16,6 +16,7 @@ from __future__ import annotations +import fnmatch import logging from time import perf_counter from typing import Any, Dict, Iterable, List, Optional @@ -116,11 +117,11 @@ def _collect_stage_files( @udf( stage_refs=["stage_location"], - input_types=["INT"], + input_types=["VARCHAR", "INT"], result_type=[ ("stage_name", "VARCHAR"), ("path", "VARCHAR"), - ("fullpath", "VARCHAR"), + ("uri", "VARCHAR"), ("size", "UINT64"), ("last_modified", "VARCHAR"), ("etag", "VARCHAR"), @@ -129,14 +130,15 @@ def _collect_stage_files( name="ai_list_files", ) def ai_list_files( - stage_location: StageLocation, max_files: Optional[int] + stage_location: StageLocation, pattern: Optional[str], max_files: Optional[int] ) -> Iterable[Dict[str, Any]]: """List objects in a stage.""" logging.getLogger(__name__).info( - "ai_list_files start stage=%s relative=%s max_files=%s", + "ai_list_files start stage=%s relative=%s pattern=%s max_files=%s", stage_location.stage_name, stage_location.relative_path, + pattern, max_files, ) @@ -154,6 +156,9 @@ def ai_list_files( count = 0 for entry in scanner: + if pattern and not fnmatch.fnmatch(entry.path, pattern): + continue + if max_files > 0 and count >= max_files: truncated = True break @@ -174,7 +179,7 @@ def ai_list_files( yield { "stage_name": stage_location.stage_name, "path": entry.path, - "fullpath": resolve_storage_uri(stage_location, entry.path), + "uri": resolve_storage_uri(stage_location, entry.path), "size": metadata.content_length, "last_modified": _format_last_modified( getattr(metadata, "last_modified", None) diff --git a/tests/integration/test_docparse_integration.py b/tests/integration/test_docparse_integration.py index 0de2349..a40b14a 100644 --- a/tests/integration/test_docparse_integration.py +++ b/tests/integration/test_docparse_integration.py @@ -49,7 +49,7 @@ def _normalize_payload(payload): return { "chunk_count": metadata.get("chunk_count"), "chunk_len": len(chunks), - "has_path": "path" in metadata, + "has_uri": "uri" in metadata, "has_filename": "filename" in metadata, "has_file_size": "file_size" in metadata, "has_duration": "duration_ms" in metadata, @@ -63,7 +63,7 @@ def test_docparse_pdf_structure(running_server, memory_stage): norm = _normalize_payload(payload) assert norm["chunk_count"] == norm["chunk_len"] - assert norm["has_path"] + assert norm["has_uri"] assert norm["has_filename"] assert 
norm["has_file_size"] assert norm["has_duration"] @@ -84,7 +84,7 @@ def test_docparse_docx_structure(running_server, memory_stage): norm = _normalize_payload(payload) assert norm["chunk_count"] == norm["chunk_len"] - assert norm["has_path"] + assert norm["has_uri"] assert norm["has_filename"] assert norm["has_file_size"] assert norm["has_duration"] diff --git a/tests/integration/test_listing_integration.py b/tests/integration/test_listing_integration.py index dcdcf7e..e482b6d 100644 --- a/tests/integration/test_listing_integration.py +++ b/tests/integration/test_listing_integration.py @@ -14,16 +14,23 @@ import pytest from databend_udf.client import UDFClient +from typing import List, Dict, Any -from tests.integration.conftest import build_stage_mapping +from tests.integration.conftest import build_stage_mapping, StageLocation -def _get_listing(running_server, memory_stage, max_files=0): - client = UDFClient(host="127.0.0.1", port=running_server) +def _get_listing( + server_port: int, stage: StageLocation, pattern: str = None, max_files: int = 0 +) -> List[Dict[str, Any]]: + client = UDFClient(host="127.0.0.1", port=server_port) + + # ai_list_files(stage_location, pattern, max_files) + # UDFClient.call_function accepts *args, not RecordBatch return client.call_function( "ai_list_files", + pattern, max_files, - stage_locations=[build_stage_mapping(memory_stage, "stage_location")], + stage_locations=[build_stage_mapping(stage, "stage_location")], ) @@ -43,10 +50,9 @@ def test_list_stage_files_content(running_server, memory_stage): def test_list_stage_files_metadata(running_server, memory_stage): rows = _get_listing(running_server, memory_stage) assert {row["stage_name"] for row in rows} == {memory_stage.stage_name} - # Check for fullpath instead of relative_path - # Memory stage fullpath might be just the path if no bucket/root - assert all("fullpath" in row for row in rows) - assert all(row["fullpath"].endswith(row["path"]) for row in rows) + # Memory stage uri might be just the path if no bucket/root + assert all("uri" in row for row in rows) + assert all(row["uri"].endswith(row["path"]) for row in rows) # Check that last_modified key exists (value might be None for memory backend) assert all("last_modified" in row for row in rows) @@ -55,18 +61,18 @@ def test_list_stage_files_schema(running_server, memory_stage): rows = _get_listing(running_server, memory_stage) for row in rows: assert "path" in row - assert "fullpath" in row + assert "uri" in row assert "size" in row assert "last_modified" in row assert "etag" in row # May be None assert "content_type" in row # May be None - - # Verify order implicitly by checking keys list if needed, + + # Verify order implicitly by checking keys list if needed, # but for now just existence is enough as dicts are ordered in Python 3.7+ keys = list(row.keys()) - # Expected keys: stage_name, path, fullpath, size, last_modified, etag, content_type + # Expected keys: stage_name, path, uri, size, last_modified, etag, content_type # Note: stage_name is added by _get_listing or the UDF logic, let's check the core ones - assert keys.index("path") < keys.index("fullpath") + assert keys.index("path") < keys.index("uri") assert keys.index("last_modified") < keys.index("etag") @@ -74,3 +80,21 @@ def test_list_stage_files_truncation(running_server, memory_stage): rows = _get_listing(running_server, memory_stage, max_files=1) assert len(rows) == 1 assert "last_modified" in rows[0] + + +def test_list_stage_files_pattern(running_server, memory_stage): + # Test pattern 
matching - patterns match against full path (e.g., "data/file.pdf") + rows = _get_listing(running_server, memory_stage, pattern="data/*.pdf") + assert len(rows) == 1 + assert rows[0]["path"].endswith(".pdf") + + rows = _get_listing(running_server, memory_stage, pattern="data/*.docx") + assert len(rows) == 1 + assert rows[0]["path"].endswith(".docx") + + rows = _get_listing(running_server, memory_stage, pattern="data/subdir/*") + # Matches data/subdir/ and data/subdir/note.txt + assert len(rows) == 2 + paths = {r["path"] for r in rows} + assert "data/subdir/note.txt" in paths + assert "data/subdir/" in paths diff --git a/tests/unit/test_docparse_path.py b/tests/unit/test_docparse_path.py index bfb0233..e88f347 100644 --- a/tests/unit/test_docparse_path.py +++ b/tests/unit/test_docparse_path.py @@ -20,5 +20,5 @@ def test_docparse_metadata_path_uses_root(memory_stage_with_root): raw = ai_parse_document(memory_stage_with_root, "2206.01062.pdf") payload = json.loads(raw) if isinstance(raw, str) else raw meta = payload.get("metadata", {}) - assert meta["path"] == "s3://wizardbend/dataset/data/2206.01062.pdf" + assert meta["uri"] == "s3://wizardbend/dataset/data/2206.01062.pdf" assert meta["filename"] == "2206.01062.pdf" diff --git a/tests/unit/test_opendal_api.py b/tests/unit/test_opendal_api.py index 37445a0..6c8ed52 100644 --- a/tests/unit/test_opendal_api.py +++ b/tests/unit/test_opendal_api.py @@ -31,16 +31,16 @@ def test_opendal_entry_has_path_attribute(): assert entry.path == "test.txt" -def test_opendal_entry_no_metadata_attribute(): - """Verify Entry objects don't have metadata attribute (API changed).""" +def test_opendal_entry_has_metadata_attribute(): + """Verify Entry objects have metadata attribute.""" op = Operator("memory") op.write("test.txt", b"hello") entries = list(op.list("")) entry = entries[0] - # Entry should NOT have metadata attribute - assert not hasattr(entry, "metadata") + # Entry SHOULD have metadata attribute in newer opendal + assert hasattr(entry, "metadata") def test_opendal_stat_returns_metadata(): @@ -60,15 +60,15 @@ def test_opendal_stat_returns_metadata(): assert metadata.content_length == 11 # len("hello world") -def test_opendal_metadata_no_is_dir_method(): - """Verify Metadata doesn't have is_dir() method.""" +def test_opendal_metadata_has_is_dir_method(): + """Verify Metadata has is_dir() method.""" op = Operator("memory") op.write("test.txt", b"hello") metadata = op.stat("test.txt") - # Metadata should NOT have is_dir() method - assert not hasattr(metadata, "is_dir") + # Metadata SHOULD have is_dir() method in newer opendal + assert hasattr(metadata, "is_dir") def test_opendal_directory_detection_via_path():