
Commit 07f4a5e

Merge pull request #4 from databendlabs/fix/integration-tests
fix: update integration tests and UDF signatures
2 parents e366407 + 1139b1d commit 07f4a5e


7 files changed: +74 additions, -45 deletions


README.md

Lines changed: 3 additions & 3 deletions
@@ -22,15 +22,15 @@ uv run databend-aiserver --port 8815
 ### 1. Register Functions in Databend
 
 ```sql
-CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, max_files INT)
-RETURNS TABLE (stage_name VARCHAR, path VARCHAR, fullpath VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
+CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, pattern VARCHAR, max_files INT)
+RETURNS TABLE (stage_name VARCHAR, path VARCHAR, uri VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
 LANGUAGE PYTHON HANDLER = 'ai_list_files' ADDRESS = '<your-ai-server-address>';
 
 CREATE OR REPLACE FUNCTION ai_embed_1024(text VARCHAR)
 RETURNS VECTOR(1024)
 LANGUAGE PYTHON HANDLER = 'ai_embed_1024' ADDRESS = '<your-ai-server-address>';
 
-CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, path VARCHAR)
+CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, file_path VARCHAR)
 RETURNS VARIANT
 LANGUAGE PYTHON HANDLER = 'ai_parse_document' ADDRESS = '<your-ai-server-address>';
 ```
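The Python handlers on the AI server side carry the same argument changes. Below is a condensed sketch of the two updated handler signatures, with bodies elided and the @udf decorators omitted; StageLocation stands in for the server's stage reference type (see the full diffs that follow).

```python
# Condensed view of the handler signatures matching the SQL registrations above.
# Bodies and @udf decorator arguments are elided; see databend_aiserver/udfs/*.py.
from typing import Any, Dict, Iterable, Optional


def ai_list_files(
    stage_location: "StageLocation", pattern: Optional[str], max_files: Optional[int]
) -> Iterable[Dict[str, Any]]:
    """List objects in a stage, optionally filtered by an fnmatch pattern."""
    ...


def ai_parse_document(stage_location: "StageLocation", file_path: str) -> Dict[str, Any]:
    """Parse a document and return Snowflake-compatible layout output."""
    ...
```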

databend_aiserver/udfs/docparse.py

Lines changed: 12 additions & 12 deletions
@@ -218,8 +218,8 @@ def _chunk_document(doc: Any) -> Tuple[List[Dict[str, Any]], int]:
 
 
 def _format_response(
-    path: str,
-    full_path: str,
+    file_path: str,
+    uri: str,
     pages: List[Dict[str, Any]],
     file_size: int,
     timings: Dict[str, float],
@@ -232,9 +232,9 @@ def _format_response(
         "chunk_size": DEFAULT_CHUNK_SIZE,
         "duration_ms": timings.get("total", 0.0),
         "file_size": file_size,
-        "filename": Path(path).name,
+        "filename": Path(file_path).name,
         "num_tokens": num_tokens,
-        "path": full_path,
+        "uri": uri,
         "timings_ms": timings,
         "version": 1,
     }
@@ -257,27 +257,27 @@ def _format_response(
     result_type="VARIANT",
     io_threads=4,
 )
-def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
+def ai_parse_document(stage_location: StageLocation, file_path: str) -> Dict[str, Any]:
     """Parse a document and return Snowflake-compatible layout output."""
     try:
         t_total_ns = perf_counter_ns()
         runtime = get_runtime()
         logger.info(
             "ai_parse_document start path=%s runtime_device=%s kind=%s",
-            path,
+            file_path,
             runtime.capabilities.preferred_device,
             runtime.capabilities.device_kind,
         )
 
         backend = _get_doc_parser_backend()
         t_convert_start_ns = perf_counter_ns()
-        result, file_size = backend.convert(stage_location, path)
+        result, file_size = backend.convert(stage_location, file_path)
         t_convert_end_ns = perf_counter_ns()
 
         pages, num_tokens = _chunk_document(result.document)
         t_chunk_end_ns = perf_counter_ns()
 
-        full_path = resolve_full_path(stage_location, path)
+        uri = resolve_full_path(stage_location, file_path)
 
         timings = {
             "convert": (t_convert_end_ns - t_convert_start_ns) / 1_000_000.0,
@@ -286,12 +286,12 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
         }
 
         payload = _format_response(
-            path, full_path, pages, file_size, timings, num_tokens
+            file_path, uri, pages, file_size, timings, num_tokens
         )
 
         logger.info(
             "ai_parse_document path=%s backend=%s chunks=%s duration_ms=%.1f",
-            path,
+            file_path,
             getattr(backend, "name", "unknown"),
             len(pages),
             timings["total"],
@@ -301,8 +301,8 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
     except Exception as exc:  # pragma: no cover
         return {
             "metadata": {
-                "path": path,
-                "filename": Path(path).name,
+                "file_path": file_path,
+                "filename": Path(file_path).name,
             },
             "chunks": [],
             "error_information": [{"message": str(exc), "type": exc.__class__.__name__}],

databend_aiserver/udfs/stage.py

Lines changed: 10 additions & 5 deletions
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import fnmatch
 import logging
 from time import perf_counter
 from typing import Any, Dict, Iterable, List, Optional
@@ -116,11 +117,11 @@ def _collect_stage_files(
 
 @udf(
     stage_refs=["stage_location"],
-    input_types=["INT"],
+    input_types=["VARCHAR", "INT"],
     result_type=[
         ("stage_name", "VARCHAR"),
         ("path", "VARCHAR"),
-        ("fullpath", "VARCHAR"),
+        ("uri", "VARCHAR"),
         ("size", "UINT64"),
         ("last_modified", "VARCHAR"),
         ("etag", "VARCHAR"),
@@ -129,14 +130,15 @@ def _collect_stage_files(
     name="ai_list_files",
 )
 def ai_list_files(
-    stage_location: StageLocation, max_files: Optional[int]
+    stage_location: StageLocation, pattern: Optional[str], max_files: Optional[int]
 ) -> Iterable[Dict[str, Any]]:
     """List objects in a stage."""
 
     logging.getLogger(__name__).info(
-        "ai_list_files start stage=%s relative=%s max_files=%s",
+        "ai_list_files start stage=%s relative=%s pattern=%s max_files=%s",
        stage_location.stage_name,
        stage_location.relative_path,
+        pattern,
        max_files,
    )
 
@@ -154,6 +156,9 @@ def ai_list_files(
 
     count = 0
     for entry in scanner:
+        if pattern and not fnmatch.fnmatch(entry.path, pattern):
+            continue
+
         if max_files > 0 and count >= max_files:
             truncated = True
             break
@@ -174,7 +179,7 @@ def ai_list_files(
         yield {
             "stage_name": stage_location.stage_name,
             "path": entry.path,
-            "fullpath": resolve_storage_uri(stage_location, entry.path),
+            "uri": resolve_storage_uri(stage_location, entry.path),
             "size": metadata.content_length,
             "last_modified": _format_last_modified(
                 getattr(metadata, "last_modified", None)
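The filtering that ai_list_files now applies is just the fnmatch check plus the existing max_files guard. A minimal standalone re-creation of that loop is sketched below; the helper name and sample paths are invented for illustration.

```python
# Standalone sketch of the listing filter added above: fnmatch pattern
# matching plus max_files truncation (max_files <= 0 means "no limit").
import fnmatch
from typing import Iterable, List, Optional


def filter_listing(paths: Iterable[str], pattern: Optional[str], max_files: int) -> List[str]:
    selected: List[str] = []
    for path in paths:
        if pattern and not fnmatch.fnmatch(path, pattern):
            continue
        if max_files > 0 and len(selected) >= max_files:
            break
        selected.append(path)
    return selected


print(filter_listing(["data/a.pdf", "data/b.docx"], "data/*.pdf", 0))  # ['data/a.pdf']
print(filter_listing(["data/a.pdf", "data/b.docx"], None, 1))          # ['data/a.pdf']
```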

tests/integration/test_docparse_integration.py

Lines changed: 3 additions & 3 deletions
@@ -49,7 +49,7 @@ def _normalize_payload(payload):
     return {
         "chunk_count": metadata.get("chunk_count"),
         "chunk_len": len(chunks),
-        "has_path": "path" in metadata,
+        "has_uri": "uri" in metadata,
         "has_filename": "filename" in metadata,
         "has_file_size": "file_size" in metadata,
         "has_duration": "duration_ms" in metadata,
@@ -63,7 +63,7 @@ def test_docparse_pdf_structure(running_server, memory_stage):
 
     norm = _normalize_payload(payload)
     assert norm["chunk_count"] == norm["chunk_len"]
-    assert norm["has_path"]
+    assert norm["has_uri"]
     assert norm["has_filename"]
     assert norm["has_file_size"]
     assert norm["has_duration"]
@@ -84,7 +84,7 @@ def test_docparse_docx_structure(running_server, memory_stage):
 
     norm = _normalize_payload(payload)
     assert norm["chunk_count"] == norm["chunk_len"]
-    assert norm["has_path"]
+    assert norm["has_uri"]
     assert norm["has_filename"]
     assert norm["has_file_size"]
     assert norm["has_duration"]

tests/integration/test_listing_integration.py

Lines changed: 37 additions & 13 deletions
@@ -14,16 +14,23 @@
 
 import pytest
 from databend_udf.client import UDFClient
+from typing import List, Dict, Any
 
-from tests.integration.conftest import build_stage_mapping
+from tests.integration.conftest import build_stage_mapping, StageLocation
 
 
-def _get_listing(running_server, memory_stage, max_files=0):
-    client = UDFClient(host="127.0.0.1", port=running_server)
+def _get_listing(
+    server_port: int, stage: StageLocation, pattern: str = None, max_files: int = 0
+) -> List[Dict[str, Any]]:
+    client = UDFClient(host="127.0.0.1", port=server_port)
+
+    # ai_list_files(stage_location, pattern, max_files)
+    # UDFClient.call_function accepts *args, not RecordBatch
     return client.call_function(
         "ai_list_files",
+        pattern,
         max_files,
-        stage_locations=[build_stage_mapping(memory_stage, "stage_location")],
+        stage_locations=[build_stage_mapping(stage, "stage_location")],
     )
 
 
@@ -43,10 +50,9 @@ def test_list_stage_files_content(running_server, memory_stage):
 def test_list_stage_files_metadata(running_server, memory_stage):
     rows = _get_listing(running_server, memory_stage)
     assert {row["stage_name"] for row in rows} == {memory_stage.stage_name}
-    # Check for fullpath instead of relative_path
-    # Memory stage fullpath might be just the path if no bucket/root
-    assert all("fullpath" in row for row in rows)
-    assert all(row["fullpath"].endswith(row["path"]) for row in rows)
+    # Memory stage uri might be just the path if no bucket/root
+    assert all("uri" in row for row in rows)
+    assert all(row["uri"].endswith(row["path"]) for row in rows)
     # Check that last_modified key exists (value might be None for memory backend)
     assert all("last_modified" in row for row in rows)
 
@@ -55,22 +61,40 @@ def test_list_stage_files_schema(running_server, memory_stage):
     rows = _get_listing(running_server, memory_stage)
     for row in rows:
         assert "path" in row
-        assert "fullpath" in row
+        assert "uri" in row
         assert "size" in row
         assert "last_modified" in row
         assert "etag" in row  # May be None
         assert "content_type" in row  # May be None
-
-        # Verify order implicitly by checking keys list if needed,
+
+        # Verify order implicitly by checking keys list if needed,
         # but for now just existence is enough as dicts are ordered in Python 3.7+
         keys = list(row.keys())
-        # Expected keys: stage_name, path, fullpath, size, last_modified, etag, content_type
+        # Expected keys: stage_name, path, uri, size, last_modified, etag, content_type
         # Note: stage_name is added by _get_listing or the UDF logic, let's check the core ones
-        assert keys.index("path") < keys.index("fullpath")
+        assert keys.index("path") < keys.index("uri")
         assert keys.index("last_modified") < keys.index("etag")
 
 
 def test_list_stage_files_truncation(running_server, memory_stage):
     rows = _get_listing(running_server, memory_stage, max_files=1)
     assert len(rows) == 1
     assert "last_modified" in rows[0]
+
+
+def test_list_stage_files_pattern(running_server, memory_stage):
+    # Test pattern matching - patterns match against full path (e.g., "data/file.pdf")
+    rows = _get_listing(running_server, memory_stage, pattern="data/*.pdf")
+    assert len(rows) == 1
+    assert rows[0]["path"].endswith(".pdf")
+
+    rows = _get_listing(running_server, memory_stage, pattern="data/*.docx")
+    assert len(rows) == 1
+    assert rows[0]["path"].endswith(".docx")
+
+    rows = _get_listing(running_server, memory_stage, pattern="data/subdir/*")
+    # Matches data/subdir/ and data/subdir/note.txt
+    assert len(rows) == 2
+    paths = {r["path"] for r in rows}
+    assert "data/subdir/note.txt" in paths
+    assert "data/subdir/" in paths

tests/unit/test_docparse_path.py

Lines changed: 1 addition & 1 deletion
@@ -20,5 +20,5 @@ def test_docparse_metadata_path_uses_root(memory_stage_with_root):
     raw = ai_parse_document(memory_stage_with_root, "2206.01062.pdf")
     payload = json.loads(raw) if isinstance(raw, str) else raw
     meta = payload.get("metadata", {})
-    assert meta["path"] == "s3://wizardbend/dataset/data/2206.01062.pdf"
+    assert meta["uri"] == "s3://wizardbend/dataset/data/2206.01062.pdf"
     assert meta["filename"] == "2206.01062.pdf"

tests/unit/test_opendal_api.py

Lines changed: 8 additions & 8 deletions
@@ -31,16 +31,16 @@ def test_opendal_entry_has_path_attribute():
     assert entry.path == "test.txt"
 
 
-def test_opendal_entry_no_metadata_attribute():
-    """Verify Entry objects don't have metadata attribute (API changed)."""
+def test_opendal_entry_has_metadata_attribute():
+    """Verify Entry objects have metadata attribute."""
     op = Operator("memory")
     op.write("test.txt", b"hello")
 
     entries = list(op.list(""))
     entry = entries[0]
 
-    # Entry should NOT have metadata attribute
-    assert not hasattr(entry, "metadata")
+    # Entry SHOULD have metadata attribute in newer opendal
+    assert hasattr(entry, "metadata")
 
 
 def test_opendal_stat_returns_metadata():
@@ -60,15 +60,15 @@ def test_opendal_stat_returns_metadata():
     assert metadata.content_length == 11  # len("hello world")
 
 
-def test_opendal_metadata_no_is_dir_method():
-    """Verify Metadata doesn't have is_dir() method."""
+def test_opendal_metadata_has_is_dir_method():
+    """Verify Metadata has is_dir() method."""
     op = Operator("memory")
     op.write("test.txt", b"hello")
 
     metadata = op.stat("test.txt")
 
-    # Metadata should NOT have is_dir() method
-    assert not hasattr(metadata, "is_dir")
+    # Metadata SHOULD have is_dir() method in newer opendal
+    assert hasattr(metadata, "is_dir")
 
 
 def test_opendal_directory_detection_via_path():
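Taken together, the flipped assertions document the newer opendal Python API the server now targets: Entry carries a metadata attribute and Metadata exposes is_dir(). The hedged usage sketch below is built only from the calls the tests themselves exercise; exact semantics may vary across opendal releases.

```python
# Sketch of the newer opendal surface asserted above; the exact semantics of
# entry.metadata and is_dir() may differ between opendal releases.
from opendal import Operator

op = Operator("memory")
op.write("test.txt", b"hello world")

for entry in op.list(""):
    meta = op.stat(entry.path)
    if hasattr(meta, "is_dir") and meta.is_dir():
        continue  # skip directory markers
    print(entry.path, meta.content_length)  # -> test.txt 11
```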
