diff --git a/.claude/skills/address-pr-review/SKILL.md b/.claude/skills/address-pr-review/SKILL.md index f9eb8bc..94bb7a1 100644 --- a/.claude/skills/address-pr-review/SKILL.md +++ b/.claude/skills/address-pr-review/SKILL.md @@ -79,18 +79,20 @@ python .claude/skills/address-pr-review/scripts/fetch_comments.py --all | Phase | Actions | |-------|---------| -| **Fetch** | Run `--summary` first to see counts
Then `--id ` for each comment to analyze
Exit if no unresolved comments | +| **Fetch** | Run `--summary` first to see counts
**Only process unresolved comments** — resolved ones are already closed, skip them
Then `--id ` for each unresolved comment to analyze
Exit if no unresolved comments | | **Per Comment** | Show: file:line, author, comment, ±10 lines context
Analyze: Valid/Nitpick/Disagree/Question
Recommend: Fix/Reply/Skip with reasoning | -| **Fix** | Minimal changes per llm/rules-*.md
Offer reply draft: `Fixed: [what]. [why]`
Show: `gh api --method POST repos/{owner}/{repo}/pulls/comments/$ID/replies -f body="..."` | -| **Reply** | Draft based on type: Question/Suggestion/Disagreement
Let user edit
Show gh command (never auto-post) | +| **Fix** | Minimal changes per llm/rules-*.md
Do NOT reply — just fix the code | +| **Reply** | Draft based on type: Question/Suggestion/Disagreement
Wait 2 minutes between each reply
Post with: `gh api --method POST repos/{owner}/{repo}/pulls/{PR}/comments -f body="..." -F in_reply_to=`
(never auto-post without user confirmation) | | **Summary** | Processed X/N: Fixed Y, Replied Z, Skipped W
List: files modified, reply drafts, next steps | ## Critical Principles | Principle | Violation Pattern | |-----------|-------------------| +| **Unresolved only** | Processing already-resolved comments — the script filters to unresolved by default; never re-open resolved threads | | **Analyze first** | Accepting all feedback as valid without critical analysis | -| **Never auto-post** | Posting replies automatically instead of showing gh command | +| **Never auto-post** | Posting replies automatically without user confirmation, or skipping the 2-minute wait between replies | +| **No reply on fix** | Replying to comments that were addressed with a code fix — fixes speak for themselves | | **One at a time** | Batch processing all comments without individual analysis | | **Show context** | Making changes without displaying ±10 lines around code | | **Minimal changes** | Large refactors in response to small comments | diff --git a/.claude/skills/creating-pipeline-templates/SKILL.md b/.claude/skills/creating-pipeline-templates/SKILL.md index b1ed9a8..217d6af 100644 --- a/.claude/skills/creating-pipeline-templates/SKILL.md +++ b/.claude/skills/creating-pipeline-templates/SKILL.md @@ -71,7 +71,58 @@ StructureSampler → SemanticInfiller → DuplicateRemover # generation + metrics StructuredGenerator → FieldMapper → RagasMetrics + +# generation + review-friendly output +StructuredGenerator → FieldMapper (flatten for review) +``` + +## Adding a FieldMapper for Review + +The Review page displays records from the **last block's accumulated_state**. Only **first-level keys** are shown as primary/secondary fields. Nested objects (e.g. `generated.confirmed_dependencies`) appear as raw JSON strings and can't be configured as separate review fields. + +**Always add a `FieldMapper` as the last block** to surface the fields reviewers need at the top level.
+ +### Why it matters + +Without a FieldMapper, the accumulated_state after a `StructuredGenerator` looks like: +```json +{ + "input_field": "...", + "generated": { + "question": "...", + "answer": "...", + "contexts": ["..."] + } +} ``` +The review UI sees `input_field` and `generated` (a blob). Reviewers can't configure `question` or `answer` as primary fields. + +### How to add it + +Add a `FieldMapper` as the **last block** (or last before metrics/observability blocks): + +```yaml + - type: FieldMapper + config: + mappings: + # Flatten nested fields to top level + question: "{{ generated.question }}" + answer: "{{ generated.answer }}" + # tojson is safe only for structured data (IDs, numbers, short labels) + # avoid tojson on arrays/objects with free-text — newlines/quotes break JSON parsing + context_count: "{{ generated.contexts | length }}" + # Carry forward useful seed metadata + source: "{{ source_document }}" +``` + +### Rules + +1. **Map every field the reviewer needs** — if it's not a first-level key after the last block, it won't be configurable in the review field settings +2. **Use `| tojson`** for arrays/objects — FieldMapper auto-parses JSON strings back to objects, so the review UI can display them properly. **Exception:** `tojson` on arrays/objects whose values contain unescaped quotes or newlines (e.g. free-text descriptions) will break FieldMapper JSON parsing. In that case, map only scalar summaries (counts, IDs) and let the array flow through as an existing first-level key. +3. **Use `| length`** for counts — gives reviewers a quick numeric summary without expanding lists +4. **Use `| default('')`** for optional fields — prevents Jinja2 errors when a field is missing +5. **Don't map internal/noisy fields** — skip `folder_path`, `_usage`, `_seed_samples` etc. Only map what's useful for human review +6. 
**Order matters** — FieldMapper outputs merge into accumulated_state, so its keys become the available fields in the Review "Configure Fields" modal ## Step-by-Step Workflow @@ -126,6 +177,8 @@ StructuredGenerator → FieldMapper → RagasMetrics | Missing seed variable referenced in prompt | Add the variable to seed metadata | | MarkdownMultiplierBlock not first | Multiplier blocks must always be first | | Seed file not named `seed_.*` | Template ID must match: `foo.yaml` → `seed_foo.json` | +| Nested fields not visible in Review UI | Add a `FieldMapper` as last block to flatten nested outputs to top-level keys | +| Review shows `generated` as a JSON blob | Map individual sub-fields: `question: "{{ generated.question }}"` | ## Checklist @@ -135,6 +188,7 @@ StructuredGenerator → FieldMapper → RagasMetrics - [ ] Single execution produces expected output fields - [ ] Trace shows all blocks executed successfully - [ ] Seed file has 2-3 diverse examples +- [ ] FieldMapper as last block flattens outputs for Review UI (all reviewer-relevant fields are top-level keys) ## Related Skills diff --git a/.claude/skills/implementing-datagenflow-blocks/SKILL.md b/.claude/skills/implementing-datagenflow-blocks/SKILL.md index 9cd2c05..6769a54 100644 --- a/.claude/skills/implementing-datagenflow-blocks/SKILL.md +++ b/.claude/skills/implementing-datagenflow-blocks/SKILL.md @@ -467,6 +467,88 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: cached_embeddings = self._embeddings_cache[trace_id] ``` +## Agentic Tool-Calling Block Pattern + +For blocks that need multi-turn LLM reasoning with tool use (e.g. 
exploring an external data source before generating output): + +```python +async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager + + llm_config = await llm_config_manager.get_llm_model(self.model_name) + total_usage = pipeline.Usage(input_tokens=0, output_tokens=0, cached_tokens=0) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": render_template(self.user_prompt, context.accumulated_state)}, + ] + + for turn in range(self.max_turns): + if turn == self.max_turns - 1: + messages.append({"role": "user", "content": "Wrap up and return final JSON now."}) + + llm_params = llm_config_manager.prepare_llm_call( + llm_config, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + tools=TOOLS, + tool_choice="auto", + ) + llm_params["metadata"] = {"trace_id": context.trace_id, "tags": ["datagenflow"]} + + response = await litellm.acompletion(**llm_params) + msg = response.choices[0].message + total_usage.input_tokens += response.usage.prompt_tokens or 0 + total_usage.output_tokens += response.usage.completion_tokens or 0 + total_usage.cached_tokens += getattr(response.usage, "cache_read_input_tokens", 0) or 0 + + if not msg.tool_calls: + # final answer — parse JSON + try: + result = json.loads(msg.content or "{}") + except json.JSONDecodeError: + result = {} + return {"my_result": result.get("my_result", []), "_usage": total_usage.model_dump()} + + # append assistant message and process tool calls + messages.append({"role": "assistant", "content": None, "tool_calls": [ + {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + for tc in msg.tool_calls + ]}) + for tc in msg.tool_calls: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + args = {} + # always use .get() — LLM may send malformed args + tool_result = _execute_tool(tc.function.name, args) + 
messages.append({"role": "tool", "tool_call_id": tc.id, "content": tool_result}) + + # max turns exhausted — force final answer without tools + messages.append({"role": "user", "content": "No more tool calls. Return final JSON NOW."}) + llm_params = llm_config_manager.prepare_llm_call( + llm_config, messages=messages, + temperature=self.temperature, max_tokens=self.max_tokens, + ) + llm_params["metadata"] = {"trace_id": context.trace_id, "tags": ["datagenflow"]} + response = await litellm.acompletion(**llm_params) + try: + result = json.loads(response.choices[0].message.content or "{}") + except json.JSONDecodeError: + result = {} + return {"my_result": result.get("my_result", []), "_usage": total_usage.model_dump()} +``` + +**Key rules:** +- Always nudge on last turn (`turn == max_turns - 1`) before the forced final call +- Always force a final call without tools when max_turns exhausted — otherwise you get no output +- Use `args.get("key", "")` not `args["key"]` — LLM may send malformed arguments +- If tool responses contain `"$ref"` keys, rename before sending: `output.replace('"$ref"', '"schema_ref"')` — Gemini rejects `$ref` in tool responses +- Cap tool result sizes (e.g. 50 items max) to avoid context overflow + +--- + ## Multiplier Blocks Blocks that generate multiple items from one input: diff --git a/.claude/skills/testing-pipeline-templates/SKILL.md b/.claude/skills/testing-pipeline-templates/SKILL.md index c5d8f5d..2159fcb 100644 --- a/.claude/skills/testing-pipeline-templates/SKILL.md +++ b/.claude/skills/testing-pipeline-templates/SKILL.md @@ -32,7 +32,12 @@ curl -s -X POST http://localhost:8000/api/pipelines//execute \ - `trace` — each entry has `block_type`, `execution_time`, `output` - `accumulated_state` — data flowing correctly between blocks? -**Red flags:** missing fields, metadata pollution (extra fields like `samples`, `target_count`), execution_time >30s, empty/null generator outputs. 
+**Check review readiness:** +- Look at the **last trace entry's `accumulated_state`** — these are the fields visible in the Review UI +- All reviewer-relevant fields should be **first-level keys** (not nested under `generated` or other objects) +- If useful fields are nested, add a `FieldMapper` as the last block to flatten them (see `creating-pipeline-templates` skill) + +**Red flags:** missing fields, metadata pollution (extra fields like `samples`, `target_count`), execution_time >30s, empty/null generator outputs, reviewer-relevant data buried in nested objects. ## Phase 2: Small Batch diff --git a/.claude/skills/using-datagenflow-extensibility/SKILL.md b/.claude/skills/using-datagenflow-extensibility/SKILL.md new file mode 100644 index 0000000..e73fed9 --- /dev/null +++ b/.claude/skills/using-datagenflow-extensibility/SKILL.md @@ -0,0 +1,255 @@ +--- +name: using-datagenflow-extensibility +description: Use when creating data generation pipelines with DataGenFlow's extensibility system — user_templates, user_blocks, docker-compose setup, and dgf CLI. Use for any task involving generating synthetic data, building custom pipelines, or extending DataGenFlow from an external project without modifying its source. +--- + +# Using DataGenFlow Extensibility + +Build data generation pipelines from your own repo using DataGenFlow as a Docker image. Custom blocks and templates live in your project — no DataGenFlow source modifications needed. + +## Project Structure + +```text +your-project/ + user_blocks/ # custom Python blocks (auto-discovered) + user_templates/ # custom YAML pipelines (auto-discovered) + data/ # persisted output data + docker-compose.yml # mounts volumes into DataGenFlow container + .env # API keys + config +``` + +## Quick Setup + +```bash +# 1. Create project structure +mkdir -p user_blocks user_templates data + +# 2. Create .env with your LLM provider API key +echo "LLM_API_KEY=your-api-key" > .env + +# 3. 
Create docker-compose.yml (see Docker section) + +# 4. Start DataGenFlow +docker-compose up -d + +# 5. Verify +curl http://localhost:8000/health +``` + +## docker-compose.yml + +```yaml +services: + datagenflow: + image: datagenflow:local + ports: + - "8000:8000" + volumes: + - ./user_blocks:/app/user_blocks + - ./user_templates:/app/user_templates + - ./data:/app/data + env_file: + - .env + environment: + - DATAGENFLOW_HOT_RELOAD=true + restart: unless-stopped +``` + +**Note (contributor setup):** A published image is not yet available. Build locally from the DataGenFlow repo: `docker build -f docker/Dockerfile -t datagenflow:local .` + +## Writing Templates + +Templates are YAML files in `user_templates/`. Template ID = filename stem. + +### YAML Format + +```yaml +name: "Display Name" +description: "What this pipeline generates" +blocks: + - type: BlockClassName # exact class name + config: + param: value # exact __init__ parameter names + user_prompt: "{{ var }}" # Jinja2 refs to seed metadata +``` + +### Seed Files + +Place next to template: `user_templates/seed_.json` + +```json +[ + {"repetitions": 3, "metadata": {"content": "input text here"}}, + {"repetitions": 2, "metadata": {"content": "another input"}} +] +``` + +For `MarkdownMultiplierBlock` as first block, use `seed_.md` instead. + +### Variable Flow + +Seed `metadata` keys become `{{ key }}` in first block's prompt. Each block's outputs become available to subsequent blocks via `{{ output_name }}`. 
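
A minimal two-block sketch of this chaining (the template name, prompts, and the inline `json_schema` shape are illustrative; `content` is a seed metadata key and `assistant` is `TextGenerator`'s generated text):

```yaml
name: "Variable Flow Example"
description: "Illustrative only — shows seed vars and block outputs chaining"
blocks:
  - type: TextGenerator
    config:
      user_prompt: "Summarize: {{ content }}"   # {{ content }} comes from seed metadata
  - type: StructuredGenerator
    config:
      # {{ assistant }} is the previous TextGenerator's output
      user_prompt: "Write a QA pair based on: {{ assistant }}"
      json_schema:
        type: object
        properties:
          question: { type: string }
          answer: { type: string }
```
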
+ +Key output names by block: +- `TextGenerator` outputs: `assistant`, `system`, `user` +- `StructuredGenerator` outputs: `generated` +- `JSONValidatorBlock` outputs: `valid`, `parsed_json` +- `FieldMapper` outputs: dynamic (whatever you map) + +## Available Blocks — Quick Reference + +| Block | Config Params | Use For | +|-------|--------------|---------| +| `TextGenerator` | `model`, `temperature`, `max_tokens`, `system_prompt`, `user_prompt` | Free-text generation | +| `StructuredGenerator` | `model`, `temperature`, `max_tokens`, `user_prompt`, `json_schema` | JSON generation with schema | +| `JSONValidatorBlock` | `field_name`, `required_fields`, `strict` | Validate JSON output | +| `FieldMapper` | `mappings` | Rename/transform fields between blocks | +| `MarkdownMultiplierBlock` | `parser_type`, `chunk_size`, `chunk_overlap` | Split documents (must be first) | +| `StructureSampler` | _(see source)_ | Sample from structure (must be first) | +| `ValidatorBlock` | _(see source)_ | Text rule validation | +| `DuplicateRemover` | _(see source)_ | Embedding-based dedup | +| `DiversityScore` | _(see source)_ | Lexical diversity metric | +| `CoherenceScore` | _(see source)_ | Text coherence metric | +| `RagasMetrics` | _(see source)_ | RAGAS QA evaluation | +| `LangfuseBlock` | _(see source)_ | Observability tracing | + +## Common Pipeline Patterns + +```text +# Simple: generate structured JSON + validate +StructuredGenerator → JSONValidatorBlock + +# Document processing: chunk → generate text → structure → validate +MarkdownMultiplierBlock → TextGenerator → StructuredGenerator → JSONValidatorBlock + +# Augmentation: sample → fill → deduplicate +StructureSampler → SemanticInfiller → DuplicateRemover + +# With metrics: generate → map fields → evaluate +StructuredGenerator → FieldMapper → RagasMetrics +``` + +## Writing Custom Blocks + +Place `.py` files in `user_blocks/`. Auto-discovered if class inherits `BaseBlock`. 
+ +```python +from lib.blocks.base import BaseBlock +from lib.entities.block_execution_context import BlockExecutionContext +from typing import Any + + +class MyCustomBlock(BaseBlock): + name = "My Custom Block" + description = "What it does" + category = "validators" # generators, validators, metrics, utilities, seeders, observability + inputs = ["text"] + outputs = ["result"] + + # optional: pip deps auto-detected + dependencies = ["some-package>=1.0"] + + def __init__(self, threshold: float = 0.5): + self.threshold = threshold + + async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + text = context.get_state("text", "") + return {"result": f"processed: {text}"} +``` + +Scaffold with: `dgf blocks scaffold MyBlock -c validators` + +## dgf CLI Commands + +```bash +# Status +dgf status # server health + counts + +# Blocks +dgf blocks list # all blocks with source/status +dgf blocks validate ./my_block.py # check syntax +dgf blocks scaffold MyBlock -c general # generate starter + +# Templates +dgf templates list # all templates with source +dgf templates validate ./flow.yaml # check YAML structure +dgf templates scaffold "My Flow" # generate starter + +# Image (production) +dgf image scaffold --blocks-dir ./user_blocks # Dockerfile with deps +dgf image build -t my-datagenflow:latest # build custom image +``` + +**Note (contributor setup):** Until `dgf` is published to PyPI, run from the DataGenFlow repo: `cd /path/to/DataGenFlow && uv run dgf ` + +## Testing a Template + +```bash +# 1. Validate YAML +uv run dgf templates validate ./user_templates/my_template.yaml + +# 2. Check it's discovered +uv run dgf templates list + +# 3. Create pipeline from template +curl -s -X POST http://localhost:8000/api/pipelines/from_template/my_template | python -m json.tool + +# 4. 
Execute with seed +curl -s -X POST http://localhost:8000/api/pipelines//execute \ + -H 'Content-Type: application/json' \ + -d '{"content": "test input"}' | python -m json.tool +``` + +Or use the UI at `http://localhost:8000` — templates appear in the pipeline creation flow. + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `DATAGENFLOW_ENDPOINT` | `http://localhost:8000` | API endpoint (for CLI) | +| `DATAGENFLOW_BLOCKS_PATH` | `user_blocks` | Path to user blocks dir | +| `DATAGENFLOW_TEMPLATES_PATH` | `user_templates` | Path to user templates dir | +| `DATAGENFLOW_HOT_RELOAD` | `true` | Enable file watching | +| `DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS` | `500` | Debounce interval | + +## Step-by-Step Workflow + +1. **Define the use case** — what data to generate, what schema, what seed inputs +2. **Choose blocks** — pick from table, wire outputs → inputs +3. **Write YAML** in `user_templates/.yaml` +4. **Write seed file** — `user_templates/seed_.json` with all `{{ vars }}` as metadata keys +5. **Validate** — `dgf templates validate` + `dgf templates list` +6. **Test single execution** — create pipeline from template, run with seed +7. **Iterate** — adjust prompts, schema, temperature based on output quality +8. 
**Scale** — increase seed repetitions, add more seed examples + +## Common Mistakes + +| Mistake | Fix | +|---------|-----| +| Template ID conflicts with builtin | Rename your file — builtins take precedence | +| Block `type` doesn't match class name | Use exact class name (e.g., `JSONValidatorBlock` not `JSONValidator`) | +| Config key doesn't match `__init__` param | Read block source or use `dgf blocks list` | +| Seed variable missing from metadata | Every `{{ var }}` in prompts needs a matching metadata key | +| Multiplier block not first | `MarkdownMultiplierBlock` and `StructureSampler` must be first | +| Hot reload not picking up changes | Check `DATAGENFLOW_HOT_RELOAD=true` and dirs exist before startup | +| Block shows unavailable | Missing deps — install via API or build custom image | + +## Checklist + +- [ ] Project structure created (user_blocks/, user_templates/, data/, docker-compose.yml, .env) +- [ ] DataGenFlow running and healthy (`curl /health`) +- [ ] Template YAML with correct block types and config keys +- [ ] Seed file named `seed_.json` with all referenced variables +- [ ] Template appears in `dgf templates list` +- [ ] Single execution produces expected output fields +- [ ] Seed file has 2-3 diverse examples for quality testing +- [ ] Custom blocks (if any) appear in `dgf blocks list` as source "user" + +## Related DataGenFlow Skills + +- `creating-pipeline-templates` — reference for builtin template patterns +- `implementing-datagenflow-blocks` — deep dive on block internals +- `debugging-pipelines` — troubleshooting execution failures +- `testing-pipeline-templates` — thorough end-to-end testing +- `configuring-models` — LLM provider setup diff --git a/.claude/skills/writing-e2e-tests/SKILL.md b/.claude/skills/writing-e2e-tests/SKILL.md index 56ad4a0..af3942b 100644 --- a/.claude/skills/writing-e2e-tests/SKILL.md +++ b/.claude/skills/writing-e2e-tests/SKILL.md @@ -23,7 +23,7 @@ tests/e2e/ ```bash # single suite python 
.claude/skills/webapp-testing/scripts/with_server.py \ - --server "cd /home/nicof/develop/DataGenFlow && DATABASE_PATH=data/test_qa_records.db uv run python app.py" \ + --server "DATABASE_PATH=data/test_qa_records.db uv run python app.py" \ --port 8000 \ -- python tests/e2e/test__e2e.py @@ -121,7 +121,7 @@ page.locator("text=completed").wait_for(timeout=30000) ```bash echo "Running tests..." python .claude/skills/webapp-testing/scripts/with_server.py \ - --server "cd $PROJECT_DIR && DATABASE_PATH=data/test_qa_records.db uv run python app.py" \ + --server "DATABASE_PATH=data/test_qa_records.db uv run python app.py" \ --port 8000 \ -- python tests/e2e/test__e2e.py ``` diff --git a/.env.example b/.env.example index 2a0a3c8..b4b8acd 100644 --- a/.env.example +++ b/.env.example @@ -16,6 +16,15 @@ PORT=8000 # optional: enable debug logging (defaults to false) # DEBUG=true +# =========================================== +# Extensibility System +# =========================================== +# DATAGENFLOW_ENDPOINT=http://localhost:8000 +# DATAGENFLOW_BLOCKS_PATH=user_blocks +# DATAGENFLOW_TEMPLATES_PATH=user_templates +# DATAGENFLOW_HOT_RELOAD=true +# DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS=500 + # langfuse configuration LANGFUSE_SECRET_KEY="sk-..." LANGFUSE_PUBLIC_KEY="pk-..." 
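
The extensibility variables added to `.env.example` above are all optional with documented defaults. A minimal sketch of how a consumer could resolve them via getenv-with-default (illustrative only — variable names and defaults are taken from `.env.example`; DataGenFlow's actual settings code may differ):

```python
import os


def resolve_extensibility_config() -> dict:
    # defaults mirror the documented values in .env.example;
    # sketch only, not DataGenFlow's real config module
    return {
        "endpoint": os.getenv("DATAGENFLOW_ENDPOINT", "http://localhost:8000"),
        "blocks_path": os.getenv("DATAGENFLOW_BLOCKS_PATH", "user_blocks"),
        "templates_path": os.getenv("DATAGENFLOW_TEMPLATES_PATH", "user_templates"),
        "hot_reload": os.getenv("DATAGENFLOW_HOT_RELOAD", "true").lower() == "true",
        "debounce_ms": int(os.getenv("DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS", "500")),
    }
```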
diff --git a/app.py b/app.py index 625b97d..6fcb9f6 100644 --- a/app.py +++ b/app.py @@ -11,8 +11,9 @@ from pydantic import ValidationError as PydanticValidationError from config import settings +from lib.api.extensions import router as extensions_router from lib.blocks.registry import registry -from lib.constants import RECORD_UPDATABLE_FIELDS +from lib.constants import DEFAULT_BLOCKS_PATH, DEFAULT_TEMPLATES_PATH, RECORD_UPDATABLE_FIELDS from lib.entities import ( ConnectionTestResult, EmbeddingModelConfig, @@ -26,7 +27,9 @@ SeedValidationRequest, ValidationConfig, ) +from lib.entities.extensions import BlockInfo, TemplateInfo from lib.errors import BlockExecutionError, BlockNotFoundError, ValidationError +from lib.file_watcher import ExtensionFileWatcher from lib.job_processor import process_job_in_thread from lib.job_queue import JobQueue from lib.llm_config import LLMConfigError, LLMConfigManager, LLMConfigNotFoundError @@ -84,6 +87,18 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: await storage.init_db() + # ensure extension directories exist + Path(os.getenv("DATAGENFLOW_BLOCKS_PATH", DEFAULT_BLOCKS_PATH)).mkdir( + parents=True, exist_ok=True + ) + Path(os.getenv("DATAGENFLOW_TEMPLATES_PATH", DEFAULT_TEMPLATES_PATH)).mkdir( + parents=True, exist_ok=True + ) + + # start file watcher for hot reload + file_watcher = ExtensionFileWatcher(registry, template_registry) + file_watcher.start() + # patch langfuse bug before enabling it _patch_langfuse_usage_bug() @@ -97,7 +112,8 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: litellm.callbacks = [UsageTracker.callback] yield - # close storage connection on shutdown + + file_watcher.stop() await storage.close() @@ -480,7 +496,7 @@ async def download_export( @api_router.get("/blocks") -async def list_blocks() -> list[dict[str, Any]]: +async def list_blocks() -> list[BlockInfo]: """list all registered blocks with dynamically injected model options""" blocks = registry.list_blocks() @@ -492,16 
+508,13 @@ async def list_blocks() -> list[dict[str, Any]]: # inject model options into block schemas for block in blocks: - block_type = block.get("type") - props = block.get("config_schema", {}).get("properties", {}) + props = block.config_schema.get("properties", {}) - # inject LLM model options - if block_type in ["TextGenerator", "StructuredGenerator", "RagasMetrics"]: + if block.type in ["TextGenerator", "StructuredGenerator", "RagasMetrics"]: if "model" in props: props["model"]["enum"] = model_names - # inject embedding model options for RagasMetrics - if block_type == "RagasMetrics": + if block.type == "RagasMetrics": if "embedding_model" in props: props["embedding_model"]["enum"] = embedding_names @@ -788,7 +801,7 @@ async def test_embedding_connection( @api_router.get("/templates") -async def list_templates() -> list[dict[str, Any]]: +async def list_templates() -> list[TemplateInfo]: """List all available pipeline templates""" return template_registry.list_templates() @@ -808,8 +821,9 @@ async def create_pipeline_from_template(template_id: str) -> dict[str, Any]: return {"id": pipeline_id, "name": pipeline_name, "template_id": template_id} -# include api router with /api prefix +# include api routers app.include_router(api_router, prefix="/api") +app.include_router(extensions_router, prefix="/api") # serve frontend (built react app) frontend_dir = Path("frontend/build") diff --git a/docker/Dockerfile b/docker/Dockerfile index 29607a4..7486202 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -61,8 +61,8 @@ COPY --from=builder /app/pyproject.toml /app/pyproject.toml # Copy built frontend from frontend-builder COPY --from=frontend-builder /app/build /app/frontend/build -# Create data directory and ensure custom blocks directory exists -RUN mkdir -p /app/data /app/lib/blocks/custom +# Create data directory, custom blocks, and user extension directories +RUN mkdir -p /app/data /app/lib/blocks/custom /app/user_blocks /app/user_templates # Set 
environment variables ENV PATH="/app/.venv/bin:$PATH" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 08610a2..01bd850 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,18 +1,22 @@ services: - backend: + datagenflow: build: context: .. dockerfile: docker/Dockerfile + image: datagenflow:local container_name: datagenflow ports: - "8000:8000" env_file: - .env + environment: + - DATAGENFLOW_HOT_RELOAD=true volumes: - # Mount data directory for persistence + # data persistence - ./data:/app/data - # Mount custom blocks directory for hot-reloading (restart required for new blocks) - - ./lib/blocks/custom:/app/lib/blocks/custom + # user extensions (create these directories in your project) + - ./user_blocks:/app/user_blocks + - ./user_templates:/app/user_templates restart: unless-stopped healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] @@ -21,6 +25,3 @@ services: retries: 3 start_period: 40s -volumes: - data: - diff --git a/docs/extensibility.md b/docs/extensibility.md new file mode 100644 index 0000000..597f33a --- /dev/null +++ b/docs/extensibility.md @@ -0,0 +1,389 @@ +--- +title: Extensibility System +description: Use DataGenFlow without cloning it — add custom blocks and templates from your own repo +--- + +# Extensibility System + +DataGenFlow's extensibility system lets engineers consume DataGenFlow as a Docker image and maintain custom blocks and templates in their own repositories. 
+ +## Table of Contents +- [Overview](#overview) +- [Quick Start](#quick-start) +- [Writing Custom Blocks](#writing-custom-blocks) + - [Block with Dependencies](#block-with-dependencies) + - [Block Discovery](#block-discovery) +- [Writing Custom Templates](#writing-custom-templates) +- [CLI Reference](#cli-reference) + - [Status](#status) + - [Blocks Commands](#blocks-commands) + - [Templates Commands](#templates-commands) + - [Image Commands](#image-commands) + - [Configuration](#configuration) +- [Hot Reload](#hot-reload) +- [Extensions API](#extensions-api) +- [Extensions Page](#extensions-page) +- [Docker Setup](#docker-setup) +- [Building Custom Images](#building-custom-images) +- [Troubleshooting](#troubleshooting) + +## Overview + +Engineers clone the DataGenFlow repository once, then: + +1. Build the Docker image locally +2. Mount custom `user_blocks/` and `user_templates/` directories +3. Manage extensions with the `dgf` CLI or the Extensions UI page + +```text +your-repo/ + user_blocks/ + sentiment_analyzer.py + translator.py + user_templates/ + my_qa_pipeline.yaml + docker-compose.yml + .env +``` + +The system provides: +- **Block registry** with source tracking (`builtin`, `custom`, `user`) +- **Dependency declaration** via class attribute on blocks +- **Hot reload** via file watcher (watchdog) with 500ms debounce +- **CLI tool** (`dgf`) for managing blocks, templates, and images +- **Extensions page** in the frontend showing all blocks and templates with status + +## Quick Start + +```bash +# 1. clone DataGenFlow +git clone https://github.com/your-org/DataGenFlow.git +cd DataGenFlow + +# 2. build the Docker image +docker build -f docker/Dockerfile -t datagenflow:local . + +# 3. create your project directory +mkdir -p my-project/user_blocks my-project/user_templates my-project/data +cd my-project + +# 4. create docker-compose.yml (see Docker Setup section) + +# 5. start DataGenFlow +docker-compose up -d + +# 6. 
scaffold a block +cd ../DataGenFlow && uv run dgf blocks scaffold SentimentAnalyzer -c validators +mv sentiment_analyzer.py ../my-project/user_blocks/ + +# 7. check it's registered +uv run dgf blocks list + +# 8. open the Extensions page in the UI +open http://localhost:8000/extensions +``` + +## Writing Custom Blocks + +Custom blocks follow the same `BaseBlock` interface as builtin blocks. See [How to Create Custom Blocks](how_to_create_blocks) for the full guide. + +### Block with Dependencies + +Blocks can declare pip dependencies via a `dependencies` class attribute. Missing dependencies are detected at registration time, and the block appears as "unavailable" in the UI with an actionable error. + +```python +from lib.blocks.base import BaseBlock +from lib.entities.block_execution_context import BlockExecutionContext +from typing import Any + + +class SentimentAnalyzer(BaseBlock): + name = "Sentiment Analyzer" + description = "Analyze text sentiment using transformers" + category = "validators" + inputs = ["text"] + outputs = ["sentiment", "confidence"] + + # declare pip dependencies + dependencies = ["transformers>=4.30.0", "torch>=2.0.0"] + + def __init__(self, model: str = "distilbert-base-uncased"): + self.model = model + self._pipeline = None + + async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + if self._pipeline is None: + from transformers import pipeline + self._pipeline = pipeline("sentiment-analysis", model=self.model) + + text = context.get_state("text", "") + result = self._pipeline(text)[0] + + return { + "sentiment": result["label"], + "confidence": result["score"], + } +``` + +Install missing dependencies via CLI or the Extensions page: + +```bash +dgf blocks list # see which blocks are unavailable +# POST /api/extensions/blocks/SentimentAnalyzer/install-deps +``` + +### Block Discovery + +Blocks are discovered from three directories: + +| Directory | Source Label | Purpose | +|-----------|-------------|---------| +| 
`lib/blocks/builtin/` | `builtin` | Ships with DataGenFlow |
+| `lib/blocks/custom/` | `custom` | Project-specific blocks |
+| `user_blocks/` | `user` | User-mounted blocks (extensibility) |
+
+Any `.py` file (not starting with `_`) containing a `BaseBlock` subclass is auto-discovered. The `user_blocks/` path is configurable via the `DATAGENFLOW_BLOCKS_PATH` environment variable.
+
+## Writing Custom Templates
+
+Templates are YAML files that define pre-configured pipelines.
+
+```yaml
+name: "My QA Pipeline"
+description: "Generate question-answer pairs from content"
+
+blocks:
+  - type: TextGenerator
+    config:
+      model: "gpt-4o-mini"
+      user_prompt: |
+        Generate a question-answer pair from:
+        {{ content }}
+```
+
+Place templates in `user_templates/` (or the path set by `DATAGENFLOW_TEMPLATES_PATH`). They appear in the Templates section of the UI and CLI.
+
+> **Note:** If a user template has the same ID (filename stem) as a builtin template, the builtin takes precedence and the user template is skipped.
+
+## CLI Reference
+
+The `dgf` CLI is included in the DataGenFlow repository. Run it with `uv`:
+
+```bash
+cd /path/to/DataGenFlow
+uv run dgf --help
+```
+
+Or install it into your environment as an editable package (requires the repo to be cloned):
+
+```bash
+cd /path/to/DataGenFlow
+uv pip install -e .
+dgf --help
+```
+
+### Status
+
+```bash
+dgf status
+```
+
+Shows server health, block counts, template counts, and hot reload status. 
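
For scripting against the same data, the `/api/extensions/status` payload (see the Extensions API section) can be condensed client-side. A small illustrative helper, not actual `dgf` output, assuming only the documented response shape:

```python
def summarize_status(status: dict) -> str:
    """condense the /api/extensions/status payload into a one-line summary"""
    b, t = status["blocks"], status["templates"]
    return (
        f"blocks {b['available']}/{b['total']} available "
        f"(builtin {b['builtin_blocks']}, custom {b['custom_blocks']}, user {b['user_blocks']}); "
        f"templates {t['total']} (builtin {t['builtin_templates']}, user {t['user_templates']})"
    )
```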
+ +### Blocks Commands + +```bash +dgf blocks list # list all blocks with status and source +dgf blocks validate ./my_block.py # check syntax and find block classes +dgf blocks scaffold MyBlock -c general # generate a starter block file +``` + +### Templates Commands + +```bash +dgf templates list # list all templates with source +dgf templates validate ./flow.yaml # check YAML structure and required fields +dgf templates scaffold "My Flow" # generate a starter template YAML +``` + +### Image Commands + +```bash +dgf image scaffold --blocks-dir ./user_blocks # generate Dockerfile with deps +dgf image build -t my-datagenflow:latest # build custom Docker image +``` + +The scaffold command parses `dependencies` attributes from block files and generates a `Dockerfile.custom` with the right `uv pip install` commands. + +### Configuration + +```bash +dgf configure --show # show current endpoint +dgf configure --endpoint https://my-server:8000 +``` + +Configuration resolution order: +1. `DATAGENFLOW_ENDPOINT` environment variable (highest priority) +2. `.env` file in current directory +3. Default: `http://localhost:8000` + +## Hot Reload + +The file watcher monitors `user_blocks/` and `user_templates/` for changes. When a file is created, modified, or deleted: + +- **Blocks**: The block registry re-scans all directories +- **Templates**: The specific template is registered or unregistered + +Events are debounced at 500ms (configurable via `DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS`) to handle rapid saves. + +| Environment Variable | Default | Description | +|---------------------|---------|-------------| +| `DATAGENFLOW_HOT_RELOAD` | `true` | Enable/disable file watching | +| `DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS` | `500` | Debounce interval in milliseconds | + +> **Tip:** Set `DATAGENFLOW_HOT_RELOAD=false` in production to avoid unnecessary file system overhead. + +## Extensions API + +All extension endpoints live under `/api/extensions/`. 
+ +| Method | Endpoint | Description | +|--------|----------|-------------| +| `GET` | `/api/extensions/status` | Block/template counts by source | +| `GET` | `/api/extensions/blocks` | List all blocks with source and availability | +| `GET` | `/api/extensions/templates` | List all templates with source | +| `POST` | `/api/extensions/reload` | Trigger manual reload of all extensions | +| `POST` | `/api/extensions/blocks/{name}/validate` | Validate block availability and dependencies | +| `GET` | `/api/extensions/blocks/{name}/dependencies` | Get dependency info for a block | +| `POST` | `/api/extensions/blocks/{name}/install-deps` | Install missing dependencies via uv | + +**Example response** — `GET /api/extensions/status`: + +```json +{ + "blocks": { + "total": 14, + "builtin_blocks": 12, + "custom_blocks": 0, + "user_blocks": 2, + "available": 13, + "unavailable": 1 + }, + "templates": { + "total": 6, + "builtin_templates": 4, + "user_templates": 2 + } +} +``` + +## Extensions Page + +The Extensions page (`/extensions`) in the frontend shows: + +- **Status cards** with block and template counts by source +- **Block list** with availability status, source badges, and dependency info. Unavailable blocks show a red border, error message, and an "Install Deps" button. +- **Template list** with source badges and a **"Create Pipeline"** button that creates a pipeline from the template and navigates to `/pipelines` +- **Reload button** to trigger a manual re-scan of all extension directories + +## Docker Setup + +### Building the Image + +```bash +# from DataGenFlow repository root +docker build -f docker/Dockerfile -t datagenflow:local . 
+```
+
+### docker-compose.yml for Your Project
+
+Create this in your project directory (outside DataGenFlow):
+
+```yaml
+services:
+  datagenflow:
+    image: datagenflow:local
+    ports:
+      - "8000:8000"
+    volumes:
+      - ./user_blocks:/app/user_blocks
+      - ./user_templates:/app/user_templates
+      - ./data:/app/data
+    env_file:
+      - .env
+    environment:
+      - DATAGENFLOW_HOT_RELOAD=true
+    restart: unless-stopped
+```
+
+### Environment Variables
+
+Create a `.env` file:
+
+```bash
+# Required: LLM provider API key
+LLM_API_KEY=your-api-key
+
+# Optional: endpoint for dgf CLI
+DATAGENFLOW_ENDPOINT=http://localhost:8000
+
+# Optional: hot reload settings
+DATAGENFLOW_HOT_RELOAD=true
+DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS=500
+```
+
+**All extensibility variables:**
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `DATAGENFLOW_ENDPOINT` | `http://localhost:8000` | API endpoint (for CLI) |
+| `DATAGENFLOW_BLOCKS_PATH` | `user_blocks` | Path to user blocks directory |
+| `DATAGENFLOW_TEMPLATES_PATH` | `user_templates` | Path to user templates directory |
+| `DATAGENFLOW_HOT_RELOAD` | `true` | Enable file watching |
+| `DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS` | `500` | Debounce interval in milliseconds |
+
+## Building Custom Images
+
+For production, pre-bake dependencies into the image:
+
+```bash
+# 1. generate Dockerfile with dependencies from your blocks
+cd /path/to/DataGenFlow
+uv run dgf image scaffold --blocks-dir /path/to/my-project/user_blocks -o /path/to/my-project/Dockerfile.custom
+
+# 2. build the custom image (from DataGenFlow repo root)
+docker build -f /path/to/my-project/Dockerfile.custom -t my-datagenflow:latest .
+
+# 3. update docker-compose.yml to use new image
+# image: my-datagenflow:latest
+```
+
+The generated Dockerfile builds from source and runs `uv pip install` for all declared dependencies. 
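
The dependency-collection pass behind `dgf image scaffold` can be approximated with a static `ast` walk. This is a hypothetical sketch assuming only the `dependencies = [...]` class-attribute convention shown earlier, not the actual dgf implementation:

```python
import ast
from pathlib import Path


def collect_dependencies(blocks_dir: str) -> list[str]:
    """gather pip requirements from `dependencies` class attributes in block files"""
    deps: set[str] = set()
    for py_file in Path(blocks_dir).glob("*.py"):
        if py_file.name.startswith("_"):
            continue  # skip private modules, mirroring block discovery
        tree = ast.parse(py_file.read_text())
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for stmt in node.body:
                # match plain `dependencies = ["pkg>=1.0", ...]` assignments at class level
                if (
                    isinstance(stmt, ast.Assign)
                    and any(isinstance(t, ast.Name) and t.id == "dependencies" for t in stmt.targets)
                    and isinstance(stmt.value, ast.List)
                ):
                    deps.update(
                        elt.value
                        for elt in stmt.value.elts
                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                    )
    return sorted(deps)
```

Parsing statically instead of importing means requirements can be gathered even from blocks whose own imports would currently fail.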
+ +## Troubleshooting + +### Block not appearing in UI + +- **Cause**: File not in a discovered directory, or class doesn't inherit from `BaseBlock` +- **Fix**: Verify the file is in `user_blocks/`, the filename doesn't start with `_`, and the class inherits from `BaseBlock` + +### Block shows as unavailable + +Two sub-cases: + +1. **Import succeeded but runtime deps are missing** — `dependencies` attribute is readable, `GET /dependencies` lists them, `POST /install-deps` installs and reloads automatically. +2. **Import itself failed** (syntax error, missing module) — `block_class` is `None`, so `/dependencies` and `/install-deps` both return `422` with the import error. Fix the source file (or install the missing module), then trigger a reload via `POST /api/extensions/reload`. Once the class loads successfully the block becomes available. + +### Hot reload not working + +- **Cause**: `DATAGENFLOW_HOT_RELOAD=false` or directory doesn't exist at startup +- **Fix**: Check the environment variable and ensure `user_blocks/` and `user_templates/` exist before the server starts + +### CLI cannot connect + +- **Cause**: Wrong endpoint or server not running +- **Fix**: Run `dgf configure --show` to check the endpoint, then `dgf status` to test connectivity + +### User template ignored + +- **Cause**: Template ID (filename stem) conflicts with a builtin template +- **Fix**: Rename the template file to avoid the collision. Check server logs for "skipped: conflicts with builtin" warnings. 
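
The precedence rule behind that warning can be sketched as a two-pass registration, builtins first. This is a hypothetical illustration of the documented behavior (builtin IDs win, conflicting user templates are skipped), not the actual template registry:

```python
from pathlib import Path


def register_templates(builtin_dir: str, user_dir: str) -> dict[str, Path]:
    """map template id (filename stem) to path; builtin templates take precedence"""
    templates: dict[str, Path] = {}
    for path in sorted(Path(builtin_dir).glob("*.yaml")):
        templates[path.stem] = path
    for path in sorted(Path(user_dir).glob("*.yaml")):
        if path.stem in templates:
            # mirrors the "skipped: conflicts with builtin" warning in server logs
            continue
        templates[path.stem] = path
    return templates
```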
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index ca08629..b102a95 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -8,11 +8,13 @@ import { ChecklistIcon, WorkflowIcon, GearIcon, + PackageIcon, } from "@primer/octicons-react"; import Generator from "./pages/Generator"; import Review from "./pages/Review"; import Pipelines from "./pages/Pipelines"; import Settings from "./pages/Settings"; +import Extensions from "./pages/Extensions"; import GlobalJobIndicator from "./components/GlobalJobIndicator"; import { JobProvider } from "./contexts/JobContext"; import { useTheme as shadcnUseTheme, ThemeProvider as ShadcnThemeProvider } from "next-themes"; @@ -54,6 +56,7 @@ function Navigation() { { path: "/pipelines", label: "Pipelines", icon: WorkflowIcon }, { path: "/", label: "Generator", icon: BeakerIcon }, { path: "/review", label: "Review", icon: ChecklistIcon }, + { path: "/extensions", label: "Extensions", icon: PackageIcon }, { path: "/settings", label: "Settings", icon: GearIcon }, ]; @@ -147,6 +150,7 @@ function Navigation() { } /> } /> } /> + } /> } /> diff --git a/frontend/src/components/pipeline-editor/BlockPalette.tsx b/frontend/src/components/pipeline-editor/BlockPalette.tsx index 6bc1ba8..05c4a4e 100644 --- a/frontend/src/components/pipeline-editor/BlockPalette.tsx +++ b/frontend/src/components/pipeline-editor/BlockPalette.tsx @@ -1,5 +1,5 @@ import { useState, useMemo } from "react"; -import { Box, Text, TextInput } from "@primer/react"; +import { Box, Text, TextInput, Tooltip } from "@primer/react"; import { SearchIcon, ChevronDownIcon, ChevronRightIcon } from "@primer/octicons-react"; interface Block { @@ -10,6 +10,9 @@ interface Block { outputs: string[]; config_schema: Record; category: string; + source?: string; + available?: boolean; + error?: string | null; } interface BlockPaletteProps { @@ -185,46 +188,71 @@ export default function BlockPalette({ blocks }: BlockPaletteProps) { {/* Block Items */} {!isCollapsed && ( - 
{blocks.map((block) => ( - ) => - onDragStart(e, block.type) - } - sx={{ - display: "flex", - alignItems: "center", - gap: 1, - p: 2, - mb: 1, - borderRadius: 1, - bg: "canvas.subtle", - cursor: "grab", - borderLeft: "2px solid", - borderColor: info.color, - "&:hover": { - bg: "accent.subtle", - }, - "&:active": { - cursor: "grabbing", - }, - }} - > - {/* {info.icon} */} - { + const isAvailable = block.available !== false; + const blockItem = ( + ) => onDragStart(e, block.type) + : undefined + } sx={{ - fontSize: "13px", - color: "fg.default", - whiteSpace: "nowrap", - overflow: "hidden", - textOverflow: "ellipsis", + display: "flex", + alignItems: "center", + gap: 1, + p: 2, + mb: 1, + borderRadius: 1, + bg: isAvailable ? "canvas.subtle" : "neutral.subtle", + cursor: isAvailable ? "grab" : "not-allowed", + borderLeft: "2px solid", + borderColor: isAvailable ? info.color : "border.muted", + opacity: isAvailable ? 1 : 0.5, + "&:hover": isAvailable ? { bg: "accent.subtle" } : {}, + "&:active": isAvailable ? 
{ cursor: "grabbing" } : {}, }} > - {block.name} - - - ))} + + {block.name} + + {block.source && block.source !== "builtin" && ( + + {block.source} + + )} + + ); + + if (!isAvailable && block.error) { + return ( + + {blockItem} + + ); + } + + return blockItem; + })} )} diff --git a/frontend/src/pages/Extensions.tsx b/frontend/src/pages/Extensions.tsx new file mode 100644 index 0000000..96f376d --- /dev/null +++ b/frontend/src/pages/Extensions.tsx @@ -0,0 +1,362 @@ +import { useCallback, useEffect, useState } from "react"; +import { Box, Heading, Text, Button, Spinner, Label } from "@primer/react"; +import { + SyncIcon, + CheckCircleFillIcon, + XCircleFillIcon, + PackageIcon, + CheckIcon, + DownloadIcon, + PlusIcon, +} from "@primer/octicons-react"; +import { toast } from "sonner"; +import { useNavigate } from "react-router-dom"; +import type { BlockInfo, TemplateInfo, ExtensionsStatus } from "../types"; +import { extensionsApi } from "../services/extensionsApi"; + +export default function Extensions() { + const [status, setStatus] = useState(null); + const [blocks, setBlocks] = useState([]); + const [templates, setTemplates] = useState([]); + const [loading, setLoading] = useState(true); + const [reloading, setReloading] = useState(false); + + const loadAll = useCallback(async () => { + try { + const [s, b, t] = await Promise.all([ + extensionsApi.getStatus(), + extensionsApi.listBlocks(), + extensionsApi.listTemplates(), + ]); + setStatus(s); + setBlocks(b); + setTemplates(t); + } catch (error) { + console.error("failed to load extensions:", error); + const message = error instanceof Error ? 
error.message : "Unknown error"; + toast.error(`Failed to load extensions: ${message}`); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + loadAll(); + }, [loadAll]); + + const handleReload = useCallback(async () => { + setReloading(true); + try { + await extensionsApi.reload(); + await loadAll(); + toast.success("Extensions reloaded"); + } catch (error) { + console.error("failed to reload extensions:", error); + const message = error instanceof Error ? error.message : "Unknown error"; + toast.error(`Failed to reload: ${message}`); + } finally { + setReloading(false); + } + }, [loadAll]); + + if (loading) { + return ( + + + + ); + } + + return ( + + + Extensions + + + + {/* status overview */} + {status && ( + + + + + )} + + {/* blocks section */} + + + Blocks + + + {blocks.map((block) => ( + + ))} + + + + {/* templates section */} + + + Templates + + + {templates.map((tmpl) => ( + + ))} + + + + ); +} + +function StatusCard({ + title, + items, + available, + unavailable, +}: { + title: string; + items: { label: string; value: number }[]; + available: number; + unavailable: number; +}) { + return ( + + + {title} + + + {items.map((item) => ( + + + {item.value} + + {item.label} + + ))} + + + + + {available} available + + {unavailable > 0 && ( + + + {unavailable} unavailable + + )} + + + ); +} + +function BlockCard({ block, onReload }: { block: BlockInfo; onReload: () => Promise }) { + const [validating, setValidating] = useState(false); + const [installing, setInstalling] = useState(false); + + const handleValidate = async () => { + setValidating(true); + try { + const result = await extensionsApi.validateBlock(block.type); + if (result.valid) { + toast.success(`${block.name} is valid`); + } else { + toast.error(`${block.name} validation failed: ${result.error || "unknown error"}`); + } + } catch (error) { + console.error("block validation failed:", error); + const message = error instanceof Error ? 
error.message : "Unknown error"; + toast.error(`Validation error: ${message}`); + } finally { + setValidating(false); + } + }; + + const handleInstallDeps = async () => { + setInstalling(true); + try { + const result = await extensionsApi.installBlockDeps(block.type); + toast.success(`Installed: ${result.installed.join(", ") || "all deps satisfied"}`); + await onReload(); + } catch (error) { + console.error("dependency install failed:", error); + const message = error instanceof Error ? error.message : "Unknown error"; + toast.error(`Install failed: ${message}`); + } finally { + setInstalling(false); + } + }; + + return ( + + + + {block.name} + + + + + {block.type} + + {!block.available && ( + + )} + + + + {block.description} + + {!block.available && block.error && ( + + {block.error} + + )} + {block.dependencies?.length > 0 && ( + + + {block.dependencies?.join(", ")} + + )} + + ); +} + +function TemplateCard({ template }: { template: TemplateInfo }) { + const navigate = useNavigate(); + const [creating, setCreating] = useState(false); + + const handleCreatePipeline = async () => { + setCreating(true); + try { + await extensionsApi.createPipelineFromTemplate(template.id); + toast.success("Pipeline created from template"); + navigate("/pipelines"); + } catch (error) { + console.error("failed to create pipeline from template:", error); + const message = error instanceof Error ? 
error.message : "Unknown error"; + toast.error(`Failed to create pipeline: ${message}`); + } finally { + setCreating(false); + } + }; + + return ( + + + + {template.name} + + + + + + {template.description} + + + ); +} + +function SourceBadge({ source }: { source: string }) { + const variants: Record = { + builtin: { bg: "accent.subtle", color: "accent.fg" }, + custom: { bg: "attention.subtle", color: "attention.fg" }, + user: { bg: "done.subtle", color: "done.fg" }, + }; + const style = variants[source] || variants.builtin; + + return ( + + {source} + + ); +} diff --git a/frontend/src/services/extensionsApi.ts b/frontend/src/services/extensionsApi.ts new file mode 100644 index 0000000..62979c9 --- /dev/null +++ b/frontend/src/services/extensionsApi.ts @@ -0,0 +1,70 @@ +import type { BlockInfo, TemplateInfo, ExtensionsStatus } from "../types"; + +const API_BASE = "/api"; + +class ExtensionsApi { + async getStatus(): Promise { + const response = await fetch(`${API_BASE}/extensions/status`); + if (!response.ok) throw new Error(`http ${response.status}`); + return response.json(); + } + + async listBlocks(): Promise { + const response = await fetch(`${API_BASE}/extensions/blocks`); + if (!response.ok) throw new Error(`http ${response.status}`); + return response.json(); + } + + async listTemplates(): Promise { + const response = await fetch(`${API_BASE}/extensions/templates`); + if (!response.ok) throw new Error(`http ${response.status}`); + return response.json(); + } + + async reload(): Promise<{ status: string; message: string }> { + const response = await fetch(`${API_BASE}/extensions/reload`, { method: "POST" }); + if (!response.ok) throw new Error(`http ${response.status}`); + return response.json(); + } + + async validateBlock(name: string): Promise<{ valid: boolean; block: string; error?: string }> { + const response = await fetch( + `${API_BASE}/extensions/blocks/${encodeURIComponent(name)}/validate`, + { + method: "POST", + } + ); + if (!response.ok) throw 
new Error(`http ${response.status}`); + return response.json(); + } + + async createPipelineFromTemplate(templateId: string): Promise { + const response = await fetch( + `${API_BASE}/pipelines/from_template/${encodeURIComponent(templateId)}`, + { method: "POST" } + ); + if (!response.ok) throw new Error(`http ${response.status}`); + } + + async installBlockDeps(name: string): Promise<{ status: string; installed: string[] }> { + const response = await fetch( + `${API_BASE}/extensions/blocks/${encodeURIComponent(name)}/install-deps`, + { + method: "POST", + } + ); + if (!response.ok) { + let detail = `http ${response.status}`; + try { + const error = await response.json(); + detail = error.detail || detail; + } catch { + // response body not JSON + } + throw new Error(detail); + } + return response.json(); + } +} + +export const extensionsApi = new ExtensionsApi(); diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 2ee19fb..df9a072 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -123,3 +123,65 @@ export interface ConnectionTestResult { message: string; latency_ms: number | null; } + +// extensions system types + +export type JsonValue = + | string + | number + | boolean + | null + | JsonValue[] + | { [key: string]: JsonValue }; + +export interface JsonSchemaObject { + properties?: Record; + required?: string[]; +} + +export interface BlockInfo { + type: string; + name: string; + description: string; + category: string; + inputs: string[]; + outputs: string[]; + config_schema: JsonSchemaObject; + is_multiplier: boolean; + dependencies: string[]; + source: string; + available: boolean; + error: string | null; +} + +export interface TemplateInfo { + id: string; + name: string; + description: string; + example_seed?: JsonValue; + source: string; +} + +export interface ExtensionsStatus { + blocks: { + total: number; + builtin_blocks: number; + custom_blocks: number; + user_blocks: number; + available: number; + 
unavailable: number; + }; + templates: { + total: number; + builtin_templates: number; + user_templates: number; + }; +} + +export interface DependencyInfo { + requirement: string; + name: string; + installed_version: string | null; + status: string; + error: string | null; +} diff --git a/lib/api/__init__.py b/lib/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/api/extensions.py b/lib/api/extensions.py new file mode 100644 index 0000000..74690f6 --- /dev/null +++ b/lib/api/extensions.py @@ -0,0 +1,121 @@ +from typing import Any + +from fastapi import APIRouter, HTTPException + +from lib.blocks.registry import registry +from lib.dependency_manager import DependencyError, dependency_manager +from lib.entities.extensions import ( + BlockInfo, + BlocksStatus, + DependencyInfo, + ExtensionsStatus, + TemplateInfo, + TemplatesStatus, +) +from lib.templates import template_registry + +router = APIRouter(prefix="/extensions", tags=["extensions"]) + + +@router.get("/status") +async def extensions_status() -> ExtensionsStatus: + blocks = registry.list_blocks() + templates = template_registry.list_templates() + + builtin_count = custom_count = user_count = available_count = 0 + for b in blocks: + if b.source == "builtin": + builtin_count += 1 + elif b.source == "custom": + custom_count += 1 + elif b.source == "user": + user_count += 1 + if b.available: + available_count += 1 + + return ExtensionsStatus( + blocks=BlocksStatus( + total=len(blocks), + builtin_blocks=builtin_count, + custom_blocks=custom_count, + user_blocks=user_count, + available=available_count, + unavailable=len(blocks) - available_count, + ), + templates=TemplatesStatus( + total=len(templates), + builtin_templates=sum(1 for t in templates if t.source == "builtin"), + user_templates=sum(1 for t in templates if t.source == "user"), + ), + ) + + +@router.get("/blocks") +async def extensions_blocks() -> list[BlockInfo]: + return registry.list_blocks() + + +@router.get("/templates") 
+async def extensions_templates() -> list[TemplateInfo]: + return template_registry.list_templates() + + +@router.post("/reload") +async def reload_extensions() -> dict[str, str]: + """manually trigger extension reload""" + registry.reload() + template_registry.reload() + return {"status": "ok", "message": "Extensions reloaded"} + + +@router.post("/blocks/{name}/validate") +async def validate_block(name: str) -> dict[str, Any]: + """validate a block's availability and dependencies""" + block_class = registry.get_block_class(name) + if block_class is None: + entry = registry.get_entry(name) + if entry and not entry.available: + return {"valid": False, "block": name, "error": entry.error} + raise HTTPException(status_code=404, detail=f"Block '{name}' not found") + + missing = dependency_manager.check_missing(block_class.dependencies) + if missing: + return {"valid": False, "block": name, "missing_dependencies": missing} + return {"valid": True, "block": name} + + +@router.get("/blocks/{name}/dependencies") +async def block_dependencies(name: str) -> list[DependencyInfo]: + """get dependency info for a block""" + entry = registry.get_entry(name) + if entry is None: + raise HTTPException(status_code=404, detail=f"Block '{name}' not found") + if entry.block_class is None: + raise HTTPException( + status_code=422, detail=f"Block '{name}' failed to import — dependencies unknown" + ) + return dependency_manager.get_dependency_info(entry.block_class.dependencies) + + +@router.post("/blocks/{name}/install-deps") +async def install_block_deps(name: str) -> dict[str, Any]: + """install missing dependencies for a block (works for unavailable blocks too)""" + entry = registry.get_entry(name) + if entry is None: + raise HTTPException(status_code=404, detail=f"Block '{name}' not found") + if entry.block_class is None: + raise HTTPException( + status_code=422, detail=f"Block '{name}' failed to import — dependencies unknown" + ) + + deps = entry.block_class.dependencies + missing = 
dependency_manager.check_missing(deps) + if not missing: + return {"status": "ok", "installed": [], "message": "All dependencies already installed"} + + try: + installed = await dependency_manager.install(missing) + registry.reload() + except DependencyError as e: + raise HTTPException(status_code=500, detail=str(e)) from e + return {"status": "ok", "installed": installed} diff --git a/lib/blocks/base.py b/lib/blocks/base.py index c1a8a4a..21351a6 100644 --- a/lib/blocks/base.py +++ b/lib/blocks/base.py @@ -11,6 +11,7 @@ class BaseBlock(ABC): category: str = "general" inputs: list[str] = [] outputs: list[str] = [] + dependencies: list[str] = [] @abstractmethod async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: @@ -40,6 +41,7 @@ def get_schema(cls) -> dict[str, Any]: "outputs": cls.outputs, "config_schema": cls.get_config_schema(), "is_multiplier": getattr(cls, "is_multiplier", False), + "dependencies": cls.dependencies, } diff --git a/lib/blocks/builtin/structure_sampler.py b/lib/blocks/builtin/structure_sampler.py index 66b20f5..bfbcb8c 100644 --- a/lib/blocks/builtin/structure_sampler.py +++ b/lib/blocks/builtin/structure_sampler.py @@ -74,6 +74,7 @@ def __init__( ) self.seed = seed self._rng = random.Random(seed) + self._field_deps: dict[str, list[str]] = {} def _validate_samples(self, samples: list[dict[str, Any]]) -> None: """validate samples meet minimum requirements""" @@ -109,7 +110,7 @@ def _compute_conditional_probabilities( ) -> dict[str, dict[str, float]]: """compute conditional probabilities for dependent fields""" conditional_probs = {} - for child_field, parent_fields in self.dependencies.items(): + for child_field, parent_fields in self._field_deps.items(): if child_field not in self.categorical_fields: continue @@ -194,7 +195,7 @@ def _topological_sort(self, fields: list[str]) -> list[str]: """ # build in-degree map in_degree = {field: 0 for field in fields} - for child_field, parent_fields in self.dependencies.items(): + 
for child_field, parent_fields in self._field_deps.items(): if child_field in in_degree: in_degree[child_field] = len(parent_fields) @@ -209,7 +210,7 @@ def _topological_sort(self, fields: list[str]) -> list[str]: if not no_deps: raise ValidationError( "Circular dependency detected in field dependencies", - detail={"dependencies": self.dependencies}, + detail={"dependencies": self._field_deps}, ) # add to result @@ -218,7 +219,7 @@ def _topological_sort(self, fields: list[str]) -> list[str]: # decrease in-degree for children for field in no_deps: - for child_field, parent_fields in self.dependencies.items(): + for child_field, parent_fields in self._field_deps.items(): if field in parent_fields and child_field in remaining: in_degree[child_field] -= 1 @@ -237,9 +238,9 @@ def _sample_categorical_field( self, field: str, skeleton: dict[str, Any], profile: dict[str, Any] ) -> Any: """sample value for a single categorical field, respecting dependencies""" - if field in self.dependencies: + if field in self._field_deps: # conditional sampling based on parent values - parent_fields = self.dependencies[field] + parent_fields = self._field_deps[field] parent_values = tuple(skeleton.get(p) for p in parent_fields) parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_values)) key = f"{field}|{parent_str}" @@ -370,7 +371,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: # store parsed values for use in methods self.categorical_fields = categorical_fields self.numeric_fields = numeric_fields - self.dependencies = dependencies + self._field_deps = dependencies # read samples from initial state samples = context.get_state("samples", []) diff --git a/lib/blocks/registry.py b/lib/blocks/registry.py index aa3d102..2330b78 100644 --- a/lib/blocks/registry.py +++ b/lib/blocks/registry.py @@ -1,29 +1,73 @@ import importlib import inspect import logging +import sys +import threading from pathlib import Path from typing import Any +from 
pydantic import BaseModel + from lib.blocks.base import BaseBlock, BaseMultiplierBlock +from lib.entities.extensions import BlockInfo logger = logging.getLogger(__name__) +# resolve builtin/custom paths relative to this file so they work regardless of cwd +_BLOCKS_DIR = Path(__file__).resolve().parent + +# maps (path, module_prefix) to source label +_SOURCE_MAP = { + (_BLOCKS_DIR / "builtin", "lib.blocks.builtin"): "builtin", + (_BLOCKS_DIR / "custom", "lib.blocks.custom"): "custom", + (Path("user_blocks").resolve(), "user_blocks"): "user", +} + + +class BlockEntry(BaseModel): + """internal registry entry — wraps a block class with extensibility metadata""" + + block_class: type[BaseBlock] | None = None # None when import failed + type_name: str = "" # used as fallback type when block_class is None + source: str = "builtin" + available: bool = True + error: str | None = None + + model_config = {"arbitrary_types_allowed": True} + + def to_block_info(self) -> BlockInfo: + if self.block_class is None: + return BlockInfo( + type=self.type_name, + name=self.type_name, + description="", + category="", + inputs=[], + outputs=[], + config_schema={}, + source=self.source, + available=False, + error=self.error, + ) + schema = self.block_class.get_schema() + return BlockInfo( + source=self.source, + available=self.available, + error=self.error, + **schema, + ) + class BlockRegistry: def __init__(self) -> None: - self._blocks: dict[str, type[BaseBlock]] = {} - self._discover_blocks() - - def _discover_blocks(self) -> None: - # scan lib/blocks/builtin/, lib/blocks/custom/, and user_blocks/ for block classes - scan_dirs = [ - "lib/blocks/builtin", - "lib/blocks/custom", - "user_blocks", - ] - - for blocks_dir in scan_dirs: - blocks_path = Path(blocks_dir) + self._lock = threading.Lock() + self._entries: dict[str, BlockEntry] = {} + self._entries = self._discover_blocks() + + def _discover_blocks(self) -> dict[str, BlockEntry]: + """scan all block directories and return a fresh 
entries dict""" + entries: dict[str, BlockEntry] = {} + for (blocks_path, module_prefix), source in _SOURCE_MAP.items(): if not blocks_path.exists(): continue @@ -31,25 +75,73 @@ def _discover_blocks(self) -> None: if py_file.name.startswith("_"): continue - module_name = f"{blocks_dir.replace('/', '.')}.{py_file.stem}" + module_name = f"{module_prefix}.{py_file.stem}" try: - module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module, inspect.isclass): - # only register classes inheriting from BaseBlock, excluding base classes - if issubclass(obj, BaseBlock) and obj not in ( - BaseBlock, - BaseMultiplierBlock, + # reload already-imported modules so file changes are picked up + module = ( + importlib.reload(sys.modules[module_name]) + if module_name in sys.modules + else importlib.import_module(module_name) + ) + for _name, obj in inspect.getmembers(module, inspect.isclass): + if ( + issubclass(obj, BaseBlock) + and obj not in (BaseBlock, BaseMultiplierBlock) + and obj.__module__ == module.__name__ ): - self._blocks[obj.__name__] = obj + entries[obj.__name__] = BlockEntry(block_class=obj, source=source) except Exception as e: - logger.warning(f"failed to load block module {module_name}: {e}") - continue + logger.exception("failed to load block module %s", module_name) + # register as unavailable so the UI can surface the failure + entries[py_file.stem] = BlockEntry( + type_name=py_file.stem, + source=source, + available=False, + error=str(e), + ) + return entries + + def register( + self, + block_class: type[BaseBlock], + source: str = "user", + available: bool = True, + error: str | None = None, + ) -> None: + with self._lock: + self._entries[block_class.__name__] = BlockEntry( + block_class=block_class, + source=source, + available=available, + error=error, + ) + + def reload(self) -> None: + """re-scan all block directories and refresh the registry. 
+ serialized with a lock since importlib.reload is not thread-safe.""" + with self._lock: + self._entries = self._discover_blocks() + + def unregister(self, block_type: str) -> None: + with self._lock: + self._entries.pop(block_type, None) def get_block_class(self, block_type: str) -> type[BaseBlock] | None: - return self._blocks.get(block_type) + entry = self._entries.get(block_type) + return entry.block_class if entry else None + + def list_block_types(self) -> list[str]: + return list(self._entries.keys()) + + def get_entry(self, block_type: str) -> BlockEntry | None: + return self._entries.get(block_type) + + def get_block_source(self, block_type: str) -> str | None: + entry = self._entries.get(block_type) + return entry.source if entry else None - def list_blocks(self) -> list[dict[str, Any]]: - return [block_class.get_schema() for block_class in self._blocks.values()] + def list_blocks(self) -> list[BlockInfo]: + return [entry.to_block_info() for entry in self._entries.values()] def compute_accumulated_state_schema(self, blocks: list[dict[str, Any]]) -> list[str]: """ diff --git a/lib/cli/__init__.py b/lib/cli/__init__.py new file mode 100644 index 0000000..6337063 --- /dev/null +++ b/lib/cli/__init__.py @@ -0,0 +1 @@ +"""DataGenFlow CLI package.""" diff --git a/lib/cli/client.py b/lib/cli/client.py new file mode 100644 index 0000000..2dcdb2f --- /dev/null +++ b/lib/cli/client.py @@ -0,0 +1,43 @@ +"""HTTP client for DataGenFlow API.""" + +from typing import Any, cast + +import httpx + + +class DataGenFlowClient: + """thin wrapper around the DataGenFlow REST API""" + + def __init__(self, endpoint: str, timeout: float = 30.0): + self.endpoint = endpoint.rstrip("/") + self.timeout = timeout + + def _request(self, method: str, path: str, **kwargs: Any) -> Any: + url = f"{self.endpoint}{path}" + with httpx.Client(timeout=self.timeout) as client: + response = client.request(method, url, **kwargs) + response.raise_for_status() + return response.json() + + def 
health(self) -> dict[str, Any]: + return cast(dict[str, Any], self._request("GET", "/health")) + + def extension_status(self) -> dict[str, Any]: + return cast(dict[str, Any], self._request("GET", "/api/extensions/status")) + + def list_blocks(self) -> list[dict[str, Any]]: + return cast(list[dict[str, Any]], self._request("GET", "/api/extensions/blocks")) + + def list_templates(self) -> list[dict[str, Any]]: + return cast(list[dict[str, Any]], self._request("GET", "/api/extensions/templates")) + + def reload_extensions(self) -> dict[str, Any]: + return cast(dict[str, Any], self._request("POST", "/api/extensions/reload")) + + def validate_block(self, name: str) -> dict[str, Any]: + path = f"/api/extensions/blocks/{name}/validate" + return cast(dict[str, Any], self._request("POST", path)) + + def install_block_deps(self, name: str) -> dict[str, Any]: + path = f"/api/extensions/blocks/{name}/install-deps" + return cast(dict[str, Any], self._request("POST", path)) diff --git a/lib/cli/main.py b/lib/cli/main.py new file mode 100644 index 0000000..3c942f2 --- /dev/null +++ b/lib/cli/main.py @@ -0,0 +1,670 @@ +"""DataGenFlow CLI - Manage blocks and templates.""" + +import ast +import os +import re +import shutil +import subprocess +from pathlib import Path + +import httpx +import typer +from rich.console import Console +from rich.table import Table + +from lib.cli.client import DataGenFlowClient +from lib.constants import DEFAULT_BLOCKS_PATH, DEFAULT_TEMPLATES_PATH + +app = typer.Typer( + name="dgf", + help="DataGenFlow CLI - Manage blocks and templates", + no_args_is_help=True, +) +console = Console() + +blocks_app = typer.Typer(help="Manage custom blocks") +templates_app = typer.Typer(help="Manage pipeline templates") +image_app = typer.Typer(help="Build custom Docker images") + +app.add_typer(blocks_app, name="blocks") +app.add_typer(templates_app, name="templates") +app.add_typer(image_app, name="image") + + +def get_endpoint() -> str: + """get API endpoint from env 
or .env file""" + endpoint = os.getenv("DATAGENFLOW_ENDPOINT") + if endpoint: + return endpoint + + env_file = Path(".env") + if env_file.exists(): + for line in env_file.read_text().splitlines(): + if line.startswith("DATAGENFLOW_ENDPOINT="): + return line.split("=", 1)[1].strip().strip('"').strip("'") + + return "http://localhost:8000" + + +def get_client() -> DataGenFlowClient: + return DataGenFlowClient(get_endpoint()) + + +def get_user_blocks_dir() -> Path: + """resolve user blocks directory from env or default""" + return Path(os.getenv("DATAGENFLOW_BLOCKS_PATH", DEFAULT_BLOCKS_PATH)) + + +def get_user_templates_dir() -> Path: + """resolve user templates directory from env or default""" + return Path(os.getenv("DATAGENFLOW_TEMPLATES_PATH", DEFAULT_TEMPLATES_PATH)) + + +# ============ Status ============ + + +@app.command() +def status() -> None: + """Show DataGenFlow server status and extension info.""" + client = get_client() + + try: + client.health() + ext = client.extension_status() + + blocks = ext["blocks"] + templates = ext["templates"] + + console.print(f"[green]✓[/green] Server: {get_endpoint()}") + console.print( + f" Blocks: {blocks['available']} available, {blocks['unavailable']} unavailable" + ) + console.print( + f" Templates: {templates['total']} total " + f"({templates['builtin_templates']} builtin, " + f"{templates['user_templates']} user)" + ) + + except Exception as e: + console.print(f"[red]✗[/red] Cannot connect to {get_endpoint()}: {e}") + raise typer.Exit(1) + + +# ============ Blocks ============ + + +@blocks_app.command("list") +def blocks_list() -> None: + """List all registered blocks.""" + client = get_client() + blocks = client.list_blocks() + + table = Table(title="Registered Blocks") + table.add_column("Name", style="cyan") + table.add_column("Type", style="magenta") + table.add_column("Category", style="green") + table.add_column("Status") + table.add_column("Source", style="dim") + + for block in blocks: + if 
block.get("available", True): + block_status = "[green]✓[/green]" + else: + error = block.get("error", "unavailable") + block_status = f"[red]✗ {error[:30]}[/red]" + + table.add_row( + block.get("name", block["type"]), + block["type"], + block.get("category", "general"), + block_status, + block.get("source", "unknown"), + ) + + console.print(table) + + +@blocks_app.command("add") +def blocks_add( + file: Path = typer.Argument(..., help="Path to block Python file"), + install_deps: bool = typer.Option( + False, "--install-deps", help="Install block dependencies after adding" + ), +) -> None: + """Add a block file to the user_blocks directory.""" + if not file.exists(): + console.print(f"[red]✗[/red] File not found: {file}") + raise typer.Exit(1) + + if not file.suffix == ".py": + console.print("[red]✗[/red] Block file must be a .py file") + raise typer.Exit(1) + + try: + tree = ast.parse(file.read_text()) + except SyntaxError as e: + console.print(f"[red]✗[/red] Syntax error: {e}") + raise typer.Exit(1) + + block_names = _find_block_classes(tree) + if not block_names: + console.print("[red]✗[/red] No block classes found (must inherit from BaseBlock)") + raise typer.Exit(1) + + user_blocks_dir = get_user_blocks_dir() + user_blocks_dir.mkdir(parents=True, exist_ok=True) + + dest = user_blocks_dir / file.name + shutil.copy2(file, dest) + console.print(f"[green]✓[/green] Copied {file.name} to {user_blocks_dir}") + + client = get_client() + + # trigger server-side reload so the new block is registered + try: + client.reload_extensions() + except httpx.HTTPError as e: + console.print(f"[yellow]![/yellow] Could not trigger reload: {e}") + + for name in block_names: + try: + result = client.validate_block(name) + if result.get("valid"): + console.print(f"[green]✓[/green] Block '{name}' validated") + else: + console.print( + f"[yellow]![/yellow] Block '{name}': {result.get('error', 'unknown error')}" + ) + except httpx.HTTPError as e: + console.print(f"[yellow]![/yellow] 
Could not validate '{name}': {e}") + + if not install_deps: + return + + for name in block_names: + try: + console.print(f" Installing deps for '{name}'...") + client.install_block_deps(name) + console.print(f"[green]✓[/green] Dependencies installed for '{name}'") + except httpx.HTTPError as e: + console.print(f"[red]✗[/red] Failed to install deps for '{name}': {e}") + + +@blocks_app.command("remove") +def blocks_remove( + name: str = typer.Argument(..., help="Block class name to remove"), +) -> None: + """Remove a block file from the user_blocks directory.""" + user_blocks_dir = get_user_blocks_dir() + if not user_blocks_dir.exists(): + console.print(f"[red]✗[/red] User blocks directory not found: {user_blocks_dir}") + raise typer.Exit(1) + + # search all .py files for the block class + for py_file in user_blocks_dir.glob("*.py"): + try: + tree = ast.parse(py_file.read_text()) + except SyntaxError: + console.print(f"[yellow]![/yellow] Skipping {py_file.name}: syntax error") + continue + block_names = _find_block_classes(tree) + if name in block_names: + py_file.unlink() + console.print(f"[green]✓[/green] Removed {py_file.name} (contained block '{name}')") + return + + console.print(f"[red]✗[/red] Block '{name}' not found in {user_blocks_dir}") + raise typer.Exit(1) + + +@blocks_app.command("validate") +def blocks_validate( + path: Path = typer.Argument(..., help="Path to block Python file"), +) -> None: + """Validate a block file without adding it.""" + if not path.exists(): + console.print(f"[red]✗[/red] File not found: {path}") + raise typer.Exit(1) + + try: + tree = ast.parse(path.read_text()) + except SyntaxError as e: + console.print(f"[red]✗[/red] Syntax error: {e}") + raise typer.Exit(1) + + block_names = _find_block_classes(tree) + if not block_names: + console.print("[red]✗[/red] No block classes found (must inherit from BaseBlock)") + raise typer.Exit(1) + + console.print(f"[green]✓[/green] {path.name} is valid") + console.print(f" Blocks found: {', 
'.join(block_names)}")
+
+
+@blocks_app.command("scaffold")
+def blocks_scaffold(
+    name: str = typer.Argument(..., help="Block class name (e.g., SentimentAnalyzer)"),
+    output: Path = typer.Option(Path("."), "-o", "--output", help="Output directory"),
+    category: str = typer.Option("general", "-c", "--category", help="Block category"),
+) -> None:
+    """Generate a block template file."""
+    # convert CamelCase class name to snake_case filename
+    filename = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() + ".py"
+    output_path = output / filename
+
+    template = f'''"""Custom block: {name}."""
+
+from typing import Any
+
+from lib.blocks.base import BaseBlock, BlockExecutionContext
+
+
+class {name}(BaseBlock):
+    """TODO: Add description"""
+
+    category = "{category}"
+    # optional pip dependencies, e.g. dependencies = ["example-package>=1.0.0"]
+
+    def __init__(self, param: str = "default"):
+        self.param = param
+
+    async def execute(self, context: BlockExecutionContext) -> dict[str, Any]:
+        text = context.get_state("text", "")
+        result = text
+        return {{"result": result}}
+'''
+
+    output_path.write_text(template)
+    console.print(f"[green]✓[/green] Created {output_path}")
+
+
+# ============ Templates ============
+
+
+@templates_app.command("list")
+def templates_list() -> None:
+    """List all available templates."""
+    client = get_client()
+    templates = client.list_templates()
+
+    table = Table(title="Available Templates")
+    table.add_column("ID", style="cyan")
+    table.add_column("Name", style="magenta")
+    table.add_column("Description", style="green")
+    table.add_column("Source", style="dim")
+
+    for tmpl in templates:
+        table.add_row(
+            tmpl["id"],
+            tmpl["name"],
+            (tmpl.get("description", "") or "")[:50],
+            tmpl.get("source", "unknown"),
+        )
+
+    console.print(table)
+
+
+@templates_app.command("add")
+def templates_add(
+    file: Path = typer.Argument(..., help="Path to template YAML file"),
+) -> None:
+    """Add a template file to the user_templates directory."""
+    if not file.exists():
+        console.print(f"[red]✗[/red] File not found: {file}")
+        raise typer.Exit(1)
+
+    if file.suffix not in (".yaml", ".yml"):
+        console.print("[red]✗[/red] Template file must be a .yaml or .yml file")
+        raise typer.Exit(1)
+
+    import yaml  # type: ignore[import-untyped]
+
+    try:
+        with open(file) as f:
+            data = yaml.safe_load(f)
+    except yaml.YAMLError as e:
+        console.print(f"[red]✗[/red] Invalid YAML: {e}")
+ raise typer.Exit(1) + + if not isinstance(data, dict): + console.print("[red]✗[/red] Template file must contain a YAML mapping") + raise typer.Exit(1) + + errors = [] + if "name" not in data: + errors.append("Missing 'name' field") + if "blocks" not in data: + errors.append("Missing 'blocks' field") + elif not isinstance(data["blocks"], list) or len(data["blocks"]) == 0: + errors.append("'blocks' must be a non-empty list") + + if errors: + console.print(f"[red]✗[/red] {file.name} is invalid:") + for error in errors: + console.print(f" - {error}") + raise typer.Exit(1) + + user_templates_dir = get_user_templates_dir() + user_templates_dir.mkdir(parents=True, exist_ok=True) + + dest = user_templates_dir / file.name + shutil.copy2(file, dest) + console.print(f"[green]✓[/green] Added template '{data['name']}' to {user_templates_dir}") + + +@templates_app.command("remove") +def templates_remove( + template_id: str = typer.Argument(..., help="Template ID to remove"), +) -> None: + """Remove a template file from the user_templates directory.""" + user_templates_dir = get_user_templates_dir() + if not user_templates_dir.exists(): + console.print(f"[red]✗[/red] User templates directory not found: {user_templates_dir}") + raise typer.Exit(1) + + import yaml + + for yaml_file in user_templates_dir.iterdir(): + if yaml_file.suffix not in (".yaml", ".yml") or yaml_file.is_dir(): + continue + try: + with open(yaml_file) as f: + data = yaml.safe_load(f) + except (yaml.YAMLError, OSError) as e: + console.print(f"[yellow]![/yellow] Skipping {yaml_file.name}: {e}") + continue + + if not isinstance(data, dict): + console.print(f"[yellow]![/yellow] Skipping {yaml_file.name}: not a YAML mapping") + continue + + # match by explicit id field or by filename stem + file_id = data.get("id", yaml_file.stem) + if file_id != template_id: + continue + + yaml_file.unlink() + console.print(f"[green]✓[/green] Removed template '{template_id}' ({yaml_file.name})") + return + + 
console.print(f"[red]✗[/red] Template '{template_id}' not found in {user_templates_dir}") + raise typer.Exit(1) + + +@templates_app.command("validate") +def templates_validate( + path: Path = typer.Argument(..., help="Path to template YAML file"), +) -> None: + """Validate a template file without adding it.""" + if not path.exists(): + console.print(f"[red]✗[/red] File not found: {path}") + raise typer.Exit(1) + + import yaml + + try: + with open(path) as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + console.print("[red]✗[/red] Template file must contain a YAML mapping") + raise typer.Exit(1) + + errors = [] + if "name" not in data: + errors.append("Missing 'name' field") + if "blocks" not in data: + errors.append("Missing 'blocks' field") + elif not isinstance(data["blocks"], list): + errors.append("'blocks' must be a list") + elif len(data["blocks"]) == 0: + errors.append("'blocks' list cannot be empty") + else: + for i, block in enumerate(data["blocks"]): + if "type" not in block: + errors.append(f"Block {i} missing 'type' field") + + if errors: + console.print(f"[red]✗[/red] {path.name} is invalid:") + for error in errors: + console.print(f" - {error}") + raise typer.Exit(1) + + console.print(f"[green]✓[/green] {path.name} is valid") + console.print(f" Name: {data['name']}") + console.print(f" Blocks: {len(data['blocks'])}") + + except yaml.YAMLError as e: + console.print(f"[red]✗[/red] Invalid YAML: {e}") + raise typer.Exit(1) + + +@templates_app.command("scaffold") +def templates_scaffold( + name: str = typer.Argument(..., help="Template name"), + output: Path = typer.Option(Path("."), "-o", "--output", help="Output directory"), +) -> None: + """Generate a template YAML file.""" + filename = name.lower().replace(" ", "_") + ".yaml" + output_path = output / filename + + template = f'''name: "{name}" +description: "TODO: Add description" + +example_seed: + text: "Sample input text" + +blocks: + - type: TextGenerator + config: + model: 
"gpt-4o-mini" + user_prompt: | + Process the following text: + {{{{ text }}}} +''' + + output_path.write_text(template) + console.print(f"[green]✓[/green] Created {output_path}") + + +# ============ Image ============ + + +@image_app.command("scaffold") +def image_scaffold( + blocks_dir: Path = typer.Option(None, "--blocks-dir", "-b", help="Directory containing blocks"), + output: Path = typer.Option(Path("Dockerfile.custom"), "-o", "--output", help="Output path"), +) -> None: + """Generate a Dockerfile for custom image with dependencies. + + The generated Dockerfile builds from source (must be run from the + DataGenFlow repository root) and optionally installs additional + dependencies declared in user block files. + """ + deps: set[str] = set() + + if blocks_dir and blocks_dir.exists(): + for py_file in blocks_dir.glob("*.py"): + if py_file.name.startswith("_"): + continue + try: + tree = ast.parse(py_file.read_text()) + deps.update(_extract_block_deps(tree)) + except (SyntaxError, ValueError) as e: + console.print(f"[yellow]Warning:[/yellow] skipping {py_file.name}: {e}") + + # generate multi-stage dockerfile that builds from source + dockerfile = """# Custom DataGenFlow image with user blocks +# Generated by: dgf image scaffold +# Build from DataGenFlow repository root: +# docker build -f Dockerfile.custom -t my-datagenflow:latest . + +# Backend build stage +FROM python:3.10-slim AS builder + +COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv + +WORKDIR /app + +COPY pyproject.toml uv.lock ./ +RUN uv sync --frozen --no-dev + +COPY . . 
+
+RUN python -m compileall -b -q lib/
+
+# Frontend build stage
+FROM node:20-alpine AS frontend-builder
+
+WORKDIR /app
+COPY frontend/package.json frontend/yarn.lock ./
+RUN yarn install --frozen-lockfile
+COPY frontend/ ./
+RUN yarn build
+
+# Production stage
+FROM python:3.10-slim
+
+WORKDIR /app
+
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
+COPY --from=builder /app/.venv /app/.venv
+COPY --from=builder /app/lib /app/lib
+COPY --from=builder /app/app.py /app/
+COPY --from=builder /app/config.py /app/
+COPY --from=builder /app/models.py /app/
+COPY --from=builder /app/pyproject.toml /app/
+COPY --from=frontend-builder /app/build /app/frontend/build
+
+RUN mkdir -p /app/data /app/lib/blocks/custom /app/user_blocks /app/user_templates
+
+"""
+
+    if deps:
+        dockerfile += "# Install custom block dependencies\n"
+        dockerfile += "RUN uv pip install \\\n"
+        # quote each requirement so shell operators in specifiers (e.g. '>=') are not treated as redirections
+        dockerfile += " \\\n".join(f'    "{dep}"' for dep in sorted(deps))
+        dockerfile += "\n\n"
+
+    dockerfile += """ENV PATH="/app/.venv/bin:$PATH"
+ENV PYTHONUNBUFFERED=1
+
+EXPOSE 8000
+
+CMD ["uv", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
+"""
+
+    output.write_text(dockerfile)
+    console.print(f"[green]✓[/green] Created {output}")
+
+    if deps:
+        console.print(f" Dependencies: {len(deps)}")
+        for dep in sorted(deps):
+            console.print(f" - {dep}")
+
+    console.print("\n[bold]Build from DataGenFlow repo root:[/bold]")
+    console.print(f" docker build -f {output} -t my-datagenflow:latest .")
+
+
+@image_app.command("build")
+def image_build(
+    dockerfile: Path = typer.Option(Path("Dockerfile.custom"), "-f", "--dockerfile"),
+    tag: str = typer.Option("my-datagenflow:latest", "-t", "--tag"),
+) -> None:
+    """Build a custom Docker image."""
+    if not dockerfile.exists():
+        console.print(f"[red]✗[/red] Dockerfile not found: {dockerfile}")
+        console.print("Run 'dgf image scaffold' first")
+        raise typer.Exit(1)
+
+    cmd = ["docker", "build", "-f", str(dockerfile), "-t", tag, "."]
+    console.print(f"Building image: {tag}")
+
+    try:
+        subprocess.run(cmd, check=True)
+        console.print(f"\n[green]✓[/green] Successfully built {tag}")
+    except subprocess.CalledProcessError:
+        console.print("\n[red]✗[/red] Build failed")
+        raise typer.Exit(1)
+
+
+# ============ Configure ============
+
+
+@app.command()
+def configure(
+    endpoint: str = typer.Option(None, "--endpoint", "-e", help="DataGenFlow API endpoint"),
+    show: bool = typer.Option(False, "--show", "-s", help="Show current configuration"),
+) -> None:
+    """Configure CLI settings."""
+    env_file = Path(".env")
+
+    if show or endpoint is None:
+        console.print("Current configuration:")
+        console.print(f" Endpoint: {get_endpoint()}")
+        return
+
+    lines = []
+    found = False
+
+    if env_file.exists():
+        for line in env_file.read_text().splitlines():
+            if line.startswith("DATAGENFLOW_ENDPOINT="):
+                lines.append(f"DATAGENFLOW_ENDPOINT={endpoint}")
+                found = True
+            else:
+                lines.append(line)
+
+    if not found:
+        lines.append(f"DATAGENFLOW_ENDPOINT={endpoint}")
+
+    env_file.write_text("\n".join(lines) + "\n")
+    console.print(f"[green]✓[/green] Configuration saved to {env_file}")
+
+
+# ============ Helpers ============
+
+
+def _extract_block_deps(tree: ast.AST) -> set[str]:
+    """extract dependency strings from class-level 'dependencies' list assignments (plain or annotated)"""
+    deps: set[str] = set()
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.ClassDef):
+            continue
+        for item in node.body:
+            # handle both 'dependencies = [...]' and 'dependencies: list[str] = [...]'
+            if isinstance(item, ast.Assign):
+                targets = item.targets
+            elif isinstance(item, ast.AnnAssign):
+                targets = [item.target]
+            else:
+                continue
+            if isinstance(item.value, ast.List) and any(
+                isinstance(t, ast.Name) and t.id == "dependencies" for t in targets
+            ):
+                for elt in item.value.elts:
+                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                        deps.add(elt.value)
+    return deps
+
+
+def _find_block_classes(tree: ast.AST) -> list[str]:
+    """extract class names that inherit from a *Block base class"""
+    block_names = []
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef):
+            for base in
node.bases: + base_name = "" + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + base_name = base.attr + if "Block" in base_name: + block_names.append(node.name) + break + return block_names + + +if __name__ == "__main__": + app() diff --git a/lib/constants.py b/lib/constants.py index 0f6eeb4..0d0df3d 100644 --- a/lib/constants.py +++ b/lib/constants.py @@ -1,5 +1,8 @@ """shared constants for the application""" +DEFAULT_BLOCKS_PATH = "user_blocks" +DEFAULT_TEMPLATES_PATH = "user_templates" + # fields that can be updated on a record via API RECORD_UPDATABLE_FIELDS = frozenset({"output", "status", "metadata"}) diff --git a/lib/dependency_manager.py b/lib/dependency_manager.py new file mode 100644 index 0000000..8f96ad1 --- /dev/null +++ b/lib/dependency_manager.py @@ -0,0 +1,110 @@ +""" +Dependency manager for block dependencies. + +Parses, checks, and installs pip dependencies declared in block classes. +""" + +import asyncio +import importlib.metadata +import logging +import re +import subprocess +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lib.blocks.base import BaseBlock + +from lib.entities.extensions import DependencyInfo + +logger = logging.getLogger(__name__) + +# only allow valid pip package names (PEP 508) +_VALID_PACKAGE_RE = re.compile(r"^[A-Za-z0-9]([A-Za-z0-9._-]*[A-Za-z0-9])?") + + +class DependencyError(Exception): + pass + + +def _parse_package_name(requirement: str) -> str: + """extract package name from a requirement string like 'torch>=2.0.0'""" + for sep in (">=", "<=", "==", ">", "<", "[", "!=", "~="): + requirement = requirement.split(sep)[0] + return requirement.strip() + + +def _validate_requirement(req: str) -> None: + """reject requirements that look like argument injection""" + if req.startswith("-"): + raise DependencyError(f"invalid requirement (looks like a flag): {req}") + name = _parse_package_name(req) + if not _VALID_PACKAGE_RE.fullmatch(name): + raise 
DependencyError(f"invalid package name: {name}") + + +class DependencyManager: + def get_block_dependencies(self, block_class: type["BaseBlock"]) -> list[str]: + return getattr(block_class, "dependencies", []) + + def check_missing(self, requirements: list[str]) -> list[str]: + missing = [] + for req in requirements: + name = _parse_package_name(req) + try: + importlib.metadata.version(name) + except importlib.metadata.PackageNotFoundError: + missing.append(req) + return missing + + def get_dependency_info(self, requirements: list[str]) -> list[DependencyInfo]: + result = [] + for req in requirements: + name = _parse_package_name(req) + try: + version = importlib.metadata.version(name) + result.append( + DependencyInfo( + requirement=req, + name=name, + installed_version=version, + status="ok", + ) + ) + except importlib.metadata.PackageNotFoundError: + result.append( + DependencyInfo( + requirement=req, + name=name, + status="not_installed", + ) + ) + return result + + def _install_sync(self, requirements: list[str], timeout: int = 300) -> list[str]: + """synchronous install — run via asyncio.to_thread from async code""" + if not requirements: + return [] + + for req in requirements: + _validate_requirement(req) + + cmd = ["uv", "pip", "install", "--quiet"] + requirements + logger.info(f"installing dependencies: {requirements}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0: + raise DependencyError(f"uv pip install failed: {result.stderr}") + logger.info(f"successfully installed: {requirements}") + return requirements + except subprocess.TimeoutExpired: + raise DependencyError(f"installation timed out after {timeout}s") from None + except FileNotFoundError: + raise DependencyError("uv not found") from None + + async def install(self, requirements: list[str], timeout: int = 300) -> list[str]: + """install requirements using uv without blocking the event loop""" + return await 
asyncio.to_thread(self._install_sync, requirements, timeout) + + +dependency_manager = DependencyManager() diff --git a/lib/entities/extensions.py b/lib/entities/extensions.py new file mode 100644 index 0000000..32a5575 --- /dev/null +++ b/lib/entities/extensions.py @@ -0,0 +1,58 @@ +from typing import Any, Literal + +from pydantic import BaseModel + + +class BlockInfo(BaseModel): + """block schema with extensibility metadata""" + + type: str + name: str + description: str + category: str + inputs: list[str] + outputs: list[str] + config_schema: dict[str, Any] + is_multiplier: bool = False + dependencies: list[str] = [] + source: Literal["builtin", "custom", "user"] = "builtin" + available: bool = True + error: str | None = None + + +class TemplateInfo(BaseModel): + """template listing with source metadata""" + + id: str + name: str + description: str + example_seed: list[dict[str, Any]] | None = None + source: Literal["builtin", "custom", "user"] = "builtin" + + +class BlocksStatus(BaseModel): + total: int + builtin_blocks: int + custom_blocks: int + user_blocks: int + available: int + unavailable: int + + +class TemplatesStatus(BaseModel): + total: int + builtin_templates: int + user_templates: int + + +class ExtensionsStatus(BaseModel): + blocks: BlocksStatus + templates: TemplatesStatus + + +class DependencyInfo(BaseModel): + requirement: str + name: str + installed_version: str | None = None + status: str # "ok", "not_installed", "invalid" + error: str | None = None diff --git a/lib/file_watcher.py b/lib/file_watcher.py new file mode 100644 index 0000000..8094fc2 --- /dev/null +++ b/lib/file_watcher.py @@ -0,0 +1,175 @@ +""" +File watcher for hot reload of extensions. + +Monitors user_blocks/ and user_templates/ for changes +and triggers registry reload when files are added, modified, or deleted. 
+""" + +import logging +import os +import threading +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer + +from lib.constants import DEFAULT_BLOCKS_PATH, DEFAULT_TEMPLATES_PATH + +if TYPE_CHECKING: + from lib.blocks.registry import BlockRegistry + from lib.templates import TemplateRegistry + +logger = logging.getLogger(__name__) + + +class DebouncedHandler(FileSystemEventHandler): + """file event handler with debouncing to prevent rapid reloads""" + + def __init__( + self, + callback: Callable[[Path, str], None], + debounce_ms: int = 500, + ): + self.callback = callback + self.debounce_ms = debounce_ms + self._pending: dict[str, threading.Timer] = {} + self._lock = threading.Lock() + + def _schedule_callback(self, path: Path, event_type: str) -> None: + key = str(path) + + with self._lock: + if key in self._pending: + self._pending[key].cancel() + + timer = threading.Timer( + self.debounce_ms / 1000, + self._execute_callback, + args=(path, event_type), + ) + self._pending[key] = timer + timer.start() + + def cancel_pending(self) -> None: + with self._lock: + for timer in self._pending.values(): + timer.cancel() + self._pending.clear() + + def _execute_callback(self, path: Path, event_type: str) -> None: + with self._lock: + self._pending.pop(str(path), None) + + try: + self.callback(path, event_type) + except Exception: + logger.exception("error in file watcher callback") + + def on_created(self, event: FileSystemEvent) -> None: + if not event.is_directory: + self._schedule_callback(Path(os.fsdecode(event.src_path)), "created") + + def on_modified(self, event: FileSystemEvent) -> None: + if not event.is_directory: + self._schedule_callback(Path(os.fsdecode(event.src_path)), "modified") + + def on_deleted(self, event: FileSystemEvent) -> None: + if not event.is_directory: + self._schedule_callback(Path(os.fsdecode(event.src_path)), 
"deleted") + + +class BlockFileHandler(DebouncedHandler): + """handler for block file changes — triggers registry rediscovery""" + + def __init__(self, registry: "BlockRegistry", debounce_ms: int = 500): + self.registry = registry + super().__init__(self._handle_change, debounce_ms) + + def _handle_change(self, path: Path, event_type: str) -> None: + if path.suffix != ".py" or path.name.startswith("_"): + return + + logger.info("block file %s: %s", event_type, path) + self.registry.reload() + + +class TemplateFileHandler(DebouncedHandler): + """handler for template file changes""" + + def __init__(self, registry: "TemplateRegistry", user_dir: Path, debounce_ms: int = 500): + self.registry = registry + self.user_dir = user_dir + super().__init__(self._handle_change, debounce_ms) + + def _handle_change(self, path: Path, event_type: str) -> None: + if path.suffix not in (".yaml", ".yml"): + return + + logger.info("template file %s: %s", event_type, path) + # full reload is safe — uses atomic swap internally + self.registry.reload() + + +class ExtensionFileWatcher: + """watches extension directories for changes""" + + def __init__( + self, + block_registry: "BlockRegistry", + template_registry: "TemplateRegistry", + blocks_path: Path | None = None, + templates_path: Path | None = None, + ): + self.block_registry = block_registry + self.template_registry = template_registry + self.blocks_path = ( + blocks_path or Path(os.getenv("DATAGENFLOW_BLOCKS_PATH", DEFAULT_BLOCKS_PATH)).resolve() + ) + self.templates_path = ( + templates_path + or Path(os.getenv("DATAGENFLOW_TEMPLATES_PATH", DEFAULT_TEMPLATES_PATH)).resolve() + ) + self._observer: Any = None # watchdog.Observer, no stubs available + self._handlers: list[DebouncedHandler] = [] + + @property + def is_running(self) -> bool: + return self._observer is not None + + def start(self) -> None: + hot_reload = os.getenv("DATAGENFLOW_HOT_RELOAD", "true").lower() == "true" + if not hot_reload: + logger.info("hot reload 
disabled")
+            return
+
+        self._observer = Observer()
+        self._handlers = []
+        debounce_ms = int(os.getenv("DATAGENFLOW_HOT_RELOAD_DEBOUNCE_MS", "500"))
+
+        if self.blocks_path.exists():
+            block_handler = BlockFileHandler(self.block_registry, debounce_ms)
+            self._observer.schedule(block_handler, str(self.blocks_path), recursive=False)
+            self._handlers.append(block_handler)
+            logger.info("watching %s for block changes", self.blocks_path)
+
+        if self.templates_path.exists():
+            template_handler = TemplateFileHandler(
+                self.template_registry, self.templates_path, debounce_ms
+            )
+            self._observer.schedule(template_handler, str(self.templates_path), recursive=False)
+            self._handlers.append(template_handler)
+            logger.info("watching %s for template changes", self.templates_path)
+
+        self._observer.start()
+        logger.info("extension file watcher started")
+
+    def stop(self) -> None:
+        if self._observer:
+            for handler in self._handlers:
+                handler.cancel_pending()
+            self._handlers = []
+            self._observer.stop()
+            self._observer.join(timeout=5)
+            self._observer = None
+            logger.info("extension file watcher stopped")
diff --git a/lib/templates/__init__.py b/lib/templates/__init__.py
index 5d60954..abffcc4 100644
--- a/lib/templates/__init__.py
+++ b/lib/templates/__init__.py
@@ -3,25 +3,46 @@
 """
 
 import json
+import logging
+import threading
 from pathlib import Path
 from typing import Any
 
 import yaml  # type: ignore[import-untyped]
 
+from lib.entities.extensions import TemplateInfo
+
+logger = logging.getLogger(__name__)
+
 
 class TemplateRegistry:
-    """Registry for pipeline templates"""
+    """Registry for pipeline templates with builtin and user template support"""
 
-    def __init__(self, templates_dir: Path | None = None):
+    def __init__(
+        self,
+        templates_dir: Path | None = None,
+        user_templates_dir: Path | None = None,
+    ):
         if templates_dir is None:
             templates_dir = Path(__file__).parent
+        if user_templates_dir is None:
+            user_templates_dir = Path("user_templates")
         self.templates_dir = templates_dir
         self.seeds_dir = templates_dir / "seeds"
+        self.user_templates_dir = user_templates_dir
+        self._lock = threading.Lock()
         self._templates: dict[str, dict[str, Any]] = {}
-        self._load_templates()
+        self._sources: dict[str, str] = {}
+        self._load_builtin_into(self._templates, self._sources)
+        if self.user_templates_dir.exists():
+            self._load_user_into(self.user_templates_dir, self._templates, self._sources)
 
-    def _load_templates(self) -> None:
-        """load all template yaml files from templates directory"""
+    def _load_builtin_into(
+        self,
+        templates: dict[str, dict[str, Any]],
+        sources: dict[str, str],
+    ) -> None:
+        """load all template yaml files from builtin templates directory"""
         for template_file in self.templates_dir.glob("*.yaml"):
             try:
                 with open(template_file, "r") as f:
@@ -42,25 +63,86 @@ def _load_templates(self) -> None:
                         {"repetitions": 1, "metadata": {"file_content": sf.read()}}
                     ]
 
-                self._templates[template_id] = template_data
-            except Exception:
-                pass
+                templates[template_id] = template_data
+                sources[template_id] = "builtin"
+            except Exception as e:
+                logger.warning(f"failed to load builtin template {template_file}: {e}")
+
+    def _load_user_into(
+        self,
+        user_dir: Path,
+        templates: dict[str, dict[str, Any]],
+        sources: dict[str, str],
+    ) -> None:
+        """load user templates, skipping ids that already exist as builtin"""
+        for template_file in user_dir.glob("*.yaml"):
+            try:
+                with open(template_file, "r") as f:
+                    template_data = yaml.safe_load(f)
+                template_id = template_file.stem
 
-    def list_templates(self) -> list[dict[str, Any]]:
+                if template_id in templates:
+                    logger.warning(
+                        f"user template '{template_id}' skipped: conflicts with builtin"
+                    )
+                    continue
+
+                templates[template_id] = template_data
+                sources[template_id] = "user"
+            except Exception as e:
+                logger.warning(f"failed to load user template {template_file}: {e}")
+
+    def register(
+        self,
+        template_id: str,
+        template_data: dict[str, Any],
+        source: str = "user",
+    ) -> None:
+        with self._lock:
+            self._templates[template_id] = template_data
+            self._sources[template_id] = source
+
+    def unregister(self, template_id: str) -> None:
+        with self._lock:
+            self._templates.pop(template_id, None)
+            self._sources.pop(template_id, None)
+
+    def reload(self) -> None:
+        """re-scan builtin and user template directories.
+        serialized with a lock to prevent concurrent partial-state reads."""
+        with self._lock:
+            templates: dict[str, dict[str, Any]] = {}
+            sources: dict[str, str] = {}
+            self._load_builtin_into(templates, sources)
+            if self.user_templates_dir.exists():
+                self._load_user_into(self.user_templates_dir, templates, sources)
+            self._templates = templates
+            self._sources = sources
+
+    def list_templates(self) -> list[TemplateInfo]:
         """List all available templates"""
+        with self._lock:
+            templates = self._templates.copy()
+            sources = self._sources.copy()
         return [
-            {
-                "id": template_id,
-                "name": template["name"],
-                "description": template["description"],
-                "example_seed": template.get("example_seed"),
-            }
-            for template_id, template in self._templates.items()
+            TemplateInfo(
+                id=template_id,
+                name=template.get("name", template_id),
+                description=template.get("description", ""),
+                example_seed=template.get("example_seed"),
+                source=sources.get(template_id, "builtin"),
+            )
+            for template_id, template in templates.items()
         ]
 
     def get_template(self, template_id: str) -> dict[str, Any] | None:
         """Get template definition by ID"""
-        return self._templates.get(template_id)
+        with self._lock:
+            return self._templates.get(template_id)
+
+    def get_template_source(self, template_id: str) -> str | None:
+        with self._lock:
+            return self._sources.get(template_id)
 
 
 # Singleton instance
diff --git a/lib/workflow.py b/lib/workflow.py
index 61ef460..30338eb 100644
--- a/lib/workflow.py
+++ b/lib/workflow.py
@@ -37,7 +37,7 @@ def _initialize_blocks(self) -> None:
         block_class = registry.get_block_class(block_type)
 
         if not block_class:
-            available = list(registry._blocks.keys())
+            available = registry.list_block_types()
             raise BlockNotFoundError(
                 f"Block '{block_type}' not found",
                 detail={"block_type": block_type, "available_blocks": available},
diff --git a/llm/state-backend.md b/llm/state-backend.md
index 817877c..d653be8 100644
--- a/llm/state-backend.md
+++ b/llm/state-backend.md
@@ -32,7 +32,17 @@ lib/
   job_queue.py          # in-memory job tracking
   job_processor.py      # background processing + usage tracking + constraints
   llm_config.py         # LLMConfigManager
-  constants.py          # RECORD_UPDATABLE_FIELDS
+  constants.py          # RECORD_UPDATABLE_FIELDS, DEFAULT_BLOCKS_PATH, DEFAULT_TEMPLATES_PATH
+  entities/
+    extensions.py       # BlockInfo, TemplateInfo, ExtensionsStatus pydantic models
+  api/
+    extensions.py       # /api/extensions/* router
+  blocks/
+    registry.py         # BlockRegistry (thread-safe, singleton)
+  templates/
+    __init__.py         # TemplateRegistry (thread-safe, singleton)
+    file_watcher.py     # ExtensionFileWatcher + DebouncedHandler (hot reload)
+    dependency_manager.py # DependencyManager (uv-based pip install)
 app.py                  # endpoints + lifespan
 config.py               # env Settings
 ```
@@ -46,6 +56,16 @@ config.py # env Settings
 
 ### blocks
 - `GET /api/blocks` - list registered blocks with schemas
+
+### extensions
+- `GET /api/extensions/status` - block/template counts + hot-reload status
+- `GET /api/extensions/blocks` - list all blocks (with source, available, error)
+- `GET /api/extensions/blocks/{name}` - get single block info
+- `POST /api/extensions/blocks/{name}/validate` - validate block can be instantiated
+- `GET /api/extensions/blocks/{name}/dependencies` - list declared pip requirements
+- `POST /api/extensions/blocks/{name}/install-deps` - install missing dependencies via uv
+- `GET /api/extensions/templates` - list all templates (with source)
+- `POST /api/extensions/reload` - trigger hot reload of blocks + templates
 
 ### templates
 - `GET /api/templates` - list pipeline templates
 - `POST /api/pipelines/from_template/{template_id}` - create from template
diff --git a/llm/state-frontend.md b/llm/state-frontend.md
index be4f406..42e2217 100644
--- a/llm/state-frontend.md
+++ b/llm/state-frontend.md
@@ -15,6 +15,7 @@ frontend/src/
     Generator.tsx       # upload + job progress + validation
     Review.tsx          # cards + collapsible trace + job filter
     Settings.tsx        # LLM/embedding config management
+    Extensions.tsx      # blocks + templates status, install deps, create pipeline from template
   components/
     GlobalJobIndicator.tsx  # header job status
     ConfigureFieldsModal.tsx  # field configuration
@@ -69,6 +70,13 @@ frontend/src/
 - real-time updates: 2s polling, incremental record visibility
 - view stability: tracks by ID, single mode preserves current record
 
+### Extensions.tsx
+- status cards: block counts by source (builtin/custom/user), template counts
+- block list: availability badge, source badge (SourceBadge), validate button, install-deps button (unavailable only)
+- template list: source badge, "Create Pipeline" button → POST /api/pipelines/from_template/{id} → navigate to /pipelines
+- reload button: triggers manual re-scan of extension directories
+- all api calls via extensionsApi service (extensionsApi.ts)
+
 ### Settings.tsx
 - LLM/embedding model management via ModelCard components
 - provider/model selection (OpenAI, Anthropic, Ollama, etc.)
diff --git a/pyproject.toml b/pyproject.toml
index 6e58535..01674f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,9 @@ dependencies = [
     "pytest-timeout>=2.4.0",
     "langfuse==2.59.7",
     "instructor",
+    "watchdog>=6.0.0",
+    "typer>=0.9.0",
+    "rich>=13.0.0",
 ]
 description = "Q&A dataset generation and validation tool"
 name = "datagenflow"
@@ -34,6 +37,9 @@ readme = "README.md"
 requires-python = ">=3.10"
 version = "0.1.0"
 
+[project.scripts]
+dgf = "lib.cli.main:app"
+
 [tool.setuptools.packages.find]
 exclude = ["ui*", "data*"]
 include = ["lib*"]
@@ -58,7 +64,7 @@ warn_unused_configs = true
 exclude = ["scripts/"]
 
 [[tool.mypy.overrides]]
-disable_error_code = ["no-untyped-def", "no-untyped-call", "var-annotated", "override", "union-attr", "arg-type", "index", "type-arg", "unused-ignore", "import-not-found", "no-redef"]
+disable_error_code = ["no-untyped-def", "no-untyped-call", "var-annotated", "override", "union-attr", "arg-type", "index", "type-arg", "unused-ignore", "import-not-found", "no-redef", "import-untyped", "type-abstract", "operator"]
 module = "tests.*"
 
 [[tool.mypy.overrides]]
diff --git a/tests/blocks/test_base_dependencies.py b/tests/blocks/test_base_dependencies.py
new file mode 100644
index 0000000..6b57c42
--- /dev/null
+++ b/tests/blocks/test_base_dependencies.py
@@ -0,0 +1,51 @@
+"""
+Tests for BaseBlock.dependencies attribute and its inclusion in get_schema().
+"""
+
+from lib.blocks.base import BaseBlock
+
+
+class NoDepsBlock(BaseBlock):
+    name = "No Deps"
+    description = "Block with no dependencies"
+    inputs = ["text"]
+    outputs = ["result"]
+
+    async def execute(self, context):
+        return {"result": "ok"}
+
+
+class WithDepsBlock(BaseBlock):
+    name = "With Deps"
+    description = "Block with dependencies"
+    inputs = ["text"]
+    outputs = ["result"]
+    dependencies = ["transformers>=4.30.0", "torch>=2.0.0"]
+
+    async def execute(self, context):
+        return {"result": "ok"}
+
+
+def test_base_block_has_dependencies_default():
+    assert hasattr(BaseBlock, "dependencies")
+    assert BaseBlock.dependencies == []
+
+
+def test_block_without_dependencies_defaults_to_empty():
+    assert NoDepsBlock.dependencies == []
+
+
+def test_block_with_dependencies_has_list():
+    assert WithDepsBlock.dependencies == ["transformers>=4.30.0", "torch>=2.0.0"]
+
+
+def test_get_schema_includes_dependencies():
+    schema = WithDepsBlock.get_schema()
+    assert "dependencies" in schema
+    assert schema["dependencies"] == ["transformers>=4.30.0", "torch>=2.0.0"]
+
+
+def test_get_schema_includes_empty_dependencies():
+    schema = NoDepsBlock.get_schema()
+    assert "dependencies" in schema
+    assert schema["dependencies"] == []
diff --git a/tests/blocks/test_registry.py b/tests/blocks/test_registry.py
index 08f92a9..3e8a77e 100644
--- a/tests/blocks/test_registry.py
+++ b/tests/blocks/test_registry.py
@@ -7,7 +7,7 @@ def test_registry_discovers_blocks():
     blocks = registry.list_blocks()
 
     # should discover at least the core blocks
-    block_types = [b["type"] for b in blocks]
+    block_types = [b.type for b in blocks]
     assert "TextGenerator" in block_types
     assert "ValidatorBlock" in block_types
     assert "JSONValidatorBlock" in block_types
@@ -22,3 +22,63 @@ def test_get_block_class():
 
     invalid_class = registry.get_block_class("NonExistent")
     assert invalid_class is None
+
+
+class TestBlockRegistryReload:
+    """tests for registry reload functionality"""
+
+    def test_reload_method_exists(self):
+        """registry has reload method"""
+        registry = BlockRegistry()
+        assert hasattr(registry, "reload")
+        assert callable(registry.reload)
+
+    def test_reload_preserves_block_count(self):
+        """reload discovers same blocks"""
+        registry = BlockRegistry()
+        initial_count = len(registry.list_blocks())
+
+        registry.reload()
+
+        assert len(registry.list_blocks()) == initial_count
+
+    def test_reload_preserves_builtin_blocks(self):
+        """reload keeps builtin blocks available"""
+        registry = BlockRegistry()
+        before = {b.type for b in registry.list_blocks() if b.source == "builtin"}
+
+        registry.reload()
+
+        after = {b.type for b in registry.list_blocks() if b.source == "builtin"}
+        assert before == after
+
+
+class TestBlockRegistryGetEntry:
+    """tests for get_entry method"""
+
+    def test_get_entry_returns_block_entry(self):
+        """get_entry returns BlockEntry for known block"""
+        registry = BlockRegistry()
+        entry = registry.get_entry("TextGenerator")
+
+        assert entry is not None
+        assert entry.available is True
+        assert entry.source == "builtin"
+        assert entry.block_class.__name__ == "TextGenerator"
+
+    def test_get_entry_returns_none_for_unknown(self):
+        """get_entry returns None for unknown block"""
+        registry = BlockRegistry()
+        entry = registry.get_entry("NonExistentBlock")
+
+        assert entry is None
+
+    def test_get_entry_has_block_class(self):
+        """entry contains valid block class"""
+        registry = BlockRegistry()
+        entry = registry.get_entry("ValidatorBlock")
+
+        assert entry is not None
+        assert hasattr(entry.block_class, "execute")
+        assert hasattr(entry.block_class, "inputs")
+        assert hasattr(entry.block_class, "outputs")
diff --git a/tests/blocks/test_registry_enhanced.py b/tests/blocks/test_registry_enhanced.py
new file mode 100644
index 0000000..2fe4270
--- /dev/null
+++ b/tests/blocks/test_registry_enhanced.py
@@ -0,0 +1,133 @@
+"""
+Tests for enhanced BlockRegistry: source tracking, register/unregister, unavailable blocks.
+"""
+
+from lib.blocks.base import BaseBlock
+from lib.blocks.registry import BlockRegistry
+from lib.entities.extensions import BlockInfo
+
+
+class FakeUserBlock(BaseBlock):
+    name = "Fake User Block"
+    description = "A user-provided block"
+    category = "validators"
+    inputs = ["text"]
+    outputs = ["result"]
+
+    async def execute(self, context):
+        return {"result": "ok"}
+
+
+class BlockWithDeps(BaseBlock):
+    name = "Block With Deps"
+    description = "Needs missing deps"
+    category = "generators"
+    inputs = ["text"]
+    outputs = ["result"]
+    dependencies = ["some_nonexistent_package>=1.0.0"]
+
+    async def execute(self, context):
+        return {"result": "ok"}
+
+
+# --- source tracking ---
+
+
+def test_list_blocks_includes_source_field():
+    reg = BlockRegistry()
+    blocks = reg.list_blocks()
+    for block in blocks:
+        assert isinstance(block, BlockInfo)
+        assert block.source is not None
+
+
+def test_builtin_blocks_have_builtin_source():
+    reg = BlockRegistry()
+    blocks = reg.list_blocks()
+    text_gen = next(b for b in blocks if b.type == "TextGenerator")
+    assert text_gen.source == "builtin"
+
+
+def test_list_blocks_includes_available_field():
+    reg = BlockRegistry()
+    blocks = reg.list_blocks()
+    for block in blocks:
+        assert isinstance(block.available, bool)
+
+
+def test_builtin_blocks_are_available():
+    reg = BlockRegistry()
+    blocks = reg.list_blocks()
+    text_gen = next(b for b in blocks if b.type == "TextGenerator")
+    assert text_gen.available is True
+
+
+# --- register / unregister ---
+
+
+def test_register_user_block():
+    reg = BlockRegistry()
+    initial_count = len(reg.list_blocks())
+    reg.register(FakeUserBlock, source="user")
+    blocks = reg.list_blocks()
+    assert len(blocks) == initial_count + 1
+    fake = next(b for b in blocks if b.type == "FakeUserBlock")
+    assert fake.source == "user"
+    assert fake.available is True
+
+
+def test_unregister_block():
+    reg = BlockRegistry()
+    reg.register(FakeUserBlock, source="user")
+    assert reg.get_block_class("FakeUserBlock") is not None
+
+    reg.unregister("FakeUserBlock")
+    assert reg.get_block_class("FakeUserBlock") is None
+
+
+def test_unregister_nonexistent_is_noop():
+    reg = BlockRegistry()
+    reg.unregister("DoesNotExist")  # should not raise
+
+
+def test_register_replaces_existing():
+    reg = BlockRegistry()
+    reg.register(FakeUserBlock, source="user")
+    reg.register(FakeUserBlock, source="user")
+    matches = [b for b in reg.list_blocks() if b.type == "FakeUserBlock"]
+    assert len(matches) == 1
+
+
+# --- unavailable blocks ---
+
+
+def test_register_unavailable_block():
+    reg = BlockRegistry()
+    reg.register(
+        BlockWithDeps, source="user", available=False, error="Missing: some_nonexistent_package"
+    )
+    blocks = reg.list_blocks()
+    block = next(b for b in blocks if b.type == "BlockWithDeps")
+    assert block.available is False
+    assert block.error is not None
+    assert "some_nonexistent_package" in block.error
+
+
+def test_unavailable_block_class_still_accessible():
+    """even unavailable blocks can be retrieved by class for inspection"""
+    reg = BlockRegistry()
+    reg.register(BlockWithDeps, source="user", available=False, error="missing deps")
+    cls = reg.get_block_class("BlockWithDeps")
+    assert cls is not None
+    assert cls is BlockWithDeps
+
+
+# --- get_block_source ---
+
+
+def test_get_block_source():
+    reg = BlockRegistry()
+    reg.register(FakeUserBlock, source="user")
+    assert reg.get_block_source("FakeUserBlock") == "user"
+    assert reg.get_block_source("TextGenerator") == "builtin"
+    assert reg.get_block_source("NonExistent") is None
diff --git a/tests/blocks/test_structure_sampler.py b/tests/blocks/test_structure_sampler.py
index 4e6b3bb..b3f84c8 100644
--- a/tests/blocks/test_structure_sampler.py
+++ b/tests/blocks/test_structure_sampler.py
@@ -75,7 +75,7 @@ async def test_conditional_probabilities(self):
         )
         # set attributes that would normally be set in execute()
         block.categorical_fields = ["plan", "role"]
-        block.dependencies = {"role": ["plan"]}
+        block._field_deps = {"role": ["plan"]}
 
         samples = [
             {"plan": "Free", "role": "Viewer"},
diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py
new file mode 100644
index 0000000..b2a1e2c
--- /dev/null
+++ b/tests/cli/test_commands.py
@@ -0,0 +1,245 @@
+"""
+Tests for dgf CLI commands.
+Uses typer CliRunner to test commands without a running server.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from typer.testing import CliRunner
+
+runner = CliRunner()
+
+
+@pytest.fixture
+def mock_client():
+    """mock DataGenFlowClient that returns realistic data"""
+    client = MagicMock()
+    client.health.return_value = {"status": "healthy"}
+    client.extension_status.return_value = {
+        "blocks": {
+            "total": 8,
+            "builtin_blocks": 6,
+            "custom_blocks": 1,
+            "user_blocks": 1,
+            "available": 7,
+            "unavailable": 1,
+        },
+        "templates": {
+            "total": 3,
+            "builtin_templates": 2,
+            "user_templates": 1,
+        },
+    }
+    client.list_blocks.return_value = [
+        {
+            "type": "TextGenerator",
+            "name": "Text Generator",
+            "category": "generation",
+            "source": "builtin",
+            "available": True,
+            "error": None,
+        },
+        {
+            "type": "BrokenBlock",
+            "name": "Broken Block",
+            "category": "custom",
+            "source": "user",
+            "available": False,
+            "error": "missing dependency: torch",
+        },
+    ]
+    client.list_templates.return_value = [
+        {
+            "id": "qa_generation",
+            "name": "Q&A Generation",
+            "description": "Generate Q&A pairs",
+            "source": "builtin",
+        },
+    ]
+    client.validate_block.return_value = {"valid": True, "block": "TextGenerator"}
+    client.install_block_deps.return_value = {"status": "ok", "installed": ["torch>=2.0"]}
+    return client
+
+
+@pytest.fixture
+def cli_app(mock_client):
+    """import the app with mocked client"""
+    from lib.cli.main import app
+
+    # patch get_client to return our mock
+    with patch("lib.cli.main.get_client", return_value=mock_client):
+        yield app
+
+
+class TestStatusCommand:
+    def test_status_shows_server_info(self, cli_app, mock_client):
+        result = runner.invoke(cli_app, ["status"])
+        assert result.exit_code == 0
+        assert "7 available" in result.output
+        assert "1 unavailable" in result.output
+
+    def test_status_connection_error(self, cli_app, mock_client):
+        mock_client.health.side_effect = Exception("Connection refused")
+        result = runner.invoke(cli_app, ["status"])
+        assert result.exit_code == 1
+        assert "Cannot connect" in result.output
+
+
+class TestBlocksCommands:
+    def test_blocks_list(self, cli_app, mock_client):
+        result = runner.invoke(cli_app, ["blocks", "list"])
+        assert result.exit_code == 0
+        assert "TextGenerator" in result.output
+        assert "BrokenBlock" in result.output
+        assert "builtin" in result.output
+        assert "user" in result.output
+
+    def test_blocks_validate_valid_file(self, cli_app, tmp_path):
+        block_file = tmp_path / "my_block.py"
+        block_file.write_text(
+            "from lib.blocks.base import BaseBlock\nclass MyBlock(BaseBlock):\n    pass\n"
+        )
+        result = runner.invoke(cli_app, ["blocks", "validate", str(block_file)])
+        assert result.exit_code == 0
+        assert "valid" in result.output
+        assert "MyBlock" in result.output
+
+    def test_blocks_validate_no_block_class(self, cli_app, tmp_path):
+        block_file = tmp_path / "not_a_block.py"
+        block_file.write_text("class Foo:\n    pass\n")
+        result = runner.invoke(cli_app, ["blocks", "validate", str(block_file)])
+        assert result.exit_code == 1
+        assert "No block classes" in result.output
+
+    def test_blocks_validate_syntax_error(self, cli_app, tmp_path):
+        block_file = tmp_path / "bad.py"
+        block_file.write_text("def broken(\n")
+        result = runner.invoke(cli_app, ["blocks", "validate", str(block_file)])
+        assert result.exit_code == 1
+        assert "Syntax error" in result.output
+
+    def test_blocks_validate_missing_file(self, cli_app):
+        result = runner.invoke(cli_app, ["blocks", "validate", "/nonexistent.py"])
+        assert result.exit_code == 1
+        assert "not found" in result.output
+
+    def test_blocks_scaffold(self, cli_app, tmp_path):
+        result = runner.invoke(
+            cli_app, ["blocks", "scaffold", "SentimentAnalyzer", "-o", str(tmp_path)]
+        )
+        assert result.exit_code == 0
+        output_file = tmp_path / "sentiment_analyzer.py"
+        assert output_file.exists()
+        content = output_file.read_text()
+        assert "class SentimentAnalyzer" in content
+        assert "BaseBlock" in content
+
+    def test_blocks_add_with_install_deps(self, cli_app, mock_client, tmp_path):
+        block_file = tmp_path / "my_block.py"
+        block_file.write_text(
+            "from lib.blocks.base import BaseBlock\nclass MyBlock(BaseBlock):\n    pass\n"
+        )
+        mock_client.validate_block.return_value = {"valid": True}
+        result = runner.invoke(cli_app, ["blocks", "add", str(block_file), "--install-deps"])
+        assert result.exit_code == 0
+        mock_client.install_block_deps.assert_called_once_with("MyBlock")
+        assert "Dependencies installed" in result.output
+
+    def test_blocks_add_install_deps_failure(self, cli_app, mock_client, tmp_path):
+        block_file = tmp_path / "my_block.py"
+        block_file.write_text(
+            "from lib.blocks.base import BaseBlock\nclass MyBlock(BaseBlock):\n    pass\n"
+        )
+        import httpx
+
+        mock_client.validate_block.return_value = {"valid": True}
+        mock_client.install_block_deps.side_effect = httpx.HTTPError("install failed")
+        result = runner.invoke(cli_app, ["blocks", "add", str(block_file), "--install-deps"])
+        assert "Failed to install deps" in result.output
+
+
+class TestTemplatesCommands:
+    def test_templates_list(self, cli_app, mock_client):
+        result = runner.invoke(cli_app, ["templates", "list"])
+        assert result.exit_code == 0
+        assert "qa_generation" in result.output
+        assert "Q&A Generation" in result.output
+
+    def test_templates_validate_valid(self, cli_app, tmp_path):
+        template_file = tmp_path / "my_template.yaml"
+        template_file.write_text(
+            'name: "Test Template"\n'
+            'description: "A test"\n'
+            "blocks:\n"
+            "  - type: TextGenerator\n"
+            "    config:\n"
+            "      model: gpt-4o-mini\n"
+        )
+        result = runner.invoke(cli_app, ["templates", "validate", str(template_file)])
+        assert result.exit_code == 0
+        assert "valid" in result.output
+
+    def test_templates_validate_missing_name(self, cli_app, tmp_path):
+        template_file = tmp_path / "bad.yaml"
+        template_file.write_text("blocks:\n  - type: Foo\n")
+        result = runner.invoke(cli_app, ["templates", "validate", str(template_file)])
+        assert result.exit_code == 1
+        assert "name" in result.output
+
+    def test_templates_validate_missing_blocks(self, cli_app, tmp_path):
+        template_file = tmp_path / "bad2.yaml"
+        template_file.write_text('name: "Test"\n')
+        result = runner.invoke(cli_app, ["templates", "validate", str(template_file)])
+        assert result.exit_code == 1
+        assert "blocks" in result.output
+
+    def test_templates_scaffold(self, cli_app, tmp_path):
+        result = runner.invoke(
+            cli_app, ["templates", "scaffold", "My Custom Pipeline", "-o", str(tmp_path)]
+        )
+        assert result.exit_code == 0
+        output_file = tmp_path / "my_custom_pipeline.yaml"
+        assert output_file.exists()
+        content = output_file.read_text()
+        assert "My Custom Pipeline" in content
+        assert "blocks:" in content
+
+
+class TestConfigureCommand:
+    def test_configure_show(self, cli_app):
+        result = runner.invoke(cli_app, ["configure", "--show"])
+        assert result.exit_code == 0
+        assert "Endpoint" in result.output
+
+    def test_configure_set_endpoint(self, cli_app, tmp_path, monkeypatch):
+        monkeypatch.chdir(tmp_path)
+        result = runner.invoke(cli_app, ["configure", "-e", "http://myserver:9000"])
+        assert result.exit_code == 0
+        env_content = (tmp_path / ".env").read_text()
+        assert "DATAGENFLOW_ENDPOINT=http://myserver:9000" in env_content
+
+
+class TestImageCommands:
+    def test_image_scaffold(self, cli_app, tmp_path):
+        output = tmp_path / "Dockerfile.custom"
+        result = runner.invoke(cli_app, ["image", "scaffold", "-o", str(output)])
+        assert result.exit_code == 0
+        assert output.exists()
+        content = output.read_text()
+        assert "datagenflow" in content.lower()
+
+    def test_image_scaffold_with_blocks(self, cli_app, tmp_path):
+        blocks_dir = tmp_path / "blocks"
+        blocks_dir.mkdir()
+        (blocks_dir / "my_block.py").write_text(
+            'class MyBlock:\n    dependencies = ["torch>=2.0", "transformers"]\n'
+        )
+        output = tmp_path / "Dockerfile.custom"
+        result = runner.invoke(
+            cli_app, ["image", "scaffold", "-b", str(blocks_dir), "-o", str(output)]
+        )
+        assert result.exit_code == 0
+        content = output.read_text()
+        assert "torch>=2.0" in content
+        assert "transformers" in content
diff --git a/tests/e2e/run_all_tests.sh b/tests/e2e/run_all_tests.sh
index e783cc6..4f574ec 100644
--- a/tests/e2e/run_all_tests.sh
+++ b/tests/e2e/run_all_tests.sh
@@ -75,6 +75,21 @@ uv run python "$PROJECT_ROOT/scripts/with_server.py" \
     -- uv run python "$SCRIPT_DIR/test_review_e2e.py"
 echo ""
 
+echo "📋 Test Suite 4: Extensions (UI)"
+echo "-------------------------"
+uv run python "$PROJECT_ROOT/scripts/with_server.py" \
+    --server "$BACKEND_CMD" --port 8000 \
+    --server "$FRONTEND_CMD" --port 5173 \
+    -- uv run python "$SCRIPT_DIR/test_extensions_e2e.py"
+echo ""
+
+echo "📋 Test Suite 5: Extensions (API)"
+echo "-------------------------"
+uv run python "$PROJECT_ROOT/scripts/with_server.py" \
+    --server "$BACKEND_CMD" --port 8000 \
+    -- uv run python "$SCRIPT_DIR/test_extensions_api_e2e.py"
+echo ""
+
 echo "✅ All E2E tests completed!"
 echo ""
 echo "📸 Screenshots saved to /tmp/"
diff --git a/tests/e2e/test_extensions_api_e2e.py b/tests/e2e/test_extensions_api_e2e.py
new file mode 100644
index 0000000..abc5fbb
--- /dev/null
+++ b/tests/e2e/test_extensions_api_e2e.py
@@ -0,0 +1,363 @@
+"""
+E2E tests for extensions REST API.
+Tests the backend API endpoints directly via HTTP,
+complementing the Playwright UI tests.
+
+Requires: running server (uvicorn on port 8000).
+"""
+
+import pytest
+
+try:
+    from .test_helpers import (
+        cleanup_database,
+        get_block_dependencies,
+        get_blocks_list,
+        get_extensions_status,
+        get_templates_list,
+        reload_extensions,
+        validate_block,
+        wait_for_server,
+    )
+except ImportError:
+    from test_helpers import (
+        cleanup_database,
+        get_block_dependencies,
+        get_blocks_list,
+        get_extensions_status,
+        get_templates_list,
+        reload_extensions,
+        validate_block,
+        wait_for_server,
+    )
+
+import httpx
+
+BASE_URL = "http://localhost:8000"
+
+
+@pytest.fixture(scope="module", autouse=True)
+def _e2e_setup():
+    if not wait_for_server():
+        pytest.skip("server not ready for e2e tests")
+    cleanup_database()
+    yield
+    cleanup_database()
+
+
+# --- GET /api/extensions/status ---
+
+
+def test_status_returns_valid_structure():
+    """verify status endpoint returns expected fields"""
+    status = get_extensions_status()
+
+    assert "blocks" in status
+    assert "templates" in status
+
+    blocks = status["blocks"]
+    assert "total" in blocks
+    assert "builtin_blocks" in blocks
+    assert "custom_blocks" in blocks
+    assert "user_blocks" in blocks
+    assert "available" in blocks
+    assert "unavailable" in blocks
+
+    templates = status["templates"]
+    assert "total" in templates
+    assert "builtin_templates" in templates
+    assert "user_templates" in templates
+
+
+def test_status_counts_are_consistent():
+    """verify status counts add up correctly"""
+    status = get_extensions_status()
+
+    blocks = status["blocks"]
+    # total should equal sum of sources
+    assert (
+        blocks["total"]
+        == blocks["builtin_blocks"] + blocks["custom_blocks"] + blocks["user_blocks"]
+    )
+    # total should equal available + unavailable
+    assert blocks["total"] == blocks["available"] + blocks["unavailable"]
+    # should have at least some builtin blocks
+    assert blocks["builtin_blocks"] > 0
+
+
+def test_status_block_count_matches_blocks_list():
+    """verify status total matches actual blocks list length"""
+    status = get_extensions_status()
+    blocks = get_blocks_list()
+
+    assert status["blocks"]["total"] == len(blocks)
+
+
+def test_status_template_count_matches_templates_list():
+    """verify status total matches actual templates list length"""
+    status = get_extensions_status()
+    templates = get_templates_list()
+
+    assert status["templates"]["total"] == len(templates)
+
+
+# --- GET /api/extensions/blocks ---
+
+
+def test_blocks_list_returns_expected_fields():
+    """verify each block has required fields"""
+    blocks = get_blocks_list()
+    assert len(blocks) > 0
+
+    for block in blocks:
+        assert "name" in block, f"block missing 'name': {block}"
+        assert "type" in block, f"block missing 'type': {block}"
+        assert "source" in block, f"block missing 'source': {block}"
+        assert "available" in block, f"block missing 'available': {block}"
+        assert "category" in block, f"block missing 'category': {block}"
+        assert "dependencies" in block, f"block missing 'dependencies': {block}"
+        assert isinstance(block["dependencies"], list)
+
+
+def test_blocks_list_contains_known_builtin_blocks():
+    """verify well-known builtin blocks are present"""
+    blocks = get_blocks_list()
+    block_types = {b["type"] for b in blocks}
+
+    # these should always exist as builtin blocks
+    expected_builtins = {"TextGenerator", "JSONValidatorBlock", "FieldMapper"}
+    for expected in expected_builtins:
+        assert expected in block_types, f"expected builtin block '{expected}' not found"
+
+
+def test_blocks_list_contains_new_blocks():
+    """verify new blocks from this branch are registered"""
+    blocks = get_blocks_list()
+    block_types = {b["type"] for b in blocks}
+
+    new_blocks = {"DuplicateRemover", "SemanticInfiller", "StructureSampler"}
+    for expected in new_blocks:
+        assert expected in block_types, f"new block '{expected}' not found in registry"
+
+
+def test_blocks_sources_are_valid():
+    """verify all block sources are one of the allowed values"""
+    blocks = get_blocks_list()
+    valid_sources = {"builtin", "custom", "user"}
+
+    for block in blocks:
+        assert block["source"] in valid_sources, (
+            f"block '{block['type']}' has invalid source '{block['source']}'"
+        )
+
+
+def test_blocks_categories_are_valid():
+    """verify all block categories are known"""
+    blocks = get_blocks_list()
+    valid_categories = {
+        "generators",
+        "validators",
+        "processors",
+        "seeders",
+        "metrics",
+        "integrations",
+        "utilities",
+        "general",
+    }
+
+    for block in blocks:
+        assert block["category"] in valid_categories, (
+            f"block '{block['type']}' has unknown category '{block['category']}'"
+        )
+
+
+# --- GET /api/extensions/templates ---
+
+
+def test_templates_list_returns_expected_fields():
+    """verify each template has required fields"""
+    templates = get_templates_list()
+    assert len(templates) > 0
+
+    for tmpl in templates:
+        assert "id" in tmpl, f"template missing 'id': {tmpl}"
+        assert "name" in tmpl, f"template missing 'name': {tmpl}"
+        assert "source" in tmpl, f"template missing 'source': {tmpl}"
+        assert "description" in tmpl, f"template missing 'description': {tmpl}"
+
+
+def test_templates_sources_are_valid():
+    """verify all template sources are valid"""
+    templates = get_templates_list()
+    valid_sources = {"builtin", "user"}
+
+    for tmpl in templates:
+        assert tmpl["source"] in valid_sources, (
+            f"template '{tmpl['id']}' has invalid source '{tmpl['source']}'"
+        )
+
+
+# --- POST /api/extensions/reload ---
+
+
+def test_reload_returns_ok():
+    """verify reload endpoint returns success response"""
+    result = reload_extensions()
+
+    assert result["status"] == "ok"
+    assert "message" in result
+
+
+def test_reload_is_idempotent():
+    """verify multiple reloads don't change state"""
+    status_before = get_extensions_status()
+
+    reload_extensions()
+    reload_extensions()
+
+    status_after = get_extensions_status()
+    assert status_after["blocks"]["total"] == status_before["blocks"]["total"]
+    assert status_after["templates"]["total"] == status_before["templates"]["total"]
+
+
+# --- POST /api/extensions/blocks/{name}/validate ---
+
+
+def test_validate_available_block():
+    """verify validation of an available block returns valid=True"""
+    blocks = get_blocks_list()
+    available = next((b for b in blocks if b["available"]), None)
+    assert available is not None, "need at least one available block"
+
+    result = validate_block(available["type"])
+    assert result["valid"] is True
+    assert result["block"] == available["type"]
+
+
+def test_validate_returns_block_name():
+    """verify validation response includes block name"""
+    blocks = get_blocks_list()
+    assert len(blocks) > 0
+
+    result = validate_block(blocks[0]["type"])
+    assert "block" in result
+    assert result["block"] == blocks[0]["type"]
+
+
+def test_validate_nonexistent_block_returns_404():
+    """verify validation of nonexistent block returns 404"""
+    resp = httpx.post(
+        f"{BASE_URL}/api/extensions/blocks/nonexistent_block_xyz/validate", timeout=10.0
+    )
+    assert resp.status_code == 404
+
+
+# --- GET /api/extensions/blocks/{name}/dependencies ---
+
+
+def test_dependencies_returns_list():
+    """verify dependencies endpoint returns a list"""
+    blocks = get_blocks_list()
+    assert len(blocks) > 0
+
+    result = get_block_dependencies(blocks[0]["type"])
+    assert isinstance(result, list)
+
+
+def test_dependencies_for_block_with_deps():
+    """verify blocks with declared dependencies return dependency info"""
+    blocks = get_blocks_list()
+    block_with_deps = next(
+        (b for b in blocks if b.get("dependencies") and len(b["dependencies"]) > 0), None
+    )
+    if block_with_deps is None:
+        pytest.skip("no blocks with dependencies found")
+
+    result = get_block_dependencies(block_with_deps["type"])
+    assert len(result) > 0
+
+    for dep in result:
+        assert "name" in dep
+        assert "installed" in dep
+
+
+def test_dependencies_nonexistent_block_returns_404():
+    """verify dependencies for nonexistent block returns 404"""
+    resp = httpx.get(
+        f"{BASE_URL}/api/extensions/blocks/nonexistent_block_xyz/dependencies", timeout=10.0
+    )
+    assert resp.status_code == 404
+
+
+# --- cross-endpoint consistency ---
+
+
+def test_available_blocks_are_all_valid():
+    """verify all blocks marked available pass validation"""
+    blocks = get_blocks_list()
+    available_blocks = [b for b in blocks if b["available"]]
+
+    for block in available_blocks:
+        result = validate_block(block["type"])
+        assert result["valid"] is True, (
+            f"block '{block['type']}' is marked available but fails validation: {result}"
+        )
+
+
+def test_reload_then_validate_still_works():
+    """verify validation works correctly after a reload"""
+    blocks_before = get_blocks_list()
+    available = next((b for b in blocks_before if b["available"]), None)
+    assert available is not None
+
+    reload_extensions()
+
+    result = validate_block(available["type"])
+    assert result["valid"] is True
+
+
+if __name__ == "__main__":
+    print("running extensions API e2e tests...")
+
+    wait_for_server()
+    cleanup_database()
+
+    tests = [
+        ("status structure", test_status_returns_valid_structure),
+        ("status counts consistent", test_status_counts_are_consistent),
+        ("status matches blocks list", test_status_block_count_matches_blocks_list),
+        ("status matches templates list", test_status_template_count_matches_templates_list),
+        ("blocks have required fields", test_blocks_list_returns_expected_fields),
+        ("known builtin blocks exist", test_blocks_list_contains_known_builtin_blocks),
+        ("new blocks registered", test_blocks_list_contains_new_blocks),
+        ("block sources valid", test_blocks_sources_are_valid),
+        ("block categories valid", test_blocks_categories_are_valid),
+        ("templates have required fields", test_templates_list_returns_expected_fields),
+        ("template sources valid", test_templates_sources_are_valid),
+        ("reload returns ok", test_reload_returns_ok),
+        ("reload is idempotent", test_reload_is_idempotent),
+        ("validate available block", test_validate_available_block),
+        ("validate returns block name", test_validate_returns_block_name),
+        ("validate nonexistent 404", 
test_validate_nonexistent_block_returns_404),
+        ("dependencies returns list", test_dependencies_returns_list),
+        ("dependencies for block with deps", test_dependencies_for_block_with_deps),
+        ("dependencies nonexistent 404", test_dependencies_nonexistent_block_returns_404),
+        ("available blocks all valid", test_available_blocks_are_all_valid),
+        ("reload then validate", test_reload_then_validate_still_works),
+    ]
+
+    failures = 0
+    for name, test_fn in tests:
+        print(f"\ntest: {name}")
+        try:
+            test_fn()
+            print("✓ passed")
+        except BaseException as e:
+            if type(e).__name__ == "Skipped":
+                print(f"⊘ skipped: {e}")
+            elif isinstance(e, (KeyboardInterrupt, SystemExit)):
+                raise
+            else:
+                print(f"✗ failed: {e}")
+                failures += 1
+
+    cleanup_database()
+    if failures:
+        raise SystemExit(f"\n{failures} extensions API e2e test(s) failed")
+    print("\n✅ all extensions API e2e tests completed!")
diff --git a/tests/e2e/test_extensions_e2e.py b/tests/e2e/test_extensions_e2e.py
new file mode 100644
index 0000000..bd6424a
--- /dev/null
+++ b/tests/e2e/test_extensions_e2e.py
@@ -0,0 +1,414 @@
+"""
+E2E tests for extensions page.
+Tests UI interactions with the extensions management interface,
+block validation, reload, template creation, and status display.
+
+Requires: running server (yarn dev + uvicorn) and playwright installed.
+""" + +import pytest +from playwright.sync_api import expect, sync_playwright + +try: + from .test_helpers import ( + cleanup_database, + get_blocks_list, + get_extensions_status, + get_headless_mode, + get_templates_list, + wait_for_server, + ) +except ImportError: + from test_helpers import ( + cleanup_database, + get_blocks_list, + get_extensions_status, + get_headless_mode, + get_templates_list, + wait_for_server, + ) + + +@pytest.fixture(scope="module", autouse=True) +def _e2e_setup(): + if not wait_for_server(): + pytest.skip("server not ready for e2e tests") + cleanup_database() + yield + cleanup_database() + + +def _navigate_to_extensions(page): + """navigate to extensions page and wait for data to load""" + page.goto("http://localhost:5173/extensions") + page.wait_for_load_state("networkidle") + # wait for blocks section to render (means API data loaded) + page.wait_for_selector("h2:has-text('Blocks')", timeout=10000) + + +# --- page structure tests --- + + +def test_extensions_page_loads(): + """verify extensions page loads with all major sections""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # verify all three main sections + expect(page.get_by_role("heading", name="Extensions")).to_be_visible() + expect(page.get_by_role("heading", name="Blocks")).to_be_visible() + expect(page.get_by_role("heading", name="Templates")).to_be_visible() + + # verify reload button exists + expect(page.get_by_role("button", name="Reload")).to_be_visible() + + browser.close() + + +def test_extensions_status_cards_show_counts(): + """verify status overview cards display correct counts matching API""" + api_status = get_extensions_status() + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # verify "available" text appears with correct count + available_text = 
page.locator(f"text={api_status['blocks']['available']} available")
+        expect(available_text.first).to_be_visible()
+
+        # verify "Builtin" label is present
+        expect(page.locator("text=Builtin").first).to_be_visible()
+
+        # sanity check: the API should report at least one builtin block
+        builtin_count = api_status["blocks"]["builtin_blocks"]
+        assert builtin_count > 0, "should have at least one builtin block"
+
+        browser.close()
+
+
+# --- block cards tests ---
+
+
+def test_extensions_shows_block_cards_with_details():
+    """verify block cards render with name, source badge, type, and description"""
+    api_blocks = get_blocks_list()
+    assert len(api_blocks) > 0, "API should return at least one block"
+
+    with sync_playwright() as p:
+        browser = p.chromium.launch(headless=get_headless_mode())
+        page = browser.new_page()
+
+        _navigate_to_extensions(page)
+
+        # verify builtin badge appears
+        expect(page.locator("text=builtin").first).to_be_visible()
+
+        # verify at least one known block name is rendered
+        first_block = api_blocks[0]
+        expect(page.locator(f"text={first_block['name']}").first).to_be_visible()
+
+        # verify block type (mono text) is shown
+        expect(page.locator(f"text={first_block['type']}").first).to_be_visible()
+
+        # verify each block card has a Validate button
+        validate_buttons = page.get_by_role("button", name="Validate")
+        assert validate_buttons.count() >= len(api_blocks), (
+            f"expected at least {len(api_blocks)} validate buttons, got {validate_buttons.count()}"
+        )
+
+        # verify available label is present on cards
+        available_labels = page.locator("text=available")
+        assert available_labels.count() > 0, "should show 'available' labels on block cards"
+
+        browser.close()
+
+
+def test_block_validate_shows_success_toast():
+    """verify clicking Validate on an available block shows success toast"""
+    api_blocks = get_blocks_list()
+    # find an available block
+    available_block = next((b for b in api_blocks if b["available"]), None)
+    assert available_block is not None,
"need at least one available block for this test"
+
+    with sync_playwright() as p:
+        browser = p.chromium.launch(headless=get_headless_mode())
+        page = browser.new_page()
+
+        _navigate_to_extensions(page)
+
+        # locate the card for the selected block: the nearest ancestor of the
+        # block name element that also contains a Validate button
+        block_card = page.locator(
+            f"xpath=//*[normalize-space()='{available_block['name']}']/ancestor::*[.//button[normalize-space()='Validate']][1]"
+        )
+        block_card.get_by_role("button", name="Validate").click()
+
+        # wait for success toast
+        toast = page.locator("text=is valid")
+        expect(toast).to_be_visible(timeout=5000)
+
+        browser.close()
+
+
+def test_block_cards_count_matches_api():
+    """verify the number of block cards in UI matches API response"""
+    api_blocks = get_blocks_list()
+
+    with sync_playwright() as p:
+        browser = p.chromium.launch(headless=get_headless_mode())
+        page = browser.new_page()
+
+        _navigate_to_extensions(page)
+
+        # count validate buttons as proxy for block cards (each card has exactly one)
+        validate_buttons = page.get_by_role("button", name="Validate")
+        assert validate_buttons.count() == len(api_blocks), (
+            f"UI shows {validate_buttons.count()} blocks, API returns {len(api_blocks)}"
+        )
+
+        browser.close()
+
+
+# --- reload tests ---
+
+
+def test_extensions_reload_shows_success_toast():
+    """verify reload button triggers reload and shows success toast"""
+    with sync_playwright() as p:
+        browser = p.chromium.launch(headless=get_headless_mode())
+        page = browser.new_page()
+
+        _navigate_to_extensions(page)
+
+        # click reload
+        reload_btn = page.get_by_role("button", name="Reload")
+        reload_btn.click()
+
+        # button should show "Reloading..."
while in progress + # then success toast appears + toast = page.locator("text=Extensions reloaded") + expect(toast).to_be_visible(timeout=5000) + + browser.close() + + +def test_reload_preserves_block_count(): + """verify reload does not lose any blocks""" + api_status_before = get_extensions_status() + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # click reload + page.get_by_role("button", name="Reload").click() + page.locator("text=Extensions reloaded").wait_for(timeout=5000) + + # wait for page to re-render after reload + page.wait_for_load_state("networkidle") + + browser.close() + + # verify API still returns same counts + api_status_after = get_extensions_status() + assert api_status_after["blocks"]["total"] == api_status_before["blocks"]["total"], ( + "reload should not change block count" + ) + assert api_status_after["templates"]["total"] == api_status_before["templates"]["total"], ( + "reload should not change template count" + ) + + +# --- template cards tests --- + + +def test_extensions_shows_template_cards(): + """verify template cards render with name, source badge, and Create Pipeline button""" + api_templates = get_templates_list() + assert len(api_templates) > 0, "API should return at least one template" + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # verify first template name is displayed + first_template = api_templates[0] + expect(page.locator(f"text={first_template['name']}").first).to_be_visible() + + # verify "Create Pipeline" buttons match template count + create_buttons = page.get_by_role("button", name="Create Pipeline") + assert create_buttons.count() == len(api_templates), ( + f"UI shows {create_buttons.count()} template buttons, API returns {len(api_templates)}" + ) + + browser.close() + + +def 
test_create_pipeline_from_template_card(): + """verify clicking Create Pipeline on a template card creates pipeline and navigates""" + api_templates = get_templates_list() + assert len(api_templates) > 0, "need at least one template" + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # click first "Create Pipeline" button + create_btn = page.get_by_role("button", name="Create Pipeline").first + create_btn.click() + + # should show success toast + toast = page.locator("text=Pipeline created from template") + expect(toast).to_be_visible(timeout=5000) + + # should navigate to pipelines page + page.wait_for_url("**/pipelines", timeout=5000) + expect(page.get_by_role("heading", name="Pipelines", exact=True)).to_be_visible() + + browser.close() + + +# --- navigation tests --- + + +def test_navigate_to_extensions_from_sidebar(): + """verify navigating to extensions via sidebar link""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # start from homepage + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # click Extensions in sidebar + page.get_by_text("Extensions", exact=True).click() + page.wait_for_url("**/extensions", timeout=5000) + + # verify page content loaded + expect(page.get_by_role("heading", name="Extensions")).to_be_visible() + page.wait_for_selector("h2:has-text('Blocks')", timeout=10000) + + browser.close() + + +# --- edge case tests --- + + +def test_validate_button_produces_toast(): + """verify Validate button produces a toast (success or error) without crashing""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # click the first validate button + validate_btn = page.get_by_role("button", name="Validate").first + validate_btn.click() + + 
# should show some toast (success or error) -- not a crash + # look for any toast notification (sonner uses [data-sonner-toast]) + toast = page.locator("[data-sonner-toast]") + expect(toast.first).to_be_visible(timeout=5000) + + browser.close() + + +def test_extensions_page_shows_block_descriptions(): + """verify block descriptions from API are rendered in UI""" + api_blocks = get_blocks_list() + # find a block with a non-empty description + block_with_desc = next((b for b in api_blocks if b.get("description")), None) + if block_with_desc is None: + pytest.skip("no blocks with descriptions found") + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # verify description text is visible (use partial match for long descriptions) + desc_text = block_with_desc["description"][:50] + expect(page.locator(f"text={desc_text}").first).to_be_visible() + + browser.close() + + +def test_extensions_page_shows_block_dependencies(): + """verify blocks with dependencies display them in the UI""" + api_blocks = get_blocks_list() + # find a block with dependencies + block_with_deps = next( + (b for b in api_blocks if b.get("dependencies") and len(b["dependencies"]) > 0), None + ) + if block_with_deps is None: + pytest.skip("no blocks with dependencies found") + + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + _navigate_to_extensions(page) + + # verify at least the first dependency name is rendered + first_dep = block_with_deps["dependencies"][0] + expect(page.locator(f"text={first_dep}").first).to_be_visible() + + browser.close() + + +if __name__ == "__main__": + print("running extensions e2e tests...") + + if not wait_for_server(): + raise SystemExit("server not ready for e2e tests") + cleanup_database() + + tests = [ + ("extensions page loads", test_extensions_page_loads), + ("status cards show counts", 
test_extensions_status_cards_show_counts), + ("block cards with details", test_extensions_shows_block_cards_with_details), + ("block validate success toast", test_block_validate_shows_success_toast), + ("block cards count matches API", test_block_cards_count_matches_api), + ("reload shows success toast", test_extensions_reload_shows_success_toast), + ("reload preserves block count", test_reload_preserves_block_count), + ("template cards render", test_extensions_shows_template_cards), + ("create pipeline from template", test_create_pipeline_from_template_card), + ("navigate from sidebar", test_navigate_to_extensions_from_sidebar), + ("validate button produces toast", test_validate_button_produces_toast), + ("block descriptions shown", test_extensions_page_shows_block_descriptions), + ("block dependencies shown", test_extensions_page_shows_block_dependencies), + ] + + failures = 0 + for name, test_fn in tests: + print(f"\ntest: {name}") + try: + test_fn() + print("✓ passed") + except BaseException as e: + if type(e).__name__ == "Skipped": + print(f"⊘ skipped: {e}") + elif isinstance(e, (KeyboardInterrupt, SystemExit)): + raise + else: + print(f"✗ failed: {e}") + failures += 1 + + cleanup_database() + if failures: + raise SystemExit(f"\n{failures} extensions e2e test(s) failed") + print("\n✅ all extensions e2e tests completed!") diff --git a/tests/e2e/test_helpers.py b/tests/e2e/test_helpers.py index 97463a4..fa8940f 100644 --- a/tests/e2e/test_helpers.py +++ b/tests/e2e/test_helpers.py @@ -70,3 +70,57 @@ def get_pipeline_count(): except Exception as e: print(f"get_pipeline_count warning: {e}") return -1 + + +# --- extensions helpers --- + +BASE_URL = "http://localhost:8000" + + +def get_extensions_status(): + """get extensions status from API""" + resp = httpx.get(f"{BASE_URL}/api/extensions/status", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def get_blocks_list(): + """get all registered blocks from API""" + resp = 
httpx.get(f"{BASE_URL}/api/extensions/blocks", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def get_templates_list(): + """get all registered templates from API""" + resp = httpx.get(f"{BASE_URL}/api/extensions/templates", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def reload_extensions(): + """trigger extension reload via API""" + resp = httpx.post(f"{BASE_URL}/api/extensions/reload", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def validate_block(block_type: str): + """validate a block via API""" + resp = httpx.post(f"{BASE_URL}/api/extensions/blocks/{block_type}/validate", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def get_block_dependencies(block_type: str): + """get dependency info for a block""" + resp = httpx.get(f"{BASE_URL}/api/extensions/blocks/{block_type}/dependencies", timeout=10.0) + resp.raise_for_status() + return resp.json() + + +def create_pipeline_from_template(template_id: str): + """create a pipeline from template via API""" + resp = httpx.post(f"{BASE_URL}/api/pipelines/from_template/{template_id}", timeout=10.0) + resp.raise_for_status() + return resp.json() diff --git a/tests/integration/test_extensions.py b/tests/integration/test_extensions.py new file mode 100644 index 0000000..1f90a93 --- /dev/null +++ b/tests/integration/test_extensions.py @@ -0,0 +1,219 @@ +""" +Integration tests for the extensions system. +Tests the full stack: registry + API + dependency manager working together. 
+""" + + +class TestExtensionsFullStack: + """tests that exercise registry -> API -> response chain""" + + def test_extensions_status_counts_match_blocks_list(self, client): + """status endpoint counts should match actual blocks list length""" + status = client.get("/api/extensions/status").json() + blocks = client.get("/api/extensions/blocks").json() + + assert status["blocks"]["total"] == len(blocks) + assert status["blocks"]["available"] == sum(1 for b in blocks if b["available"]) + assert status["blocks"]["unavailable"] == sum(1 for b in blocks if not b["available"]) + + def test_extensions_status_counts_match_templates_list(self, client): + """status template counts should match actual templates list length""" + status = client.get("/api/extensions/status").json() + templates = client.get("/api/extensions/templates").json() + + assert status["templates"]["total"] == len(templates) + + def test_all_blocks_have_extensibility_fields(self, client): + """every block from extensions endpoint has source and available fields""" + blocks = client.get("/api/extensions/blocks").json() + assert len(blocks) > 0 + + for block in blocks: + assert "source" in block + assert "available" in block + assert "dependencies" in block + assert block["source"] in ("builtin", "custom", "user") + + def test_all_templates_have_source(self, client): + """every template from extensions endpoint has source field""" + templates = client.get("/api/extensions/templates").json() + assert len(templates) > 0 + + for tmpl in templates: + assert "source" in tmpl + assert tmpl["source"] in ("builtin", "user") + + def test_validate_then_check_dependencies(self, client): + """validate a block, then check its dependencies - full flow""" + resp = client.post("/api/extensions/blocks/TextGenerator/validate") + assert resp.status_code == 200 + assert resp.json()["valid"] is True + + resp = client.get("/api/extensions/blocks/TextGenerator/dependencies") + assert resp.status_code == 200 + deps = resp.json() 
+ assert isinstance(deps, list) + + def test_reload_returns_success(self, client): + """POST /api/extensions/reload returns ok status""" + response = client.post("/api/extensions/reload") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + def test_reload_preserves_builtin_blocks(self, client): + """reloading extensions should not lose builtin blocks""" + blocks_before = client.get("/api/extensions/blocks").json() + builtin_before = [b["type"] for b in blocks_before if b["source"] == "builtin"] + + client.post("/api/extensions/reload") + + blocks_after = client.get("/api/extensions/blocks").json() + builtin_after = [b["type"] for b in blocks_after if b["source"] == "builtin"] + + assert set(builtin_before) == set(builtin_after) + + def test_validate_nonexistent_block_returns_404(self, client): + resp = client.post("/api/extensions/blocks/DoesNotExist/validate") + assert resp.status_code == 404 + + def test_dependencies_nonexistent_block_returns_404(self, client): + resp = client.get("/api/extensions/blocks/DoesNotExist/dependencies") + assert resp.status_code == 404 + + def test_install_deps_nonexistent_block_returns_404(self, client): + resp = client.post("/api/extensions/blocks/DoesNotExist/install-deps") + assert resp.status_code == 404 + + def test_install_deps_block_with_no_missing_deps_returns_ok(self, client): + """POST /api/extensions/blocks/{name}/install-deps when all deps present""" + resp = client.post("/api/extensions/blocks/TextGenerator/install-deps") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert "message" in data or "installed" in data + + def test_install_deps_invokes_installer_and_returns_installed_list(self, client): + """POST /api/extensions/blocks/{name}/install-deps calls installer when deps are missing""" + from unittest.mock import AsyncMock, patch + + with ( + patch( + "lib.api.extensions.dependency_manager.check_missing", + 
return_value=["some-pkg"], + ), + patch( + "lib.api.extensions.dependency_manager.install", + new_callable=AsyncMock, + return_value=["some-pkg"], + ), + ): + resp = client.post("/api/extensions/blocks/TextGenerator/install-deps") + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["installed"] == ["some-pkg"] + + def test_get_dependencies_for_block_without_deps(self, client): + """GET /api/extensions/blocks/{name}/dependencies for block without deps""" + response = client.get("/api/extensions/blocks/ValidatorBlock/dependencies") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + +class TestRegistryWithUserBlocks: + """tests for dynamic block registration via the registry""" + + def test_register_and_list_user_block(self): + """registering a block makes it appear in list_blocks""" + from lib.blocks.base import BaseBlock + from lib.blocks.registry import BlockRegistry + + registry = BlockRegistry() + initial_count = len(registry.list_blocks()) + + class DummyIntegrationBlock(BaseBlock): + name = "Dummy Integration" + description = "test block" + category = "generators" + inputs = ["text"] + outputs = ["result"] + + registry.register(DummyIntegrationBlock, source="user") + + blocks = registry.list_blocks() + assert len(blocks) == initial_count + 1 + + dummy = next(b for b in blocks if b.type == "DummyIntegrationBlock") + assert dummy.source == "user" + assert dummy.available is True + + registry.unregister("DummyIntegrationBlock") + assert len(registry.list_blocks()) == initial_count + + def test_register_unavailable_block(self): + """registering an unavailable block shows error info""" + from lib.blocks.base import BaseBlock + from lib.blocks.registry import BlockRegistry + + registry = BlockRegistry() + + class BrokenIntegrationBlock(BaseBlock): + name = "Broken" + description = "broken block" + category = "generators" + inputs = ["text"] + outputs = ["result"] + 
dependencies = ["nonexistent-package-xyz"] + + registry.register( + BrokenIntegrationBlock, + source="user", + available=False, + error="missing dependency: nonexistent-package-xyz", + ) + + blocks = registry.list_blocks() + broken = next(b for b in blocks if b.type == "BrokenIntegrationBlock") + assert broken.available is False + assert "nonexistent-package-xyz" in broken.error + + registry.unregister("BrokenIntegrationBlock") + + +class TestDependencyManagerIntegration: + """tests for dependency checking with real packages""" + + def test_check_installed_package(self): + """pydantic should be detected as installed""" + from lib.dependency_manager import dependency_manager + + missing = dependency_manager.check_missing(["pydantic"]) + assert "pydantic" not in missing + + def test_check_missing_package(self): + """nonexistent package should be detected as missing""" + from lib.dependency_manager import dependency_manager + + missing = dependency_manager.check_missing(["nonexistent-package-xyz-999"]) + assert "nonexistent-package-xyz-999" in missing + + def test_get_dependency_info_installed(self): + """dependency info for installed package has version""" + from lib.dependency_manager import dependency_manager + + info = dependency_manager.get_dependency_info(["pydantic"]) + assert len(info) == 1 + assert info[0].status == "ok" + assert info[0].installed_version is not None + + def test_get_dependency_info_missing(self): + """dependency info for missing package shows not_installed""" + from lib.dependency_manager import dependency_manager + + info = dependency_manager.get_dependency_info(["nonexistent-package-xyz-999"]) + assert len(info) == 1 + assert info[0].status == "not_installed" + assert info[0].installed_version is None diff --git a/tests/test_api_regression.py b/tests/test_api_regression.py new file mode 100644 index 0000000..235a353 --- /dev/null +++ b/tests/test_api_regression.py @@ -0,0 +1,99 @@ +""" +API regression tests — lock current behavior before 
extensibility changes. +""" + + +def test_health_endpoint(client): + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +def test_blocks_endpoint_returns_list(client): + response = client.get("/api/blocks") + assert response.status_code == 200 + blocks = response.json() + assert isinstance(blocks, list) + assert len(blocks) > 0 + + +def test_blocks_endpoint_schema_shape(client): + """each block must have type, name, description, category, inputs, outputs, config_schema""" + response = client.get("/api/blocks") + blocks = response.json() + + required_keys = { + "type", + "name", + "description", + "category", + "inputs", + "outputs", + "config_schema", + "dependencies", + } + for block in blocks: + missing = required_keys - set(block.keys()) + assert not missing, f"Block {block.get('type', '?')} missing keys: {missing}" + assert isinstance(block["dependencies"], list) + + +def test_blocks_endpoint_includes_core_blocks(client): + response = client.get("/api/blocks") + block_types = [b["type"] for b in response.json()] + + assert "TextGenerator" in block_types + assert "StructuredGenerator" in block_types + assert "ValidatorBlock" in block_types + assert "JSONValidatorBlock" in block_types + assert "FieldMapper" in block_types + + +def test_templates_endpoint_returns_list(client): + response = client.get("/api/templates") + assert response.status_code == 200 + templates = response.json() + assert isinstance(templates, list) + assert len(templates) > 0 + + +def test_templates_endpoint_schema_shape(client): + """each template must have id, name, description""" + response = client.get("/api/templates") + templates = response.json() + + for template in templates: + assert "id" in template, f"Template missing 'id': {template}" + assert "name" in template, f"Template missing 'name': {template}" + assert "description" in template, f"Template missing 'description': {template}" + + +def 
test_templates_endpoint_includes_core_templates(client):
+    response = client.get("/api/templates")
+    template_ids = [t["id"] for t in response.json()]
+
+    assert "json_generation" in template_ids
+    assert "text_classification" in template_ids
+    assert "qa_generation" in template_ids
+    assert "ragas_evaluation" in template_ids
+
+
+def test_pipelines_endpoint_returns_list(client):
+    response = client.get("/api/pipelines")
+    assert response.status_code == 200
+    assert isinstance(response.json(), list)
+
+
+def test_create_pipeline_from_template(client):
+    response = client.post("/api/pipelines/from_template/json_generation")
+    assert response.status_code == 200
+    data = response.json()
+    assert "id" in data
+    assert data["name"] == "JSON Generation"
+    assert data["template_id"] == "json_generation"
+
+
+def test_create_pipeline_from_invalid_template(client):
+    response = client.post("/api/pipelines/from_template/nonexistent")
+    assert response.status_code == 404
diff --git a/tests/test_dependency_manager.py b/tests/test_dependency_manager.py
new file mode 100644
index 0000000..fc3ad32
--- /dev/null
+++ b/tests/test_dependency_manager.py
@@ -0,0 +1,103 @@
+"""
+Tests for dependency manager: parse, check, info.
+"""
+
+from lib.blocks.base import BaseBlock
+
+
+class BlockWithDeps(BaseBlock):
+    name = "Test"
+    description = "Test"
+    category = "general"
+    inputs = []
+    outputs = []
+    dependencies = ["requests>=2.28.0", "pandas>=1.5.0"]
+
+    async def execute(self, context):
+        return {}
+
+
+class BlockNoDeps(BaseBlock):
+    name = "No Deps"
+    description = "No deps"
+    category = "general"
+    inputs = []
+    outputs = []
+
+    async def execute(self, context):
+        return {}
+
+
+def test_get_block_dependencies():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    assert manager.get_block_dependencies(BlockWithDeps) == [
+        "requests>=2.28.0",
+        "pandas>=1.5.0",
+    ]
+
+
+def test_get_block_dependencies_empty():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    assert manager.get_block_dependencies(BlockNoDeps) == []
+
+
+def test_check_missing_returns_uninstalled():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    missing = manager.check_missing(["nonexistent-package-xyz123"])
+    assert "nonexistent-package-xyz123" in missing
+
+
+def test_check_missing_returns_empty_for_installed():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    missing = manager.check_missing(["pytest"])
+    assert missing == []
+
+
+def test_check_missing_handles_version_specifiers():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    # pytest is installed, version spec shouldn't break parsing
+    missing = manager.check_missing(["pytest>=1.0.0"])
+    assert missing == []
+
+
+def test_get_dependency_info_installed():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    info = manager.get_dependency_info(["pytest"])
+    assert len(info) == 1
+    assert info[0].name == "pytest"
+    assert info[0].status == "ok"
+    assert info[0].installed_version is not None
+
+
+def test_get_dependency_info_not_installed():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    info = manager.get_dependency_info(["nonexistent-xyz-999"])
+    assert len(info) == 1
+    assert info[0].status == "not_installed"
+    assert info[0].installed_version is None
+
+
+def test_get_dependency_info_mixed():
+    from lib.dependency_manager import DependencyManager
+
+    manager = DependencyManager()
+    info = manager.get_dependency_info(["pytest", "nonexistent-xyz-999"])
+    assert len(info) == 2
+
+    by_name = {i.name: i for i in info}
+    assert by_name["pytest"].status == "ok"
+    assert by_name["nonexistent-xyz-999"].status == "not_installed"
diff --git a/tests/test_extensions_api.py b/tests/test_extensions_api.py
new file mode 100644
index 0000000..4663e63
--- /dev/null
+++ b/tests/test_extensions_api.py
@@ -0,0 +1,75 @@
+"""
+Tests for extensions API endpoints.
+"""
+
+
+def test_extensions_status(client):
+    response = client.get("/api/extensions/status")
+    assert response.status_code == 200
+    data = response.json()
+    assert "blocks" in data
+    assert "templates" in data
+    assert "builtin_blocks" in data["blocks"]
+    assert "user_blocks" in data["blocks"]
+    assert "builtin_templates" in data["templates"]
+    assert "user_templates" in data["templates"]
+
+
+def test_extensions_blocks(client):
+    response = client.get("/api/extensions/blocks")
+    assert response.status_code == 200
+    blocks = response.json()
+    assert isinstance(blocks, list)
+    assert len(blocks) > 0
+    # every block has source and available
+    for b in blocks:
+        assert "source" in b
+        assert "available" in b
+
+
+def test_extensions_templates(client):
+    response = client.get("/api/extensions/templates")
+    assert response.status_code == 200
+    templates = response.json()
+    assert isinstance(templates, list)
+    assert len(templates) > 0
+    for t in templates:
+        assert "source" in t
+        assert "id" in t
+
+
+def test_extensions_reload(client):
+    response = client.post("/api/extensions/reload")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["status"] == "ok"
+
+
+def test_validate_block_available(client):
+    response = client.post("/api/extensions/blocks/TextGenerator/validate")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["valid"] is True
+    assert data["block"] == "TextGenerator"
+
+
+def test_validate_block_not_found(client):
+    response = client.post("/api/extensions/blocks/NonExistent/validate")
+    assert response.status_code == 404
+
+
+def test_block_dependencies_endpoint(client):
+    response = client.get("/api/extensions/blocks/TextGenerator/dependencies")
+    assert response.status_code == 200
+    data = response.json()
+    assert isinstance(data, list)
+
+
+def test_install_deps_already_installed_returns_installed_key(client):
+    # regression: response must always include 'installed' key so the frontend
+    # can safely call result.installed.join() without a TypeError
+    response = client.post("/api/extensions/blocks/TextGenerator/install-deps")
+    assert response.status_code == 200
+    data = response.json()
+    assert "installed" in data
+    assert isinstance(data["installed"], list)
diff --git a/tests/test_file_watcher.py b/tests/test_file_watcher.py
new file mode 100644
index 0000000..2ec312b
--- /dev/null
+++ b/tests/test_file_watcher.py
@@ -0,0 +1,106 @@
+"""
+Tests for file watcher module: debouncing, start/stop, reload on file changes.
+"""
+
+import threading
+from pathlib import Path
+
+from lib.blocks.registry import BlockRegistry
+from lib.templates import TemplateRegistry
+
+
+class TestDebouncedHandler:
+    def test_debounce_multiple_events(self):
+        """multiple rapid events result in single callback"""
+        from lib.file_watcher import DebouncedHandler
+
+        call_count = 0
+        done = threading.Event()
+
+        def callback(path: Path, event_type: str):
+            nonlocal call_count
+            call_count += 1
+            done.set()
+
+        handler = DebouncedHandler(callback, debounce_ms=50)
+
+        test_path = Path("/tmp/test.py")
+        handler._schedule_callback(test_path, "modified")
+        handler._schedule_callback(test_path, "modified")
+        handler._schedule_callback(test_path, "modified")
+
+        assert done.wait(timeout=2), "callback never fired"
+        assert call_count == 1
+
+    def test_different_paths_not_debounced(self):
+        """events for different paths fire independently"""
+        from lib.file_watcher import DebouncedHandler
+
+        paths_seen: list[str] = []
+        done = threading.Event()
+
+        def callback(path: Path, event_type: str):
+            paths_seen.append(str(path))
+            if len(paths_seen) >= 2:
+                done.set()
+
+        handler = DebouncedHandler(callback, debounce_ms=50)
+
+        handler._schedule_callback(Path("/tmp/a.py"), "modified")
+        handler._schedule_callback(Path("/tmp/b.py"), "modified")
+
+        assert done.wait(timeout=2), "not all callbacks fired"
+        assert len(paths_seen) == 2
+
+
+class TestExtensionFileWatcher:
+    def test_watcher_starts_and_stops(self, tmp_path):
+        from lib.file_watcher import ExtensionFileWatcher
+
+        blocks_dir = tmp_path / "user_blocks"
+        blocks_dir.mkdir()
+        templates_dir = tmp_path / "user_templates"
+        templates_dir.mkdir()
+
+        watcher = ExtensionFileWatcher(
+            block_registry=BlockRegistry(),
+            template_registry=TemplateRegistry(),
+            blocks_path=blocks_dir,
+            templates_path=templates_dir,
+        )
+
+        watcher.start()
+        assert watcher.is_running
+
+        watcher.stop()
+        assert not watcher.is_running
+
+    def test_watcher_noop_when_disabled(self, tmp_path, monkeypatch):
+        from lib.file_watcher import ExtensionFileWatcher
+
+        monkeypatch.setenv("DATAGENFLOW_HOT_RELOAD", "false")
+
+        blocks_dir = tmp_path / "user_blocks"
+        blocks_dir.mkdir()
+        templates_dir = tmp_path / "user_templates"
+        templates_dir.mkdir()
+
+        watcher = ExtensionFileWatcher(
+            block_registry=BlockRegistry(),
+            template_registry=TemplateRegistry(),
+            blocks_path=blocks_dir,
+            templates_path=templates_dir,
+        )
+
+        watcher.start()
+        assert not watcher.is_running
+
+    def test_stop_when_not_started_is_noop(self):
+        from lib.file_watcher import ExtensionFileWatcher
+
+        watcher = ExtensionFileWatcher(
+            block_registry=BlockRegistry(),
+            template_registry=TemplateRegistry(),
+        )
+        watcher.stop()  # should not raise
+        assert not watcher.is_running
diff --git a/tests/test_registry_regression.py b/tests/test_registry_regression.py
new file mode 100644
index 0000000..27236c7
--- /dev/null
+++ b/tests/test_registry_regression.py
@@ -0,0 +1,82 @@
+"""
+Registry regression tests — lock current BlockRegistry behavior before extensibility changes.
+"""
+
+from lib.blocks.base import BaseBlock
+from lib.blocks.registry import BlockRegistry
+from lib.entities.extensions import BlockInfo
+
+
+def test_registry_discovers_all_builtin_blocks():
+    reg = BlockRegistry()
+    block_types = {b.type for b in reg.list_blocks()}
+
+    expected = {
+        "TextGenerator",
+        "StructuredGenerator",
+        "ValidatorBlock",
+        "JSONValidatorBlock",
+        "FieldMapper",
+        "DiversityScore",
+        "CoherenceScore",
+        "RougeScore",
+        "RagasMetrics",
+        "MarkdownMultiplierBlock",
+        "DuplicateRemover",
+        "LangfuseDatasetBlock",
+        "StructureSampler",
+        "SemanticInfiller",
+    }
+    assert expected.issubset(block_types), f"Missing blocks: {expected - block_types}"
+
+
+def test_registry_get_block_class_returns_class():
+    reg = BlockRegistry()
+    cls = reg.get_block_class("TextGenerator")
+    assert cls is not None
+    assert issubclass(cls, BaseBlock)
+
+
+def test_registry_get_block_class_returns_none_for_unknown():
+    reg = BlockRegistry()
+    assert reg.get_block_class("DoesNotExist") is None
+
+
+def test_registry_list_blocks_returns_block_info():
+    reg = BlockRegistry()
+    blocks = reg.list_blocks()
+    assert isinstance(blocks, list)
+    for block in blocks:
+        assert isinstance(block, BlockInfo)
+        assert block.type
+        assert block.name
+        assert block.category
+        assert isinstance(block.inputs, list)
+        assert isinstance(block.outputs, list)
+        assert isinstance(block.config_schema, dict)
+
+
+def test_registry_compute_accumulated_state_schema():
+    reg = BlockRegistry()
+    blocks = [
+        {"type": "TextGenerator", "config": {}},
+        {"type": "ValidatorBlock", "config": {}},
+    ]
+    fields = reg.compute_accumulated_state_schema(blocks)
+    assert isinstance(fields, list)
+    assert fields == sorted(fields), "fields should be sorted"
+
+
+def test_registry_skips_base_classes():
+    """BaseBlock and BaseMultiplierBlock should not be registered"""
+    reg = BlockRegistry()
+    block_types = {b.type for b in reg.list_blocks()}
+    assert "BaseBlock" not in block_types
+    assert "BaseMultiplierBlock" not in block_types
+
+
+def test_registry_multiplier_block_detected():
+    reg = BlockRegistry()
+    cls = reg.get_block_class("MarkdownMultiplierBlock")
+    assert cls is not None
+    assert getattr(cls, "is_multiplier", False) is True
diff --git a/tests/test_template_enhanced.py b/tests/test_template_enhanced.py
new file mode 100644
index 0000000..f7b7493
--- /dev/null
+++ b/tests/test_template_enhanced.py
@@ -0,0 +1,96 @@
+"""
+Tests for enhanced TemplateRegistry: user templates dir, source tracking, register/unregister.
+"""
+
+import yaml
+
+from lib.entities.extensions import TemplateInfo
+from lib.templates import TemplateRegistry
+
+
+def _write_template(path, name="Test Template", desc="A test", blocks=None):
+    blocks = blocks or [{"type": "TextGenerator", "config": {"temperature": 0.5}}]
+    data = {"name": name, "description": desc, "blocks": blocks}
+    path.write_text(yaml.dump(data))
+
+
+# --- source tracking ---
+
+
+def test_builtin_templates_have_source():
+    reg = TemplateRegistry()
+    templates = reg.list_templates()
+    for t in templates:
+        assert isinstance(t, TemplateInfo)
+        assert t.source == "builtin"
+
+
+# --- user templates dir ---
+
+
+def test_loads_user_templates(tmp_path):
+    user_dir = tmp_path / "user_templates"
+    user_dir.mkdir()
+    _write_template(user_dir / "my_custom.yaml", name="My Custom")
+
+    reg = TemplateRegistry(user_templates_dir=user_dir)
+    templates = reg.list_templates()
+    ids = {t.id for t in templates}
+    assert "my_custom" in ids
+
+    custom = next(t for t in templates if t.id == "my_custom")
+    assert custom.source == "user"
+    assert custom.name == "My Custom"
+
+
+def test_user_templates_dont_override_builtin(tmp_path):
+    """if user template has same id as builtin, builtin wins"""
+    user_dir = tmp_path / "user_templates"
+    user_dir.mkdir()
+    _write_template(user_dir / "json_generation.yaml", name="Hijacked")
+
+    reg = TemplateRegistry(user_templates_dir=user_dir)
+    t = reg.get_template("json_generation")
+    assert t is not None
+    assert t["name"] != "Hijacked"
+
+
+# --- register / unregister ---
+
+
+def test_register_user_template():
+    reg = TemplateRegistry()
+    initial = len(reg.list_templates())
+    reg.register(
+        "my_new",
+        {"name": "My New", "description": "New template", "blocks": []},
+        source="user",
+    )
+    assert len(reg.list_templates()) == initial + 1
+    t = reg.get_template("my_new")
+    assert t is not None
+    assert t["name"] == "My New"
+
+
+def test_unregister_template():
+    reg = TemplateRegistry()
+    reg.register(
+        "to_remove", {"name": "Remove Me", "description": "...", "blocks": []}, source="user"
+    )
+    assert reg.get_template("to_remove") is not None
+
+    reg.unregister("to_remove")
+    assert reg.get_template("to_remove") is None
+
+
+def test_unregister_nonexistent_is_noop():
+    reg = TemplateRegistry()
+    reg.unregister("does_not_exist")  # should not raise
+
+
+def test_get_template_source():
+    reg = TemplateRegistry()
+    reg.register("user_t", {"name": "U", "description": "...", "blocks": []}, source="user")
+    assert reg.get_template_source("json_generation") == "builtin"
+    assert reg.get_template_source("user_t") == "user"
+    assert reg.get_template_source("nonexistent") is None
diff --git a/tests/test_template_regression.py b/tests/test_template_regression.py
new file mode 100644
index 0000000..90d8367
--- /dev/null
+++ b/tests/test_template_regression.py
@@ -0,0 +1,57 @@
+"""
+Template registry regression tests — lock current TemplateRegistry behavior.
+"""
+
+from lib.entities.extensions import TemplateInfo
+from lib.templates import TemplateRegistry, template_registry
+
+
+def test_singleton_has_builtin_templates():
+    templates = template_registry.list_templates()
+    ids = {t.id for t in templates}
+    assert "json_generation" in ids
+    assert "text_classification" in ids
+    assert "qa_generation" in ids
+    assert "ragas_evaluation" in ids
+
+
+def test_get_template_returns_dict():
+    t = template_registry.get_template("json_generation")
+    assert t is not None
+    assert "name" in t
+    assert "blocks" in t
+
+
+def test_get_template_returns_none_for_unknown():
+    assert template_registry.get_template("nonexistent") is None
+
+
+def test_list_templates_shape():
+    templates = template_registry.list_templates()
+    for t in templates:
+        assert isinstance(t, TemplateInfo)
+        assert t.id
+        assert t.name
+        assert t.description
+
+
+def test_template_blocks_reference_valid_types():
+    """all block types in templates must exist in the block registry"""
+    from lib.blocks.registry import BlockRegistry
+
+    reg = BlockRegistry()
+    available = {b.type for b in reg.list_blocks()}
+
+    for t in template_registry.list_templates():
+        full = template_registry.get_template(t.id)
+        assert full is not None, f"Template '{t.id}' returned None"
+        for block in full["blocks"]:
+            assert block["type"] in available, (
+                f"Template '{t.id}' references unknown block type '{block['type']}'"
+            )
+
+
+def test_template_registry_custom_dir(tmp_path):
+    """TemplateRegistry with empty dir returns no templates"""
+    reg = TemplateRegistry(templates_dir=tmp_path)
+    assert reg.list_templates() == []
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 436357b..1a52ae8 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -10,7 +10,7 @@ def test_template_registry_lists_all_templates():
     """test that all three templates are registered"""
     templates = template_registry.list_templates()
 
-    template_ids = [t["id"] for t in templates]
+    template_ids = [t.id for t in templates]
 
     assert "json_generation" in template_ids
     assert "text_classification" in template_ids
@@ -23,10 +23,10 @@ def test_templates_have_required_fields():
     templates = template_registry.list_templates()
 
     for template in templates:
-        assert "id" in template
-        assert "name" in template
-        assert "description" in template
-        assert "example_seed" in template
+        assert template.id
+        assert template.name
+        assert template.description
+        assert template.example_seed is not None
 
 
 def test_template_seeds_use_content_field():
@@ -34,7 +34,7 @@
     templates = template_registry.list_templates()
 
     for template in templates:
-        example_seed = template.get("example_seed")
+        example_seed = template.example_seed
         if example_seed:
             # seeds are arrays
             assert isinstance(example_seed, list)
@@ -51,7 +51,7 @@
             )
             has_samples = "samples" in first_seed["metadata"]
             assert has_content or has_samples, (
-                f"Template {template['id']} seed missing expected metadata fields"
+                f"Template {template.id} seed missing expected metadata fields"
             )
 
             # ensure no old-style system/user fields