6 changes: 3 additions & 3 deletions README.md
@@ -22,15 +22,15 @@ uv run databend-aiserver --port 8815
### 1. Register Functions in Databend

```sql
CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, max_files INT)
RETURNS TABLE (stage_name VARCHAR, path VARCHAR, fullpath VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
CREATE OR REPLACE FUNCTION ai_list_files(stage_location STAGE_LOCATION, pattern VARCHAR, max_files INT)
RETURNS TABLE (stage_name VARCHAR, path VARCHAR, uri VARCHAR, size UINT64, last_modified VARCHAR, etag VARCHAR, content_type VARCHAR)
LANGUAGE PYTHON HANDLER = 'ai_list_files' ADDRESS = '<your-ai-server-address>';

CREATE OR REPLACE FUNCTION ai_embed_1024(text VARCHAR)
RETURNS VECTOR(1024)
LANGUAGE PYTHON HANDLER = 'ai_embed_1024' ADDRESS = '<your-ai-server-address>';

CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, path VARCHAR)
CREATE OR REPLACE FUNCTION ai_parse_document(stage_location STAGE_LOCATION, file_path VARCHAR)
RETURNS VARIANT
LANGUAGE PYTHON HANDLER = 'ai_parse_document' ADDRESS = '<your-ai-server-address>';
```
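
For a quick smoke test of the updated signatures, the functions can also be called directly against the AI server with the `databend_udf` client, the same way the integration tests below do. This is a minimal sketch, not the documented workflow: the port follows the `--port 8815` example above, and the stage object plus the `build_stage_mapping` helper are borrowed from the test fixtures.

```python
from databend_udf.client import UDFClient

from tests.integration.conftest import build_stage_mapping  # test-only helper

# Assumptions: the AI server runs locally on port 8815, and `stage` is a
# placeholder for a StageLocation such as the memory-stage fixture used in
# the integration tests.
client = UDFClient(host="127.0.0.1", port=8815)

rows = client.call_function(
    "ai_list_files",
    "data/*.pdf",  # glob pattern, matched against the relative path
    100,           # max_files; values <= 0 disable truncation
    stage_locations=[build_stage_mapping(stage, "stage_location")],
)
for row in rows:
    print(row["path"], row["uri"], row["size"])
```
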
24 changes: 12 additions & 12 deletions databend_aiserver/udfs/docparse.py
@@ -218,8 +218,8 @@ def _chunk_document(doc: Any) -> Tuple[List[Dict[str, Any]], int]:


def _format_response(
path: str,
full_path: str,
file_path: str,
uri: str,
pages: List[Dict[str, Any]],
file_size: int,
timings: Dict[str, float],
@@ -232,9 +232,9 @@ def _format_response(
"chunk_size": DEFAULT_CHUNK_SIZE,
"duration_ms": timings.get("total", 0.0),
"file_size": file_size,
"filename": Path(path).name,
"filename": Path(file_path).name,
"num_tokens": num_tokens,
"path": full_path,
"uri": uri,
"timings_ms": timings,
"version": 1,
}
@@ -257,27 +257,27 @@ def _format_response(
result_type="VARIANT",
io_threads=4,
)
def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
def ai_parse_document(stage_location: StageLocation, file_path: str) -> Dict[str, Any]:
"""Parse a document and return Snowflake-compatible layout output."""
try:
t_total_ns = perf_counter_ns()
runtime = get_runtime()
logger.info(
"ai_parse_document start path=%s runtime_device=%s kind=%s",
path,
file_path,
runtime.capabilities.preferred_device,
runtime.capabilities.device_kind,
)

backend = _get_doc_parser_backend()
t_convert_start_ns = perf_counter_ns()
result, file_size = backend.convert(stage_location, path)
result, file_size = backend.convert(stage_location, file_path)
t_convert_end_ns = perf_counter_ns()

pages, num_tokens = _chunk_document(result.document)
t_chunk_end_ns = perf_counter_ns()

full_path = resolve_full_path(stage_location, path)
uri = resolve_full_path(stage_location, file_path)

timings = {
"convert": (t_convert_end_ns - t_convert_start_ns) / 1_000_000.0,
@@ -286,12 +286,12 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any
}

payload = _format_response(
path, full_path, pages, file_size, timings, num_tokens
file_path, uri, pages, file_size, timings, num_tokens
)

logger.info(
"ai_parse_document path=%s backend=%s chunks=%s duration_ms=%.1f",
path,
file_path,
getattr(backend, "name", "unknown"),
len(pages),
timings["total"],
@@ -301,8 +301,8 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any
except Exception as exc: # pragma: no cover
return {
"metadata": {
"path": path,
"filename": Path(path).name,
"file_path": file_path,
"filename": Path(file_path).name,
},
"chunks": [],
"error_information": [{"message": str(exc), "type": exc.__class__.__name__}],
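
The rename also shows up in the returned payload: the metadata block now carries `uri` (the resolved storage URI) instead of `path`. A minimal consumer-side sketch, assuming the module import path matches the file layout above and using a placeholder stage and file name:

```python
import json

from databend_aiserver.udfs.docparse import ai_parse_document

# `stage` is a placeholder StageLocation (e.g. the memory-stage fixture used in
# the unit tests); "report.pdf" is an illustrative file name.
raw = ai_parse_document(stage, "report.pdf")
payload = json.loads(raw) if isinstance(raw, str) else raw

meta = payload.get("metadata", {})
print(meta.get("uri"))       # storage URI resolved from the stage and file_path
print(meta.get("filename"))  # base name of file_path
print(meta.get("chunk_count"), len(payload.get("chunks", [])))
```
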
15 changes: 10 additions & 5 deletions databend_aiserver/udfs/stage.py
@@ -16,6 +16,7 @@

from __future__ import annotations

import fnmatch
import logging
from time import perf_counter
from typing import Any, Dict, Iterable, List, Optional
@@ -116,11 +117,11 @@ def _collect_stage_files(

@udf(
stage_refs=["stage_location"],
input_types=["INT"],
input_types=["VARCHAR", "INT"],
result_type=[
("stage_name", "VARCHAR"),
("path", "VARCHAR"),
("fullpath", "VARCHAR"),
("uri", "VARCHAR"),
("size", "UINT64"),
("last_modified", "VARCHAR"),
("etag", "VARCHAR"),
@@ -129,14 +130,15 @@ def _collect_stage_files(
name="ai_list_files",
)
def ai_list_files(
stage_location: StageLocation, max_files: Optional[int]
stage_location: StageLocation, pattern: Optional[str], max_files: Optional[int]
) -> Iterable[Dict[str, Any]]:
"""List objects in a stage."""

logging.getLogger(__name__).info(
"ai_list_files start stage=%s relative=%s max_files=%s",
"ai_list_files start stage=%s relative=%s pattern=%s max_files=%s",
stage_location.stage_name,
stage_location.relative_path,
pattern,
max_files,
)

@@ -154,6 +156,9 @@ def ai_list_files(

count = 0
for entry in scanner:
if pattern and not fnmatch.fnmatch(entry.path, pattern):
continue

if max_files > 0 and count >= max_files:
truncated = True
break
@@ -174,7 +179,7 @@ def ai_list_files(
yield {
"stage_name": stage_location.stage_name,
"path": entry.path,
"fullpath": resolve_storage_uri(stage_location, entry.path),
"uri": resolve_storage_uri(stage_location, entry.path),
"size": metadata.content_length,
"last_modified": _format_last_modified(
getattr(metadata, "last_modified", None)
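
The new `pattern` argument is applied with the standard library's `fnmatch`, so glob patterns are matched against each entry's full relative path (e.g. `data/report.pdf`), not just the base name, and `*` also crosses `/` boundaries. A standalone sketch of that matching behaviour with illustrative paths:

```python
import fnmatch

paths = ["data/report.pdf", "data/notes.docx", "data/subdir/note.txt", "data/subdir/"]

# Glob match against the relative path; only the .pdf entry matches here.
print([p for p in paths if fnmatch.fnmatch(p, "data/*.pdf")])
# ['data/report.pdf']

# "*" matches any characters including "/", so directory markers and nested
# files both match, mirroring the integration test for "data/subdir/*".
print([p for p in paths if fnmatch.fnmatch(p, "data/subdir/*")])
# ['data/subdir/note.txt', 'data/subdir/']
```
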
6 changes: 3 additions & 3 deletions tests/integration/test_docparse_integration.py
@@ -49,7 +49,7 @@ def _normalize_payload(payload):
return {
"chunk_count": metadata.get("chunk_count"),
"chunk_len": len(chunks),
"has_path": "path" in metadata,
"has_uri": "uri" in metadata,
"has_filename": "filename" in metadata,
"has_file_size": "file_size" in metadata,
"has_duration": "duration_ms" in metadata,
@@ -63,7 +63,7 @@ def test_docparse_pdf_structure(running_server, memory_stage):

norm = _normalize_payload(payload)
assert norm["chunk_count"] == norm["chunk_len"]
assert norm["has_path"]
assert norm["has_uri"]
assert norm["has_filename"]
assert norm["has_file_size"]
assert norm["has_duration"]
@@ -84,7 +84,7 @@ def test_docparse_docx_structure(running_server, memory_stage):

norm = _normalize_payload(payload)
assert norm["chunk_count"] == norm["chunk_len"]
assert norm["has_path"]
assert norm["has_uri"]
assert norm["has_filename"]
assert norm["has_file_size"]
assert norm["has_duration"]
50 changes: 37 additions & 13 deletions tests/integration/test_listing_integration.py
@@ -14,16 +14,23 @@

import pytest
from databend_udf.client import UDFClient
from typing import List, Dict, Any

from tests.integration.conftest import build_stage_mapping
from tests.integration.conftest import build_stage_mapping, StageLocation


def _get_listing(running_server, memory_stage, max_files=0):
client = UDFClient(host="127.0.0.1", port=running_server)
def _get_listing(
server_port: int, stage: StageLocation, pattern: str = None, max_files: int = 0
) -> List[Dict[str, Any]]:
client = UDFClient(host="127.0.0.1", port=server_port)

# ai_list_files(stage_location, pattern, max_files)
# UDFClient.call_function accepts *args, not RecordBatch
return client.call_function(
"ai_list_files",
pattern,
max_files,
stage_locations=[build_stage_mapping(memory_stage, "stage_location")],
stage_locations=[build_stage_mapping(stage, "stage_location")],
)


@@ -43,10 +50,9 @@ def test_list_stage_files_content(running_server, memory_stage):
def test_list_stage_files_metadata(running_server, memory_stage):
rows = _get_listing(running_server, memory_stage)
assert {row["stage_name"] for row in rows} == {memory_stage.stage_name}
# Check for fullpath instead of relative_path
# Memory stage fullpath might be just the path if no bucket/root
assert all("fullpath" in row for row in rows)
assert all(row["fullpath"].endswith(row["path"]) for row in rows)
# Memory stage uri might be just the path if no bucket/root
assert all("uri" in row for row in rows)
assert all(row["uri"].endswith(row["path"]) for row in rows)
# Check that last_modified key exists (value might be None for memory backend)
assert all("last_modified" in row for row in rows)

@@ -55,22 +61,40 @@ def test_list_stage_files_schema(running_server, memory_stage):
rows = _get_listing(running_server, memory_stage)
for row in rows:
assert "path" in row
assert "fullpath" in row
assert "uri" in row
assert "size" in row
assert "last_modified" in row
assert "etag" in row # May be None
assert "content_type" in row # May be None
# Verify order implicitly by checking keys list if needed,

# Verify order implicitly by checking keys list if needed,
# but for now just existence is enough as dicts are ordered in Python 3.7+
keys = list(row.keys())
# Expected keys: stage_name, path, fullpath, size, last_modified, etag, content_type
# Expected keys: stage_name, path, uri, size, last_modified, etag, content_type
# Note: stage_name is added by _get_listing or the UDF logic, let's check the core ones
assert keys.index("path") < keys.index("fullpath")
assert keys.index("path") < keys.index("uri")
assert keys.index("last_modified") < keys.index("etag")


def test_list_stage_files_truncation(running_server, memory_stage):
rows = _get_listing(running_server, memory_stage, max_files=1)
assert len(rows) == 1
assert "last_modified" in rows[0]


def test_list_stage_files_pattern(running_server, memory_stage):
# Test pattern matching - patterns match against full path (e.g., "data/file.pdf")
rows = _get_listing(running_server, memory_stage, pattern="data/*.pdf")
assert len(rows) == 1
assert rows[0]["path"].endswith(".pdf")

rows = _get_listing(running_server, memory_stage, pattern="data/*.docx")
assert len(rows) == 1
assert rows[0]["path"].endswith(".docx")

rows = _get_listing(running_server, memory_stage, pattern="data/subdir/*")
# Matches data/subdir/ and data/subdir/note.txt
assert len(rows) == 2
paths = {r["path"] for r in rows}
assert "data/subdir/note.txt" in paths
assert "data/subdir/" in paths
2 changes: 1 addition & 1 deletion tests/unit/test_docparse_path.py
@@ -20,5 +20,5 @@ def test_docparse_metadata_path_uses_root(memory_stage_with_root):
raw = ai_parse_document(memory_stage_with_root, "2206.01062.pdf")
payload = json.loads(raw) if isinstance(raw, str) else raw
meta = payload.get("metadata", {})
assert meta["path"] == "s3://wizardbend/dataset/data/2206.01062.pdf"
assert meta["uri"] == "s3://wizardbend/dataset/data/2206.01062.pdf"
assert meta["filename"] == "2206.01062.pdf"
16 changes: 8 additions & 8 deletions tests/unit/test_opendal_api.py
@@ -31,16 +31,16 @@ def test_opendal_entry_has_path_attribute():
assert entry.path == "test.txt"


def test_opendal_entry_no_metadata_attribute():
"""Verify Entry objects don't have metadata attribute (API changed)."""
def test_opendal_entry_has_metadata_attribute():
"""Verify Entry objects have metadata attribute."""
op = Operator("memory")
op.write("test.txt", b"hello")

entries = list(op.list(""))
entry = entries[0]

# Entry should NOT have metadata attribute
assert not hasattr(entry, "metadata")
# Entry SHOULD have metadata attribute in newer opendal
assert hasattr(entry, "metadata")


def test_opendal_stat_returns_metadata():
@@ -60,15 +60,15 @@ def test_opendal_stat_returns_metadata():
assert metadata.content_length == 11 # len("hello world")


def test_opendal_metadata_no_is_dir_method():
"""Verify Metadata doesn't have is_dir() method."""
def test_opendal_metadata_has_is_dir_method():
"""Verify Metadata has is_dir() method."""
op = Operator("memory")
op.write("test.txt", b"hello")

metadata = op.stat("test.txt")

# Metadata should NOT have is_dir() method
assert not hasattr(metadata, "is_dir")
# Metadata SHOULD have is_dir() method in newer opendal
assert hasattr(metadata, "is_dir")


def test_opendal_directory_detection_via_path():
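
These assertions track the newer opendal Python bindings, where listing entries expose a `metadata` attribute and `Metadata` gained an `is_dir()` method. A small sketch of the list-then-stat pattern the tests exercise, using the in-memory operator and illustrative file names:

```python
from opendal import Operator

op = Operator("memory")
op.write("report.pdf", b"hello world")
op.write("notes.txt", b"hello")

for entry in op.list(""):
    meta = op.stat(entry.path)  # Metadata for the entry (content_length, etc.)
    if hasattr(meta, "is_dir") and meta.is_dir():
        continue  # newer opendal exposes is_dir(); skip directory entries
    print(entry.path, meta.content_length)
```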