diff --git a/.test/README.md b/.test/README.md index d5c8fe46..015133ad 100644 --- a/.test/README.md +++ b/.test/README.md @@ -1,235 +1,718 @@ -# Skill Testing Framework +# Skill Evaluation & Optimization -Test Databricks skills with real execution on serverless compute. +Automatically evaluate and optimize SKILL.md files using [GEPA](https://github.com/gepa-ai/gepa) `optimize_anything` and MLflow judges. -**Note:** This framework is for contributors only and is not distributed via install_skills.sh. +## How It Works -## Setup +SKILL.md files teach AI agents (like Claude Code) how to use Databricks features. Every token in a skill consumes the agent's context window, so skills must be **correct** (teach the right patterns) and **concise** (waste no tokens). This framework measures both and uses GEPA to improve them. -```bash -uv pip install -e ".test/[dev]" -.test/install_skill_test.sh +### The Core Loop + +``` + ┌──────────────────────────────────────────────────┐ + │ GEPA optimize_anything │ + │ │ + │ seed_candidate ─► evaluator(candidate, task) │ + │ │ │ │ + │ │ (score, side_info) │ + │ │ │ │ + │ │ reflection LM reads │ + │ │ side_info rationale │ + │ │ │ │ + │ │ proposes mutation │ + │ │ │ │ + │ └──── best_candidate (Pareto frontier) ◄───┘│ + └──────────────────────────────────────────────────┘ ``` -Requires a Databricks workspace with serverless SQL/compute enabled. +**GEPA** ([Generalized Evolutionary Prompt Architect](https://github.com/gepa-ai/gepa)) treats the SKILL.md as a text artifact to optimize. Its `optimize_anything` API takes: +- A **seed candidate** (the current SKILL.md text) +- An **evaluator** function: `(candidate, task_example) -> (score, side_info)` +- A **dataset** of test cases from `ground_truth.yaml` ---- +GEPA's reflection LM reads the `side_info` diagnostics, proposes mutations, evaluates them, and selects the best via Pareto frontier. The critical insight: the richer the `side_info` diagnostics, the better GEPA's mutations. 
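To make the evaluator contract concrete, here is a minimal self-contained sketch of the `(candidate, task_example) -> (score, side_info)` shape that `optimize_anything` expects. Everything in it is illustrative: `judge_quality` is a hypothetical keyword check standing in for the MLflow quality judge, and the response is faked rather than generated by an LLM.

```python
# Illustrative sketch of the optimize_anything evaluator contract:
# (candidate, task_example) -> (score, side_info).
# judge_quality is a hypothetical stand-in for the MLflow judge, and the
# "response" is faked here instead of being generated by an LLM.

def judge_quality(response: str, expected_facts: list[str]) -> tuple[float, str]:
    """Score = fraction of expected facts found; rationale says which are missing."""
    hits = [fact for fact in expected_facts if fact.lower() in response.lower()]
    missing = [fact for fact in expected_facts if fact not in hits]
    rationale = f"Facts present: {len(hits)}/{len(expected_facts)}; missing: {missing}"
    return len(hits) / max(len(expected_facts), 1), rationale

def evaluator(candidate: str, task: dict) -> tuple[float, dict]:
    # Real framework: generate the WITH-skill response from an LLM at temperature=0,
    # with the candidate SKILL.md injected as system context.
    response = f"{candidate}\n\n{task['inputs']['prompt']}"
    score, rationale = judge_quality(response, task["expectations"]["expected_facts"])
    # Rich side_info is what GEPA's reflection LM reads when proposing mutations.
    return score, {"Judge_quality_with": {"score": score, "rationale": rationale}}
```

The richer the rationale returned in `side_info`, the more targeted GEPA's mutations can be; a bare numeric score gives the reflection LM nothing to work with.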
-## New Skill Journey +### Evaluation Methodology: How We Measure Skill Quality -Complete workflow for testing a skill from scratch (e.g., `mlflow-evaluation`). +Before understanding the judges and scoring, it's important to understand **what we're measuring and why the measurement is trustworthy**. -### 1. Initialize Test Scaffolding +#### The core question: "Does this skill actually help?" -``` -/skill-test init -``` +A SKILL.md is only valuable if an agent produces **better responses with the skill than without it**. This is a testable claim — we can generate responses both ways and compare. That comparison is the foundation of all evaluation and optimization in this framework. -Claude will: -1. Read the skill's SKILL.md documentation -2. Generate `manifest.yaml` with appropriate scorers -3. Create empty `ground_truth.yaml` and `candidates.yaml` templates -4. Recommend test prompts based on documentation +#### Two layers of comparison -### 2. Add Test Cases +There are two distinct comparisons happening — understanding both is key to reading the scores: -``` -/skill-test add -``` +1. **Within each evaluation** (WITH vs WITHOUT skill): measures whether a given SKILL.md adds value over a bare LLM. This is what `quality_with` and `quality_without` refer to. +2. **Across the optimization loop** (original vs optimized): measures whether GEPA's mutations improved the SKILL.md. This is what `original_score` vs `optimized_score` refer to. -Run this with the recommended prompts from init. Claude will: -1. Ask for your test prompt -2. Invoke the skill to generate a response -3. Execute code blocks on Databricks -4. Auto-save passing tests to `ground_truth.yaml` -5. Save failing tests to `candidates.yaml` for review +The first comparison runs inside the evaluator on every iteration. The second comparison runs in the runner to decide whether to keep GEPA's changes. -Repeat for each recommended prompt. +#### The WITH vs WITHOUT experimental design -### 3. 
Review Candidates +Every evaluation follows a controlled experiment that measures whether a specific SKILL.md candidate helps the LLM produce better responses: -``` -/skill-test review -``` +1. **WITH-skill trial** (`quality_with`) — An LLM generates a response with the SKILL.md injected as system context. The skill teaches the model Databricks-specific patterns, syntax, and constraints it wouldn't otherwise know. +2. **WITHOUT-skill trial** (`quality_without`) — The **same LLM** generates a response to the **same prompt** with **no SKILL.md in context**. This is the control — it shows what the model already knows on its own. **This is NOT "without optimization"** — it is the bare model with no skill document at all. +3. **Judge both** — An MLflow judge scores each response against the test case's expected facts, patterns, and guidelines, returning a 0.0-1.0 quality score plus a written rationale. -Review any tests that failed execution and were saved to candidates: -1. Load pending tests from `candidates.yaml` -2. Present each with prompt, response, and execution results -3. Allow you to approve, reject, skip, or edit -4. Promote approved candidates to `ground_truth.yaml` +The WITHOUT-skill response is **computed once and cached by prompt hash** — since the model and prompt don't change, the baseline is stable across all GEPA iterations. This means every candidate SKILL.md is compared against the same fixed control (the bare model). -For batch approval of successful tests: -``` -/skill-test review --batch --filter-success -``` +#### What "baseline score" means -### 4. Configure Scorers (Optional) +Before optimization begins, the runner evaluates the **original SKILL.md** on all training tasks using the WITH/WITHOUT protocol above. 
This produces: -``` -/skill-test scorers -``` +- A **per-task score** — the composite score (see [Scoring Weights](#scoring-weights)) for each test case +- A **mean baseline score** — the average across all tasks (e.g., `0.909`) +- **Diagnostic labels** — each task is classified: + - **OK** — skill helped (quality delta > +0.05) + - **NEEDS_SKILL** — WITH-skill quality is below 0.5 (skill isn't teaching enough) + - **REGRESSION** — skill actively hurt the response (quality delta < −0.05) + +This baseline tells you exactly where the skill stands *before* any optimization. -View current scorer configuration. To update: +#### What "improvement" means (the second layer) + +This is the **outer comparison** — original SKILL.md vs optimized SKILL.md. After GEPA produces an optimized candidate, it's re-evaluated on all training tasks using the same WITH/WITHOUT protocol. Improvement is the difference between the optimized mean score and the original mean score: ``` -/skill-test scorers update --add-guideline "Must use CLUSTER BY" +improvement = optimized_score - original_score ``` -Or edit `.test/skills//manifest.yaml` directly to: -- Add/remove scorers -- Update default guidelines -- Configure trace expectations +Both scores come from the same evaluator, which internally runs the WITH vs WITHOUT comparison. So "improvement" means the optimized SKILL.md produced a larger quality delta (WITH minus WITHOUT) than the original SKILL.md did — i.e., the optimized skill helps the LLM more than the original skill did. -### 5. Run Evaluation +This is **not** a subjective assessment. Both scores come from the same judges, same prompts, same cached WITHOUT-skill baselines. The only variable is the SKILL.md content. -``` -/skill-test run -``` +The composite score itself is a weighted combination of four dimensions (detailed in [Scoring Weights](#scoring-weights)): -Executes code blocks on Databricks or locally (depends on SKILLS, MCP, etc.) 
and reports pass/fail for each test in `ground_truth.yaml`. +| Dimension | What it measures | Why it matters | +|-----------|-----------------|----------------| +| **Skill Effectiveness (40%)** | `quality_with - quality_without` | The skill's unique contribution — what the model gets right *because* of the skill | +| **Absolute Quality (30%)** | `quality_with` score | Overall response quality with the skill present | +| **Structure (5%)** | Python/SQL syntax validity | Code in the skill must be syntactically correct | +| **Token Efficiency (25%)** | Token count vs original | Smaller skills save context window — candidates that shrink get a bonus up to 1.15x | -**Note:** Requires test cases in ground_truth.yaml (from steps 2-3). +A skill that scores 0.91 after optimization vs 0.88 at baseline has a measurable, reproducible improvement of +0.03 — driven by higher quality deltas, fewer regressions, or better token efficiency. -### 6. MLflow Evaluation (Optional) +#### Why this is rigorous, not made up +- **Same model, same prompts** — the only variable is the skill content, isolating its effect +- **Cached baselines** — WITHOUT-skill responses don't change between iterations, so score deltas are real +- **Judge rationale** — every score comes with a written explanation of which facts were present/missing and which patterns matched/failed, making scores auditable +- **Train/val split** — with 5+ test cases, stratified splitting prevents overfitting to the training set +- **Deterministic structure checks** — syntax validation and pattern adherence use regex/AST parsing, not LLM judgment + +### MLflow Judges as the Evaluator + +The evaluator uses [MLflow's `make_judge`](https://mlflow.org/docs/latest/llms/llm-evaluate/index.html) to score responses. 
Two judges run by default during optimization: + +| Judge | What it does | Returns | +|-------|-------------|---------| +| **quality_judge** | Scores a single response against expected facts, patterns, and guidelines | `float` (0.0-1.0) + rationale | +| **regression_judge** | Identifies specific ways the skill harms responses | `bool` + rationale of what to fix | + +Effectiveness is derived from the quality delta (`quality_with - quality_without`) — no separate LLM call needed. The `effectiveness_judge` is available in `judges.py` for standalone use but is not called during optimization. + +Each judge returns **full rationale** — not truncated — so GEPA's reflection LM sees exactly what failed and why: + +```python +side_info = { + "Judge_quality_with": { + "score": 0.65, + "rationale": "The response correctly uses CREATE OR REPLACE VIEW but misses " + "the MEASURE() wrapping requirement for measure references. " + "Pattern adherence: 2/3 found. Fact coverage: 3/5 present." + }, + "Judge_quality_without": { + "score": 0.2, + "rationale": "Without the skill, the model invented a non-existent " + "CREATE METRIC VIEW syntax. Only 1/5 expected facts present." + }, + "Judge_effectiveness": { + "verdict": "improved", + "delta": 0.45, + } +} ``` -/skill-test mlflow -``` -Runs full evaluation with LLM judges and logs results to MLflow. Provides deeper quality assessment beyond pass/fail execution. +### How Baseline Evaluation Works + +This section walks through how a single test case is evaluated end-to-end, from dataset loading through to the baseline score that GEPA uses for optimization. + +#### 1. Dataset Loading (`splitter.py`) + +- Loads `ground_truth.yaml` test cases via `create_gepa_datasets()` +- If >= 5 test cases: stratified train/val split by `metadata.category` (80/20 default) +- If < 5: all used as train, no val set (single-task mode) +- If no `ground_truth.yaml` exists: `generate_bootstrap_tasks()` auto-generates tasks from SKILL.md headers + +#### 2. 
Evaluator Construction (`skillbench_evaluator.py`) + +`create_skillbench_evaluator()` builds a `SkillBenchEvaluator` with: + +| Parameter | Purpose | +|-----------|---------| +| `gen_model` | LLM that generates responses (plays the role of the agent) | +| `original_token_counts` | Token count of the original SKILL.md (for efficiency scoring) | +| `skill_guidelines` | Deduplicated guidelines from all test cases (injected into quality judge) | +| `tool_context` | Read-only MCP tool descriptions (included in generation prompt but not mutated) | + +The evaluator instantiates two MLflow judges: `quality_judge` and `regression_judge`. + +#### 3. Per-Task Evaluation Flow (the `__call__` method) + +Each test case goes through four phases: + +1. **Phase 1: WITH-skill generation** -- Sends the SKILL.md + tool descriptions as system context, user prompt as user message, generates response at temperature=0 +2. **Phase 2: WITHOUT-skill generation** -- Same prompt, NO skill in context. Result is **cached by prompt hash** -- computed once and reused across all GEPA iterations (the baseline never changes) +3. **Phase 3: Judge scoring** -- `quality_judge` scores both responses against `expected_facts`, `expected_patterns`, and `guidelines` from the test case. WITHOUT-skill judge results are also cached. +4. **Phase 4: Compute composite score** -- Weighted combination of effectiveness delta, absolute quality, structure validation, and token efficiency + +#### 4. Baseline Scoring (`runner.py` step 5) + +Before optimization starts, `_evaluate_on_tasks()` runs the evaluator on ALL training tasks with the original SKILL.md: + +- Collects per-task scores and `side_info` diagnostics +- `build_skillbench_background()` summarizes: mean baseline score, which tasks are NEEDS_SKILL vs REGRESSION +- This baseline context tells GEPA's reflection LM what's already working and what needs improvement + +#### 5. 
Why This Matters for GEPA + +- The `side_info` dict returned per-task contains **full judge rationale** (not truncated) +- GEPA's reflection LM reads this rationale to understand exactly what failed +- Better diagnostics lead to more targeted mutations and faster convergence -### 7. Save Baseline +### Scoring Weights +| Weight | Dimension | Source | +|--------|-----------|--------| +| **40%** | Skill Effectiveness | `quality_with - quality_without` (the delta) | +| **30%** | Absolute Quality | `quality_with` score from judge | +| **5%** | Structure | Python/SQL syntax validation | +| **25%** | Token Efficiency | Smaller = higher score (bonus up to 1.15x) | + +### How Multi-Pass Optimization Works + +Optimization runs as a multi-pass loop where each pass feeds its best result into the next. This section explains what happens inside a single pass and how the runner decides when to stop. + +#### What happens inside a single GEPA pass + +GEPA's `optimize_anything` receives the seed candidate (current SKILL.md text), the evaluator, the training dataset, and the preset config. Within a pass, GEPA runs up to `max_metric_calls` iterations — **15** for `quick`, **50** for `standard`, **150** for `thorough`. + +Each iteration follows this cycle: + +1. **Reflect** — The reflection LM reads `side_info` from the previous evaluation. This includes the full judge rationale: which expected facts were missing, which regex patterns weren't found, which guidelines were violated, and whether regressions occurred. +2. **Mutate** — Based on the rationale, the reflection LM proposes a targeted mutation to the SKILL.md (or tool docstring). Mutations are surgical — informed by exactly what the judges flagged. +3. **Evaluate** — The evaluator scores the mutated candidate on a task from the dataset. This involves generating responses WITH the candidate, running MLflow judges, and computing the composite score. +4. **Select** — GEPA tracks a Pareto frontier of best candidates. 
If the mutation improves the frontier, it's kept; otherwise, it's discarded. + +The key insight: because `side_info` contains **full judge rationale** (not truncated summaries), the reflection LM sees exactly which facts were missed, which patterns were absent, and which regressions occurred — leading to more targeted mutations. + +#### How multi-pass works and when it stops + +The runner (`runner.py`) wraps GEPA in a multi-pass loop (default: up to 5 passes, controlled by `--max-passes`): + +1. **Pass N starts** — The best candidate from pass N-1 (or the original SKILL.md for pass 1) becomes the seed. +2. **GEPA optimizes** — Runs up to `max_metric_calls` iterations within the pass. +3. **Re-evaluate** — After the pass completes, the best candidate is re-evaluated on **all** training tasks to get a stable score. +4. **Compare** — The pass score is compared to the previous best score. +5. **Decision:** + - If improvement > **0.0005** (the `improvement_threshold`): the best candidate becomes the seed for pass N+1, and optimization continues. + - If improvement ≤ **0.0005**: early stop — no further passes are run. + +This creates a refinement chain: each pass starts from the previous pass's best, allowing incremental improvements that compound across passes. Early stopping prevents wasting compute when the skill has converged. + +#### Component scaling + +When optimizing multiple components (e.g., SKILL.md + tool modules with `--include-tools`), metric calls scale: + +- **Base formula:** `base_calls × num_components` +- **Per-preset caps:** quick → 45, standard → 150, thorough → 300 +- **Global cap:** 300 (applied for slower reflection models) +- **Round-robin:** GEPA's component selector alternates which component to mutate each iteration, so all components get roughly equal optimization effort. 
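The budget rules above can be sketched as a small helper. This is an illustration of the documented numbers only, not the runner's actual implementation:

```python
# Sketch of the per-pass metric-call budget described above. The numbers
# mirror the documented presets; the framework's runner owns the real logic.

PRESET_BASE = {"quick": 15, "standard": 50, "thorough": 150}
PRESET_CAP = {"quick": 45, "standard": 150, "thorough": 300}
GLOBAL_CAP = 300  # applied for slower reflection models

def metric_calls_per_pass(preset: str, num_components: int) -> int:
    scaled = PRESET_BASE[preset] * num_components  # base formula
    return min(scaled, PRESET_CAP[preset], GLOBAL_CAP)
```

So `metric_calls_per_pass("quick", 3)` returns 45, and a `thorough` run never exceeds the global cap of 300 no matter how many components are in play.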
+ +For example, with `--include-tools --tool-modules sql serving` (3 components: `skill_md` + `tools_sql` + `tools_serving`), a `quick` preset uses min(15 × 3, 45) = **45** metric calls per pass. + +--- + +## Quick Start + +```bash +# Install +uv pip install -e ".test/[all]" + +# Auth (pick one) +export DATABRICKS_API_KEY="dapi..." +export DATABRICKS_API_BASE="https://.cloud.databricks.com/serving-endpoints" +# OR +export OPENAI_API_KEY="sk-..." +export GEPA_REFLECTION_LM="openai/gpt-4o" +export GEPA_GEN_LM="openai/gpt-4o" + +# OR use Databricks AI Gateway (routes through a centralized gateway with rate limits and logging) +export DATABRICKS_API_KEY="dapi..." +export DATABRICKS_API_BASE="https://.ai-gateway.cloud.databricks.com/mlflow/v1/serving-endpoints" +# IMPORTANT: When using AI Gateway, OPENAI_API_KEY must also be set to your Databricks API token. +# The MLflow judges and litellm call OpenAI-compatible endpoints, which read OPENAI_API_KEY for auth. +export OPENAI_API_KEY="$DATABRICKS_API_KEY" + +# Optimize +uv run python .test/scripts/optimize.py databricks-metric-views --preset quick --apply ``` -/skill-test baseline + +--- + +## What Can Be Optimized + +GEPA treats any text artifact as a candidate for optimization. Skills and tools are optimized **separately** to avoid cross-skill interference. + +### Skills (SKILL.md files) — default mode + +SKILL.md files teach agents Databricks patterns — API syntax, code examples, best practices. Each skill is a standalone GEPA component (`skill_md`). Tool descriptions are loaded as **read-only context** — included in the generation prompt so the evaluator sees realistic agent behavior, but not mutated by GEPA. + +This means `--preset quick` always uses **1 component / 15 metric calls per pass**, regardless of how many tool modules exist. 
+ +```bash +# Optimize a skill (tools loaded as read-only context) +uv run python .test/scripts/optimize.py databricks-metric-views --preset quick + +# Optimize all skills that have test cases +uv run python .test/scripts/optimize.py --all --preset quick ``` -Saves current metrics to `baselines//baseline.yaml`. +### MCP Tool Descriptions — `--tools-only` mode + +`@mcp.tool` docstrings in `databricks-mcp-server/` are what the agent sees when deciding which tool to call. Concise, accurate descriptions improve tool selection. Each tool module becomes a separate GEPA component (`tools_sql`, `tools_serving`, etc.). + +Tool optimization uses a **cross-skill dataset** — tasks are sampled from all skills with `ground_truth.yaml` — so optimized docstrings work well across skills, not just one. -### 8. Check Regressions +```bash +# Optimize tool descriptions with cross-skill evaluation +uv run python .test/scripts/optimize.py databricks-metric-views --tools-only -After skill changes: +# Optimize specific tool modules only +uv run python .test/scripts/optimize.py databricks-metric-views --tools-only --tool-modules sql serving compute ``` -/skill-test regression + +When applied (`--apply`), optimized docstrings are written back to the MCP server source files via AST, preserving all surrounding code. + +### Skills + Tools Together — `--include-tools` (advanced) + +For advanced use: optimize both skill and tool descriptions in a single GEPA run. Both are treated as GEPA components (round-robin mutation). Per-preset metric call caps prevent budget blowup. + +```bash +# Skill + specific tool modules +uv run python .test/scripts/optimize.py databricks-metric-views --include-tools --tool-modules sql + +# Dry run to see all components and their token counts +uv run python .test/scripts/optimize.py databricks-metric-views --include-tools --dry-run ``` -Compares current pass rate against the saved baseline. 
+Available tool modules: `agent_bricks`, `aibi_dashboards`, `apps`, `compute`, `file`, `genie`, `jobs`, `lakebase`, `manifest`, `pipelines`, `serving`, `sql`, `unity_catalog`, `user`, `vector_search`, `volume_files` --- -## Trace Evaluation (In Progress) +## Example Workflow: `databricks-metric-views` -Capture Claude Code sessions and evaluate against skill expectations. +This walks through the full lifecycle of evaluating and optimizing the metric views skill. -### Enable MLflow Tracing +### 1. Inspect the skill and test cases -```bash -export DATABRICKS_CONFIG_PROFILE=aws-apps -export MLFLOW_EXPERIMENT_NAME="/Users//Claude Code Skill Traces" +The skill lives at `databricks-skills/databricks-metric-views/SKILL.md`. Test cases live at `.test/skills/databricks-metric-views/ground_truth.yaml`: -pip install mlflow[databricks] -mlflow autolog claude -u databricks -n "$MLFLOW_EXPERIMENT_NAME" . +```yaml +test_cases: + - id: metric-views_create_sql_001 + inputs: + prompt: "Create a metric view for order analytics with revenue and order count measures" + outputs: + response: | + ```sql + CREATE OR REPLACE VIEW main.default.order_metrics + WITH METRICS LANGUAGE YAML + $$ + source: main.default.orders + dimensions: + - name: Order Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Total Revenue + expr: SUM(amount) + $$ + ``` + expectations: + expected_facts: + - "Uses CREATE OR REPLACE VIEW with WITH METRICS LANGUAGE YAML" + - "Defines dimensions with name and expr fields" + - "Defines measures with name and expr using aggregate functions" + expected_patterns: + - pattern: "WITH METRICS LANGUAGE YAML" + description: "Metric view DDL syntax" + - pattern: "MEASURE\\(" + description: "MEASURE() function for querying" + guidelines: + - "Must use WITH METRICS LANGUAGE YAML syntax" + - "Must define dimensions and measures in YAML block" + + - id: metric-views_query_measure_002 + inputs: + prompt: "Query a metric view to get total revenue and order count by month" 
+ expectations: + expected_facts: + - "Uses MEASURE() function to reference measures" + - "SELECT * is NOT supported on metric views" + expected_patterns: + - pattern: "MEASURE\\(" + description: "MEASURE() wrapping for measures" + - pattern: "GROUP BY ALL" + description: "GROUP BY ALL for metric view queries" ``` -### Evaluate Traces +Each test case defines: +- **`inputs.prompt`** — what the user asks +- **`expectations.expected_facts`** — facts the response must mention +- **`expectations.expected_patterns`** — regex patterns the response must contain +- **`expectations.guidelines`** — soft rules for the MLflow quality judge -**Local trace file:** -``` -/skill-test trace-eval --trace ~/.claude/projects/.../session.jsonl +### 2. Dry run to check baseline + +```bash +uv run python .test/scripts/optimize.py databricks-metric-views --dry-run ``` -**From MLflow run ID** (from `mlflow.search_runs`): ``` -/skill-test trace-eval --run-id abc123 +=== Dry Run: databricks-metric-views (skillbench) === +SKILL.md path: databricks-skills/databricks-metric-views/SKILL.md +Components: ['skill_md'] +Total original tokens: 1,234 + skill_md: 1,234 tokens +Tool context (read-only): 16,757 tokens +Train tasks: 8 +Evaluator: skillbench (judge-driven) +Preset: quick (max_metric_calls=15, scaled for 1 component(s)) +Current score: 0.909 + metric-views_create_sql_001: 0.952 + metric-views_query_measure_002: 0.871 + metric-views_create_mcp_003: 0.934 + ... +``` + +The evaluator runs each test case **twice** — once WITH the skill in context and once WITHOUT — then judges the delta. Test case 002 scores lower because the MEASURE() wrapping example in the skill has a syntax gap. + +### 3. Run optimization + +```bash +uv run python .test/scripts/optimize.py databricks-metric-views --preset quick ``` -**From MLflow trace ID** (from `mlflow.get_trace`): +GEPA runs 15 iterations per component across up to 5 passes. Each iteration: +1. Mutates the SKILL.md based on judge rationale +2. 
Generates responses WITH the mutated skill +3. Judges score the responses +4. GEPA keeps mutations that improve the Pareto frontier + ``` -/skill-test trace-eval --trace-id tr-d416fccdab46e2dea6bad1d0bd8aaaa8 + Starting multi-pass optimization (up to 5 passes, 1 component(s), 15 metric calls/pass) + + --- Pass 1/5 (best score so far: 0.9090) --- + Pass 1 score: 0.9350 (delta: +0.0260) + + --- Pass 2/5 (best score so far: 0.9350) --- + No significant improvement in pass 2 -- stopping early. ``` -**List available traces:** +### 4. Review and apply + ``` -/skill-test list-traces --local -/skill-test list-traces --experiment "$MLFLOW_EXPERIMENT_NAME" +============================================================ + Optimization Results: databricks-metric-views +============================================================ + Score: 0.909 -> 0.935 (+0.026) + Skill Effectiveness: 0.42 + Quality (with): 0.78 + Quality (without): 0.36 (baseline) + Tokens: 1,234 -> 1,198 (-2.9%) + + Per-task: + metric-views_create_sql_001 WITH 0.85 WITHOUT 0.35 delta +0.50 [OK] + metric-views_query_measure_002 WITH 0.79 WITHOUT 0.22 delta +0.57 [OK] + ... 
+ + Saved: .test/skills/databricks-metric-views/optimized_SKILL.md + Apply: uv run python .test/scripts/optimize.py databricks-metric-views --apply-last +============================================================ ``` -### Configure Expectations +Review the diff, then apply: -In `manifest.yaml`: -```yaml -scorers: - trace_expectations: - tool_limits: - Bash: 15 - mcp__databricks__execute_sql: 10 - token_budget: - max_total: 150000 - required_tools: - - Read - banned_tools: - - "DROP DATABASE" +```bash +# Review what changed +diff databricks-skills/databricks-metric-views/SKILL.md \ + .test/skills/databricks-metric-views/optimized_SKILL.md + +# Apply +uv run python .test/scripts/optimize.py databricks-metric-views --apply-last ``` --- -## Command Reference - -| Command | Description | -|---------|-------------| -| `run` | Execute tests against ground truth (default) | -| `init` | Generate test scaffolding from skill docs | -| `add` | Add test cases interactively | -| `review` | Review and promote candidates | -| `baseline` | Save current results as baseline | -| `regression` | Compare against baseline | -| `mlflow` | Full evaluation with LLM judges | -| `trace-eval` | Evaluate session traces | -| `list-traces` | List available traces | -| `scorers` | View/update scorer config | +## CLI Reference ---- +```bash +# Presets +uv run python .test/scripts/optimize.py --preset quick # 15 iterations +uv run python .test/scripts/optimize.py --preset standard # 50 iterations (default) +uv run python .test/scripts/optimize.py --preset thorough # 150 iterations -## Files +# Options +--dry-run # Show scores without optimizing +--apply # Run + apply immediately +--apply-last # Apply saved result without re-running +--gen-model "..." # Override generation model (default: databricks/databricks-claude-sonnet-4-6) +--reflection-lm "..." 
# Override reflection model (default: databricks/databricks-claude-opus-4-6) +--max-passes N # Max optimization passes (default: 5) +--token-budget N # Hard token ceiling +--include-tools # Include MCP tool descriptions as GEPA components (advanced) +--tool-modules sql ... # Specific tool modules to include +--tools-only # Optimize only tool descriptions (cross-skill evaluation) +--all # Optimize all skills with ground_truth.yaml +--run-dir DIR # Directory for GEPA checkpoints (resumes if dir exists) +# Test case generation +--generate-from FILE # Generate test cases from requirements file +--requirement "..." # Inline requirement (repeatable) ``` -.test/skills// -├── manifest.yaml # Scorers, guidelines, trace expectations -├── ground_truth.yaml # Verified test cases -└── candidates.yaml # Pending review -.test/baselines// -└── baseline.yaml # Regression baseline +### Flag Details + +- **`--dry-run`**: Runs baseline evaluation on all training tasks — scores the current SKILL.md WITH and WITHOUT the skill in context, shows per-task scores and a cost estimate, then exits without running optimization. Useful for checking your baseline before committing to a full run. + +- **`--apply`**: Runs optimization to completion, then immediately writes the optimized SKILL.md back to `databricks-skills/`. Combines `optimize` + `--apply-last` in one step. Use when you're confident in the preset and want a hands-off workflow. + +- **`--apply-last`**: Loads the previously saved `optimized_SKILL.md` and `last_optimization.json` from `.test/skills//` and writes the optimized content back to the repo. Does **not** re-run optimization. Use after reviewing a previous run's diff to confirm the changes look good. + +- **`--include-tools`**: Makes MCP tool docstrings optimizable GEPA components alongside SKILL.md. Both are mutated by GEPA via round-robin selection. Tool descriptions are no longer read-only context — they become first-class candidates. 
Metric calls scale with component count (see [Component scaling](#component-scaling)). + +- **`--tools-only`**: Drops SKILL.md entirely. Only tool module docstrings become GEPA components. Uses a **cross-skill dataset** (tasks sampled from ALL skills with `ground_truth.yaml`, max 5 per skill) so optimized descriptions generalize across skills rather than overfitting to one. + +- **`--tool-modules`**: Filters which tool modules are extracted for optimization. Without this flag, all modules are included. Example: `--tool-modules sql serving` optimizes only the `tools_sql` and `tools_serving` components. + +- **`--all`**: Discovers all skills with `ground_truth.yaml` in `.test/skills/`, runs optimization sequentially for each, and prints per-skill results plus a summary table at the end. + +- **`--run-dir`**: Enables GEPA checkpointing. Each pass saves state to `{run_dir}/pass_{N}/`. If the same `--run-dir` is passed on a subsequent run, GEPA resumes from the last checkpoint. Use `touch {run_dir}/pass_N/gepa.stop` for graceful mid-pass stop. + +- **`--max-passes`**: Maximum number of optimization passes (default 5). Each pass feeds the previous best as seed. Early stops if improvement falls below the threshold (0.0005). Lower values trade potential quality for faster completion. + +- **`--token-budget`**: Hard ceiling on candidate token count. The efficiency scorer penalizes candidates that exceed this budget. Also available via `GEPA_TOKEN_BUDGET` env var. + +### Model Configuration + +| Env Var | Default | Purpose | +|---------|---------|---------| +| `GEPA_GEN_LM` | `databricks/databricks-claude-sonnet-4-6` | Generation model (produces responses from skill) | +| `GEPA_REFLECTION_LM` | `databricks/databricks-claude-opus-4-6` | Reflection model (proposes mutations) | +| `GEPA_TOKEN_BUDGET` | none | Hard token ceiling for candidates | + +Model strings use [litellm provider prefixes](https://docs.litellm.ai/docs/providers): `databricks/`, `openai/`, `anthropic/`. 
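As a minimal sketch of how these env vars could resolve to litellm model strings (defaults mirror the table above; the actual resolution code in `optimize.py` may differ):

```python
# Hypothetical resolution of generation/reflection models from env vars.
# Defaults mirror the documented table; optimize.py's real logic may differ.
import os

MODEL_DEFAULTS = {
    "GEPA_GEN_LM": "databricks/databricks-claude-sonnet-4-6",
    "GEPA_REFLECTION_LM": "databricks/databricks-claude-opus-4-6",
}

def resolve_model(env_var: str) -> str:
    """Env override wins; otherwise fall back to the documented default."""
    return os.environ.get(env_var) or MODEL_DEFAULTS[env_var]
```

Setting `GEPA_GEN_LM=openai/gpt-4o` would route generation through litellm's OpenAI provider; unset variables fall back to the Databricks-hosted defaults.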
+ +--- + +## Resuming Long Runs + +GEPA saves optimization state to a run directory. If interrupted, resume from where you left off: + +```bash +# Start with checkpointing +uv run python .test/scripts/optimize.py databricks-metric-views \ + --preset standard --run-dir ./opt_runs/metric-views + +# Resume after interruption (same command) +uv run python .test/scripts/optimize.py databricks-metric-views \ + --preset standard --run-dir ./opt_runs/metric-views + +# Graceful stop (GEPA finishes current iteration then exits) +touch ./opt_runs/metric-views/pass_1/gepa.stop ``` +Each pass gets its own subdirectory (`pass_1/`, `pass_2/`, ...) so checkpoints are isolated per pass. + --- -## Test Case Format +## Writing Test Cases + +Test cases in `ground_truth.yaml` define what each skill should teach. Minimal example: ```yaml +metadata: + skill_name: my-skill + version: "1.0" + test_cases: - - id: "eval_basic_001" + - id: basic_001 inputs: - prompt: "Create a scorer for response length" + prompt: "Show me how to create a streaming table" outputs: response: | - ```python - @scorer - def response_length(outputs): - return Feedback(name="length", value=len(outputs["response"])) + ```sql + CREATE OR REFRESH STREAMING TABLE bronze_events + AS SELECT * FROM STREAM read_files('s3://bucket/events/') ``` - execution_success: true expectations: - expected_facts: ["@scorer", "Feedback"] - guidelines: ["Must use mlflow.genai.scorers"] -``` + expected_facts: + - "Uses CREATE OR REFRESH STREAMING TABLE syntax" + expected_patterns: + - pattern: "CREATE OR REFRESH STREAMING TABLE" + description: "SDP DDL syntax" + guidelines: + - "Must use SDP syntax, not legacy DLT syntax" + metadata: + category: happy_path +``` + +**Tips:** +- **5+ test cases** enables a train/val split for generalization +- **Cover categories**: happy_path, error_handling, edge cases — the splitter stratifies by `metadata.category` +- **`expected_patterns`** use regex — be specific (`"MEASURE\\("` not `".*MEASURE.*"`) +- 
**`guidelines`** are evaluated by the MLflow quality judge — use for soft expectations that can't be regex-matched
+- **Generate from requirements**: `--requirement "Must explain MEASURE() wrapping"` auto-generates test cases
+
+---
+
+## Test Case & Configuration Files
+
+Each skill under `.test/skills/<skill-name>/` has two configuration files that drive evaluation and optimization.
+
+### `ground_truth.yaml` — What the skill must teach
+
+The evaluation dataset. Each test case represents a user prompt and the expected behavior when the skill is in context.
+
+**Full field schema:**
+
+| Field | Required | Description |
+|-------|----------|-------------|
+| `metadata.skill_name` | yes | Identifier matching the skill directory name |
+| `metadata.version` | yes | Schema version (e.g., `"1.0"`) |
+| `metadata.created_at` | no | ISO timestamp of creation |
+| `test_cases[].id` | yes | Unique identifier (convention: `<skill>_<description>_<number>`, e.g. `metric-views_create_sql_001`) |
+| `test_cases[].inputs.prompt` | yes | The user question sent to the generation model |
+| `test_cases[].outputs.response` | no | Expected reference answer. Used for judge comparison, **not** exact matching. Omit if you only want pattern/fact checks. |
+| `test_cases[].expectations.expected_facts` | yes | List of factual claims the response must contain. The quality judge checks each one. |
+| `test_cases[].expectations.expected_patterns` | no | Regex patterns with fields: `pattern`, `description`, and optionally `min_count` / `max_count`. Checked deterministically. |
+| `test_cases[].expectations.guidelines` | no | Soft rules evaluated by the quality judge for things regex can't check (e.g., "Should explain why SELECT * doesn't work"). |
+| `test_cases[].metadata.category` | recommended | Used for stratified train/val splitting. Common values: `happy_path`, `error_handling`, `advanced`, `conceptual`, `edge_case`. 
| + +**Example with all fields:** + +```yaml +metadata: + skill_name: databricks-metric-views + version: "1.0" + created_at: "2025-01-15T10:00:00Z" + +test_cases: + - id: metric-views_create_sql_001 + inputs: + prompt: "Create a metric view for order analytics" + outputs: + response: | + ```sql + CREATE OR REPLACE VIEW main.default.order_metrics + WITH METRICS LANGUAGE YAML + $$ + source: main.default.orders + measures: + - name: Total Revenue + expr: SUM(amount) + $$ + ``` + expectations: + expected_facts: + - "Uses CREATE OR REPLACE VIEW with WITH METRICS LANGUAGE YAML" + - "Defines measures with name and expr using aggregate functions" + expected_patterns: + - pattern: "WITH METRICS LANGUAGE YAML" + description: "Metric view DDL syntax" + min_count: 1 + - pattern: "MEASURE\\(" + description: "MEASURE() function for querying" + min_count: 0 + max_count: 5 + guidelines: + - "Must use WITH METRICS LANGUAGE YAML syntax, not CREATE METRIC VIEW" + - "Should include a complete YAML block between $$ delimiters" + metadata: + category: happy_path +``` + +### `manifest.yaml` — How to evaluate the skill + +Configures which scorers run and what quality thresholds apply during evaluation. + +**Full field schema:** + +| Field | Description | +|-------|-------------| +| `skill_name` | Identifier matching the skill directory name | +| `scorers.enabled` | List of deterministic scorers to run: `python_syntax`, `sql_syntax`, `pattern_adherence`, `no_hallucinated_apis`, `expected_facts_present` | +| `scorers.llm_scorers` | List of LLM-based scorers: `Safety`, `guidelines_from_expectations`, `Guidelines` | +| `scorers.default_guidelines` | Fallback guidelines applied when a test case doesn't specify its own `guidelines` field | +| `quality_gates` | Minimum score thresholds per scorer (e.g., `syntax_valid: 1.0`, `pattern_adherence: 0.9`). Failing a gate flags the test case. 
| +| `scorers.trace_expectations.tool_limits` | Max number of tool calls allowed (for trace-based scoring) | +| `scorers.trace_expectations.token_budget` | Max tokens allowed in the response | +| `scorers.trace_expectations.required_tools` | Tools that must be called (e.g., `["execute_sql"]`) | +| `scorers.trace_expectations.banned_tools` | Tools that must not be called | + +**Example:** + +```yaml +skill_name: databricks-metric-views + +scorers: + enabled: + - sql_syntax + - pattern_adherence + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - "Responses must use Databricks-specific syntax, not generic SQL" + - "Code examples must be runnable without modification" + +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + safety: 1.0 +``` + +--- + +## Architecture + +``` +.test/ +├── scripts/ +│ ├── optimize.py # CLI entry point +│ ├── generate_examples.py # Generate test cases from requirements +│ └── trace_to_examples.py # Extract test cases from MLflow traces +├── src/skill_test/optimize/ +│ ├── judges.py # MLflow make_judge factories (quality, effectiveness, regression) +│ ├── skillbench_evaluator.py # WITH vs WITHOUT evaluator using judges +│ ├── runner.py # GEPA optimize_anything orchestrator +│ ├── utils.py # Token counting, path resolution +│ ├── asi.py # MLflow Feedback → side_info conversion +│ ├── alignment.py # MemAlign judge alignment (future) +│ ├── config.py # GEPA presets, model registration +│ ├── splitter.py # Train/val dataset splitting +│ └── tools.py # MCP tool description extraction +├── src/skill_test/scorers/ +│ ├── universal.py # Deterministic: python_syntax, sql_syntax, etc. +│ ├── trace.py # Trace-based: tool_count, token_budget, etc. 
+│   └── routing.py               # Skill routing accuracy (deprecated)
+└── skills/<skill-name>/
+    ├── ground_truth.yaml        # Test cases
+    ├── manifest.yaml            # Scorer configuration
+    ├── optimized_SKILL.md       # Last optimization output
+    └── last_optimization.json   # Metadata for --apply-last
+```
diff --git a/.test/notebooks/gepa_skill_optimization_demo.ipynb b/.test/notebooks/gepa_skill_optimization_demo.ipynb
new file mode 100644
index 00000000..8099460d
--- /dev/null
+++ b/.test/notebooks/gepa_skill_optimization_demo.ipynb
@@ -0,0 +1,502 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "metadata": {},
+   "source": [
+    "{\n",
+    " \"cells\": [\n",
+    "  {\n",
+    "   \"cell_type\": \"markdown\",\n",
+    "   \"metadata\": {},\n",
+    "   \"source\": [\n",
+    "    \"# GEPA Skill Optimization Demo\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"This notebook demonstrates how the skill-test framework uses [GEPA](https://github.com/gepa-ai/gepa) to automatically optimize Databricks SKILL.md files for **quality** and **token efficiency**.\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"SKILL.md files teach AI agents (like Claude Code) Databricks patterns. Every token in a skill consumes agent context window budget, so skills should be as concise and high-quality as possible.\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"**What GEPA does:**\\n\",\n",
+    "    \"1. Scores the current SKILL.md against deterministic scorers (syntax, patterns, APIs, facts)\\n\",\n",
+    "    \"2. Reflects on failures and proposes mutations to improve the skill\\n",
+    "    \"3. Selects the best candidate via Pareto frontier optimization\\n\",\n",
+    "    \"4. 
Repeats until quality converges or budget is exhausted\"\n",
+    "   ]\n",
+    "  },\n",
+    "  {\n",
+    "   \"cell_type\": \"markdown\",\n",
+    "   \"metadata\": {},\n",
+    "   \"source\": [\n",
+    "    \"## Setup\"\n",
+    "   ]\n",
+    "  },\n",
+    "  {\n",
+    "   \"cell_type\": \"code\",\n",
+    "   \"execution_count\": null,\n",
+    "   \"metadata\": {},\n",
+    "   \"outputs\": [],\n",
+    "   \"source\": [\n",
+    "    \"import sys\\n\",\n",
+    "    \"from pathlib import Path\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"# Add skill-test to path\\n\",\n",
+    "    \"repo_root = Path(\\\".\\\").resolve()\\n\",\n",
+    "    \"while not (repo_root / \\\".test\\\" / \\\"src\\\").exists() and repo_root != repo_root.parent:\\n\",\n",
+    "    \"    repo_root = repo_root.parent\\n\",\n",
+    "    \"sys.path.insert(0, str(repo_root / \\\".test\\\" / \\\"src\\\"))\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"print(f\\\"Repo root: {repo_root}\\\")\"\n",
+    "   ]\n",
+    "  },\n",
+    "  {\n",
+    "   \"cell_type\": \"code\",\n",
+    "   \"execution_count\": null,\n",
+    "   \"metadata\": {},\n",
+    "   \"outputs\": [],\n",
+    "   \"source\": \"import os\\n\\n# Configure the reflection model -- pick ONE:\\n\\n# Option A: Databricks Model Serving (default, recommended)\\n# IMPORTANT: DATABRICKS_API_BASE must end with /serving-endpoints\\n# os.environ[\\\"DATABRICKS_API_KEY\\\"] = \\\"dapi...\\\" \\n# os.environ[\\\"DATABRICKS_API_BASE\\\"] = \\\"https://<workspace>.cloud.databricks.com/serving-endpoints\\\"\\n# os.environ[\\\"GEPA_REFLECTION_LM\\\"] = \\\"databricks/databricks-gpt-5-2\\\"\\n\\n# Option B: OpenAI\\n# os.environ[\\\"OPENAI_API_KEY\\\"] = \\\"sk-...\\\"\\n# os.environ[\\\"GEPA_REFLECTION_LM\\\"] = \\\"openai/gpt-4o\\\"\\n\\nprint(f\\\"Reflection LM: {os.environ.get('GEPA_REFLECTION_LM', 'databricks/databricks-gpt-5-2 (default)')}\\\")\"\n",
+    "  },\n",
+    "  {\n",
+    "   \"cell_type\": \"markdown\",\n",
+    "   \"metadata\": {},\n",
+    "   \"source\": [\n",
+    "    \"## Step 1: Inspect the Skill\\n\",\n",
+    "    \"\\n\",\n",
+    "    \"Let's look at the `databricks-model-serving` skill -- its current size, test cases, and 
baseline score.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": \"SKILL_NAME = \\\"databricks-model-serving\\\"\\n\\nfrom skill_test.optimize.utils import find_skill_md, count_tokens\\nfrom skill_test.optimize.splitter import create_gepa_datasets\\n\\n# Load skill\\nskill_path = find_skill_md(SKILL_NAME)\\noriginal_content = skill_path.read_text()\\noriginal_tokens = count_tokens(original_content)\\n\\n# Load test cases\\ntrain, val = create_gepa_datasets(SKILL_NAME)\\n\\nprint(f\\\"Skill: {SKILL_NAME}\\\")\\nprint(f\\\"Path: {skill_path}\\\")\\nprint(f\\\"Lines: {len(original_content.splitlines())}\\\")\\nprint(f\\\"Tokens: {original_tokens:,}\\\")\\nprint(f\\\"Train cases: {len(train)}\\\")\\nprint(f\\\"Val cases: {len(val) if val else 'None'}\\\")\"\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"# Show first few test cases\\n\",\n", + " \"for t in train[:3]:\\n\",\n", + " \" print(f\\\"\\\\n--- {t['id']} ---\\\")\\n\",\n", + " \" print(f\\\"Prompt: {t['input'][:100]}...\\\")\\n\",\n", + " \" if t.get('answer'):\\n\",\n", + " \" print(f\\\"Answer: {t['answer'][:100]}...\\\")\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Step 2: Evaluate Current Quality (Baseline)\\n\",\n", + " \"\\n\",\n", + " \"Before optimizing, measure the current skill quality using the scorer pipeline.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": \"from skill_test.optimize.skillbench_evaluator import create_skillbench_evaluator\\nfrom skill_test.optimize.utils import SKILL_KEY\\nfrom skill_test.optimize.splitter import 
to_gepa_instances\\n\\n# Note: SkillBench evaluator requires a gen_model for WITH/WITHOUT comparison\\n# If you don't have API access, this cell will raise an error\\nGEN_MODEL = os.environ.get(\\\"GEPA_GEN_LM\\\", \\\"databricks/databricks-claude-sonnet-4-6\\\")\\n\\nevaluator = create_skillbench_evaluator(SKILL_NAME, gen_model=GEN_MODEL)\\nseed_candidate = {SKILL_KEY: original_content}\\n\\n# Evaluate on all train tasks\\ngepa_instances = to_gepa_instances(train)\\n\\nprint(f\\\"{'Task ID':<35} {'Score':>8}\\\")\\nprint(\\\"-\\\" * 45)\\nfor i, inst in enumerate(gepa_instances):\\n score, side_info = evaluator(seed_candidate, inst)\\n task_id = train[i]['id']\\n status = 'PASS' if score >= 0.5 else 'FAIL'\\n print(f\\\"{task_id:<35} {score:>7.3f} {status}\\\")\\n\\n# Quick baseline\\nscores = [evaluator(seed_candidate, inst)[0] for inst in gepa_instances]\\nbaseline_score = sum(scores) / len(scores)\\nprint(f\\\"\\\\nBaseline Score: {baseline_score:.3f}\\\")\\nprint(f\\\"Token Count: {original_tokens:,}\\\")\"\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Step 3: Run GEPA Optimization\\n\",\n", + " \"\\n\",\n", + " \"Now run the optimization. 
GEPA will:\\n\",\n", + " \"- Use the current SKILL.md as the seed candidate\\n\",\n", + " \"- Run scorers against each test case\\n\",\n", + " \"- Reflect on failures to propose mutations\\n\",\n", + " \"- Select the best candidate via Pareto frontier\\n\",\n", + " \"- Penalize token bloat (80% quality, 20% efficiency weighting)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"from skill_test.optimize.runner import optimize_skill\\n\",\n", + " \"\\n\",\n", + " \"result = optimize_skill(\\n\",\n", + " \" skill_name=SKILL_NAME,\\n\",\n", + " \" mode=\\\"static\\\",\\n\",\n", + " \" preset=\\\"quick\\\", # 15 iterations -- increase to \\\"standard\\\" (50) or \\\"thorough\\\" (150) for better results\\n\",\n", + " \")\\n\",\n", + " \"\\n\",\n", + " \"print(f\\\"Optimization complete!\\\")\\n\",\n", + " \"print(f\\\"GEPA metric calls: {result.gepa_result.total_metric_calls}\\\")\\n\",\n", + " \"print(f\\\"Candidates explored: {result.gepa_result.num_candidates}\\\")\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Step 4: Results Comparison\\n\",\n", + " \"\\n\",\n", + " \"Compare the original vs. 
optimized skill across quality and token efficiency.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"print(\\\"=\\\" * 60)\\n\",\n", + " \"print(f\\\" OPTIMIZATION RESULTS: {SKILL_NAME}\\\")\\n\",\n", + " \"print(\\\"=\\\" * 60)\\n\",\n", + " \"print()\\n\",\n", + " \"\\n\",\n", + " \"# Quality comparison\\n\",\n", + " \"quality_delta = result.improvement\\n\",\n", + " \"quality_pct = (quality_delta / result.original_score * 100) if result.original_score > 0 else 0\\n\",\n", + " \"print(f\\\" Quality Score\\\")\\n\",\n", + " \"print(f\\\" Before: {result.original_score:.3f}\\\")\\n\",\n", + " \"print(f\\\" After: {result.optimized_score:.3f}\\\")\\n\",\n", + " \"print(f\\\" Delta: {quality_delta:+.3f} ({quality_pct:+.1f}%)\\\")\\n\",\n", + " \"print()\\n\",\n", + " \"\\n\",\n", + " \"# Token comparison \\n\",\n", + " \"token_delta = result.original_token_count - result.optimized_token_count\\n\",\n", + " \"print(f\\\" Token Count\\\")\\n\",\n", + " \"print(f\\\" Before: {result.original_token_count:,}\\\")\\n\",\n", + " \"print(f\\\" After: {result.optimized_token_count:,}\\\")\\n\",\n", + " \"print(f\\\" Saved: {token_delta:,} tokens ({result.token_reduction_pct:.1f}% reduction)\\\")\\n\",\n", + " \"print()\\n\",\n", + " \"\\n\",\n", + " \"# Line count comparison\\n\",\n", + " \"orig_lines = len(result.original_content.splitlines())\\n\",\n", + " \"opt_lines = len(result.optimized_content.splitlines())\\n\",\n", + " \"print(f\\\" Lines\\\")\\n\",\n", + " \"print(f\\\" Before: {orig_lines}\\\")\\n\",\n", + " \"print(f\\\" After: {opt_lines}\\\")\\n\",\n", + " \"print(f\\\" Saved: {orig_lines - opt_lines} lines\\\")\\n\",\n", + " \"print()\\n\",\n", + " \"\\n\",\n", + " \"# Validation scores\\n\",\n", + " \"if result.val_scores:\\n\",\n", + " \" avg_val = sum(result.val_scores.values()) / len(result.val_scores)\\n\",\n", + " \" 
print(f\\\" Validation (held-out test cases)\\\")\\n\",\n", + " \" for tid, score in result.val_scores.items():\\n\",\n", + " \" print(f\\\" {tid}: {score:.3f}\\\")\\n\",\n", + " \" print(f\\\" Average: {avg_val:.3f}\\\")\\n\",\n", + " \"\\n\",\n", + " \"print()\\n\",\n", + " \"print(\\\"=\\\" * 60)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"# Visual comparison bar chart\\n\",\n", + " \"try:\\n\",\n", + " \" import matplotlib.pyplot as plt\\n\",\n", + " \" import matplotlib\\n\",\n", + " \" matplotlib.rcParams['font.family'] = 'monospace'\\n\",\n", + " \"\\n\",\n", + " \" fig, axes = plt.subplots(1, 2, figsize=(12, 5))\\n\",\n", + " \"\\n\",\n", + " \" # Quality scores\\n\",\n", + " \" ax = axes[0]\\n\",\n", + " \" bars = ax.bar(\\n\",\n", + " \" ['Before', 'After'],\\n\",\n", + " \" [result.original_score, result.optimized_score],\\n\",\n", + " \" color=['#d4534b', '#4a9c5d'],\\n\",\n", + " \" width=0.5\\n\",\n", + " \" )\\n\",\n", + " \" ax.set_ylim(0, 1.1)\\n\",\n", + " \" ax.set_ylabel('Quality Score')\\n\",\n", + " \" ax.set_title(f'Quality: {result.original_score:.3f} → {result.optimized_score:.3f}')\\n\",\n", + " \" for bar, val in zip(bars, [result.original_score, result.optimized_score]):\\n\",\n", + " \" ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,\\n\",\n", + " \" f'{val:.3f}', ha='center', fontweight='bold')\\n\",\n", + " \"\\n\",\n", + " \" # Token counts\\n\",\n", + " \" ax = axes[1]\\n\",\n", + " \" bars = ax.bar(\\n\",\n", + " \" ['Before', 'After'],\\n\",\n", + " \" [result.original_token_count, result.optimized_token_count],\\n\",\n", + " \" color=['#d4534b', '#4a9c5d'],\\n\",\n", + " \" width=0.5\\n\",\n", + " \" )\\n\",\n", + " \" ax.set_ylabel('Token Count')\\n\",\n", + " \" ax.set_title(f'Tokens: {result.original_token_count:,} → {result.optimized_token_count:,} 
({result.token_reduction_pct:.0f}% reduction)')\\n\",\n", + " \" for bar, val in zip(bars, [result.original_token_count, result.optimized_token_count]):\\n\",\n", + " \" ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 50,\\n\",\n", + " \" f'{val:,}', ha='center', fontweight='bold')\\n\",\n", + " \"\\n\",\n", + " \" fig.suptitle(f'GEPA Optimization: {SKILL_NAME}', fontsize=14, fontweight='bold')\\n\",\n", + " \" plt.tight_layout()\\n\",\n", + " \" plt.show()\\n\",\n", + " \"except ImportError:\\n\",\n", + " \" print(\\\"(matplotlib not installed -- skipping chart)\\\")\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Step 5: Review the Diff\\n\",\n", + " \"\\n\",\n", + " \"Inspect what GEPA changed in the SKILL.md.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"from skill_test.optimize.review import review_optimization\\n\",\n", + " \"\\n\",\n", + " \"review_optimization(result)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Step 6: Apply (Optional)\\n\",\n", + " \"\\n\",\n", + " \"If the results look good, apply the optimized SKILL.md. Uncomment the cell below to write it.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"# Uncomment to apply:\\n\",\n", + " \"# from skill_test.optimize.review import apply_optimization\\n\",\n", + " \"# apply_optimization(result)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": \"## Multi-Component Optimization: Skills + Tools\\n\\nGEPA supports optimizing multiple text components simultaneously. 
You can optimize SKILL.md files **alongside** MCP tool descriptions in a single run.\\n\\nGEPA's `RoundRobinReflectionComponentSelector` cycles through components one at a time, so each gets dedicated reflection and mutation.\"\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": \"# Inspect available MCP tools\\nfrom skill_test.optimize.tools import get_tool_stats, extract_tool_descriptions, tools_to_gepa_components\\nfrom skill_test.optimize.utils import count_tokens\\n\\nstats = get_tool_stats()\\nprint(f\\\"MCP Tool Modules: {stats['modules']}\\\")\\nprint(f\\\"Total Tools: {stats['total_tools']}\\\")\\nprint(f\\\"Total Chars: {stats['total_description_chars']:,}\\\")\\nprint()\\nfor mod, info in stats[\\\"per_module\\\"].items():\\n print(f\\\" {mod:<20} {info['tools']:>2} tools {info['chars']:>6,} chars\\\")\\n\\n# Show what GEPA components look like for selected modules\\ntool_map = extract_tool_descriptions(modules=[\\\"serving\\\", \\\"sql\\\"])\\ncomponents = tools_to_gepa_components(tool_map, per_module=True)\\nprint(f\\\"\\\\nGEPA components for serving + sql: {list(components.keys())}\\\")\\nfor name, text in components.items():\\n print(f\\\" {name}: {count_tokens(text):,} tokens\\\")\"\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"source\": \"## Changing the Reflection Model\\n\\nBy default, GEPA uses `databricks/databricks-gpt-5-2` via Databricks Model Serving.\\nOverride per-call or via environment variable:\\n\\n```python\\n# Per-call\\nresult = optimize_skill(\\\"my-skill\\\", reflection_lm=\\\"openai/gpt-4o\\\")\\n\\n# Environment variable (persistent)\\nos.environ[\\\"GEPA_REFLECTION_LM\\\"] = \\\"databricks/databricks-gpt-5-2\\\"\\n```\\n\\nSee README.md for full model configuration options.\",\n", + " \"metadata\": {},\n", + " \"execution_count\": null,\n", + " \"outputs\": []\n", + " },\n", + " {\n", + " \"cell_type\": 
\"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## Summary\\n\",\n", + " \"\\n\",\n", + " \"The GEPA optimization pipeline:\\n\",\n", + " \"\\n\",\n", + " \"| Metric | Before | After | Change |\\n\",\n", + " \"|--------|--------|-------|--------|\\n\",\n", + " \"| Quality Score | `result.original_score` | `result.optimized_score` | `result.improvement` |\\n\",\n", + " \"| Token Count | `result.original_token_count` | `result.optimized_token_count` | `result.token_reduction_pct`% |\\n\",\n", + " \"\\n\",\n", + " \"Key points:\\n\",\n", + " \"- **Quality gate**: Existing scorers (syntax, patterns, APIs, facts) are reused as-is\\n\",\n", + " \"- **Token efficiency**: 80/20 quality/efficiency weighting penalizes bloated skills\\n\",\n", + " \"- **Validation split**: Held-out test cases detect overfitting\\n\",\n", + " \"- **Reflection LM**: Configurable via `--reflection-lm` flag or `GEPA_REFLECTION_LM` env var\\n\",\n", + " \"- **Default model**: `databricks/databricks-gpt-5-2` via Databricks Model Serving\"\n", + " ]\n", + " }\n", + " ],\n", + " \"metadata\": {\n", + " \"kernelspec\": {\n", + " \"display_name\": \"Python 3\",\n", + " \"language\": \"python\",\n", + " \"name\": \"python3\"\n", + " },\n", + " \"language_info\": {\n", + " \"name\": \"python\",\n", + " \"version\": \"3.12.0\"\n", + " }\n", + " },\n", + " \"nbformat\": 4,\n", + " \"nbformat_minor\": 4\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add skill-test to path\n", + "repo_root = Path(\".\").resolve()\n", + "while not (repo_root / \".test\" / \"src\").exists() and repo_root != repo_root.parent:\n", + " repo_root = repo_root.parent\n", + "sys.path.insert(0, str(repo_root / \".test\" / \"src\"))\n", + "\n", + "print(f\"Repo root: {repo_root}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Inspect 
the Skill\n", + "\n", + "Let's look at the `databricks-model-serving` skill -- its current size, test cases, and baseline score." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show first few test cases\n", + "for t in train[:3]:\n", + " print(f\"\\n--- {t['id']} ---\")\n", + " print(f\"Prompt: {t['input'][:100]}...\")\n", + " if t.get('answer'):\n", + " print(f\"Answer: {t['answer'][:100]}...\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from skill_test.optimize.evaluator import create_skill_evaluator, SKILL_KEY\nfrom skill_test.optimize.splitter import to_gepa_instances\n\nevaluator = create_skill_evaluator(SKILL_NAME)\nseed_candidate = {SKILL_KEY: original_content}\n\n# Evaluate on all train tasks\ngepa_instances = to_gepa_instances(train)\n\nprint(f\"{'Task ID':<35} {'Score':>8}\")\nprint(\"-\" * 45)\nfor i, inst in enumerate(gepa_instances):\n score, side_info = evaluator(seed_candidate, inst)\n task_id = train[i]['id']\n status = 'PASS' if score >= 0.5 else 'FAIL'\n print(f\"{task_id:<35} {score:>7.3f} {status}\")\n\n# Quick baseline\nscores = [evaluator(seed_candidate, inst)[0] for inst in gepa_instances]\nbaseline_score = sum(scores) / len(scores)\nprint(f\"\\nBaseline Score: {baseline_score:.3f}\")\nprint(f\"Token Count: {original_tokens:,}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skill_test.optimize.runner import optimize_skill\n", + "\n", + "result = optimize_skill(\n", + " skill_name=SKILL_NAME,\n", + " mode=\"static\",\n", + " preset=\"quick\", # 15 iterations -- increase to \"standard\" (50) or \"thorough\" (150) for better results\n", + ")\n", + "\n", + "print(f\"Optimization complete!\")\n", + "print(f\"GEPA metric calls: {result.gepa_result.total_metric_calls}\")\n", + "print(f\"Candidates explored: 
{result.gepa_result.num_candidates}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(f\" OPTIMIZATION RESULTS: {SKILL_NAME}\")\n", + "print(\"=\" * 60)\n", + "print()\n", + "\n", + "# Quality comparison\n", + "quality_delta = result.improvement\n", + "quality_pct = (quality_delta / result.original_score * 100) if result.original_score > 0 else 0\n", + "print(f\" Quality Score\")\n", + "print(f\" Before: {result.original_score:.3f}\")\n", + "print(f\" After: {result.optimized_score:.3f}\")\n", + "print(f\" Delta: {quality_delta:+.3f} ({quality_pct:+.1f}%)\")\n", + "print()\n", + "\n", + "# Token comparison \n", + "token_delta = result.original_token_count - result.optimized_token_count\n", + "print(f\" Token Count\")\n", + "print(f\" Before: {result.original_token_count:,}\")\n", + "print(f\" After: {result.optimized_token_count:,}\")\n", + "print(f\" Saved: {token_delta:,} tokens ({result.token_reduction_pct:.1f}% reduction)\")\n", + "print()\n", + "\n", + "# Line count comparison\n", + "orig_lines = len(result.original_content.splitlines())\n", + "opt_lines = len(result.optimized_content.splitlines())\n", + "print(f\" Lines\")\n", + "print(f\" Before: {orig_lines}\")\n", + "print(f\" After: {opt_lines}\")\n", + "print(f\" Saved: {orig_lines - opt_lines} lines\")\n", + "print()\n", + "\n", + "# Validation scores\n", + "if result.val_scores:\n", + " avg_val = sum(result.val_scores.values()) / len(result.val_scores)\n", + " print(f\" Validation (held-out test cases)\")\n", + " for tid, score in result.val_scores.items():\n", + " print(f\" {tid}: {score:.3f}\")\n", + " print(f\" Average: {avg_val:.3f}\")\n", + "\n", + "print()\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Review the Diff\n", + "\n", + "Inspect what GEPA changed in the SKILL.md." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 6: Apply (Optional)\n",
+    "\n",
+    "If the results look good, apply the optimized SKILL.md. Uncomment the cell below to write it."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## Multi-Component Optimization: Skills + Tools\n\nGEPA supports optimizing multiple text components simultaneously. You can optimize SKILL.md files **alongside** MCP tool descriptions in a single run.\n\nGEPA's `RoundRobinReflectionComponentSelector` cycles through components one at a time, so each gets dedicated reflection and mutation."
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## Changing the Reflection Model\n\nBy default, GEPA uses `databricks/databricks-gpt-5-2` via Databricks Model Serving.\nOverride per-call or via environment variable:\n\n```python\n# Per-call\nresult = optimize_skill(\"my-skill\", reflection_lm=\"openai/gpt-4o\")\n\n# Environment variable (persistent)\nos.environ[\"GEPA_REFLECTION_LM\"] = \"databricks/databricks-gpt-5-2\"\n```\n\nSee README.md for full model configuration options."
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/.test/pyproject.toml b/.test/pyproject.toml
index 8feaeed3..88dd19ab 100644
--- a/.test/pyproject.toml
+++ b/.test/pyproject.toml
@@ -17,7 +17,10 @@ dependencies = [
 [project.optional-dependencies]
 databricks = ["databricks-sdk>=0.20.0"]
 dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
-all = ["skill-test[databricks,dev]"]
+optimize = ["gepa>=0.1.0", "tiktoken>=0.7.0"]
+# judges group: install separately when mlflow-deepeval/mlflow-ragas are published
+# judges = ["mlflow-deepeval>=0.1.0", "mlflow-ragas>=0.1.0"]
+all = ["skill-test[databricks,dev,optimize]"]
 
 [project.scripts]
 
skill-test = "skill_test.cli:main"
diff --git a/.test/scripts/add_example.py b/.test/scripts/add_example.py
new file mode 100644
index 00000000..68f5c920
--- /dev/null
+++ b/.test/scripts/add_example.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+"""Manually add a test case to a skill's ground_truth.yaml.
+
+Usage:
+    # Interactive mode — prompts for each field
+    uv run python .test/scripts/add_example.py databricks-model-serving
+
+    # Inline mode — provide prompt and response directly
+    uv run python .test/scripts/add_example.py databricks-model-serving \
+        --prompt "Create a ChatAgent with tool calling" \
+        --response-file /path/to/response.md \
+        --facts "Uses ChatAgent class" "Implements predict method" \
+        --patterns "ChatAgent" "def predict"
+
+    # From clipboard
+    uv run python .test/scripts/add_example.py databricks-model-serving --from-clipboard
+"""
+
+import argparse
+import re
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent))
+from _common import setup_path
+
+setup_path()
+
+
+def _detect_language(response: str) -> str | None:
+    """Auto-detect code language from response code blocks."""
+    langs = re.findall(r"```(\w+)\n", response)
+    if not langs:
+        return None
+    # Most common language wins
+    from collections import Counter
+    counts = Counter(lang for lang in langs if lang != "text")
+    return counts.most_common(1)[0][0] if counts else None
+
+
+def _auto_extract_patterns(response: str) -> list[str]:
+    """Extract patterns from code blocks."""
+    patterns = []
+    for match in re.finditer(r"```(?:python)\n(.*?)```", response, re.DOTALL):
+        code = match.group(1)
+        for m in re.finditer(r"class\s+(\w+)", code):
+            patterns.append(m.group(1))
+        for m in re.finditer(r"def\s+(\w+)", code):
+            patterns.append(m.group(1))
+
+    for match in re.finditer(r"```(?:sql)\n(.*?)```", response, re.DOTALL):
+        code = match.group(1)
+        for m in re.finditer(r"(?:CREATE|ALTER)\s+(?:TABLE|VIEW)\s+(\S+)", code, re.I):
+            patterns.append(m.group(1))
+
+    return list(dict.fromkeys(patterns))
+
+
+def _next_id(skill_name: str, existing_ids: set[str]) -> str:
+    """Generate the next sequential ID for a skill."""
+    prefix = skill_name.replace("-", "_")
+    idx = 1
+    while True:
+        candidate = f"{prefix}_{idx:03d}"
+        if candidate not in existing_ids:
+            return candidate
+        idx += 1
+
+
+def _read_clipboard() -> str:
+    """Read text from system clipboard."""
+    import subprocess
+    try:
+        result = subprocess.run(["pbpaste"], capture_output=True, text=True, check=True)
+        return result.stdout
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        try:
+            result = subprocess.run(
+                ["xclip", "-selection", "clipboard", "-o"],
+                capture_output=True, text=True, check=True,
+            )
+            return result.stdout
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            print("Error: Could not read clipboard (tried pbpaste and xclip)")
+            sys.exit(1)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Add a test case to a skill's ground_truth.yaml",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument(
+        "skill_name",
+        help="Name of the skill (e.g., databricks-model-serving)",
+    )
+    parser.add_argument(
+        "--prompt", "-p",
+        default=None,
+        help="The user prompt for the test case",
+    )
+    parser.add_argument(
+        "--response", "-r",
+        default=None,
+        help="The expected response text (inline)",
+    )
+    parser.add_argument(
+        "--response-file",
+        type=Path,
+        default=None,
+        help="Path to a file containing the expected response",
+    )
+    parser.add_argument(
+        "--facts", "-f",
+        nargs="*",
+        default=None,
+        help="Expected facts that must appear in the response",
+    )
+    parser.add_argument(
+        "--patterns",
+        nargs="*",
+        default=None,
+        help="Expected patterns (regex) that must match in the response",
+    )
+    parser.add_argument(
+        "--category", "-c",
+        default="happy_path",
+        help="Test case category (default: happy_path)",
+    )
+    parser.add_argument(
+        "--from-clipboard",
+        action="store_true",
+        help="Read prompt and response from clipboard (separated by ---)",
+    )
+    parser.add_argument(
+        "--id",
+        default=None,
+        help="Override the auto-generated test case ID",
+    )
+
+    args = parser.parse_args()
+
+    import yaml
+    from skill_test.dataset import YAMLDatasetSource
+
+    # Validate skill exists
+    skill_dir = Path(".test/skills") / args.skill_name
+    gt_path = skill_dir / "ground_truth.yaml"
+
+    if not skill_dir.exists():
+        print(f"Error: skill directory not found: {skill_dir}")
+        print("Available skills:")
+        for d in sorted(Path(".test/skills").iterdir()):
+            if d.is_dir() and not d.name.startswith("_"):
+                print(f"  {d.name}")
+        sys.exit(1)
+
+    # Load existing records
+    existing_ids = set()
+    if gt_path.exists():
+        try:
+            source = YAMLDatasetSource(gt_path)
+            existing = source.load()
+            existing_ids = {r.id for r in existing}
+        except Exception:
+            pass
+
+    # Get prompt
+    prompt = args.prompt
+    response = args.response
+
+    if args.from_clipboard:
+        clipboard = _read_clipboard()
+        if "---" in clipboard:
+            parts = clipboard.split("---", 1)
+            prompt = parts[0].strip()
+            response = parts[1].strip()
+        else:
+            prompt = clipboard.strip()
+            print("Clipboard content set as prompt (no --- separator found for response)")
+
+    if args.response_file:
+        response = args.response_file.read_text()
+
+    if not prompt:
+        print("Enter the user prompt (Ctrl+D to finish):")
+        prompt = sys.stdin.read().strip()
+
+    if not prompt:
+        print("Error: prompt is required")
+        sys.exit(1)
+
+    if not response:
+        print("Enter the expected response (Ctrl+D to finish):")
+        response = sys.stdin.read().strip()
+
+    # Generate ID
+    test_id = args.id or _next_id(args.skill_name, existing_ids)
+
+    # Auto-extract patterns and facts
+    auto_patterns = _auto_extract_patterns(response) if response else []
+    auto_facts = args.facts or []
+    user_patterns = args.patterns or []
+
+    # Merge auto and user patterns
+    all_patterns = list(dict.fromkeys(user_patterns + auto_patterns))
+
+    # Detect language
+    language = _detect_language(response) if response else None
+
+    # Build test case
+    test_case = {
+        "id": test_id,
+        "inputs": {"prompt": prompt},
+        "metadata": {
+            "category": args.category,
+            "source": "manual",
+        },
+    }
+
+    if response:
+        test_case["outputs"] = {"response": response}
+    if language:
+        test_case["metadata"]["language"] = language
+
+    expectations = {}
+    if auto_facts:
+        expectations["expected_facts"] = auto_facts
+    if all_patterns:
+        expectations["expected_patterns"] = all_patterns
+    if expectations:
+        test_case["expectations"] = expectations
+
+    # Show summary
+    print("\n--- Test Case Preview ---")
+    print(f"ID: {test_id}")
+    print(f"Skill: {args.skill_name}")
+    print(f"Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}")
+    if response:
+        print(f"Response: {len(response)} chars")
+    if all_patterns:
+        print(f"Patterns: {all_patterns}")
+    if auto_facts:
+        print(f"Facts: {auto_facts}")
+    print(f"Category: {args.category}")
+
+    # Confirm
+    if sys.stdin.isatty():
+        confirm = input("\nAppend to ground_truth.yaml? [Y/n] ").strip().lower()
+        if confirm and confirm != "y":
+            print("Aborted.")
+            sys.exit(0)
+
+    # Save
+    if gt_path.exists():
+        with open(gt_path) as f:
+            data = yaml.safe_load(f) or {"test_cases": []}
+    else:
+        gt_path.parent.mkdir(parents=True, exist_ok=True)
+        data = {"test_cases": []}
+
+    data["test_cases"].append(test_case)
+
+    with open(gt_path, "w") as f:
+        yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
+
+    print(f"Added test case '{test_id}' to {gt_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.test/scripts/generate_examples.py b/.test/scripts/generate_examples.py
new file mode 100644
index 00000000..6b0f764f
--- /dev/null
+++ b/.test/scripts/generate_examples.py
@@ -0,0 +1,428 @@
+#!/usr/bin/env python3
+"""Generate test cases from user requirements for skill optimization.
+
+Converts natural-language requirements into ground_truth.yaml test cases
+with binary assertions (patterns + facts) and guidelines for LLM judge
+evaluation.
+
+Usage:
+    # From a requirements file (one requirement per line)
+    uv run python .test/scripts/generate_examples.py databricks-metric-views \
+        --requirements requirements.txt
+
+    # Inline requirements (repeatable)
+    uv run python .test/scripts/generate_examples.py databricks-metric-views \
+        --requirement "Must explain MEASURE() wrapping for all measure references" \
+        --requirement "Should show error handling when SELECT * is used on metric views"
+
+    # Interactive mode (prompts for requirements)
+    uv run python .test/scripts/generate_examples.py databricks-metric-views --interactive
+
+    # Auto-append to ground_truth.yaml (skip manual review)
+    uv run python .test/scripts/generate_examples.py databricks-metric-views \
+        --requirement "Must explain MEASURE() wrapping" --trust
+
+    # With a second LLM pass to tighten assertions
+    uv run python .test/scripts/generate_examples.py databricks-metric-views \
+        --requirement "Must explain MEASURE() wrapping" --refine
+"""
+
+import argparse
+import json
+import re
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent))
+from _common import setup_path
+
+setup_path()
+
+
+def _find_skill_md(skill_name: str) -> str:
+    """Load SKILL.md content for the given skill."""
+    from skill_test.optimize.utils import find_skill_md as find_md
+    path = find_md(skill_name)
+    if path is None:
+        raise FileNotFoundError(f"Could not find SKILL.md for '{skill_name}'")
+    return path.read_text()
+
+
+def _load_existing_ids(skill_name: str) -> set[str]:
+    """Load existing test case IDs from ground_truth.yaml."""
+    import yaml
+    gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml"
+    if not gt_path.exists():
+        return set()
+    with open(gt_path) as f:
+        data = yaml.safe_load(f) or {}
+    return {tc["id"] for tc in data.get("test_cases", []) if "id" in tc}
+
+
+def generate_examples_from_requirements(
+    skill_name: str,
+    requirements: list[str],
+    skill_md: str,
+    gen_model: str,
+    count_per_requirement: int = 3,
+) -> list[dict]:
+    """Generate test cases from requirements using an LLM.
+
+    For each requirement, generates ``count_per_requirement`` test cases
+    grounded in the SKILL.md content.
+
+    Returns:
+        List of test case dicts in ground_truth.yaml format.
+    """
+    import litellm
+
+    existing_ids = _load_existing_ids(skill_name)
+    all_examples: list[dict] = []
+
+    for req_idx, requirement in enumerate(requirements):
+        print(f"\n  Generating for requirement {req_idx + 1}/{len(requirements)}:")
+        print(f"    {requirement[:100]}")
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are an expert test case generator for Databricks skills. "
+                    "Given a SKILL.md document and a user requirement, generate "
+                    f"{count_per_requirement} realistic test cases that evaluate "
+                    "whether the skill satisfies the requirement.\n\n"
+                    "Each test case must include:\n"
+                    "- A realistic user prompt\n"
+                    "- A reference response grounded in SKILL.md\n"
+                    "- Binary assertions: expected_patterns (regex) and expected_facts (substring)\n"
+                    "- Guidelines derived from the requirement (for LLM judge evaluation)\n"
+                    "- Category and difficulty metadata\n\n"
+                    "Return a JSON array of test cases. Each test case:\n"
+                    "{\n"
+                    '  "prompt": "user question",\n'
+                    '  "response": "reference answer grounded in SKILL.md",\n'
+                    '  "expected_patterns": [{"pattern": "regex", "min_count": 1, "description": "what it checks"}],\n'
+                    '  "expected_facts": ["substring that must appear"],\n'
+                    '  "guidelines": ["evaluation guideline from the requirement"],\n'
+                    '  "category": "happy_path|edge_case|error_handling",\n'
+                    '  "difficulty": "easy|medium|hard"\n'
+                    "}\n\n"
+                    "Important:\n"
+                    "- Patterns should be regex that work with re.findall(pattern, response, re.IGNORECASE)\n"
+                    "- Facts should be exact substrings (case-insensitive) from the response\n"
+                    "- Guidelines should be evaluable by an LLM judge looking at the response\n"
+                    "- Ground everything in SKILL.md — don't invent APIs or syntax"
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    f"## SKILL.md for '{skill_name}':\n\n"
+                    f"{skill_md[:8000]}\n\n"
+                    f"## Requirement:\n{requirement}\n\n"
+                    "## Existing test case IDs (avoid duplicates):\n"
+                    f"{', '.join(sorted(existing_ids)[:20]) or 'None'}\n\n"
+                    f"Generate {count_per_requirement} test cases as a JSON array."
+                ),
+            },
+        ]
+
+        try:
+            resp = litellm.completion(
+                model=gen_model,
+                messages=messages,
+                temperature=0.7,
+            )
+            content = resp.choices[0].message.content or ""
+
+            # Extract JSON array from response
+            json_match = re.search(r"\[.*\]", content, re.DOTALL)
+            if not json_match:
+                print("    Warning: no JSON array found in response")
+                continue
+
+            cases = json.loads(json_match.group())
+            if not isinstance(cases, list):
+                print(f"    Warning: expected JSON array, got {type(cases)}")
+                continue
+
+            for case_idx, case in enumerate(cases):
+                test_id = f"{skill_name}_gen_{req_idx:02d}_{case_idx:02d}"
+                # Avoid collisions with existing IDs
+                while test_id in existing_ids:
+                    test_id += "_x"
+                existing_ids.add(test_id)
+
+                example = {
+                    "id": test_id,
+                    "inputs": {"prompt": case.get("prompt", "")},
+                    "outputs": {
+                        "response": case.get("response", ""),
+                        "execution_success": True,
+                    },
+                    "expectations": {},
+                    "metadata": {
+                        "category": case.get("category", "happy_path"),
+                        "difficulty": case.get("difficulty", "medium"),
+                        "source": "generated_from_requirement",
+                        "requirement": requirement[:200],
+                    },
+                }
+
+                if case.get("expected_patterns"):
+                    example["expectations"]["expected_patterns"] = case["expected_patterns"]
+                if case.get("expected_facts"):
+                    example["expectations"]["expected_facts"] = case["expected_facts"]
+                if case.get("guidelines"):
+                    example["expectations"]["guidelines"] = case["guidelines"]
+
+                all_examples.append(example)
+
+            print(f"    Generated {len(cases)} test case(s)")
+
+        except Exception as e:
+            print(f"    Error generating for requirement: {e}")
+
+    return all_examples
+
+
+def refine_examples(examples: list[dict], gen_model: str) -> list[dict]:
+    """Second LLM pass to validate and tighten assertions."""
+    import litellm
+
+    for ex in examples:
+        prompt_text = ex["inputs"]["prompt"][:200]
+        response_text = ex["outputs"]["response"][:1000]
+        patterns = ex["expectations"].get("expected_patterns", [])
+        facts = ex["expectations"].get("expected_facts", [])
+        guidelines = ex["expectations"].get("guidelines", [])
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are reviewing auto-generated test expectations. "
+                    "Validate that patterns actually match the response, "
+                    "facts are actually present as substrings, and guidelines "
+                    "are clear and evaluable. Tighten or fix as needed. "
+                    "Return JSON only."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    f"Prompt: {prompt_text}\n\n"
+                    f"Response: {response_text}\n\n"
+                    f"Patterns: {json.dumps(patterns)}\n"
+                    f"Facts: {json.dumps(facts)}\n"
+                    f"Guidelines: {json.dumps(guidelines)}\n\n"
+                    "Return a JSON object with:\n"
+                    '- "expected_patterns": refined list\n'
+                    '- "expected_facts": refined list (must be exact substrings of response)\n'
+                    '- "guidelines": refined list\n'
+                    "Remove any patterns/facts that don't actually match the response."
+                ),
+            },
+        ]
+
+        try:
+            resp = litellm.completion(model=gen_model, messages=messages, temperature=0)
+            content = resp.choices[0].message.content or ""
+            json_match = re.search(r"\{.*\}", content, re.DOTALL)
+            if json_match:
+                refined = json.loads(json_match.group())
+                if "expected_patterns" in refined:
+                    ex["expectations"]["expected_patterns"] = refined["expected_patterns"]
+                if "expected_facts" in refined:
+                    ex["expectations"]["expected_facts"] = refined["expected_facts"]
+                if "guidelines" in refined:
+                    ex["expectations"]["guidelines"] = refined["guidelines"]
+        except Exception as e:
+            print(f"  Warning: refinement failed for {ex['id']}: {e}")
+
+    return examples
+
+
+def save_candidates(examples: list[dict], skill_name: str) -> Path:
+    """Save generated examples to candidates.yaml for review."""
+    import yaml
+
+    output_path = Path(".test/skills") / skill_name / "candidates.yaml"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    data = {"test_cases": examples}
+    with open(output_path, "w") as f:
+        yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
+
+    print(f"\nSaved {len(examples)} examples to {output_path}")
+    print("Review and then append to ground_truth.yaml, or re-run with --trust")
+    return output_path
+
+
+def append_to_ground_truth(examples: list[dict], skill_name: str) -> None:
+    """Append generated examples to ground_truth.yaml."""
+    import yaml
+
+    gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml"
+    gt_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if gt_path.exists():
+        with open(gt_path) as f:
+            data = yaml.safe_load(f) or {"test_cases": []}
+    else:
+        data = {"test_cases": []}
+
+    existing_ids = {tc["id"] for tc in data.get("test_cases", []) if "id" in tc}
+    new_examples = [ex for ex in examples if ex["id"] not in existing_ids]
+
+    if not new_examples:
+        print("No new examples to add (all IDs already exist).")
+        return
+
+    data["test_cases"].extend(new_examples)
+
+    with open(gt_path, "w") as f:
+        yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
+
+    print(f"\nAppended {len(new_examples)} examples to {gt_path}")
+
+
+def run_generation(
+    skill_name: str,
+    requirements: list[str],
+    gen_model: str,
+    trust: bool = False,
+    refine: bool = False,
+    count_per_requirement: int = 3,
+) -> list[dict]:
+    """Run the full generation pipeline. Called by optimize.py --generate-from.
+
+    Returns:
+        List of generated test case dicts.
+    """
+    skill_md = _find_skill_md(skill_name)
+    print(f"Generating test cases for '{skill_name}' from {len(requirements)} requirement(s)")
+
+    examples = generate_examples_from_requirements(
+        skill_name=skill_name,
+        requirements=requirements,
+        skill_md=skill_md,
+        gen_model=gen_model,
+        count_per_requirement=count_per_requirement,
+    )
+
+    if not examples:
+        print("No examples generated.")
+        return []
+
+    if refine:
+        print("\nRefining assertions with LLM...")
+        examples = refine_examples(examples, gen_model)
+
+    if trust:
+        append_to_ground_truth(examples, skill_name)
+    else:
+        save_candidates(examples, skill_name)
+
+    return examples
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate test cases from requirements for skill optimization",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument(
+        "skill_name",
+        help="Name of the skill (e.g., databricks-metric-views)",
+    )
+    parser.add_argument(
+        "--requirements",
+        type=str,
+        default=None,
+        metavar="FILE",
+        help="Path to a requirements file (one requirement per line)",
+    )
+    parser.add_argument(
+        "--requirement",
+        action="append",
+        default=None,
+        dest="inline_requirements",
+        help="Inline requirement (repeatable)",
+    )
+    parser.add_argument(
+        "--interactive",
+        action="store_true",
+        help="Interactively enter requirements",
+    )
+    parser.add_argument(
+        "--gen-model",
+        default=None,
+        help="LLM model for generation (default: GEPA_GEN_LM env or Sonnet)",
+    )
+    parser.add_argument(
+        "--count",
+        type=int,
+        default=3,
+        help="Number of test cases per requirement (default: 3)",
+    )
+    parser.add_argument(
+        "--trust",
+        action="store_true",
+        help="Auto-append to ground_truth.yaml instead of writing candidates.yaml",
+    )
+    parser.add_argument(
+        "--refine",
+        action="store_true",
+        help="Second LLM pass to validate and tighten assertions",
+    )
+
+    args = parser.parse_args()
+
+    # Collect requirements from all sources
+    requirements: list[str] = []
+
+    if args.requirements:
+        req_path = Path(args.requirements)
+        if not req_path.exists():
+            print(f"Error: requirements file not found: {req_path}")
+            sys.exit(1)
+        requirements.extend(
+            line.strip() for line in req_path.read_text().splitlines()
+            if line.strip() and not line.strip().startswith("#")
+        )
+
+    if args.inline_requirements:
+        requirements.extend(args.inline_requirements)
+
+    if args.interactive:
+        print("Enter requirements (one per line, empty line to finish):")
+        while True:
+            line = input("  > ").strip()
+            if not line:
+                break
+            requirements.append(line)
+
+    if not requirements:
+        parser.error("Provide requirements via --requirements, --requirement, or --interactive")
+
+    # Resolve gen_model
+    gen_model = args.gen_model
+    if gen_model is None:
+        from skill_test.optimize.config import DEFAULT_GEN_LM
+        gen_model = DEFAULT_GEN_LM
+
+    run_generation(
+        skill_name=args.skill_name,
+        requirements=requirements,
+        gen_model=gen_model,
+        trust=args.trust,
+        refine=args.refine,
+        count_per_requirement=args.count,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.test/scripts/generate_ground_truth.py b/.test/scripts/generate_ground_truth.py
new file mode 100644
index 00000000..083b08a2
--- /dev/null
+++ b/.test/scripts/generate_ground_truth.py
@@ -0,0 +1,645 @@
+#!/usr/bin/env python3
+"""Generate ground_truth.yaml and manifest.yaml for skills missing test cases.
+
+Reads each SKILL.md, extracts code examples, headers, and key patterns,
+then generates test cases that enable GEPA scorers to produce real signal.
+
+Usage:
+    # Generate for a specific skill
+    uv run python .test/scripts/generate_ground_truth.py databricks-metric-views
+
+    # Generate for all missing skills
+    uv run python .test/scripts/generate_ground_truth.py --all
+
+    # Preview without writing (dry run)
+    uv run python .test/scripts/generate_ground_truth.py --all --dry-run
+"""
+
+import argparse
+import re
+import sys
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+
+import yaml
+
+
+# ---------------------------------------------------------------------------
+# Data structures
+# ---------------------------------------------------------------------------
+
+@dataclass
+class CodeBlock:
+    language: str
+    code: str
+    context: str  # surrounding prose/header text
+
+
+@dataclass
+class Section:
+    level: int  # 2 for ##, 3 for ###
+    title: str
+    content: str
+    code_blocks: list[CodeBlock] = field(default_factory=list)
+
+
+@dataclass
+class Pattern:
+    pattern: str
+    description: str
+    min_count: int = 1
+
+
+# ---------------------------------------------------------------------------
+# SKILL.md parsing
+# ---------------------------------------------------------------------------
+
+def extract_frontmatter(content: str) -> dict[str, str]:
+    """Extract YAML frontmatter from SKILL.md."""
+    m = re.match(r"^---\s*\n(.*?)\n---", content, re.DOTALL)
+    if m:
+        try:
+            return yaml.safe_load(m.group(1)) or {}
+        except yaml.YAMLError:
+            return {}
+    return {}
+
+
+def extract_code_blocks(content: str) -> list[CodeBlock]:
+    """Extract fenced code blocks with their language and surrounding context."""
+    blocks = []
+    # Match ```lang ... ```
+    for m in re.finditer(
+        r"```(\w+)\s*\n(.*?)```",
+        content,
+        re.DOTALL,
+    ):
+        lang = m.group(1).lower()
+        code = m.group(2).strip()
+        # Get surrounding context (up to 200 chars before)
+        start = max(0, m.start() - 200)
+        ctx = content[start : m.start()].strip()
+        # Find the nearest header, i.e. the last one in the context window
+        headers = re.findall(r"#+\s+(.+)", ctx)
+        context = headers[-1] if headers else ctx[-100:] if ctx else ""
+        blocks.append(CodeBlock(language=lang, code=code, context=context))
+    return blocks
+
+
+def extract_sections(content: str) -> list[Section]:
+    """Extract H2 and H3 sections with their content and code blocks."""
+    # Remove frontmatter
+    content = re.sub(r"^---\s*\n.*?\n---\s*\n", "", content, flags=re.DOTALL)
+
+    sections = []
+    # Split by headers
+    parts = re.split(r"(^#{2,3}\s+.+$)", content, flags=re.MULTILINE)
+
+    current_title = ""
+    current_level = 0
+    current_content = ""
+
+    for part in parts:
+        header_match = re.match(r"^(#{2,3})\s+(.+)$", part)
+        if header_match:
+            # Save previous section
+            if current_title:
+                code_blocks = extract_code_blocks(current_content)
+                sections.append(Section(
+                    level=current_level,
+                    title=current_title,
+                    content=current_content.strip(),
+                    code_blocks=code_blocks,
+                ))
+            current_level = len(header_match.group(1))
+            current_title = header_match.group(2).strip()
+            current_content = ""
+        else:
+            current_content += part
+
+    # Don't forget last section
+    if current_title:
+        code_blocks = extract_code_blocks(current_content)
+        sections.append(Section(
+            level=current_level,
+            title=current_title,
+            content=current_content.strip(),
+            code_blocks=code_blocks,
+        ))
+
+    return sections
+
+
+def extract_patterns_from_code(code: str, language: str) -> list[Pattern]:
+    """Extract function/class/keyword patterns from a code block."""
+    patterns = []
+
+    if language in ("python", "py"):
+        # Function calls: word(
+        for m in re.finditer(r"\b([a-z_]\w+)\s*\(", code):
+            name = m.group(1)
+            if name not in ("print", "str", "int", "float", "len", "range", "list",
+                            "dict", "set", "tuple", "type", "isinstance", "if", "for",
+                            "while", "return", "import", "from", "as", "with", "round",
+                            "max", "min", "abs", "sum", "enumerate", "zip", "map",
+                            "filter", "sorted", "any", "all", "open", "format", "bool",
+                            "append", "extend"):
+                patterns.append(Pattern(
+                    pattern=re.escape(name),
+                    description=f"Uses {name}()",
+                ))
+        # Class names: CapitalWord
+        for m in re.finditer(r"\b([A-Z][a-zA-Z]+(?:[A-Z][a-zA-Z]+)*)\b", code):
+            name = m.group(1)
+            if name not in ("True", "False", "None", "String", "Int", "Float",
+                            "IMPORTANT", "NOTE", "WARNING", "TODO"):
+                patterns.append(Pattern(
+                    pattern=re.escape(name),
+                    description=f"References {name}",
+                ))
+    elif language in ("sql",):
+        # SQL keywords and functions
+        for m in re.finditer(r"\b(CREATE\s+(?:OR\s+REPLACE\s+)?(?:MATERIALIZED\s+VIEW|VIEW|TABLE|PROCEDURE|CONNECTION))", code, re.IGNORECASE):
+            patterns.append(Pattern(
+                # Normalize whitespace to \s+ so the pattern matches across spaces and newlines
+                pattern=re.sub(r"\s+", r"\\s+", m.group(1)),
+                description=f"Uses {m.group(1).strip()}",
+            ))
+        # Function calls in SQL
+        for m in re.finditer(r"\b([a-z_]\w+)\s*\(", code, re.IGNORECASE):
+            name = m.group(1).lower()
+            if name not in ("select", "from", "where", "and", "or", "not",
+                            "group", "order", "having", "limit", "insert",
+                            "update", "delete", "values", "into", "set",
+                            "join", "on", "as", "case", "when", "then",
+                            "else", "end", "in", "between", "like", "is",
+                            "null", "exists", "count", "sum", "avg", "min", "max"):
+                patterns.append(Pattern(
+                    pattern=re.escape(name),
+                    description=f"Uses {name}() function",
+                ))
+    elif language in ("yaml", "yml"):
+        # Key YAML keys
+        for m in re.finditer(r"^\s*(\w[\w_-]+):", code, re.MULTILINE):
+            key = m.group(1)
+            if key not in ("name", "description", "type", "default", "value",
+                           "true", "false"):
+                patterns.append(Pattern(
+                    pattern=re.escape(key),
+                    description=f"Includes {key} configuration",
+                ))
+    elif language in ("bash", "sh"):
+        # CLI commands
+        for m in re.finditer(r"\b(databricks\s+\w+(?:\s+\w+)?)", code):
+            patterns.append(Pattern(
+                pattern=re.escape(m.group(1)),
+                description=f"Uses {m.group(1)} command",
+            ))
+
+    # Deduplicate by pattern string
+    seen = set()
+    unique = []
+    for p in patterns:
+        if p.pattern not in seen:
+            seen.add(p.pattern)
+            unique.append(p)
+    return unique
+
+
+def extract_facts_from_section(section: Section) -> list[str]:
+    """Extract key factual statements from a section's prose."""
+    facts = []
+    # Look for bullet points with key info
+    for line in section.content.split("\n"):
+        line = line.strip()
+        # Bullet points with bold terms
+        m = re.match(r"[-*]\s+\*\*(.+?)\*\*\s*[-:]\s*(.+)", line)
+        if m:
+            facts.append(f"{m.group(1)}: {m.group(2).strip()}")
+            continue
+        # Table rows with useful info
+        m = re.match(r"\|\s*`?(\w[\w_.-]+)`?\s*\|\s*(.+?)\s*\|", line)
+        if m and not m.group(1).startswith("-"):
+            facts.append(f"{m.group(1)}: {m.group(2).strip()}")
+
+    return facts[:5]  # Limit to top 5
+
+
+# ---------------------------------------------------------------------------
+# Test case generation
+# ---------------------------------------------------------------------------
+
+def generate_prompt_from_section(section: Section, skill_name: str) -> str:
+    """Generate a natural user prompt from a section's content."""
+    title = section.title
+
+    # Map section titles to natural prompts
+    prompt_templates = {
+        "Quick Start": f"Show me how to get started with {skill_name.replace('databricks-', '')}",
+        "Create": f"Create a {title.lower().replace('create ', '')}",
+        "Common Patterns": f"Show me common patterns for {skill_name.replace('databricks-', '')}",
+        "Configuration": f"How do I configure {skill_name.replace('databricks-', '')}?",
+        "Filtering": "How do I filter results when querying?",
+        "Common Issues": f"What are common issues with {skill_name.replace('databricks-', '')}?",
+    }
+
+    # Check if any template matches
+    for key, template in prompt_templates.items():
+        if key.lower() in title.lower():
+            return template
+
+    # Generate from code blocks if present
+    if section.code_blocks:
+        block = section.code_blocks[0]
+        if block.language in ("python", "py"):
+            return f"Write Python code to {title.lower()}"
+        elif block.language == "sql":
+            return f"Write SQL to {title.lower()}"
+        elif block.language in ("yaml", "yml"):
+            return f"Show me the YAML configuration for {title.lower()}"
+        elif block.language in ("bash", "sh"):
+            return f"Show me the CLI commands to {title.lower()}"
+
+    # Default: use section title
+    return f"How do I {title.lower()} with {skill_name.replace('databricks-', '')}?"
+
+
+def generate_response_from_section(section: Section) -> str:
+    """Generate an expected response from a section's code blocks and content."""
+    parts = []
+
+    # Add brief explanation from prose
+    prose_lines = []
+    for line in section.content.split("\n"):
+        stripped = line.strip()
+        if stripped and not stripped.startswith("|") and not stripped.startswith("```"):
+            if not stripped.startswith("#"):
+                prose_lines.append(stripped)
+        if len(prose_lines) >= 3:
+            break
+
+    if prose_lines:
+        parts.append(" ".join(prose_lines[:3]))
+
+    # Add code blocks
+    for block in section.code_blocks[:2]:  # Max 2 code blocks per test case
+        parts.append(f"\n```{block.language}\n{block.code}\n```")
+
+    return "\n\n".join(parts) if parts else section.content[:500]
+
+
+def build_test_case(
+    skill_name: str,
+    section: Section,
+    index: int,
+    difficulty: str = "easy",
+) -> dict:
+    """Build a single test case from a section."""
+    # Generate ID
+    feature = re.sub(r"[^a-z0-9]+", "_", section.title.lower()).strip("_")[:30]
+    case_id = f"{skill_name}_{feature}_{index:03d}"
+
+    prompt = generate_prompt_from_section(section, skill_name)
+    response = generate_response_from_section(section)
+
+    # Collect patterns from all code blocks
+    all_patterns = []
+    for block in section.code_blocks:
+        all_patterns.extend(extract_patterns_from_code(block.code, block.language))
+
+    # Keep at most 5 unique patterns (scanning the first 8 candidates)
+    patterns_data = []
+    seen = set()
+    for p in all_patterns[:8]:
+        if p.pattern not in seen:
+            seen.add(p.pattern)
+            patterns_data.append({
+                "pattern": p.pattern,
+                "min_count": p.min_count,
+                "description": p.description,
+            })
+        if len(patterns_data) >= 5:
+            break
+
+    # Extract facts
+    facts = extract_facts_from_section(section)
+    if not facts:
+        # Fall back to key terms from code blocks
+        for block in section.code_blocks:
+            if block.language in ("python", "py"):
+                facts.append("Uses Python syntax")
+            elif block.language == "sql":
+                facts.append("Uses SQL syntax")
+
+    # Build guidelines from section context
+    guidelines = []
+    if any(b.language in ("python", "py") for b in section.code_blocks):
+        guidelines.append("Code must be valid Python syntax")
+    if any(b.language == "sql" for b in section.code_blocks):
+        guidelines.append("SQL must follow Databricks SQL syntax")
+    if section.code_blocks:
+        guidelines.append("Response must include working code examples")
+
+    return {
+        "id": case_id,
+        "inputs": {"prompt": prompt},
+        "outputs": {
+            "response": response,
+            "execution_success": True,
+        },
+        "expectations": {
+            "expected_facts": facts if facts else [],
+            "expected_patterns": patterns_data if patterns_data else [],
+            "guidelines": guidelines if guidelines else [],
+        },
+        "metadata": {
+            "category": "happy_path",
+            "difficulty": difficulty,
+            "source": "auto_generated",
+            "section": section.title,
+        },
+    }
+
+
+def detect_languages(sections: list[Section]) -> set[str]:
+    """Detect which languages are used across all sections."""
+    langs = set()
+    for s in sections:
+        for b in s.code_blocks:
+            langs.add(b.language)
+    return langs
+
+
+def generate_manifest(skill_name: str, description: str, languages: set[str]) -> dict:
+    """Generate a manifest.yaml for a skill."""
+    enabled_scorers = ["pattern_adherence", "no_hallucinated_apis", "expected_facts_present"]
+    if "python" in languages or "py" in languages:
+        enabled_scorers.insert(0, "python_syntax")
+    if "sql" in languages:
+        enabled_scorers.insert(0, "sql_syntax")
+
+    default_guidelines = [
+        "Response must address the user's request completely",
+        "Code examples must follow documented best practices",
+        "Response must use modern APIs (not deprecated ones)",
+    ]
+
+    return {
+        "skill_name": skill_name,
+        "description": description or f"Test cases for {skill_name} skill",
+        "scorers": {
+            "enabled": enabled_scorers,
+            "llm_scorers": ["Safety", "guidelines_from_expectations"],
+            "default_guidelines": default_guidelines,
+            "trace_expectations": {
+                "tool_limits": {"Bash": 10, "Read": 20},
+                "token_budget": {"max_total": 100000},
+                "required_tools": ["Read"],
+                "banned_tools": [],
+                "expected_files": [],
+            },
+        },
+        "quality_gates": {
+            "syntax_valid": 1.0,
+            "pattern_adherence": 0.9,
+            "execution_success": 0.8,
+        },
+    }
+
+
+# ---------------------------------------------------------------------------
+# Section selection: pick the best sections for test cases
+# ---------------------------------------------------------------------------
+
+def select_sections_for_tests(sections: list[Section], target: int = 7) -> list[Section]:
+    """Select the best sections for test case generation.
+
+    Prefers sections with code blocks and diverse topics.
+    """
+    # Score sections by relevance
+    scored = []
+    for s in sections:
+        score = 0
+        # Sections with code are much more valuable
+        score += len(s.code_blocks) * 3
+        # Prefer H2 over H3
+        if s.level == 2:
+            score += 1
+        # Skip meta sections
+        skip_titles = {"related skills", "resources", "reference files", "notes",
+                       "common issues", "current limitations", "sdk version requirements",
+                       "prerequisites", "prerequisites check", "when to use",
+                       "environment configuration", "best practices"}
+        if s.title.lower() in skip_titles:
+            score -= 5
+        # Boost pattern/example sections
+        if any(kw in s.title.lower() for kw in ("pattern", "example", "start", "create", "common")):
+            score += 2
+        # Boost if has substantial content
+        if len(s.content) > 200:
+            score += 1
+
+        scored.append((score, s))
+
+    # Sort by score descending, keep positive-scoring sections up to target
+    scored.sort(key=lambda x: x[0], reverse=True)
+    selected = [s for score, s in scored if score > 0][:target]
+
+    return selected
+
+
+# ---------------------------------------------------------------------------
+# Main generation logic
+# ---------------------------------------------------------------------------
+
+def generate_skill_tests(skill_name: str, repo_root: Path) -> tuple[list[dict], dict]:
+    """Generate test cases and manifest for a single skill.
+
+    Returns:
+        (test_cases, manifest) tuple
+    """
+    skill_md_path = repo_root / "databricks-skills" / skill_name / "SKILL.md"
+    if not skill_md_path.exists():
+        raise FileNotFoundError(f"SKILL.md not found: {skill_md_path}")
+
+    content = skill_md_path.read_text()
+    frontmatter = extract_frontmatter(content)
+    description = frontmatter.get("description", "")
+    sections = extract_sections(content)
+
+    # Select best sections for test cases
+    selected = select_sections_for_tests(sections, target=7)
+
+    if not selected:
+        raise ValueError(f"No suitable sections found in {skill_md_path}")
+
+    # Generate test cases
+    test_cases = []
+    difficulties = ["easy", "easy", "easy", "medium", "medium", "medium", "hard", "hard"]
+    for i, section in enumerate(selected):
+        difficulty = difficulties[i] if i < len(difficulties) else "medium"
+        tc = build_test_case(skill_name, section, i + 1, difficulty)
+        test_cases.append(tc)
+
+    # Generate manifest
+    languages = detect_languages(sections)
+    manifest = generate_manifest(skill_name, description, languages)
+
+    return test_cases, manifest
+
+
+def write_skill_tests(
+    skill_name: str,
+    test_cases: list[dict],
+    manifest: dict,
+    output_dir: Path,
+    dry_run: bool = False,
+) -> None:
+    """Write ground_truth.yaml and manifest.yaml for a skill."""
+    skill_dir = output_dir / skill_name
+    skill_dir.mkdir(parents=True, exist_ok=True)
+
+    gt_data = {
+        "metadata": {
+            "skill_name": skill_name,
+            "version": "0.1.0",
+            "created_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f"),
+        },
+        "test_cases": test_cases,
+    }
+
+    gt_path = skill_dir / "ground_truth.yaml"
+    manifest_path = skill_dir / "manifest.yaml"
+
+    if dry_run:
+        print(f"  [DRY RUN] Would write {gt_path} ({len(test_cases)} test cases)")
+        print(f"  [DRY RUN] Would write {manifest_path}")
+        return
+
+    # Custom YAML representer for multiline strings
+    class MultilineDumper(yaml.SafeDumper):
+        pass
+
+    def str_representer(dumper, data):
+        if "\n" in data:
+            return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+
+    MultilineDumper.add_representer(str, str_representer)
+
+    with open(gt_path, "w") as f:
+        yaml.dump(gt_data, f, Dumper=MultilineDumper, default_flow_style=False,
+                  sort_keys=False, allow_unicode=True, width=120)
+
+    with open(manifest_path, "w") as f:
+        yaml.dump(manifest, f, Dumper=MultilineDumper, default_flow_style=False,
+                  sort_keys=False, allow_unicode=True, width=120)
+
+    print(f"  Wrote {gt_path} ({len(test_cases)} test cases)")
+    print(f"  Wrote {manifest_path}")
+
+
+# ---------------------------------------------------------------------------
+# Skills that are missing test cases
+# ---------------------------------------------------------------------------
+
+MISSING_SKILLS = [
+    "databricks-app-apx",
+    "databricks-config",
+    "databricks-dbsql",
+    "databricks-docs",
+    "databricks-jobs",
+    "databricks-lakebase-autoscale",
+    "databricks-lakebase-provisioned",
+    "databricks-metric-views",
+    "databricks-mlflow-evaluation",
+    "databricks-python-sdk",
+    "databricks-spark-structured-streaming",
+    "databricks-synthetic-data-generation",
+    "databricks-unity-catalog",
+    "databricks-unstructured-pdf-generation",
+    "databricks-vector-search",
+]
+
+
+# ---------------------------------------------------------------------------
+# CLI entry point
+# ---------------------------------------------------------------------------
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate ground_truth.yaml test cases for skills",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument(
+        "skill_name",
+        nargs="?",
+        help="Skill name (e.g., databricks-metric-views)",
+    )
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        help="Generate for all missing skills",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Preview without writing files",
+    )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Overwrite existing ground_truth.yaml files",
+    )
+
+    args = parser.parse_args()
+
+    if not args.skill_name and not args.all:
+        parser.error("Provide a skill name or use --all")
+
+    # Find repo root
+    repo_root = Path(__file__).resolve().parent.parent.parent
+    if not (repo_root / "databricks-skills").exists():
+        print(f"ERROR: databricks-skills/ not found at {repo_root}", file=sys.stderr)
+        sys.exit(1)
+
+    output_dir = repo_root / ".test" / "skills"
+
+    # Determine which skills to process
+    if args.all:
+        skills = MISSING_SKILLS
+    else:
+        skills = [args.skill_name]
+
+    success = 0
+    errors = 0
+
+    for skill_name in skills:
+        print(f"\n{'=' * 50}")
+        print(f"  {skill_name}")
+        print(f"{'=' * 50}")
+
+        # Check if already exists
+        if not args.force and (output_dir / skill_name / "ground_truth.yaml").exists():
+            print("  SKIP: ground_truth.yaml already exists (use --force to overwrite)")
+            continue
+
+        try:
+            test_cases, manifest = generate_skill_tests(skill_name, repo_root)
+            write_skill_tests(skill_name, test_cases, manifest, output_dir, dry_run=args.dry_run)
+            success += 1
+        except Exception as e:
+            print(f"  ERROR: {e}")
+            errors += 1
+
+    print(f"\n{'=' * 50}")
+    print(f"  Done: {success} generated, {errors} errors")
+    print(f"{'=' * 50}")
+
+    sys.exit(1 if errors > 0 else 0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.test/scripts/optimize.py b/.test/scripts/optimize.py
new file mode 100644
index 00000000..622f7587
--- /dev/null
+++ b/.test/scripts/optimize.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+"""CLI entry point for GEPA skill optimization.
+ +Usage: + # Standard workflow: evaluate + optimize a skill + uv run python .test/scripts/optimize.py databricks-metric-views + + # Quick pass (15 iterations) + uv run python .test/scripts/optimize.py databricks-metric-views --preset quick + + # Thorough optimization (150 iterations) + uv run python .test/scripts/optimize.py databricks-metric-views --preset thorough + + # Dry run (show config, dataset info, estimate cost) + uv run python .test/scripts/optimize.py databricks-metric-views --dry-run + + # Review the saved result then apply (no re-run needed) + uv run python .test/scripts/optimize.py databricks-metric-views --apply-last + + # Run optimization and immediately apply + uv run python .test/scripts/optimize.py databricks-metric-views --apply + + # Optimize all skills that have ground_truth.yaml test cases + uv run python .test/scripts/optimize.py --all +""" + +import argparse +import sys +from pathlib import Path + +# Setup path using shared utilities +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path, handle_error, print_result + +setup_path() + + +def main(): + parser = argparse.ArgumentParser( + description="Optimize Databricks skills using GEPA", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "skill_name", + nargs="?", + help="Name of the skill to optimize (e.g., databricks-model-serving)", + ) + parser.add_argument( + "--all", + action="store_true", + help="Optimize all skills that have ground_truth.yaml", + ) + parser.add_argument( + "--preset", "-p", + choices=["quick", "standard", "thorough"], + default="standard", + help="GEPA optimization preset (default: standard)", + ) + parser.add_argument( + "--gen-model", + default=None, + help="LLM model for generative evaluation (default: GEPA_GEN_LM env or " + "databricks/databricks-claude-sonnet-4-6).", + ) + parser.add_argument( + "--reflection-lm", + default=None, + help="Override GEPA reflection model 
(default: GEPA_REFLECTION_LM env or databricks/databricks-claude-opus-4-6)",
+    )
+    parser.add_argument(
+        "--judge-model",
+        default=None,
+        help="Override judge model for quality/effectiveness evaluation (future use)",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Show config and cost estimate without running optimization",
+    )
+    parser.add_argument(
+        "--apply",
+        action="store_true",
+        help="Run optimization and immediately apply the result",
+    )
+    parser.add_argument(
+        "--apply-last",
+        action="store_true",
+        help="Apply the last saved optimization result without re-running "
+             "(reads from .test/skills/<skill_name>/optimized_SKILL.md)",
+    )
+    parser.add_argument(
+        "--include-tools",
+        action="store_true",
+        help="Include MCP tool descriptions as additional optimization components",
+    )
+    parser.add_argument(
+        "--tool-modules",
+        nargs="*",
+        default=None,
+        help="Specific tool modules to optimize (e.g., sql compute serving). Default: all.",
+    )
+    parser.add_argument(
+        "--tools-only",
+        action="store_true",
+        help="Optimize ONLY tool descriptions, not the SKILL.md",
+    )
+    parser.add_argument(
+        "--max-passes",
+        type=int,
+        default=5,
+        help="Maximum optimization passes per component (default: 5).",
+    )
+    parser.add_argument(
+        "--max-metric-calls",
+        type=int,
+        default=None,
+        help="Override max metric calls per pass (default: auto-scaled by preset).",
+    )
+    parser.add_argument(
+        "--token-budget",
+        type=int,
+        default=None,
+        help="Token budget ceiling. Candidates exceeding this are penalized.",
+    )
+    parser.add_argument(
+        "--align",
+        action="store_true",
+        help="Use MemAlign to align judges with human feedback (requires alignment traces)",
+    )
+    parser.add_argument(
+        "--run-dir",
+        type=str,
+        default=None,
+        help="Directory for GEPA checkpoints. Resumes from last state if dir exists.",
+    )
+    parser.add_argument(
+        "--generate-from",
+        type=str,
+        default=None,
+        metavar="REQUIREMENTS_FILE",
+        help="Generate test cases from a requirements file before optimizing.",
+    )
+    parser.add_argument(
+        "--requirement",
+        action="append",
+        default=None,
+        dest="requirements",
+        help="Inline requirement for test case generation (repeatable).",
+    )
+
+    args = parser.parse_args()
+
+    if not args.skill_name and not args.all:
+        parser.error("Either provide a skill name or use --all")
+
+    from skill_test.optimize.runner import optimize_skill
+    from skill_test.optimize.review import review_optimization, apply_optimization, load_last_result
+
+    # Handle requirements-driven example generation
+    if args.generate_from or args.requirements:
+        if not args.skill_name:
+            parser.error("Test case generation requires a skill name")
+        requirements = []
+        if args.generate_from:
+            req_path = Path(args.generate_from)
+            if not req_path.exists():
+                print(f"Error: requirements file not found: {req_path}")
+                sys.exit(1)
+            requirements.extend(
+                line.strip() for line in req_path.read_text().splitlines()
+                if line.strip() and not line.strip().startswith("#")
+            )
+        if args.requirements:
+            requirements.extend(args.requirements)
+        if requirements:
+            from generate_examples import run_generation
+            gen_model = args.gen_model
+            if gen_model is None:
+                from skill_test.optimize.config import DEFAULT_GEN_LM
+                gen_model = DEFAULT_GEN_LM
+            run_generation(
+                skill_name=args.skill_name,
+                requirements=requirements,
+                gen_model=gen_model,
+                trust=True,
+            )
+            print()
+
+    # Handle --apply-last: load saved result and apply without re-running
+    if args.apply_last:
+        if not args.skill_name:
+            parser.error("--apply-last requires a skill name")
+        result = load_last_result(args.skill_name)
+        if result is None:
+            print(f"No saved optimization found for '{args.skill_name}'.")
+            print(f"Run optimization first: uv run python .test/scripts/optimize.py {args.skill_name}")
+ sys.exit(1) + print(f"Applying saved optimization for '{args.skill_name}':") + print(f" Score: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({result.improvement:+.3f})") + print(f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,}") + try: + apply_optimization(result) + sys.exit(0) + except Exception as e: + print(f"Error applying: {e}") + sys.exit(1) + + if args.all: + # Find all skills with ground_truth.yaml + skills_dir = Path(".test/skills") + skill_names = [ + d.name + for d in sorted(skills_dir.iterdir()) + if d.is_dir() and (d / "ground_truth.yaml").exists() and not d.name.startswith("_") + ] + print(f"Found {len(skill_names)} skills to optimize: {', '.join(skill_names)}\n") + + results = [] + for name in skill_names: + print(f"\n{'=' * 60}") + print(f" Optimizing: {name}") + print(f"{'=' * 60}") + try: + result = optimize_skill( + skill_name=name, + preset=args.preset, + gen_model=args.gen_model, + reflection_lm=args.reflection_lm, + include_tools=args.include_tools, + tool_modules=args.tool_modules, + tools_only=args.tools_only, + dry_run=args.dry_run, + max_passes=args.max_passes, + max_metric_calls=args.max_metric_calls, + token_budget=args.token_budget, + judge_model=args.judge_model, + align=args.align, + run_dir=f"{args.run_dir}/{name}" if args.run_dir else None, + ) + review_optimization(result) + if args.apply and not args.dry_run: + apply_optimization(result) + results.append({"skill": name, "success": True, "improvement": result.improvement}) + except Exception as e: + print(f" ERROR: {e}") + results.append({"skill": name, "success": False, "error": str(e)}) + + # Summary + print(f"\n{'=' * 60}") + print(" Summary") + print(f"{'=' * 60}") + for r in results: + status = "OK" if r["success"] else "FAIL" + detail = f"+{r['improvement']:.3f}" if r["success"] else r["error"] + print(f" [{status}] {r['skill']}: {detail}") + + sys.exit(0 if all(r["success"] for r in results) else 1) + + else: + try: + 
result = optimize_skill( + skill_name=args.skill_name, + preset=args.preset, + gen_model=args.gen_model, + reflection_lm=args.reflection_lm, + include_tools=args.include_tools, + tool_modules=args.tool_modules, + tools_only=args.tools_only, + dry_run=args.dry_run, + max_passes=args.max_passes, + max_metric_calls=args.max_metric_calls, + token_budget=args.token_budget, + judge_model=args.judge_model, + align=args.align, + run_dir=args.run_dir, + ) + review_optimization(result) + if args.apply and not args.dry_run: + apply_optimization(result) + sys.exit(0) + except Exception as e: + sys.exit(handle_error(e, args.skill_name)) + + +if __name__ == "__main__": + main() diff --git a/.test/scripts/trace_to_examples.py b/.test/scripts/trace_to_examples.py new file mode 100644 index 00000000..cb056b3f --- /dev/null +++ b/.test/scripts/trace_to_examples.py @@ -0,0 +1,710 @@ +#!/usr/bin/env python3 +"""Extract test cases from Claude Code traces (local or MLflow). + +Parses real agent transcripts and auto-generates ground_truth.yaml entries +from user prompt / assistant response pairs. + +Supports three trace sources: +1. Local session.jsonl files (--trace) +2. MLflow experiment traces (--experiment-id) +3. 
Individual MLflow traces (--trace-id or --run-id) + +Usage: + # --- From local session.jsonl --- + uv run python .test/scripts/trace_to_examples.py \ + --trace ~/.claude/projects/.../session.jsonl \ + --skill databricks-model-serving + + # --- From MLflow experiment (browse traces, pick best ones) --- + # List recent traces in an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 --list + + # Extract from all recent traces in an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving + + # Extract from a specific MLflow run + uv run python .test/scripts/trace_to_examples.py \ + --run-id abc123def456 \ + --skill databricks-model-serving + + # Extract from a specific MLflow trace ID + uv run python .test/scripts/trace_to_examples.py \ + --trace-id tr-d416fccdab46e2dea6bad1d0bd8aaaa8 \ + --skill databricks-model-serving + + # --- Common options --- + # With LLM refinement of expectations + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --refine + + # Auto-append to ground_truth.yaml (skip manual review) + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --trust + + # Limit number of traces to process from an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --limit 5 +""" + +import argparse +import re +import sys +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path + +setup_path() + + +def _extract_text_content(message: dict) -> str: + """Extract text from a message's content array.""" + content = message.get("content", []) + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for 
part in content: + if isinstance(part, str): + parts.append(part) + elif isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + return "\n".join(parts) + return "" + + +def _extract_code_blocks(text: str) -> list[dict]: + """Extract fenced code blocks with language tags.""" + blocks = [] + for match in re.finditer(r"```(\w*)\n(.*?)```", text, re.DOTALL): + lang = match.group(1) or "text" + code = match.group(2).strip() + blocks.append({"language": lang, "code": code}) + return blocks + + +def _extract_patterns_from_code(code_blocks: list[dict]) -> list[str]: + """Auto-extract patterns from code blocks (function/class names, SQL keywords).""" + patterns = [] + for block in code_blocks: + code = block["code"] + lang = block["language"] + + if lang == "python": + for m in re.finditer(r"class\s+(\w+)", code): + patterns.append(m.group(1)) + for m in re.finditer(r"def\s+(\w+)", code): + patterns.append(m.group(1)) + for m in re.finditer(r"from\s+([\w.]+)\s+import", code): + patterns.append(m.group(1)) + + elif lang == "sql": + for m in re.finditer(r"(?:CREATE|ALTER|DROP)\s+(?:TABLE|VIEW|SCHEMA)\s+(\S+)", code, re.I): + patterns.append(m.group(1)) + for m in re.finditer(r"(?:STREAMING TABLE|MATERIALIZED VIEW)\s+(\S+)", code, re.I): + patterns.append(m.group(1)) + + return list(dict.fromkeys(patterns)) # dedupe preserving order + + +def _extract_facts_from_response(response: str) -> list[str]: + """Auto-extract key facts: API names, function calls, class references.""" + facts = [] + for m in re.finditer(r"(mlflow\.\w+(?:\.\w+)*)\(", response): + facts.append(m.group(1)) + for m in re.finditer(r"(spark\.\w+(?:\.\w+)*)\(", response): + facts.append(m.group(1)) + for m in re.finditer(r"\b([A-Z]\w+(?:Agent|Client|Config|Builder))\b", response): + facts.append(m.group(1)) + return list(dict.fromkeys(facts))[:10] + + +def _categorize_by_tools(tool_names: list[str]) -> str: + """Infer category from tool usage in the turn.""" + if any("sql" 
in t.lower() or "dbsql" in t.lower() for t in tool_names): + return "sql" + if any("bash" in t.lower() for t in tool_names): + return "deployment" + if any("write" in t.lower() or "edit" in t.lower() for t in tool_names): + return "code_generation" + return "general" + + +# --------------------------------------------------------------------------- +# Local trace extraction (session.jsonl) +# --------------------------------------------------------------------------- + +def extract_examples_from_file(trace_path: Path, skill_name: str | None = None) -> list[dict]: + """Parse a session.jsonl and extract test case candidates.""" + from skill_test.trace.parser import parse_transcript_file, link_tool_results + + entries = parse_transcript_file(trace_path) + link_tool_results(entries) + + examples = [] + idx = 0 + + for i, entry in enumerate(entries): + if entry.type != "user": + continue + if entry.tool_use_result: + continue + + user_text = _extract_text_content(entry.message) + if not user_text or len(user_text) < 10: + continue + + assistant_text = "" + tool_names = [] + for j in range(i + 1, len(entries)): + if entries[j].type == "assistant": + assistant_text = _extract_text_content(entries[j].message) + tool_names = [tc.name for tc in entries[j].tool_calls] + break + if entries[j].type == "user" and not entries[j].tool_use_result: + break + + if not assistant_text or len(assistant_text) < 50: + continue + + code_blocks = _extract_code_blocks(assistant_text) + auto_patterns = _extract_patterns_from_code(code_blocks) + auto_facts = _extract_facts_from_response(assistant_text) + category = _categorize_by_tools(tool_names) + + prefix = skill_name or "trace" + example = { + "id": f"{prefix}_{idx:03d}", + "inputs": {"prompt": user_text}, + "outputs": {"response": assistant_text}, + "expectations": {}, + "metadata": { + "category": category, + "source": "trace", + "trace_file": str(trace_path.name), + }, + } + + if auto_patterns: + 
example["expectations"]["expected_patterns"] = auto_patterns + if auto_facts: + example["expectations"]["expected_facts"] = auto_facts + if code_blocks: + langs = list({b["language"] for b in code_blocks if b["language"] != "text"}) + if langs: + example["metadata"]["languages"] = langs + + examples.append(example) + idx += 1 + + return examples + + +# --------------------------------------------------------------------------- +# MLflow trace extraction +# --------------------------------------------------------------------------- + +def _extract_examples_from_mlflow_trace(trace: Any, skill_name: str | None, idx_offset: int = 0) -> list[dict]: + """Extract test case candidates from an MLflow Trace object. + + MLflow traces from `mlflow autolog claude` contain spans representing + the agent conversation. We look for the root span's input/output which + contains the user prompt and final assistant response. + """ + examples = [] + prefix = skill_name or "mlflow" + + trace_info = trace.info + trace_id = trace_info.request_id if hasattr(trace_info, "request_id") else "unknown" + + # Try to get input/output from the trace data + user_text = "" + assistant_text = "" + + if trace.data: + # The root span typically has the full conversation + spans = trace.data.spans if hasattr(trace.data, "spans") else [] + + # Look for the root span (no parent) or the first AGENT/CHAIN span + root_span = None + for span in spans: + parent = getattr(span, "parent_id", None) + if parent is None or parent == "0": + root_span = span + break + + if root_span is None and spans: + root_span = spans[0] + + if root_span: + inputs = getattr(root_span, "inputs", None) + outputs = getattr(root_span, "outputs", None) + + # Extract user prompt from inputs + if isinstance(inputs, dict): + # Common patterns: {"messages": [...]}, {"input": "..."}, {"prompt": "..."} + if "messages" in inputs: + msgs = inputs["messages"] + if isinstance(msgs, list): + for msg in reversed(msgs): + if isinstance(msg, dict) 
and msg.get("role") == "user": + user_text = msg.get("content", "") + break + elif "input" in inputs: + user_text = str(inputs["input"]) + elif "prompt" in inputs: + user_text = str(inputs["prompt"]) + elif isinstance(inputs, str): + user_text = inputs + + # Extract assistant response from outputs + if isinstance(outputs, dict): + if "choices" in outputs: + choices = outputs["choices"] + if isinstance(choices, list) and choices: + msg = choices[0].get("message", {}) + assistant_text = msg.get("content", "") + elif "output" in outputs: + assistant_text = str(outputs["output"]) + elif "response" in outputs: + assistant_text = str(outputs["response"]) + elif isinstance(outputs, str): + assistant_text = outputs + + if not user_text or len(user_text) < 10: + return examples + if not assistant_text or len(assistant_text) < 50: + return examples + + # Build the test case + code_blocks = _extract_code_blocks(assistant_text) + auto_patterns = _extract_patterns_from_code(code_blocks) + auto_facts = _extract_facts_from_response(assistant_text) + + # Categorize by looking at tool spans + tool_names = [] + if trace.data and hasattr(trace.data, "spans"): + for span in trace.data.spans: + span_type = getattr(span, "span_type", "") + if span_type == "TOOL" or "tool" in getattr(span, "name", "").lower(): + tool_names.append(getattr(span, "name", "unknown")) + + category = _categorize_by_tools(tool_names) + + example = { + "id": f"{prefix}_{idx_offset:03d}", + "inputs": {"prompt": user_text}, + "outputs": {"response": assistant_text}, + "expectations": {}, + "metadata": { + "category": category, + "source": "mlflow_trace", + "trace_id": trace_id, + }, + } + + if auto_patterns: + example["expectations"]["expected_patterns"] = auto_patterns + if auto_facts: + example["expectations"]["expected_facts"] = auto_facts + if code_blocks: + langs = list({b["language"] for b in code_blocks if b["language"] != "text"}) + if langs: + example["metadata"]["languages"] = langs + + 
examples.append(example)
+    return examples
+
+
+def list_mlflow_traces(experiment_id: str, limit: int = 20) -> None:
+    """List recent traces in an MLflow experiment."""
+    import mlflow
+
+    from skill_test.trace.mlflow_integration import _configure_mlflow
+    _configure_mlflow()
+
+    print(f"Fetching traces from experiment {experiment_id}...")
+    try:
+        traces_df = mlflow.search_traces(
+            experiment_ids=[experiment_id],
+            max_results=limit,
+        )
+    except Exception as e:
+        print(f"Error fetching traces: {e}")
+        print("\nMake sure you have authentication configured:")
+        print("  export DATABRICKS_HOST='https://<workspace>.cloud.databricks.com'")
+        print("  export DATABRICKS_TOKEN='dapi...'")
+        sys.exit(1)
+
+    if traces_df.empty:
+        print("No traces found in experiment.")
+        return
+
+    print(f"\nFound {len(traces_df)} traces:\n")
+    print(f"{'Trace ID':<45} {'Status':<10} {'Timestamp':<25} {'Duration'}")
+    print("-" * 100)
+
+    for _, row in traces_df.iterrows():
+        trace_id = row.get("request_id", "unknown")
+        status = row.get("status", "?")
+        ts = row.get("timestamp_ms", 0)
+        duration = row.get("execution_time_ms", 0)
+
+        from datetime import datetime
+        ts_str = datetime.fromtimestamp(ts / 1000).strftime("%Y-%m-%d %H:%M:%S") if ts else "?"
+        dur_str = f"{duration / 1000:.1f}s" if duration else "?"
+ + print(f"{trace_id:<45} {status:<10} {ts_str:<25} {dur_str}") + + print(f"\nTo extract examples from a specific trace:") + print(f" uv run python .test/scripts/trace_to_examples.py --trace-id --skill ") + print(f"\nTo extract from all traces in this experiment:") + print(f" uv run python .test/scripts/trace_to_examples.py --experiment-id {experiment_id} --skill ") + + +def extract_examples_from_experiment(experiment_id: str, skill_name: str | None, limit: int = 10) -> list[dict]: + """Extract examples from recent traces in an MLflow experiment.""" + import mlflow + + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + print(f"Fetching up to {limit} traces from experiment {experiment_id}...") + try: + traces_df = mlflow.search_traces( + experiment_ids=[experiment_id], + max_results=limit, + filter_string="status = 'OK'", + ) + except Exception as e: + print(f"Error: {e}") + sys.exit(1) + + if traces_df.empty: + print("No successful traces found.") + return [] + + print(f"Processing {len(traces_df)} traces...") + all_examples = [] + + for _, row in traces_df.iterrows(): + trace_id = row.get("request_id") + if not trace_id: + continue + try: + trace = mlflow.get_trace(trace_id) + if trace: + examples = _extract_examples_from_mlflow_trace( + trace, skill_name, idx_offset=len(all_examples) + ) + all_examples.extend(examples) + if examples: + print(f" {trace_id}: extracted {len(examples)} example(s)") + except Exception as e: + print(f" {trace_id}: skipped ({e})") + + return all_examples + + +def extract_examples_from_trace_id(trace_id: str, skill_name: str | None) -> list[dict]: + """Extract examples from a single MLflow trace by ID.""" + import mlflow + + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + print(f"Fetching trace {trace_id}...") + trace = mlflow.get_trace(trace_id) + if trace is None: + print(f"Trace not found: {trace_id}") + return [] + + return 
_extract_examples_from_mlflow_trace(trace, skill_name) + + +def extract_examples_from_run_id(run_id: str, skill_name: str | None) -> list[dict]: + """Extract examples from an MLflow run (downloads session.jsonl artifact).""" + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + import mlflow + + print(f"Fetching artifacts from run {run_id}...") + + # Try to download session.jsonl artifact + artifact_names = ["trace.jsonl", "session.jsonl", "transcript.jsonl"] + artifact_path = None + + for name in artifact_names: + try: + artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=name) + print(f" Downloaded {name}") + break + except Exception: + continue + + if artifact_path: + return extract_examples_from_file(Path(artifact_path), skill_name) + + # Fallback: try MLflow traces API + print(" No JSONL artifact found, trying traces API...") + try: + traces_df = mlflow.search_traces( + experiment_ids=[mlflow.get_run(run_id).info.experiment_id], + filter_string=f"run_id = '{run_id}'", + max_results=10, + ) + if not traces_df.empty: + all_examples = [] + for _, row in traces_df.iterrows(): + tid = row.get("request_id") + if tid: + trace = mlflow.get_trace(tid) + if trace: + all_examples.extend( + _extract_examples_from_mlflow_trace(trace, skill_name, len(all_examples)) + ) + return all_examples + except Exception as e: + print(f" Traces API failed: {e}") + + print(" No extractable data found in this run.") + return [] + + +# --------------------------------------------------------------------------- +# LLM refinement and output +# --------------------------------------------------------------------------- + +def refine_with_llm(examples: list[dict], skill_name: str) -> list[dict]: + """Use an LLM to review and refine auto-extracted expectations.""" + import litellm + import json + + for ex in examples: + prompt_text = ex["inputs"]["prompt"][:200] + response_text = ex["outputs"]["response"][:1000] + 
current_patterns = ex["expectations"].get("expected_patterns", []) + current_facts = ex["expectations"].get("expected_facts", []) + + messages = [ + { + "role": "system", + "content": ( + "You are reviewing auto-extracted test expectations for a " + f"Databricks skill called '{skill_name}'. Refine the patterns " + "and facts to be more precise and meaningful. Return JSON only." + ), + }, + { + "role": "user", + "content": ( + f"Prompt: {prompt_text}\n\n" + f"Response excerpt: {response_text}\n\n" + f"Auto-extracted patterns: {json.dumps(current_patterns)}\n" + f"Auto-extracted facts: {json.dumps(current_facts)}\n\n" + "Return a JSON object with:\n" + '- "expected_patterns": list of regex pattern strings\n' + '- "expected_facts": list of fact strings that must appear\n' + "Keep only patterns/facts that are genuinely important for correctness." + ), + }, + ] + + try: + from skill_test.optimize.config import DEFAULT_GEN_LM + resp = litellm.completion(model=DEFAULT_GEN_LM, messages=messages) + content = resp.choices[0].message.content + json_match = re.search(r"\{.*\}", content, re.DOTALL) + if json_match: + refined = json.loads(json_match.group()) + if "expected_patterns" in refined: + ex["expectations"]["expected_patterns"] = refined["expected_patterns"] + if "expected_facts" in refined: + ex["expectations"]["expected_facts"] = refined["expected_facts"] + except Exception as e: + print(f" Warning: LLM refinement failed for {ex['id']}: {e}") + + return examples + + +def save_examples(examples: list[dict], output_path: Path) -> None: + """Save examples to a YAML file.""" + import yaml + + data = {"test_cases": examples} + with open(output_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + print(f"Saved {len(examples)} examples to {output_path}") + + +def append_to_ground_truth(examples: list[dict], skill_name: str) -> None: + """Append examples directly to a skill's ground_truth.yaml.""" + import yaml + + from 
skill_test.dataset import get_dataset_source + + try: + source = get_dataset_source(skill_name) + existing = source.load() + existing_ids = {r.id for r in existing} + gt_path = source.yaml_path + except FileNotFoundError: + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + gt_path.parent.mkdir(parents=True, exist_ok=True) + existing_ids = set() + + new_examples = [ex for ex in examples if ex["id"] not in existing_ids] + if not new_examples: + print("No new examples to add (all IDs already exist).") + return + + if gt_path.exists(): + with open(gt_path) as f: + data = yaml.safe_load(f) or {"test_cases": []} + else: + data = {"test_cases": []} + + data["test_cases"].extend(new_examples) + + with open(gt_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + + print(f"Appended {len(new_examples)} examples to {gt_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Extract test cases from Claude Code traces (local or MLflow)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Trace sources (mutually exclusive) + source_group = parser.add_argument_group("trace source (pick one)") + source_group.add_argument( + "--trace", "-t", + type=Path, + default=None, + help="Path to local session.jsonl transcript file", + ) + source_group.add_argument( + "--experiment-id", + default=None, + help="MLflow experiment ID to extract traces from (e.g., 2452310130108632)", + ) + source_group.add_argument( + "--run-id", + default=None, + help="MLflow run ID to extract traces from", + ) + source_group.add_argument( + "--trace-id", + default=None, + help="MLflow trace ID (e.g., tr-d416fccdab46e2dea6bad1d0bd8aaaa8)", + ) + + # Common options + parser.add_argument( + "--skill", "-s", + default=None, + 
help="Skill name to tag examples with (e.g., databricks-model-serving)", + ) + parser.add_argument( + "--refine", + action="store_true", + help="Use LLM to review and refine auto-extracted expectations", + ) + parser.add_argument( + "--trust", + action="store_true", + help="Auto-append to ground_truth.yaml instead of writing candidates.yaml", + ) + parser.add_argument( + "--output", "-o", + type=Path, + default=None, + help="Output file path (default: candidates.yaml in skill dir or cwd)", + ) + parser.add_argument( + "--limit", + type=int, + default=10, + help="Max number of traces to process from an experiment (default: 10)", + ) + parser.add_argument( + "--list", + action="store_true", + dest="list_traces", + help="List traces in the experiment without extracting (use with --experiment-id)", + ) + + args = parser.parse_args() + + # Validate: at least one source required + sources = [args.trace, args.experiment_id, args.run_id, args.trace_id] + if not any(sources): + parser.error("Provide a trace source: --trace, --experiment-id, --run-id, or --trace-id") + + # List mode + if args.list_traces: + if not args.experiment_id: + parser.error("--list requires --experiment-id") + list_mlflow_traces(args.experiment_id, limit=args.limit) + return + + # Extract examples based on source + examples = [] + + if args.trace: + if not args.trace.exists(): + print(f"Error: trace file not found: {args.trace}") + sys.exit(1) + print(f"Parsing local trace: {args.trace}") + examples = extract_examples_from_file(args.trace, args.skill) + + elif args.experiment_id: + examples = extract_examples_from_experiment(args.experiment_id, args.skill, limit=args.limit) + + elif args.run_id: + examples = extract_examples_from_run_id(args.run_id, args.skill) + + elif args.trace_id: + examples = extract_examples_from_trace_id(args.trace_id, args.skill) + + print(f"\nExtracted {len(examples)} candidate test cases") + + if not examples: + print("No suitable prompt/response pairs found.") + 
sys.exit(0) + + if args.refine: + print("Refining expectations with LLM...") + examples = refine_with_llm(examples, args.skill or "unknown") + + if args.trust and args.skill: + append_to_ground_truth(examples, args.skill) + else: + output_path = args.output + if output_path is None: + if args.skill: + output_path = Path(".test/skills") / args.skill / "candidates.yaml" + else: + output_path = Path("candidates.yaml") + output_path.parent.mkdir(parents=True, exist_ok=True) + save_examples(examples, output_path) + + +if __name__ == "__main__": + main() diff --git a/.test/skills/databricks-app-apx/ground_truth.yaml b/.test/skills/databricks-app-apx/ground_truth.yaml new file mode 100644 index 00000000..ab9c8bbf --- /dev/null +++ b/.test/skills/databricks-app-apx/ground_truth.yaml @@ -0,0 +1,311 @@ +metadata: + skill_name: databricks-app-apx + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: apx_initialize_project_001 + inputs: + prompt: "I want to start building a new full-stack Databricks app. How do I initialize an APX project?" + outputs: + response: | + To initialize an APX project, use the APX MCP tools: + + ```bash + # Start the APX development server + mcp-cli call apx/start '{}' + + # Check that the server is running + mcp-cli call apx/status '{}' + ``` + + After initialization, APX scaffolds a FastAPI backend and React frontend with TypeScript support. 
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "APX uses MCP tools invoked via mcp-cli call"
+    - "apx/start initializes the development server"
+    - "apx/status checks the current server status"
+    expected_patterns:
+    - pattern: "apx/start"
+      min_count: 1
+      description: "Mentions the apx/start MCP tool"
+    - pattern: "apx/status"
+      min_count: 1
+      description: "Mentions the apx/status MCP tool"
+    guidelines:
+    - "Response must show the MCP CLI commands for starting APX"
+    - "Response should explain what happens after initialization"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: apx_pydantic_models_002
+  inputs:
+    prompt: "What is the recommended pattern for defining backend data models in an APX app?"
+  outputs:
+    response: |
+      In `src/{app_name}/backend/models.py`, follow the 3-model pattern:
+
+      ```python
+      from datetime import datetime
+
+      from pydantic import BaseModel
+
+      # Input validation - what the client sends
+      class EntityIn(BaseModel):
+          name: str
+          description: str
+
+      # Database record - stored internally
+      class EntityRecord(EntityIn):
+          id: str
+          created_at: datetime
+
+      # API response - what the client receives
+      class EntityOut(EntityRecord):
+          pass
+      ```
+
+      - `EntityIn` handles input validation
+      - `EntityRecord` extends with database fields
+      - `EntityOut` controls API response shape
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "APX backend models live in src/{app_name}/backend/models.py"
+    - "The 3-model pattern uses EntityIn, EntityRecord, and EntityOut"
+    - "Models are built with Pydantic BaseModel"
+    expected_patterns:
+    - pattern: "class\\s+\\w+In\\("
+      min_count: 1
+      description: "Defines an input model class"
+    - pattern: "class\\s+\\w+Record\\("
+      min_count: 1
+      description: "Defines a record model class"
+    - pattern: "BaseModel"
+      min_count: 1
+      description: "Uses Pydantic BaseModel"
+    guidelines:
+    - "Response must show the 3-model pattern with EntityIn, EntityRecord, and EntityOut"
+    - "Code must be valid Python with Pydantic imports"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: apx_api_routes_003
+  inputs:
+    prompt: "How should I define API routes in my APX backend?"
+  outputs:
+    response: |
+      In `src/{app_name}/backend/router.py`, define FastAPI routes with these requirements:
+
+      ```python
+      from fastapi import APIRouter
+
+      router = APIRouter(prefix="/api/entities", tags=["entities"])
+
+      @router.get("/", response_model=list[EntityOut])
+      async def list_entities():
+          return entities
+
+      @router.get("/{entity_id}", response_model=EntityOut)
+      async def get_entity(entity_id: str):
+          return find_entity(entity_id)
+
+      @router.post("/", response_model=EntityOut)
+      async def create_entity(entity: EntityIn):
+          return save_entity(entity)
+      ```
+
+      Always include `response_model` on every route, as this enables OpenAPI spec generation for the frontend client.
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "API routes are defined in src/{app_name}/backend/router.py"
+    - "Every route must include response_model for OpenAPI generation"
+    - "Routes use FastAPI's APIRouter"
+    expected_patterns:
+    - pattern: "response_model"
+      min_count: 2
+      description: "Includes response_model on route decorators"
+    - pattern: "APIRouter"
+      min_count: 1
+      description: "Uses FastAPI APIRouter"
+    - pattern: "@router\\.(get|post|put|delete)"
+      min_count: 2
+      description: "Defines HTTP method route decorators"
+    guidelines:
+    - "Response must emphasize the requirement for response_model on every route"
+    - "Code must show standard CRUD route patterns"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: apx_add_ui_components_004
+  inputs:
+    prompt: "How do I add shadcn UI components to my APX frontend?"
+ outputs: + response: | + Use the shadcn MCP tool to add components: + + ```bash + mcp-cli call shadcn/get_add_command_for_items '{ + "items": ["@shadcn/button", "@shadcn/card", "@shadcn/table", + "@shadcn/badge", "@shadcn/select", "@shadcn/skeleton"] + }' + ``` + + This generates the install command for the specified shadcn components. Run the output command to add them to your frontend. + execution_success: true + expectations: + expected_facts: + - "shadcn components are added via the shadcn MCP tool" + - "The tool is invoked with shadcn/get_add_command_for_items" + expected_patterns: + - pattern: "shadcn/get_add_command_for_items" + min_count: 1 + description: "Uses the shadcn MCP tool" + - pattern: "@shadcn/" + min_count: 2 + description: "References shadcn component packages" + guidelines: + - "Response must show the MCP tool call with JSON parameters" + - "Response should list common UI components like button, card, table" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: apx_testing_endpoints_005 + inputs: + prompt: "How do I test my APX app's backend API endpoints and run type checks?" + outputs: + response: | + APX provides built-in testing and type checking: + + ```bash + # Run type checks on both backend and frontend + mcp-cli call apx/dev_check '{}' + + # Test API endpoints manually + curl http://localhost:8000/api/{entities} | jq . + curl http://localhost:8000/api/{entities}/{id} | jq . + + # Get the frontend development URL + mcp-cli call apx/get_frontend_url '{}' + ``` + + Fix any Python type errors reported by basedpyright before deploying. 
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "apx/dev_check runs type checks on both backend and frontend"
+    - "Backend API runs on localhost:8000"
+    - "APX uses basedpyright for Python type checking"
+    expected_patterns:
+    - pattern: "apx/dev_check"
+      min_count: 1
+      description: "Uses the dev_check MCP tool"
+    - pattern: "curl.*localhost:8000"
+      min_count: 1
+      description: "Tests endpoints via curl on localhost"
+    - pattern: "apx/get_frontend_url"
+      min_count: 1
+      description: "Uses the get_frontend_url MCP tool"
+    guidelines:
+    - "Response must include both type checking and API testing commands"
+    - "Response should mention basedpyright for Python type checking"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: apx_deployment_006
+  inputs:
+    prompt: "How do I deploy my APX app to Databricks and check the logs?"
+  outputs:
+    response: |
+      Deploy and monitor your APX app:
+
+      ```bash
+      # Deploy the app to Databricks
+      mcp-cli call apx/deploy '{}'
+
+      # Check deployed application logs
+      databricks apps logs --profile <profile-name>
+      ```
+
+      The APX MCP server can also automatically check deployed application logs. Simply ask: "Please check the deployed app logs for <app-name>".
+ execution_success: true + expectations: + expected_facts: + - "APX apps are deployed using the apx/deploy MCP tool" + - "Application logs are accessed via databricks apps logs CLI command" + expected_patterns: + - pattern: "apx/deploy" + min_count: 1 + description: "Uses the deploy MCP tool" + - pattern: "databricks apps logs" + min_count: 1 + description: "Uses the databricks apps logs CLI command" + guidelines: + - "Response must show both deployment and log monitoring commands" + - "Response should mention the APX MCP automatic log checking feature" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: apx_full_workflow_007 + inputs: + prompt: "Walk me through the complete APX development workflow from start to deployment." + outputs: + response: | + The APX development workflow has 5 phases: + + **Phase 1: Initialize** + ```bash + mcp-cli call apx/start '{}' + mcp-cli call apx/status '{}' + ``` + + **Phase 2: Backend** - Create Pydantic models in `models.py` and FastAPI routes in `router.py` + + **Phase 3: Frontend** - Add shadcn components and build React UI + + **Phase 4: Testing** + ```bash + mcp-cli call apx/dev_check '{}' + curl http://localhost:8000/api/{entities} | jq . 
+      ```
+
+      **Phase 5: Deploy**
+      ```bash
+      mcp-cli call apx/deploy '{}'
+      databricks apps logs --profile <profile-name>
+      ```
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "APX development follows 5 phases: Initialize, Backend, Frontend, Testing, Deploy"
+    - "Backend uses Pydantic models and FastAPI routes"
+    - "Frontend uses React with shadcn components"
+    expected_patterns:
+    - pattern: "apx/start"
+      min_count: 1
+      description: "Includes initialization step"
+    - pattern: "apx/dev_check"
+      min_count: 1
+      description: "Includes testing step"
+    - pattern: "apx/deploy"
+      min_count: 1
+      description: "Includes deployment step"
+    guidelines:
+    - "Response must cover all 5 phases of the APX workflow"
+    - "Response should include code examples for key phases"
+  metadata:
+    category: happy_path
+    difficulty: hard
+    source: auto_generated
diff --git a/.test/skills/databricks-app-apx/manifest.yaml b/.test/skills/databricks-app-apx/manifest.yaml
new file mode 100644
index 00000000..e819a067
--- /dev/null
+++ b/.test/skills/databricks-app-apx/manifest.yaml
@@ -0,0 +1,28 @@
+skill_name: databricks-app-apx
+description: Build full-stack Databricks applications using the APX framework (FastAPI + React).
+scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-config/ground_truth.yaml b/.test/skills/databricks-config/ground_truth.yaml new file mode 100644 index 00000000..10dbad09 --- /dev/null +++ b/.test/skills/databricks-config/ground_truth.yaml @@ -0,0 +1,199 @@ +metadata: + skill_name: databricks-config + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: databricks-config_setup_default_profile_001 + inputs: + prompt: Configure my Databricks profile for the DEFAULT workspace + outputs: + response: | + I'll help you configure your Databricks DEFAULT profile. First, let's authenticate: + + ```bash + databricks auth login -p DEFAULT + ``` + + This will open a browser for OAuth authentication. After authenticating, I'll update your `~/.databrickscfg` with the compute configuration. 
+ + Would you like to use: + - **Cluster ID**: A specific interactive/all-purpose cluster + - **Serverless**: Serverless compute (`serverless_compute_id = auto`) + + Example configuration with a cluster: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + cluster_id = 1217-064531-c9c3ngyn + auth_type = databricks-cli + ``` + execution_success: true + expectations: + expected_facts: + - Uses databricks auth login command + - Configures ~/.databrickscfg file + - Offers choice between cluster_id and serverless + expected_patterns: + - pattern: databricks auth login + min_count: 1 + description: Must use databricks auth login command + - pattern: databrickscfg + min_count: 1 + description: Must reference the config file + guidelines: + - Must run databricks auth login for authentication + - Must offer choice between cluster_id and serverless compute + - Must never print token values in plain text + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_serverless_config_002 + inputs: + prompt: Set up my Databricks profile to use serverless compute + outputs: + response: | + I'll configure your profile with serverless compute. + + ```bash + databricks auth login -p DEFAULT + ``` + + Then update ~/.databrickscfg: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + serverless_compute_id = auto + auth_type = databricks-cli + ``` + + Note: For serverless, set `serverless_compute_id = auto` (not just `serverless = true`). 
+ execution_success: true + expectations: + expected_facts: + - Sets serverless_compute_id to auto + - Does not use cluster_id with serverless + expected_patterns: + - pattern: serverless_compute_id + min_count: 1 + description: Must use serverless_compute_id parameter + - pattern: auto + min_count: 1 + description: Must set serverless to auto + guidelines: + - Must use serverless_compute_id = auto, not serverless = true + - Must not combine cluster_id and serverless_compute_id + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_workspace_url_003 + inputs: + prompt: Configure Databricks profile using workspace URL https://adb-1234567890123456.7.azuredatabricks.net/ + outputs: + response: | + I'll extract the profile name from the workspace URL and configure it. + + Profile name: `adb-1234567890123456` + + ```bash + databricks auth login -p adb-1234567890123456 --host https://adb-1234567890123456.7.azuredatabricks.net/ + ``` + + After authentication, choose your compute option and I'll update ~/.databrickscfg. + execution_success: true + expectations: + expected_facts: + - Extracts profile name from workspace URL + - Uses --host flag with the URL + expected_patterns: + - pattern: databricks auth login + min_count: 1 + description: Must use databricks auth login + - pattern: "\\-\\-host" + min_count: 1 + description: Must pass --host when URL is provided + guidelines: + - Must extract profile name from workspace URL + - Must pass --host flag when workspace URL is provided + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: databricks-config_cluster_id_004 + inputs: + prompt: Configure my Databricks profile with cluster ID 0123-456789-abcdef12 + outputs: + response: | + I'll configure your profile with the specified cluster ID. 
+ + ```bash + databricks auth login -p DEFAULT + ``` + + Updated ~/.databrickscfg: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + cluster_id = 0123-456789-abcdef12 + auth_type = databricks-cli + ``` + execution_success: true + expectations: + expected_facts: + - Sets cluster_id in configuration + - Preserves existing host and auth settings + expected_patterns: + - pattern: cluster_id + min_count: 1 + description: Must set cluster_id + - pattern: 0123-456789-abcdef12 + min_count: 1 + description: Must use the specified cluster ID + guidelines: + - Must set cluster_id, not serverless_compute_id + - Must preserve existing profile settings + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_token_security_005 + inputs: + prompt: Show me my Databricks profile configuration + outputs: + response: | + Here's your current profile configuration: + + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + token = [REDACTED] + cluster_id = 1217-064531-c9c3ngyn + ``` + + View full configuration at: ~/.databrickscfg + execution_success: true + expectations: + expected_facts: + - Token values must be redacted + - Directs user to ~/.databrickscfg for full config + expected_patterns: + - pattern: REDACTED + min_count: 1 + description: Must redact token values + - pattern: databrickscfg + min_count: 1 + description: Must reference config file location + guidelines: + - Must NEVER print token values in plain text + - Must redact any token field values + - Must inform user they can view full config at ~/.databrickscfg + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-config/manifest.yaml b/.test/skills/databricks-config/manifest.yaml new file mode 100644 index 00000000..6e8db2b6 --- /dev/null +++ b/.test/skills/databricks-config/manifest.yaml @@ -0,0 +1,28 @@ +skill_name: databricks-config +description: Configure 
Databricks profile and authenticate for Databricks Connect, Databricks CLI, and Databricks SDK. +scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-dbsql/ground_truth.yaml b/.test/skills/databricks-dbsql/ground_truth.yaml new file mode 100644 index 00000000..f6a3c6d6 --- /dev/null +++ b/.test/skills/databricks-dbsql/ground_truth.yaml @@ -0,0 +1,366 @@ +metadata: + skill_name: databricks-dbsql + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: dbsql_sql_scripting_etl_001 + inputs: + prompt: "Write a SQL script that conditionally processes new orders into a processed table using Databricks SQL scripting." 
+ outputs: + response: | + ```sql + BEGIN + DECLARE v_count INT; + DECLARE v_status STRING DEFAULT 'pending'; + + SET v_count = (SELECT COUNT(*) FROM catalog.schema.raw_orders WHERE status = 'new'); + + IF v_count > 0 THEN + INSERT INTO catalog.schema.processed_orders + SELECT *, current_timestamp() AS processed_at + FROM catalog.schema.raw_orders + WHERE status = 'new'; + + SET v_status = 'completed'; + ELSE + SET v_status = 'skipped'; + END IF; + + SELECT v_status AS result, v_count AS rows_processed; + END + ``` + execution_success: true + expectations: + expected_facts: + - "Databricks SQL scripting uses BEGIN/END blocks for procedural logic" + - "Variables are declared with DECLARE and assigned with SET" + - "IF/ELSE conditional logic is supported in SQL scripts" + expected_patterns: + - pattern: "BEGIN" + min_count: 1 + description: "Uses SQL scripting BEGIN block" + - pattern: "DECLARE" + min_count: 1 + description: "Declares variables using DECLARE" + - pattern: "IF.*THEN" + min_count: 1 + description: "Uses IF/THEN conditional logic" + - pattern: "current_timestamp\\(\\)" + min_count: 1 + description: "Uses current_timestamp() function" + guidelines: + - "SQL must use Databricks SQL scripting syntax with BEGIN/END blocks" + - "Response must include variable declarations and conditional logic" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: dbsql_stored_procedure_002 + inputs: + prompt: "Create a stored procedure that upserts customer data with error handling." 
+ outputs: + response: | + ```sql + CREATE OR REPLACE PROCEDURE catalog.schema.upsert_customers( + IN p_source STRING, + OUT p_rows_affected INT + ) + LANGUAGE SQL + SQL SECURITY INVOKER + BEGIN + DECLARE EXIT HANDLER FOR SQLEXCEPTION + BEGIN + SET p_rows_affected = -1; + SIGNAL SQLSTATE '45000' + SET MESSAGE_TEXT = concat('Upsert failed for source: ', p_source); + END; + + MERGE INTO catalog.schema.dim_customer AS t + USING (SELECT * FROM identifier(p_source)) AS s + ON t.customer_id = s.customer_id + WHEN MATCHED THEN UPDATE SET * + WHEN NOT MATCHED THEN INSERT *; + + SET p_rows_affected = (SELECT COUNT(*) FROM identifier(p_source)); + END; + + -- Invoke the procedure: + CALL catalog.schema.upsert_customers('catalog.schema.staging_customers', ?); + ``` + execution_success: true + expectations: + expected_facts: + - "Stored procedures are created with CREATE OR REPLACE PROCEDURE" + - "SQL SECURITY INVOKER runs with the caller's permissions" + - "The identifier() function dynamically resolves table names from string parameters" + - "Error handling uses DECLARE EXIT HANDLER FOR SQLEXCEPTION" + expected_patterns: + - pattern: "CREATE OR REPLACE PROCEDURE" + min_count: 1 + description: "Creates a stored procedure" + - pattern: "MERGE INTO" + min_count: 1 + description: "Uses MERGE for upsert logic" + - pattern: "identifier\\(" + min_count: 1 + description: "Uses identifier() for dynamic table references" + - pattern: "DECLARE EXIT HANDLER" + min_count: 1 + description: "Includes error handling with exit handler" + guidelines: + - "Procedure must include IN/OUT parameters" + - "Must include error handling with SQLEXCEPTION handler" + - "Must use MERGE INTO for upsert pattern" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: dbsql_materialized_view_003 + inputs: + prompt: "Create a materialized view that refreshes hourly and aggregates daily revenue by region." 
+ outputs: + response: | + ```sql + CREATE OR REPLACE MATERIALIZED VIEW catalog.schema.daily_revenue + CLUSTER BY (order_date) + SCHEDULE EVERY 1 HOUR + COMMENT 'Hourly-refreshed daily revenue by region' + AS SELECT + order_date, + region, + SUM(amount) AS total_revenue, + COUNT(DISTINCT customer_id) AS unique_customers + FROM catalog.schema.fact_orders + JOIN catalog.schema.dim_store USING (store_id) + GROUP BY order_date, region; + ``` + execution_success: true + expectations: + expected_facts: + - "Materialized views support scheduled refresh with SCHEDULE EVERY syntax" + - "CLUSTER BY optimizes data layout for query performance" + - "Materialized views are created with CREATE OR REPLACE MATERIALIZED VIEW" + expected_patterns: + - pattern: "CREATE OR REPLACE MATERIALIZED VIEW" + min_count: 1 + description: "Creates a materialized view" + - pattern: "SCHEDULE EVERY" + min_count: 1 + description: "Configures scheduled refresh" + - pattern: "CLUSTER BY" + min_count: 1 + description: "Uses CLUSTER BY for data layout optimization" + - pattern: "GROUP BY" + min_count: 1 + description: "Aggregates data with GROUP BY" + guidelines: + - "Must include SCHEDULE EVERY for automated refresh" + - "Must include CLUSTER BY for performance optimization" + - "Response should use realistic aggregation patterns" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: dbsql_pipe_syntax_004 + inputs: + prompt: "Show me how to use Databricks SQL pipe syntax to transform and filter order data." 
+ outputs: + response: | + ```sql + -- Pipe syntax: read left-to-right instead of inside-out + FROM catalog.schema.fact_orders + |> WHERE order_date >= current_date() - INTERVAL 30 DAYS + |> AGGREGATE SUM(amount) AS total, COUNT(*) AS cnt GROUP BY region, product_category + |> WHERE total > 10000 + |> ORDER BY total DESC + |> LIMIT 20; + ``` + + Pipe syntax starts with FROM and chains transformations with `|>`, making complex queries more readable than nested subqueries. + execution_success: true + expectations: + expected_facts: + - "Pipe syntax uses |> to chain transformations left-to-right" + - "Queries start with FROM instead of SELECT when using pipe syntax" + - "AGGREGATE replaces GROUP BY in pipe syntax context" + expected_patterns: + - pattern: "\\|>" + min_count: 3 + description: "Uses pipe operator |> for chaining" + - pattern: "^FROM" + min_count: 1 + description: "Starts the query with FROM" + - pattern: "AGGREGATE" + min_count: 1 + description: "Uses AGGREGATE keyword in pipe syntax" + guidelines: + - "Response must demonstrate the pipe |> operator" + - "Query must start with FROM, not SELECT" + - "Response should explain the readability benefit of pipe syntax" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: dbsql_ai_functions_005 + inputs: + prompt: "How can I use Databricks AI functions to classify and analyze text data in SQL?" 
+  outputs:
+    response: |
+      ```sql
+      -- Classify support tickets
+      SELECT
+        ticket_id,
+        description,
+        ai_classify(description, ARRAY('billing', 'technical', 'account', 'feature_request')) AS category,
+        ai_analyze_sentiment(description) AS sentiment
+      FROM catalog.schema.support_tickets
+      LIMIT 100;
+
+      -- Extract structured entities from text
+      SELECT
+        doc_id,
+        ai_extract(content, ARRAY('person_name', 'company', 'dollar_amount')) AS entities
+      FROM catalog.schema.contracts;
+
+      -- General-purpose AI query with typed output
+      SELECT ai_query(
+        'databricks-meta-llama-3-3-70b-instruct',
+        concat('Summarize this feedback in JSON: ', feedback),
+        returnType => 'STRUCT<summary: STRING, topics: ARRAY<STRING>>'
+      ) AS analysis
+      FROM catalog.schema.customer_feedback
+      LIMIT 50;
+      ```
+  execution_success: true
+  expectations:
+    expected_facts:
+    - "ai_classify() categorizes text into predefined labels"
+    - "ai_analyze_sentiment() returns sentiment analysis of text"
+    - "ai_extract() extracts structured entities from unstructured text"
+    - "ai_query() calls an LLM endpoint with optional structured return types"
+    expected_patterns:
+    - pattern: "ai_classify"
+      min_count: 1
+      description: "Uses ai_classify() for text classification"
+    - pattern: "ai_analyze_sentiment"
+      min_count: 1
+      description: "Uses ai_analyze_sentiment() for sentiment analysis"
+    - pattern: "ai_extract"
+      min_count: 1
+      description: "Uses ai_extract() for entity extraction"
+    - pattern: "ai_query"
+      min_count: 1
+      description: "Uses ai_query() for general LLM queries"
+    guidelines:
+    - "Response must demonstrate at least three different AI functions"
+    - "ai_query must include the returnType parameter for structured output"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: dbsql_geospatial_h3_006
+  inputs:
+    prompt: "Write a query that finds all stores within 5km of each customer using H3 geospatial indexing."
+ outputs: + response: | + ```sql + -- Find stores within 5km using H3 spatial indexing + WITH customer_h3 AS ( + SELECT *, h3_longlatash3(longitude, latitude, 7) AS h3_cell + FROM catalog.schema.customers + ), + store_h3 AS ( + SELECT *, h3_longlatash3(longitude, latitude, 7) AS h3_cell + FROM catalog.schema.stores + ) + SELECT + c.customer_id, + s.store_id, + ST_Distance( + ST_Point(c.longitude, c.latitude), + ST_Point(s.longitude, s.latitude) + ) AS distance_m + FROM customer_h3 c + JOIN store_h3 s ON h3_ischildof(c.h3_cell, h3_toparent(s.h3_cell, 5)) + WHERE ST_Distance( + ST_Point(c.longitude, c.latitude), + ST_Point(s.longitude, s.latitude) + ) < 5000; + ``` + execution_success: true + expectations: + expected_facts: + - "h3_longlatash3() converts longitude/latitude to H3 cell index at a given resolution" + - "ST_Distance() calculates distance in meters between two geometric points" + - "ST_Point() creates a geometric point from longitude and latitude" + - "h3_ischildof() and h3_toparent() navigate the H3 hierarchy for spatial joins" + expected_patterns: + - pattern: "h3_longlatash3" + min_count: 1 + description: "Uses H3 cell indexing function" + - pattern: "ST_Distance" + min_count: 1 + description: "Uses ST_Distance for distance calculation" + - pattern: "ST_Point" + min_count: 1 + description: "Uses ST_Point for point geometry" + - pattern: "h3_ischildof" + min_count: 1 + description: "Uses H3 hierarchy for efficient spatial join" + guidelines: + - "Must use H3 indexing for efficient spatial join instead of brute-force distance" + - "Must include ST_Distance for final distance filtering" + metadata: + category: happy_path + difficulty: hard + source: auto_generated + +- id: dbsql_collation_007 + inputs: + prompt: "How do I create a table with case-insensitive text columns in Databricks SQL?" 
+ outputs: + response: | + ```sql + -- Create table with case-insensitive collation + CREATE TABLE catalog.schema.products ( + product_id BIGINT GENERATED ALWAYS AS IDENTITY, + name STRING COLLATE UTF8_LCASE, + category STRING COLLATE UTF8_LCASE, + price DECIMAL(10, 2) + ); + + -- Queries are automatically case-insensitive (no LOWER() needed) + SELECT * FROM catalog.schema.products + WHERE name = 'MacBook Pro'; -- matches 'macbook pro', 'MACBOOK PRO', etc. + ``` + + The `UTF8_LCASE` collation makes string comparisons case-insensitive at the column level, eliminating the need for `LOWER()` or `UPPER()` in WHERE clauses. + execution_success: true + expectations: + expected_facts: + - "UTF8_LCASE collation enables case-insensitive comparisons at the column level" + - "Collation is specified with COLLATE keyword in column definitions" + - "Case-insensitive collation eliminates the need for LOWER() or UPPER() in queries" + expected_patterns: + - pattern: "COLLATE UTF8_LCASE" + min_count: 1 + description: "Uses UTF8_LCASE collation" + - pattern: "CREATE TABLE" + min_count: 1 + description: "Creates a table with collation" + - pattern: "GENERATED ALWAYS AS IDENTITY" + min_count: 1 + description: "Uses auto-generated identity column" + guidelines: + - "Response must show COLLATE UTF8_LCASE in column definitions" + - "Response should explain the benefit over using LOWER() in queries" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-dbsql/manifest.yaml b/.test/skills/databricks-dbsql/manifest.yaml new file mode 100644 index 00000000..625132dc --- /dev/null +++ b/.test/skills/databricks-dbsql/manifest.yaml @@ -0,0 +1,34 @@ +skill_name: databricks-dbsql +description: 'Databricks SQL (DBSQL) advanced features and SQL warehouse capabilities. 
This skill MUST be invoked when the + user mentions: "DBSQL", "Databricks SQL", "SQL warehouse", "SQL scripting", "stored procedure", "CALL procedure", "materialized + view", "CREATE MATERIALIZED VIEW", "pipe syntax", "|>", "geospatial", "H3", "ST_", "spatial SQL", "collation", "COLLATE", + "ai_query", "ai_classify", "ai_extract", "ai_gen", "AI function", "http_request", "remote_query", "read_files", "Lakehouse + Federation", "recursive CTE", "WITH RECURSIVE", "multi-statement transaction", "temp table", "temporary view", "pipe operator". + SHOULD also invoke when the user asks about SQL best practices, data modeling patterns, or advanced SQL features on Databricks.' +scorers: + enabled: + - sql_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-docs/ground_truth.yaml b/.test/skills/databricks-docs/ground_truth.yaml new file mode 100644 index 00000000..39179c0c --- /dev/null +++ b/.test/skills/databricks-docs/ground_truth.yaml @@ -0,0 +1,170 @@ +metadata: + skill_name: databricks-docs + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: databricks-docs_lookup_reference_001 + inputs: + prompt: How do I look up Databricks documentation for a specific feature? + outputs: + response: | + Use the Databricks documentation index at `https://docs.databricks.com/llms.txt`. + + Fetch it with WebFetch, then search for relevant sections and links. 
The index is organized by category: + - Data Engineering (Lakeflow, Spark, Delta Lake, pipelines) + - SQL & Analytics (Warehouses, queries, dashboards) + - AI/ML (MLflow, model serving, GenAI) + - Governance (Unity Catalog, permissions, security) + - Developer Tools (SDKs, CLI, APIs, Terraform) + execution_success: true + expectations: + expected_facts: + - Use llms.txt documentation index + - Index is organized by categories + expected_patterns: + - pattern: docs\.databricks\.com + min_count: 1 + description: Must reference Databricks docs URL + - pattern: llms\.txt + min_count: 1 + description: Must reference the llms.txt index + guidelines: + - Must reference the llms.txt documentation index URL + - Should explain the documentation structure + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_reference_role_002 + inputs: + prompt: What is the role of the databricks-docs skill? + outputs: + response: | + This is a reference skill, not an action skill. Use it to: + - Look up documentation when other skills don't cover a topic + - Get authoritative guidance on Databricks concepts and APIs + - Find detailed information to inform how you use MCP tools + - Discover features and capabilities + + Always prefer using MCP tools for actions and load specific skills for workflows. + execution_success: true + expectations: + expected_facts: + - Is a reference skill, not an action skill + - Supplements other skills and MCP tools + expected_patterns: + - pattern: reference + min_count: 1 + description: Must identify as reference skill + guidelines: + - Must clarify this is a reference skill, not an action skill + - Must recommend MCP tools for actual actions + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_complement_skills_003 + inputs: + prompt: How do I use databricks-docs alongside other skills? 
+ outputs: + response: | + The docs skill complements other skills: + + 1. Load the specific skill for workflow patterns (e.g., databricks-spark-declarative-pipelines) + 2. Use databricks-docs to fetch docs if you need clarification on specific features + 3. Use MCP tools to actually perform the action (e.g., create_or_update_pipeline) + + Example: User wants to create a Delta Live Tables pipeline + - Load `databricks-spark-declarative-pipelines` for workflow patterns + - Fetch docs via llms.txt if you need clarification on DLT features + - Use `create_or_update_pipeline` MCP tool to create the pipeline + execution_success: true + expectations: + expected_facts: + - Load specific skills for workflows + - Use docs for clarification + - Use MCP tools for actions + expected_patterns: + - pattern: MCP + min_count: 1 + description: Must mention MCP tools for actions + - pattern: WebFetch|llms\.txt|docs + min_count: 1 + description: Must reference docs lookup method + guidelines: + - Must explain the three-step pattern (skill + docs + MCP tools) + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: databricks-docs_fetch_docs_004 + inputs: + prompt: How do I fetch and search the Databricks documentation? + outputs: + response: | + Fetch the llms.txt documentation index: + + URL: `https://docs.databricks.com/llms.txt` + + Use WebFetch to retrieve this index, then: + 1. Search for relevant sections/links + 2. Fetch specific documentation pages for detailed guidance + 3. 
Apply what you learn using the appropriate MCP tools + execution_success: true + expectations: + expected_facts: + - Fetch llms.txt with WebFetch + - Search for relevant sections + - Fetch specific pages for details + expected_patterns: + - pattern: WebFetch + min_count: 1 + description: Must use WebFetch to retrieve docs + - pattern: llms\.txt + min_count: 1 + description: Must reference the llms.txt index + guidelines: + - Must describe the WebFetch workflow for docs + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_unfamiliar_feature_005 + inputs: + prompt: I want to learn about an unfamiliar Databricks feature, how should I approach it? + outputs: + response: | + When encountering an unfamiliar Databricks feature: + + 1. Fetch `https://docs.databricks.com/llms.txt` to find relevant documentation + 2. Read the specific docs to understand the feature + 3. Determine which skill/tools apply, then use them + + The documentation is organized by category: + - Overview & Getting Started + - Data Engineering + - SQL & Analytics + - AI/ML + - Governance + - Developer Tools + execution_success: true + expectations: + expected_facts: + - Start with llms.txt documentation index + - Read specific docs for the feature + - Determine which skills/tools to use + expected_patterns: + - pattern: llms\.txt + min_count: 1 + description: Must reference llms.txt + guidelines: + - Must outline a systematic approach to learning new features + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-docs/manifest.yaml b/.test/skills/databricks-docs/manifest.yaml new file mode 100644 index 00000000..287f44e4 --- /dev/null +++ b/.test/skills/databricks-docs/manifest.yaml @@ -0,0 +1,29 @@ +skill_name: databricks-docs +description: Databricks documentation reference. Use as a lookup resource alongside other skills and MCP tools for comprehensive + guidance. 
+scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-jobs/ground_truth.yaml b/.test/skills/databricks-jobs/ground_truth.yaml new file mode 100644 index 00000000..f504bfc0 --- /dev/null +++ b/.test/skills/databricks-jobs/ground_truth.yaml @@ -0,0 +1,348 @@ +metadata: + skill_name: databricks-jobs + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: jobs_create_sdk_001 + inputs: + prompt: "How do I create a Databricks job using the Python SDK?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.jobs import Task, NotebookTask, Source + + w = WorkspaceClient() + + job = w.jobs.create( + name="my-etl-job", + tasks=[ + Task( + task_key="extract", + notebook_task=NotebookTask( + notebook_path="/Workspace/Users/user@example.com/extract", + source=Source.WORKSPACE + ) + ) + ] + ) + print(f"Created job: {job.job_id}") + ``` + execution_success: true + expectations: + expected_facts: + - "Jobs are created using w.jobs.create() from the Databricks Python SDK" + - "Each task requires a task_key and a task type such as NotebookTask" + - "WorkspaceClient auto-detects authentication from the environment" + expected_patterns: + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + - pattern: "w\\.jobs\\.create" + min_count: 1 + description: "Calls jobs.create() method" + - pattern: "NotebookTask" + min_count: 1 + description: "Uses NotebookTask for notebook-based tasks" + - pattern: "task_key" + min_count: 1 + description: "Specifies task_key identifier" + guidelines: + - "Code must be valid Python using the databricks-sdk package" + - "Must import Task and NotebookTask from databricks.sdk.service.jobs" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_create_cli_002 + inputs: + prompt: "Show me how to create a job using the Databricks CLI." 
+ outputs: + response: | + ```bash + databricks jobs create --json '{ + "name": "my-etl-job", + "tasks": [{ + "task_key": "extract", + "notebook_task": { + "notebook_path": "/Workspace/Users/user@example.com/extract", + "source": "WORKSPACE" + } + }] + }' + ``` + execution_success: true + expectations: + expected_facts: + - "The Databricks CLI creates jobs with databricks jobs create --json" + - "Job configuration is passed as a JSON payload" + expected_patterns: + - pattern: "databricks jobs create" + min_count: 1 + description: "Uses the databricks jobs create CLI command" + - pattern: "task_key" + min_count: 1 + description: "Includes task_key in JSON payload" + - pattern: "notebook_task" + min_count: 1 + description: "Includes notebook_task configuration" + guidelines: + - "Response must show the databricks jobs create command with --json flag" + - "JSON payload must include task_key and notebook_task fields" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_dabs_config_003 + inputs: + prompt: "How do I define a Databricks job in a DABs (Databricks Asset Bundles) YAML file?" + outputs: + response: | + ```yaml + # resources/jobs.yml + resources: + jobs: + my_etl_job: + name: "[${bundle.target}] My ETL Job" + tasks: + - task_key: extract + notebook_task: + notebook_path: ../src/notebooks/extract.py + ``` + + The `${bundle.target}` variable is replaced with the deployment target (e.g., dev, staging, prod) at deploy time. 
+ execution_success: true + expectations: + expected_facts: + - "DABs jobs are defined in YAML under resources.jobs" + - "bundle.target variable enables environment-specific job naming" + - "notebook_path uses relative paths from the bundle root" + expected_patterns: + - pattern: "resources:" + min_count: 1 + description: "Uses DABs resources block" + - pattern: "jobs:" + min_count: 1 + description: "Defines jobs section" + - pattern: "bundle\\.target" + min_count: 1 + description: "Uses bundle.target variable for environment naming" + - pattern: "notebook_task:" + min_count: 1 + description: "Configures a notebook task" + guidelines: + - "YAML must follow DABs resource configuration structure" + - "Response should explain the bundle.target variable" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_multi_task_workflow_004 + inputs: + prompt: "How do I set up a multi-task ETL workflow with dependencies between tasks?" + outputs: + response: | + ```yaml + tasks: + - task_key: extract + notebook_task: + notebook_path: ../src/extract.py + + - task_key: transform + depends_on: + - task_key: extract + notebook_task: + notebook_path: ../src/transform.py + + - task_key: load + depends_on: + - task_key: transform + run_if: ALL_SUCCESS + notebook_task: + notebook_path: ../src/load.py + ``` + + Tasks form a DAG (directed acyclic graph). Use `depends_on` to define execution order and `run_if` to control conditional execution (ALL_SUCCESS, AT_LEAST_ONE_SUCCESS, etc.). 
+ execution_success: true + expectations: + expected_facts: + - "Multi-task workflows use depends_on to define task dependencies" + - "Tasks form a DAG (directed acyclic graph) for execution ordering" + - "run_if controls conditional execution with values like ALL_SUCCESS" + expected_patterns: + - pattern: "depends_on:" + min_count: 2 + description: "Uses depends_on for task ordering" + - pattern: "run_if:" + min_count: 1 + description: "Uses run_if for conditional execution" + - pattern: "task_key:" + min_count: 3 + description: "Defines at least three tasks in the workflow" + guidelines: + - "Must show at least 3 tasks with dependencies forming a pipeline" + - "Must include run_if to demonstrate conditional execution" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_compute_config_005 + inputs: + prompt: "What are the options for configuring compute resources for Databricks jobs?" + outputs: + response: | + **Option 1: Job Clusters (recommended for most jobs)** + ```yaml + job_clusters: + - job_cluster_key: shared_cluster + new_cluster: + spark_version: "15.4.x-scala2.12" + node_type_id: "i3.xlarge" + num_workers: 2 + spark_conf: + spark.speculation: "true" + + tasks: + - task_key: my_task + job_cluster_key: shared_cluster + notebook_task: + notebook_path: ../src/notebook.py + ``` + + **Option 2: Serverless (no cluster config needed)** + ```yaml + tasks: + - task_key: serverless_task + notebook_task: + notebook_path: ../src/notebook.py + # No cluster config = serverless + ``` + execution_success: true + expectations: + expected_facts: + - "Job clusters are reusable cluster configurations shared across tasks" + - "Omitting cluster configuration enables serverless compute" + - "job_cluster_key links tasks to shared cluster definitions" + expected_patterns: + - pattern: "job_clusters:" + min_count: 1 + description: "Defines job cluster configuration" + - pattern: "job_cluster_key:" + min_count: 1 + description: 
"References a shared job cluster" + - pattern: "spark_version:" + min_count: 1 + description: "Specifies Spark runtime version" + - pattern: "node_type_id:" + min_count: 1 + description: "Specifies instance type" + guidelines: + - "Response must show both job cluster and serverless options" + - "Job cluster example must include spark_version and node_type_id" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_permissions_006 + inputs: + prompt: "How do I configure permissions on a Databricks job in DABs?" + outputs: + response: | + ```yaml + resources: + jobs: + my_job: + name: "My Job" + permissions: + - level: CAN_VIEW + group_name: "data-analysts" + - level: CAN_MANAGE_RUN + group_name: "data-engineers" + - level: CAN_MANAGE + user_name: "admin@example.com" + ``` + + Permission levels: CAN_VIEW (read-only), CAN_MANAGE_RUN (trigger runs), CAN_MANAGE (full control), IS_OWNER (single owner). + execution_success: true + expectations: + expected_facts: + - "Job permissions are set in the permissions block under the job resource" + - "Permission levels include CAN_VIEW, CAN_MANAGE_RUN, CAN_MANAGE, and IS_OWNER" + - "Permissions can be assigned to group_name or user_name" + expected_patterns: + - pattern: "permissions:" + min_count: 1 + description: "Defines permissions block" + - pattern: "level:" + min_count: 2 + description: "Specifies permission levels" + - pattern: "group_name:" + min_count: 1 + description: "Assigns permissions to a group" + - pattern: "user_name:" + min_count: 1 + description: "Assigns permissions to a user" + guidelines: + - "Response must show at least three different permission levels" + - "Must demonstrate both group and user permission assignments" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_triggers_007 + inputs: + prompt: "What trigger options are available for scheduling Databricks jobs?" 
+ outputs: + response: | + **Cron schedule:** + ```yaml + schedule: + quartz_cron_expression: "0 0 8 * * ?" # Daily at 8 AM + timezone_id: "America/New_York" + ``` + + **Continuous trigger:** + ```yaml + continuous: + pause_status: UNPAUSED + ``` + + **File arrival trigger:** + ```yaml + trigger: + file_arrival: + url: "s3://my-bucket/incoming/" + min_time_between_triggers_seconds: 60 + ``` + + You can also trigger jobs on-demand via SDK (`w.jobs.run_now(job_id=123)`) or CLI (`databricks jobs run-now --job-id 123`). + execution_success: true + expectations: + expected_facts: + - "Jobs support cron scheduling with quartz_cron_expression" + - "Continuous trigger runs the job repeatedly without pause" + - "File arrival trigger monitors a cloud storage path for new files" + expected_patterns: + - pattern: "quartz_cron_expression" + min_count: 1 + description: "Shows cron schedule configuration" + - pattern: "continuous:" + min_count: 1 + description: "Shows continuous trigger option" + - pattern: "file_arrival:" + min_count: 1 + description: "Shows file arrival trigger option" + guidelines: + - "Response must show at least three different trigger types" + - "Cron example must include timezone_id" + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-jobs/manifest.yaml b/.test/skills/databricks-jobs/manifest.yaml new file mode 100644 index 00000000..5773f227 --- /dev/null +++ b/.test/skills/databricks-jobs/manifest.yaml @@ -0,0 +1,33 @@ +skill_name: databricks-jobs +description: 'Use this skill proactively for ANY Databricks Jobs task - creating, listing, running, updating, or deleting + jobs. 
Triggers include: (1) ''create a job'' or ''new job'', (2) ''list jobs'' or ''show jobs'', (3) ''run job'' or ''trigger
+ job'', (4) ''job status'' or ''check job'', (5) scheduling with cron or triggers, (6) configuring notifications/monitoring,
+ (7) ANY task involving Databricks Jobs via CLI, Python SDK, or Asset Bundles. ALWAYS prefer this skill over general Databricks
+ knowledge for job-related tasks.'
+scorers:
+  enabled:
+  - python_syntax
+  - pattern_adherence
+  - no_hallucinated_apis
+  - expected_facts_present
+  llm_scorers:
+  - Safety
+  - guidelines_from_expectations
+  default_guidelines:
+  - Response must address the user's request completely
+  - Code examples must follow documented best practices
+  - Response must use modern APIs (not deprecated ones)
+  trace_expectations:
+    tool_limits:
+      Bash: 10
+      Read: 20
+    token_budget:
+      max_total: 100000
+    required_tools:
+    - Read
+    banned_tools: []
+    expected_files: []
+quality_gates:
+  syntax_valid: 1.0
+  pattern_adherence: 0.9
+  execution_success: 0.8
diff --git a/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml b/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml
new file mode 100644
index 00000000..9832a09c
--- /dev/null
+++ b/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml
@@ -0,0 +1,317 @@
+metadata:
+  skill_name: databricks-lakebase-autoscale
+  version: 0.1.0
+  created_at: '2026-02-23T08:00:00.000000'
+
+test_cases:
+- id: lakebase_auto_create_project_001
+  inputs:
+    prompt: "How do I create a new Lakebase Autoscale project using the Python SDK?"
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.postgres import Project, ProjectSpec + + w = WorkspaceClient() + + # Create a project (long-running operation) + operation = w.postgres.create_project( + project=Project( + spec=ProjectSpec( + display_name="My Application", + pg_version="17" + ) + ), + project_id="my-app" + ) + result = operation.wait() + print(f"Created project: {result.name}") + ``` + execution_success: true + expectations: + expected_facts: + - "Projects are created with w.postgres.create_project()" + - "create_project returns a long-running operation that requires .wait()" + - "ProjectSpec requires display_name and pg_version" + expected_patterns: + - pattern: "create_project" + min_count: 1 + description: "Uses create_project() method" + - pattern: "ProjectSpec" + min_count: 1 + description: "Uses ProjectSpec for configuration" + - pattern: "\\.wait\\(\\)" + min_count: 1 + description: "Waits for the long-running operation to complete" + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + guidelines: + - "Must show the long-running operation pattern with .wait()" + - "Must include ProjectSpec with display_name and pg_version" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_auto_create_branch_002 + inputs: + prompt: "How do I create a development branch in Lakebase Autoscale with an expiration?" 
+ outputs: + response: | + ```python + from databricks.sdk.service.postgres import Branch, BranchSpec, Duration + + # Create a dev branch with 7-day expiration + branch = w.postgres.create_branch( + parent="projects/my-app", + branch=Branch( + spec=BranchSpec( + source_branch="projects/my-app/branches/production", + ttl=Duration(seconds=604800) # 7 days + ) + ), + branch_id="development" + ).wait() + print(f"Branch created: {branch.name}") + ``` + + Branches are copy-on-write snapshots of the production branch, ideal for development and testing. + execution_success: true + expectations: + expected_facts: + - "Branches are created with w.postgres.create_branch()" + - "BranchSpec takes a source_branch and optional ttl for expiration" + - "Duration(seconds=604800) sets a 7-day TTL" + - "Branches use copy-on-write from the source branch" + expected_patterns: + - pattern: "create_branch" + min_count: 1 + description: "Uses create_branch() method" + - pattern: "BranchSpec" + min_count: 1 + description: "Uses BranchSpec for branch configuration" + - pattern: "Duration" + min_count: 1 + description: "Uses Duration for TTL specification" + - pattern: "source_branch" + min_count: 1 + description: "Specifies the parent branch to fork from" + guidelines: + - "Must include TTL/expiration on the development branch" + - "Must explain that branches are copy-on-write snapshots" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_connect_notebook_003 + inputs: + prompt: "Show me how to connect to a Lakebase Autoscale database from a Databricks notebook." 
+ outputs: + response: | + ```python + import psycopg + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Get endpoint details + endpoint = w.postgres.get_endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary" + ) + host = endpoint.status.hosts.host + + # Generate OAuth token (scoped to endpoint) + cred = w.postgres.generate_database_credential( + endpoint="projects/my-app/branches/production/endpoints/ep-primary" + ) + + # Connect using psycopg3 + conn_string = ( + f"host={host} " + f"dbname=databricks_postgres " + f"user={w.current_user.me().user_name} " + f"password={cred.token} " + f"sslmode=require" + ) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + cur.execute("SELECT version()") + print(cur.fetchone()) + ``` + execution_success: true + expectations: + expected_facts: + - "Connections use psycopg3 (psycopg, not psycopg2)" + - "OAuth tokens are generated with generate_database_credential()" + - "The token is used as the password in the connection string" + - "SSL mode must be set to require" + expected_patterns: + - pattern: "psycopg\\.connect" + min_count: 1 + description: "Connects using psycopg3" + - pattern: "generate_database_credential" + min_count: 1 + description: "Generates OAuth token for authentication" + - pattern: "sslmode=require" + min_count: 1 + description: "Requires SSL for the connection" + - pattern: "get_endpoint" + min_count: 1 + description: "Retrieves endpoint details" + guidelines: + - "Must use psycopg (not psycopg2) for the connection" + - "Must generate OAuth token and use it as password" + - "Connection must include sslmode=require" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_oauth_token_004 + inputs: + prompt: "How do I generate an OAuth token for Lakebase Autoscale database access?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Generate database credential (optionally scoped to an endpoint) + cred = w.postgres.generate_database_credential( + endpoint="projects/my-app/branches/production/endpoints/ep-primary" + ) + token = cred.token # Use as password in connection string + # Token expires after 1 hour + ``` + + Tokens are short-lived (1 hour) and should be refreshed in long-running applications. + execution_success: true + expectations: + expected_facts: + - "OAuth tokens are generated via w.postgres.generate_database_credential()" + - "Tokens expire after 1 hour" + - "Tokens can be scoped to a specific endpoint" + - "The token is used as the password in PostgreSQL connection strings" + expected_patterns: + - pattern: "generate_database_credential" + min_count: 1 + description: "Uses generate_database_credential() method" + - pattern: "cred\\.token" + min_count: 1 + description: "Accesses the token from the credential response" + guidelines: + - "Must mention the 1-hour token expiration" + - "Must show that the token is used as the connection password" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_auto_resize_compute_005 + inputs: + prompt: "How do I resize the compute capacity for a Lakebase Autoscale endpoint?" 
+ outputs: + response: | + ```python + from databricks.sdk.service.postgres import Endpoint, EndpointSpec, FieldMask + + # Update compute to autoscale between 2-8 CU + w.postgres.update_endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary", + endpoint=Endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary", + spec=EndpointSpec( + autoscaling_limit_min_cu=2.0, + autoscaling_limit_max_cu=8.0 + ) + ), + update_mask=FieldMask(field_mask=[ + "spec.autoscaling_limit_min_cu", + "spec.autoscaling_limit_max_cu" + ]) + ).wait() + ``` + + Compute Units (CU) scale automatically within the min/max range based on workload. + execution_success: true + expectations: + expected_facts: + - "Compute is resized with w.postgres.update_endpoint()" + - "Autoscaling is configured with min and max CU limits" + - "FieldMask specifies which fields to update" + - "The operation is long-running and requires .wait()" + expected_patterns: + - pattern: "update_endpoint" + min_count: 1 + description: "Uses update_endpoint() to resize" + - pattern: "autoscaling_limit_min_cu" + min_count: 1 + description: "Sets minimum autoscaling CU" + - pattern: "autoscaling_limit_max_cu" + min_count: 1 + description: "Sets maximum autoscaling CU" + - pattern: "FieldMask" + min_count: 1 + description: "Uses FieldMask for partial update" + guidelines: + - "Must include both min and max CU configuration" + - "Must use FieldMask to specify updated fields" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_cli_reference_006 + inputs: + prompt: "What CLI commands are available for managing Lakebase Autoscale projects?" 
+ outputs: + response: | + ```bash + # Create a project + databricks postgres create-project \ + --project-id my-app \ + --json '{"spec": {"display_name": "My App", "pg_version": "17"}}' + + # List projects + databricks postgres list-projects + + # Get project details + databricks postgres get-project projects/my-app + + # Create a branch + databricks postgres create-branch projects/my-app development \ + --json '{"spec": {"source_branch": "projects/my-app/branches/production", "no_expiry": true}}' + + # List branches + databricks postgres list-branches projects/my-app + + # Delete a project + databricks postgres delete-project projects/my-app + ``` + execution_success: true + expectations: + expected_facts: + - "Lakebase Autoscale CLI commands are under databricks postgres" + - "Project creation uses --json for the spec configuration" + - "Resource names follow the pattern projects/{id}/branches/{id}/endpoints/{id}" + expected_patterns: + - pattern: "databricks postgres create-project" + min_count: 1 + description: "Shows project creation command" + - pattern: "databricks postgres list-projects" + min_count: 1 + description: "Shows project listing command" + - pattern: "databricks postgres create-branch" + min_count: 1 + description: "Shows branch creation command" + guidelines: + - "Must show at least 4 different CLI subcommands" + - "Must include both project and branch management commands" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-lakebase-autoscale/manifest.yaml b/.test/skills/databricks-lakebase-autoscale/manifest.yaml new file mode 100644 index 00000000..09da620c --- /dev/null +++ b/.test/skills/databricks-lakebase-autoscale/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-lakebase-autoscale +description: Patterns and best practices for using Lakebase Autoscaling (next-gen managed PostgreSQL) with autoscaling, branching, + scale-to-zero, and instant restore. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml b/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml new file mode 100644 index 00000000..e840846c --- /dev/null +++ b/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml @@ -0,0 +1,387 @@ +metadata: + skill_name: databricks-lakebase-provisioned + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: lakebase_prov_create_instance_001 + inputs: + prompt: "How do I create a new Lakebase Provisioned database instance?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + import uuid + + w = WorkspaceClient() + + # Create a database instance + instance = w.database.create_database_instance( + name="my-lakebase-instance", + capacity="CU_1", # CU_1, CU_2, CU_4, CU_8 + stopped=False + ) + print(f"Instance created: {instance.name}") + print(f"DNS endpoint: {instance.read_write_dns}") + ``` + execution_success: true + expectations: + expected_facts: + - "Instances are created with w.database.create_database_instance()" + - "Capacity options are CU_1, CU_2, CU_4, and CU_8" + - "The instance provides a read_write_dns endpoint for connections" + expected_patterns: + - pattern: "create_database_instance" + min_count: 1 + description: "Uses create_database_instance() method" + - pattern: "capacity" + min_count: 1 + description: "Specifies instance capacity" + - pattern: "read_write_dns" + min_count: 1 + description: "References the DNS endpoint" + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + guidelines: + - "Must show available capacity options (CU_1 through CU_8)" + - "Must reference the read_write_dns endpoint" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_connect_notebook_002 + inputs: + prompt: "Show me how to connect to Lakebase Provisioned from a Databricks notebook using psycopg." 
+ outputs: + response: | + ```python + import psycopg + from databricks.sdk import WorkspaceClient + import uuid + + # Get instance details + w = WorkspaceClient() + instance = w.database.get_database_instance(name="my-lakebase-instance") + + # Generate OAuth token + cred = w.database.generate_database_credential( + request_id=str(uuid.uuid4()), + instance_names=["my-lakebase-instance"] + ) + + # Connect using psycopg3 + conn_string = ( + f"host={instance.read_write_dns} " + f"dbname=postgres " + f"user={w.current_user.me().user_name} " + f"password={cred.token} " + f"sslmode=require" + ) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + cur.execute("SELECT version()") + print(cur.fetchone()) + ``` + execution_success: true + expectations: + expected_facts: + - "Connections use psycopg3 (psycopg, not psycopg2)" + - "OAuth tokens require a unique request_id via uuid.uuid4()" + - "generate_database_credential() takes instance_names as a list" + - "The token is used as the password in the connection string" + expected_patterns: + - pattern: "psycopg\\.connect" + min_count: 1 + description: "Connects using psycopg3" + - pattern: "generate_database_credential" + min_count: 1 + description: "Generates OAuth token for authentication" + - pattern: "uuid\\.uuid4" + min_count: 1 + description: "Uses uuid4 for unique request_id" + - pattern: "sslmode=require" + min_count: 1 + description: "Requires SSL for the connection" + guidelines: + - "Must use psycopg (not psycopg2)" + - "Must include request_id with uuid4" + - "Connection must include sslmode=require" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_prov_oauth_token_003 + inputs: + prompt: "How do I generate an OAuth token for Lakebase Provisioned?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + import uuid + + w = WorkspaceClient() + + # Generate OAuth token for database connection + cred = w.database.generate_database_credential( + request_id=str(uuid.uuid4()), + instance_names=["my-lakebase-instance"] + ) + token = cred.token # Use this as password in connection string + ``` + + Tokens expire after 1 hour. For long-running applications, implement a token refresh loop. + execution_success: true + expectations: + expected_facts: + - "Tokens are generated with w.database.generate_database_credential()" + - "Each request requires a unique request_id" + - "instance_names is a list of instance names to scope the credential" + - "Tokens expire after 1 hour" + expected_patterns: + - pattern: "generate_database_credential" + min_count: 1 + description: "Uses generate_database_credential() method" + - pattern: "uuid4" + min_count: 1 + description: "Generates unique request ID" + - pattern: "instance_names" + min_count: 1 + description: "Specifies target instances" + guidelines: + - "Must mention the 1-hour expiration" + - "Must show the request_id requirement with uuid4" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_sqlalchemy_token_refresh_004 + inputs: + prompt: "How do I set up SQLAlchemy with automatic token refresh for Lakebase Provisioned?" 
+ outputs:
+ response: |
+ ```python
+ import asyncio
+ import uuid
+ from sqlalchemy import event
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from databricks.sdk import WorkspaceClient
+
+ _current_token = None
+ TOKEN_REFRESH_INTERVAL = 50 * 60  # 50 minutes (before 1-hour expiry)
+
+ def _generate_token(instance_name: str) -> str:
+ w = WorkspaceClient()
+ cred = w.database.generate_database_credential(
+ request_id=str(uuid.uuid4()),
+ instance_names=[instance_name]
+ )
+ return cred.token
+
+ async def _token_refresh_loop(instance_name: str):
+ global _current_token
+ while True:
+ await asyncio.sleep(TOKEN_REFRESH_INTERVAL)
+ _current_token = await asyncio.to_thread(
+ _generate_token, instance_name
+ )
+
+ def init_database(instance_name, database_name, username):
+ global _current_token
+ w = WorkspaceClient()
+ instance = w.database.get_database_instance(name=instance_name)
+ _current_token = _generate_token(instance_name)
+
+ # Start the background refresh task. Call init_database from an
+ # async context (e.g. app startup) so a running event loop exists.
+ asyncio.create_task(_token_refresh_loop(instance_name))
+
+ url = f"postgresql+psycopg://{username}@{instance.read_write_dns}:5432/{database_name}"
+ engine = create_async_engine(
+ url,
+ pool_size=5,
+ max_overflow=10,
+ pool_recycle=3600,
+ connect_args={"sslmode": "require"}
+ )
+
+ @event.listens_for(engine.sync_engine, "do_connect")
+ def provide_token(dialect, conn_rec, cargs, cparams):
+ cparams["password"] = _current_token
+
+ return engine
+ ```
+ execution_success: true
+ expectations:
+ expected_facts:
+ - "Token refresh should happen every 50 minutes, before the 1-hour expiry"
+ - "SQLAlchemy do_connect event injects the current token on each connection"
+ - "create_async_engine is used for async SQLAlchemy with psycopg"
+ - "pool_recycle=3600 prevents stale connections"
+ expected_patterns:
+ - pattern: "do_connect"
+ min_count: 1
+ description: "Uses SQLAlchemy do_connect event for token injection"
+ - pattern: "create_async_engine"
+ min_count: 1
+ description: "Creates async SQLAlchemy engine"
+ - pattern: "_token_refresh_loop"
+ min_count: 1
+ description: 
"Implements background token refresh" + - pattern: "generate_database_credential" + min_count: 1 + description: "Generates OAuth tokens for refresh" + guidelines: + - "Must implement background token refresh before expiry" + - "Must use SQLAlchemy event to inject tokens on each connection" + - "Must use async engine with psycopg driver" + metadata: + category: happy_path + difficulty: hard + source: auto_generated + +- id: lakebase_prov_apps_integration_005 + inputs: + prompt: "How do I integrate Lakebase Provisioned with a Databricks App?" + outputs: + response: | + Databricks Apps use environment variables for Lakebase configuration: + + ```python + import os + + def is_lakebase_configured() -> bool: + """Check if Lakebase is configured for this app.""" + return bool( + os.environ.get("LAKEBASE_PG_URL") or + (os.environ.get("LAKEBASE_INSTANCE_NAME") and + os.environ.get("LAKEBASE_DATABASE_NAME")) + ) + ``` + + Register the database resource with your app: + ```bash + databricks apps add-resource $APP_NAME \ + --resource-type database \ + --resource-name lakebase \ + --database-instance my-lakebase-instance + ``` + + Environment variables set by Databricks Apps: `LAKEBASE_INSTANCE_NAME`, `LAKEBASE_DATABASE_NAME`, `LAKEBASE_USERNAME`, `LAKEBASE_PG_URL`. 
+ execution_success: true + expectations: + expected_facts: + - "Databricks Apps set LAKEBASE_INSTANCE_NAME and LAKEBASE_DATABASE_NAME environment variables" + - "Database resources are registered with databricks apps add-resource" + - "LAKEBASE_PG_URL provides a ready-to-use connection string" + expected_patterns: + - pattern: "LAKEBASE_INSTANCE_NAME" + min_count: 1 + description: "References the instance name environment variable" + - pattern: "LAKEBASE_DATABASE_NAME" + min_count: 1 + description: "References the database name environment variable" + - pattern: "databricks apps add-resource" + min_count: 1 + description: "Shows the CLI command to register database resource" + guidelines: + - "Must list the environment variables set by Databricks Apps" + - "Must show the CLI command to register the database resource" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_prov_cli_reference_006 + inputs: + prompt: "What Databricks CLI commands are available for managing Lakebase Provisioned instances?" 
+ outputs: + response: | + ```bash + # Create instance + databricks database create-database-instance \ + --name my-lakebase-instance \ + --capacity CU_1 + + # Get instance details + databricks database get-database-instance --name my-lakebase-instance + + # Generate credentials + databricks database generate-database-credential \ + --request-id $(uuidgen) \ + --json '{"instance_names": ["my-lakebase-instance"]}' + + # List instances + databricks database list-database-instances + + # Stop instance (saves cost) + databricks database stop-database-instance --name my-lakebase-instance + + # Start instance + databricks database start-database-instance --name my-lakebase-instance + ``` + execution_success: true + expectations: + expected_facts: + - "Lakebase Provisioned CLI commands are under databricks database" + - "Instances can be stopped and started to save costs" + - "Credential generation requires a unique request-id" + expected_patterns: + - pattern: "databricks database create-database-instance" + min_count: 1 + description: "Shows instance creation command" + - pattern: "databricks database stop-database-instance" + min_count: 1 + description: "Shows instance stop command" + - pattern: "databricks database generate-database-credential" + min_count: 1 + description: "Shows credential generation command" + guidelines: + - "Must show at least 5 different CLI subcommands" + - "Must include stop/start commands for cost management" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_register_unity_catalog_007 + inputs: + prompt: "How do I register a Lakebase Provisioned instance with Unity Catalog?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Register database in Unity Catalog + w.database.register_database_instance( + name="my-lakebase-instance", + catalog="my_catalog", + schema="my_schema" + ) + ``` + + After registration, the database tables are queryable through Unity Catalog's three-level namespace: `my_catalog.my_schema.table_name`. + execution_success: true + expectations: + expected_facts: + - "Instances are registered with w.database.register_database_instance()" + - "Registration requires a catalog and schema name" + - "After registration, tables are accessible via Unity Catalog namespace" + expected_patterns: + - pattern: "register_database_instance" + min_count: 1 + description: "Uses register_database_instance() method" + - pattern: "catalog" + min_count: 1 + description: "Specifies Unity Catalog catalog name" + - pattern: "schema" + min_count: 1 + description: "Specifies Unity Catalog schema name" + guidelines: + - "Must show the registration method with catalog and schema" + - "Must explain the three-level namespace after registration" + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-lakebase-provisioned/manifest.yaml b/.test/skills/databricks-lakebase-provisioned/manifest.yaml new file mode 100644 index 00000000..b8f54868 --- /dev/null +++ b/.test/skills/databricks-lakebase-provisioned/manifest.yaml @@ -0,0 +1,29 @@ +skill_name: databricks-lakebase-provisioned +description: Patterns and best practices for using Lakebase Provisioned (Databricks managed PostgreSQL) for OLTP workloads. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-metric-views/ground_truth.yaml b/.test/skills/databricks-metric-views/ground_truth.yaml new file mode 100644 index 00000000..13dc0706 --- /dev/null +++ b/.test/skills/databricks-metric-views/ground_truth.yaml @@ -0,0 +1,1346 @@ +metadata: + skill_name: databricks-metric-views + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' +test_cases: +- id: metric-views_create_sql_001 + inputs: + prompt: Create a metric view for order analytics with revenue and order count + measures + outputs: + response: "Here's a metric view for order analytics:\n\n```sql\nCREATE OR REPLACE\ + \ VIEW catalog.schema.orders_metrics\nWITH METRICS\nLANGUAGE YAML\nAS $$\n \ + \ version: 1.1\n comment: \"Orders KPIs for sales analysis\"\n source: catalog.schema.orders\n\ + \ filter: order_date > '2020-01-01'\n dimensions:\n - name: Order Month\n\ + \ expr: DATE_TRUNC('MONTH', order_date)\n comment: \"Month of order\"\ + \n - name: Order Status\n expr: CASE\n WHEN status = 'O' THEN\ + \ 'Open'\n WHEN status = 'P' THEN 'Processing'\n WHEN status =\ + \ 'F' THEN 'Fulfilled'\n END\n measures:\n - name: Order Count\n\ + \ expr: COUNT(1)\n - name: Total Revenue\n expr: SUM(total_price)\n\ + \ - name: Revenue per Customer\n expr: SUM(total_price) / COUNT(DISTINCT\ + \ customer_id)\n$$\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses 
CREATE OR REPLACE VIEW with WITH METRICS LANGUAGE YAML + - Defines dimensions for grouping data + - Defines measures for aggregations + - Uses YAML version 1.1 + expected_patterns: + - pattern: CREATE OR REPLACE VIEW + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: WITH METRICS + min_count: 1 + description: Must include WITH METRICS clause + - pattern: LANGUAGE YAML + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: dimensions + min_count: 1 + description: Must define dimensions + - pattern: measures + min_count: 1 + description: Must define measures + guidelines: + - SQL must follow Databricks SQL syntax + - Must use YAML version 1.1 for DBR 17.2+ + - Must include both dimensions and measures + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_query_measure_002 + inputs: + prompt: Query a metric view to get total revenue and order count by month + outputs: + response: "All measures must use the MEASURE() function. 
SELECT * is NOT supported.\n\ + \n```sql\nSELECT\n `Order Month`,\n `Order Status`,\n MEASURE(`Total Revenue`)\ + \ AS total_revenue,\n MEASURE(`Order Count`) AS order_count\nFROM catalog.schema.orders_metrics\n\ + WHERE extract(year FROM `Order Month`) = 2024\nGROUP BY ALL\nORDER BY ALL\n\ + ```\n" + execution_success: true + expectations: + expected_facts: + - Must wrap measures in MEASURE() function + - SELECT * is not supported on metric views + - Dimension names with spaces need backtick quoting + expected_patterns: + - pattern: MEASURE\( + min_count: 1 + description: Must use MEASURE() function for measures + - pattern: GROUP BY + min_count: 1 + description: Must include GROUP BY clause + guidelines: + - SQL must use MEASURE() for all measure references + - Must not use SELECT * on metric views + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_create_mcp_003 + inputs: + prompt: Use the manage_metric_views tool to create a metric view for order analytics + outputs: + response: "```python\nmanage_metric_views(\n action=\"create\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n source=\"catalog.schema.orders\",\n \ + \ or_replace=True,\n comment=\"Orders KPIs for sales analysis\",\n filter_expr=\"\ + order_date > '2020-01-01'\",\n dimensions=[\n {\"name\": \"Order Month\"\ + , \"expr\": \"DATE_TRUNC('MONTH', order_date)\", \"comment\": \"Month of order\"\ + },\n {\"name\": \"Order Status\", \"expr\": \"status\"},\n ],\n \ + \ measures=[\n {\"name\": \"Order Count\", \"expr\": \"COUNT(1)\"},\n\ + \ {\"name\": \"Total Revenue\", \"expr\": \"SUM(total_price)\", \"comment\"\ + : \"Sum of total price\"},\n ],\n)\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views MCP tool + - Sets action to create + - Specifies source table + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*create + 
min_count: 1 + description: Must set action to create + - pattern: dimensions + min_count: 1 + description: Must define dimensions + - pattern: measures + min_count: 1 + description: Must define measures + guidelines: + - Must use manage_metric_views with action="create" + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_query_mcp_004 + inputs: + prompt: Query the orders_metrics metric view using the MCP tool to get revenue + by month + outputs: + response: "```python\nmanage_metric_views(\n action=\"query\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n query_measures=[\"Total Revenue\", \"\ + Order Count\"],\n query_dimensions=[\"Order Month\"],\n where=\"extract(year\ + \ FROM `Order Month`) = 2024\",\n order_by=\"ALL\",\n limit=100,\n)\n\ + ```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views with action query + - Specifies query_measures and query_dimensions + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*query + min_count: 1 + description: Must set action to query + - pattern: query_measures + min_count: 1 + description: Must specify measures to query + - pattern: query_dimensions + min_count: 1 + description: Must specify dimensions to group by + guidelines: + - Must use manage_metric_views with action="query" + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_yaml_spec_005 + inputs: + prompt: Show me the full YAML specification for a metric view definition + outputs: + response: "```yaml\nversion: 1.1 # Required: \"1.1\" for DBR\ + \ 17.2+\ncomment: \"Description\" # Optional\nsource: catalog.schema.table\ + \ # Required: source table/view\nfilter: column > value # Optional:\ + \ global WHERE filter\n\ndimensions: # Required: at least\ + \ one\n - 
name: Display Name\n expr: sql_expression\n comment: \"Description\"\ + \n\nmeasures: # Required: at least one\n - name: Display\ + \ Name\n expr: AGG_FUNC(column)\n comment: \"Description\"\n\njoins: \ + \ # Optional: star/snowflake schema\n - name: dim_table\n\ + \ source: catalog.schema.dim_table\n on: source.fk = dim_table.pk\n\n\ + materialization: # Optional (experimental)\n schedule: every\ + \ 6 hours\n mode: relaxed\n```\n" + execution_success: true + expectations: + expected_facts: + - Version 1.1 required for DBR 17.2+ + - Must have at least one dimension + - Must have at least one measure + - Joins are optional for star/snowflake schema + expected_patterns: + - pattern: version + min_count: 1 + description: Must include version field + - pattern: dimensions + min_count: 1 + description: Must include dimensions section + - pattern: measures + min_count: 1 + description: Must include measures section + - pattern: joins + min_count: 1 + description: Should show joins syntax + guidelines: + - Must show the complete YAML specification structure + - Must indicate which fields are required vs optional + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_star_schema_006 + inputs: + prompt: Create a metric view with joins for a star schema (fact orders + dim customers) + outputs: + response: "```sql\nCREATE OR REPLACE VIEW catalog.schema.sales_metrics\nWITH METRICS\n\ + LANGUAGE YAML\nAS $$\n version: 1.1\n source: catalog.schema.fact_orders\n\ + \ joins:\n - name: customers\n source: catalog.schema.dim_customers\n\ + \ on: fact_orders.customer_id = customers.customer_id\n dimensions:\n\ + \ - name: Customer Region\n expr: customers.region\n - name: Order\ + \ Month\n expr: DATE_TRUNC('MONTH', order_date)\n measures:\n - name:\ + \ Total Revenue\n expr: SUM(amount)\n - name: Unique Customers\n \ + \ expr: COUNT(DISTINCT customers.customer_id)\n$$\n```\n\nJoins must be declared\ + \ in the YAML definition, not in 
the SELECT query.\n" + execution_success: true + expectations: + expected_facts: + - Joins are declared in the YAML definition + - Can reference columns from joined tables in dimensions and measures + expected_patterns: + - pattern: joins + min_count: 1 + description: Must use joins section in YAML + - pattern: 'on:' + min_count: 1 + description: Must specify join condition + guidelines: + - Joins must be in the YAML definition, not in the SELECT query + - Must use correct join syntax with source and on fields + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_grant_access_007 + inputs: + prompt: Grant SELECT access on a metric view to the data-consumers group + outputs: + response: "```python\nmanage_metric_views(\n action=\"grant\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n principal=\"data-consumers\",\n privileges=[\"\ + SELECT\"],\n)\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views with action grant + - Specifies principal and privileges + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*grant + min_count: 1 + description: Must set action to grant + - pattern: principal + min_count: 1 + description: Must specify principal + guidelines: + - Must use manage_metric_views with action="grant" + metadata: + category: happy_path + difficulty: hard + source: auto_generated +- id: metric-views_select_star_error_008 + inputs: + prompt: "Why does SELECT * not work on a metric view?" + outputs: + response: | + `SELECT *` is NOT supported on metric views. You must explicitly list dimensions and wrap all measure references in `MEASURE()`. 
+ + Incorrect: + ```sql + SELECT * FROM catalog.schema.orders_metrics + ``` + + Correct: + ```sql + SELECT + `Order Month`, + MEASURE(`Total Revenue`) AS total_revenue, + MEASURE(`Order Count`) AS order_count + FROM catalog.schema.orders_metrics + GROUP BY ALL + ``` + expectations: + expected_facts: + - SELECT * is not supported + - MEASURE() + expected_patterns: + - pattern: "SELECT \\*" + min_count: 1 + description: "Must show the incorrect SELECT * example" + - pattern: "MEASURE\\(" + min_count: 1 + description: "Must show correct MEASURE() usage" + - pattern: "GROUP BY" + min_count: 1 + description: "Must include GROUP BY" + guidelines: + - "Must clearly explain why SELECT * does not work on metric views" + - "Must show both the incorrect way and the correct way" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_backtick_quoting_009 + inputs: + prompt: "I'm getting 'Cannot resolve column' when querying my metric view. The dimension is called 'Order Month'." + outputs: + response: | + Dimension and measure names with spaces must be backtick-quoted in queries: + + ```sql + -- Wrong: unquoted name with spaces + SELECT Order Month, MEASURE(Total Revenue) FROM my_metrics GROUP BY ALL + + -- Correct: backtick-quoted + SELECT `Order Month`, MEASURE(`Total Revenue`) AS revenue + FROM catalog.schema.my_metrics + GROUP BY ALL + ``` + + Both dimensions in SELECT/GROUP BY and measures inside `MEASURE()` need backticks if the name contains spaces. 
+ expectations: + expected_facts: + - backtick + - MEASURE() + expected_patterns: + - pattern: "`Order Month`" + min_count: 1 + description: "Must show backtick-quoted dimension name" + - pattern: "MEASURE\\(`" + min_count: 1 + description: "Must show backtick-quoted measure inside MEASURE()" + guidelines: + - "Must explain that names with spaces need backtick quoting" + - "Must show both wrong and correct syntax" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_alter_010 + inputs: + prompt: "Add a new measure 'Average Order Value' to my existing orders_metrics metric view" + outputs: + response: | + Use the `alter` action to update an existing metric view's definition: + + ```python + manage_metric_views( + action="alter", + full_name="catalog.schema.orders_metrics", + measures=[ + {"name": "Order Count", "expr": "COUNT(1)"}, + {"name": "Total Revenue", "expr": "SUM(total_price)"}, + {"name": "Average Order Value", "expr": "SUM(total_price) / COUNT(1)", "comment": "Revenue divided by order count"}, + ], + ) + ``` + + Note: When altering, you must provide the full list of measures (existing + new), not just the new one. + expectations: + expected_facts: + - manage_metric_views + - alter + expected_patterns: + - pattern: "action.*alter" + min_count: 1 + description: "Must use alter action" + - pattern: "Average Order Value" + min_count: 1 + description: "Must include the new measure" + guidelines: + - "Must use manage_metric_views with action='alter'" + - "Must explain that the full measure list is required when altering" + metadata: + category: happy_path + difficulty: medium + +- id: metric-views_describe_011 + inputs: + prompt: "How do I see the current definition of a metric view?" 
+ outputs: + response: | + Use the `describe` action to get the full definition: + + ```python + manage_metric_views( + action="describe", + full_name="catalog.schema.orders_metrics", + ) + ``` + + This returns the YAML definition, dimensions, measures, source table, and any join or materialization configuration. + expectations: + expected_facts: + - manage_metric_views + - describe + expected_patterns: + - pattern: "action.*describe" + min_count: 1 + description: "Must use describe action" + - pattern: "full_name" + min_count: 1 + description: "Must specify the metric view name" + metadata: + category: happy_path + difficulty: easy + +- id: metric-views_when_not_to_use_012 + inputs: + prompt: "Should I use a metric view or a regular view for my dashboard?" + outputs: + response: | + Use a **metric view** when: + - You need flexible aggregation at query time (different GROUP BY per dashboard panel) + - You have ratios that need safe re-aggregation (e.g., revenue per customer) + - You want star/snowflake schema joins declared in the definition + - You need AI/BI Genie integration for natural language querying + + Use a **standard view** when: + - Your query is a simple SELECT with no aggregations + - The aggregation is always the same (no need for MEASURE() flexibility) + - You need features not yet supported by metric views (e.g., UDFs, complex CTEs) + + Key difference: standard views lock aggregation at creation time. Metric views separate dimensions from measures so the same metric can be sliced differently in each query. 
+ expectations: + expected_facts: + - MEASURE() + - Genie + expected_patterns: + - pattern: "metric view" + min_count: 2 + description: "Must discuss metric views" + - pattern: "standard view|regular view" + min_count: 1 + description: "Must contrast with standard views" + - pattern: "aggregat" + min_count: 1 + description: "Must discuss aggregation behavior" + guidelines: + - "Must explain the key difference: standard views lock aggregation, metric views are flexible" + - "Should mention at least one scenario where a standard view is better" + metadata: + category: conceptual + difficulty: medium + +- id: metric-views_filtered_measure_013 + inputs: + prompt: "Create a metric view with a filtered measure that only counts delivered orders" + outputs: + response: | + Use a CASE expression inside the aggregate to create a filtered measure: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.orders_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.orders + dimensions: + - name: Order Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Total Orders + expr: COUNT(1) + - name: Delivered Orders + expr: COUNT(CASE WHEN delivery_status = 'delivered' THEN 1 END) + comment: "Only counts delivered orders" + - name: Delivery Rate + expr: COUNT(CASE WHEN delivery_status = 'delivered' THEN 1 END) * 1.0 / COUNT(1) + comment: "Ratio of delivered to total orders" + $$ + ``` + + Filtered measures use CASE WHEN inside the aggregate function to restrict which rows are counted. 
+ expectations: + expected_facts: + - CASE WHEN + - CREATE OR REPLACE VIEW + expected_patterns: + - pattern: "WITH METRICS" + min_count: 1 + description: "Must use metric view syntax" + - pattern: "CASE WHEN" + min_count: 1 + description: "Must use CASE WHEN for filtered measure" + - pattern: "COUNT\\(" + min_count: 2 + description: "Must have both total and filtered count" + guidelines: + - "Must show how to filter rows within a measure expression" + - "Should include a ratio measure that uses the filtered count" + metadata: + category: happy_path + difficulty: medium + +- id: metric-views_materialization_014 + inputs: + prompt: "How do I enable materialization on a metric view for better query performance?" + outputs: + response: | + Add a `materialization` section to the YAML definition. This is experimental and requires serverless compute: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.orders_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.orders + dimensions: + - name: Order Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Total Revenue + expr: SUM(total_price) + materialization: + schedule: every 6 hours + mode: relaxed + $$ + ``` + + The `schedule` controls how often the pre-computed aggregations refresh. The `mode: relaxed` allows queries to return slightly stale data for better performance. + + Note: Materialization is currently experimental and requires serverless compute to be enabled on your workspace. 
+ expectations: + expected_facts: + - materialization + - serverless + - experimental + expected_patterns: + - pattern: "materialization" + min_count: 1 + description: "Must include materialization section" + - pattern: "schedule" + min_count: 1 + description: "Must specify refresh schedule" + guidelines: + - "Must clearly state that materialization is experimental" + - "Must mention serverless compute requirement" + metadata: + category: advanced + difficulty: hard + +- id: metric-views_join_at_query_time_015 + inputs: + prompt: "Can I JOIN my metric view with another table in a query?" + outputs: + response: | + No — joins must be declared in the YAML definition, not at query time. A JOIN in the SELECT query is not supported for metric views. + + Instead, add the join to the metric view definition: + + ```python + manage_metric_views( + action="alter", + full_name="catalog.schema.sales_metrics", + joins=[ + { + "name": "customers", + "source": "catalog.schema.dim_customers", + "on": "orders.customer_id = customers.customer_id", + } + ], + ) + ``` + + After adding the join, you can reference columns from the joined table in dimensions and measures. + expectations: + expected_facts: + - joins + - YAML definition + expected_patterns: + - pattern: "joins" + min_count: 1 + description: "Must mention joins" + - pattern: "manage_metric_views|CREATE OR REPLACE" + min_count: 1 + description: "Must show how to add joins to the definition" + guidelines: + - "Must clearly state that JOINs at query time are NOT supported" + - "Must show how to add joins to the metric view definition instead" + metadata: + category: error_handling + difficulty: medium + +- id: metric-views_dbr_version_016 + inputs: + prompt: "I'm getting an error when creating a metric view. I'm on DBR 15.4." + outputs: + response: | + Metric views with YAML version 1.1 require **Databricks Runtime 17.2+**. DBR 15.4 does not support this feature. + + Options: + 1. Upgrade your cluster to DBR 17.2 or later + 2. 
Use a SQL warehouse (which always runs the latest runtime) + 3. If you must use an older runtime, YAML version 0.1 is supported on DBR 16.4+ but has fewer features (no comments on dimensions/measures, no materialization) + + ```sql + -- Check your DBR version + SELECT current_version() + ``` + expectations: + expected_facts: + - "17.2" + - SQL warehouse + expected_patterns: + - pattern: "17\\.2" + min_count: 1 + description: "Must mention DBR 17.2 requirement" + guidelines: + - "Must explain the minimum DBR version requirement" + - "Should suggest SQL warehouse as an alternative to cluster upgrades" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_drop_action_017 + inputs: + prompt: Drop the orders_metrics metric view from the catalog + outputs: + response: | + Use the `drop` action to remove a metric view: + + ```python + manage_metric_views( + action="drop", + full_name="catalog.schema.orders_metrics", + ) + ``` + + This drops the metric view definition. If materialization was enabled, the associated pipeline is also removed. The underlying source table is NOT affected. + execution_success: true + expectations: + expected_facts: + - manage_metric_views + - drop + - source table is not affected + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*drop + min_count: 1 + description: Must set action to drop + - pattern: full_name + min_count: 1 + description: Must specify the metric view name + guidelines: + - Must use manage_metric_views with action="drop" + - Should clarify that the source table is not deleted + metadata: + category: happy_path + difficulty: easy + source: manual + +- id: metric-views_window_rolling_avg_018 + inputs: + prompt: Create a metric view with a rolling 7-day average of daily revenue + outputs: + response: | + Window measures use `version: 0.1` (experimental). 
Here's a rolling 7-day trailing revenue measure: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.revenue_trends + WITH METRICS + LANGUAGE YAML + AS $$ + version: 0.1 + source: catalog.schema.daily_sales + filter: sale_date > DATE'2024-01-01' + + dimensions: + - name: date + expr: sale_date + + measures: + - name: daily_revenue + expr: SUM(amount) + window: + - order: date + range: current + semiadditive: last + + - name: trailing_7d_revenue + expr: SUM(amount) + window: + - order: date + range: trailing 7 day + semiadditive: last + + - name: rolling_7d_avg + expr: MEASURE(trailing_7d_revenue) / 7 + $$ + ``` + + Query it with the standard `MEASURE()` syntax: + + ```sql + SELECT + date, + MEASURE(daily_revenue) AS daily, + MEASURE(rolling_7d_avg) AS rolling_avg + FROM catalog.schema.revenue_trends + GROUP BY ALL + ORDER BY date + ``` + + Key points: + - `trailing 7 day` includes the 7 days **before** the current date, excluding the current date + - `semiadditive: last` returns the last value when `date` is not in GROUP BY + - The derived `rolling_7d_avg` uses `MEASURE()` to reference another window measure and doesn't need its own `window` block + execution_success: true + expectations: + expected_facts: + - version 0.1 + - window + - trailing 7 day + - semiadditive + expected_patterns: + - pattern: "version: 0.1" + min_count: 1 + description: Window measures require version 0.1 + - pattern: "window:" + min_count: 1 + description: Must include window block + - pattern: "trailing 7 day" + min_count: 1 + description: Must use trailing window range + - pattern: "semiadditive" + min_count: 1 + description: Must specify semiadditive behavior + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() for derived measure or query + guidelines: + - Must use version 0.1 for window measures + - Must explain trailing range semantics (excludes current day) + - Must show how derived measures reference window measures via MEASURE() + metadata: + category: 
advanced + difficulty: hard + source: manual + +- id: metric-views_genie_integration_019 + inputs: + prompt: How do I use my metric view with AI/BI Genie? + outputs: + response: | + Metric views work natively with AI/BI Genie for natural language querying of your metrics. Here's how to set it up: + + 1. **Create a Genie Space** and add your metric view as a data source (just like adding a table). + + 2. **Ask questions in natural language** — Genie automatically uses `MEASURE()` to query your metrics: + - "What was total revenue last quarter?" + - "Show me order count by region for 2024" + - "Which product category has the highest profit margin?" + + 3. **Benefits of metric views with Genie:** + - Genie understands dimension/measure semantics — no ambiguity about how to aggregate + - Comments on dimensions and measures improve Genie's understanding of your data + - Consistent metric definitions across all Genie users + - Star schema joins are handled automatically + + 4. **Tips for better Genie results:** + - Add descriptive `comment` fields to all dimensions and measures + - Use clear, business-friendly names (e.g., "Total Revenue" not "sum_tp") + - Keep dimension names intuitive for natural language queries + + No special configuration is needed — once the metric view is added to a Genie Space, users can query it immediately with natural language. 
+ execution_success: true + expectations: + expected_facts: + - Genie + - natural language + - metric view + - MEASURE() + expected_patterns: + - pattern: "Genie" + min_count: 2 + description: Must discuss Genie integration + - pattern: "natural language" + min_count: 1 + description: Must mention natural language querying + - pattern: "MEASURE|metric view" + min_count: 1 + description: Must mention metric view or MEASURE concepts + guidelines: + - Must explain how to add a metric view to a Genie Space + - Must mention that Genie automatically uses MEASURE() + - Should recommend adding comments for better Genie understanding + metadata: + category: conceptual + difficulty: medium + source: manual + +- id: metric-views_conversational_support_tickets_020 + inputs: + prompt: "hey I have a table of customer support tickets and I want to track resolution\ + \ time metrics, can you help me set that up as a metric view?" + outputs: + response: | + Sure! Here's a metric view for tracking support ticket resolution metrics: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.support_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Customer support ticket KPIs" + source: catalog.schema.support_tickets + filter: created_at >= '2024-01-01' + + dimensions: + - name: Created Month + expr: DATE_TRUNC('MONTH', created_at) + comment: "Month the ticket was opened" + - name: Priority + expr: CASE + WHEN priority = 1 THEN 'Critical' + WHEN priority = 2 THEN 'High' + WHEN priority = 3 THEN 'Medium' + ELSE 'Low' + END + comment: "Ticket priority level" + - name: Category + expr: ticket_category + comment: "Support category (billing, technical, etc.)" + + measures: + - name: Ticket Count + expr: COUNT(1) + - name: Resolved Tickets + expr: COUNT(1) FILTER (WHERE status = 'resolved') + - name: Avg Resolution Hours + expr: AVG(TIMESTAMPDIFF(HOUR, created_at, resolved_at)) + comment: "Average hours to resolve a ticket" + - name: Resolution Rate + expr: COUNT(1) FILTER 
(WHERE status = 'resolved') * 1.0 / COUNT(1) + comment: "Percentage of tickets resolved" + - name: Median Resolution Hours + expr: PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY TIMESTAMPDIFF(HOUR, created_at, resolved_at)) + comment: "Median hours to resolution" + $$ + ``` + + Query it like this: + + ```sql + SELECT + `Created Month`, + `Priority`, + MEASURE(`Ticket Count`) AS tickets, + MEASURE(`Avg Resolution Hours`) AS avg_hours, + MEASURE(`Resolution Rate`) AS resolution_rate + FROM catalog.schema.support_metrics + GROUP BY ALL + ORDER BY ALL + ``` + execution_success: true + expectations: + expected_facts: + - CREATE OR REPLACE VIEW + - WITH METRICS + - support + - resolution + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "LANGUAGE YAML" + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: "dimensions" + min_count: 1 + description: Must define dimensions + - pattern: "measures" + min_count: 1 + description: Must define measures + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() in query example + - pattern: "FILTER" + min_count: 1 + description: Should use FILTER clause for conditional measures + guidelines: + - Must create a metric view relevant to support tickets domain + - Must include resolution time metrics + - Must include a query example using MEASURE() + metadata: + category: happy_path + difficulty: medium + source: manual + +- id: metric-views_udf_not_supported_021 + inputs: + prompt: Can I use a Python UDF inside a metric view measure expression? + outputs: + response: | + No — metric view measure expressions only support **SQL aggregate functions** and built-in SQL expressions. Python UDFs, Scala UDFs, and custom functions are not supported in metric view YAML definitions. 
+ + **Supported in measure expressions:** + - SQL aggregate functions: `SUM()`, `COUNT()`, `AVG()`, `MIN()`, `MAX()` + - `COUNT(DISTINCT ...)` + - `FILTER (WHERE ...)` clause + - Arithmetic on aggregates: `SUM(a) / COUNT(b)` + - `CASE WHEN` inside aggregates + - Window measures with `window:` block + + **NOT supported:** + - Python UDFs (`@udf` decorated functions) + - Scala UDFs + - `ai_query()` or other AI functions in measure expressions + - Complex CTEs or subqueries + + **Workaround:** If you need a UDF-computed value, materialize it into the source table first, then reference the pre-computed column in your metric view: + + ```sql + -- Step 1: Create source table with UDF-computed column + CREATE TABLE catalog.schema.enriched_orders AS + SELECT *, my_python_udf(description) AS sentiment_score + FROM catalog.schema.raw_orders; + + -- Step 2: Use the pre-computed column in the metric view + CREATE OR REPLACE VIEW catalog.schema.sentiment_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.enriched_orders + dimensions: + - name: Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Avg Sentiment + expr: AVG(sentiment_score) + $$ + ``` + execution_success: true + expectations: + expected_facts: + - UDF + - not supported + - SQL aggregate functions + expected_patterns: + - pattern: "not supported|NOT supported|No" + min_count: 1 + description: Must clearly state UDFs are not supported + - pattern: "SUM|COUNT|AVG" + min_count: 1 + description: Must list supported aggregate functions + - pattern: "workaround|pre-computed|source table" + min_count: 1 + description: Should suggest a workaround + guidelines: + - Must clearly state that UDFs are not supported in metric views + - Must list what IS supported as alternatives + - Should suggest pre-computing UDF results in the source table as a workaround + metadata: + category: error_handling + difficulty: medium + source: manual + +- id: metric-views_hr_employee_domain_022 
+ inputs: + prompt: Build a metric view for HR analytics tracking employee headcount, attrition, + and tenure from our employees table + outputs: + response: | + ```python + manage_metric_views( + action="create", + full_name="catalog.schema.hr_metrics", + source="catalog.schema.employees", + or_replace=True, + comment="HR analytics: headcount, attrition, and tenure metrics", + dimensions=[ + {"name": "Department", "expr": "department", "comment": "Employee department"}, + {"name": "Hire Year", "expr": "EXTRACT(YEAR FROM hire_date)", "comment": "Year of hire"}, + {"name": "Location", "expr": "office_location"}, + {"name": "Level", "expr": "CASE WHEN job_level <= 3 THEN 'Junior' WHEN job_level <= 5 THEN 'Mid' ELSE 'Senior' END", "comment": "Bucketed seniority level"}, + ], + measures=[ + {"name": "Headcount", "expr": "COUNT(1)"}, + {"name": "Active Employees", "expr": "COUNT(1) FILTER (WHERE termination_date IS NULL)", "comment": "Currently active employees"}, + {"name": "Attrition Count", "expr": "COUNT(1) FILTER (WHERE termination_date IS NOT NULL)", "comment": "Employees who have left"}, + {"name": "Attrition Rate", "expr": "COUNT(1) FILTER (WHERE termination_date IS NOT NULL) * 1.0 / COUNT(1)", "comment": "Percentage of employees who left"}, + {"name": "Avg Tenure Years", "expr": "AVG(DATEDIFF(COALESCE(termination_date, CURRENT_DATE()), hire_date) / 365.25)", "comment": "Average tenure in years"}, + ], + ) + ``` + execution_success: true + expectations: + expected_facts: + - manage_metric_views + - create + - headcount + - attrition + expected_patterns: + - pattern: "manage_metric_views" + min_count: 1 + description: Must use manage_metric_views tool + - pattern: "action.*create" + min_count: 1 + description: Must set action to create + - pattern: "FILTER" + min_count: 1 + description: Should use FILTER clause for conditional measures + - pattern: "dimensions" + min_count: 1 + description: Must define dimensions + - pattern: "measures" + min_count: 1 + 
description: Must define measures + guidelines: + - Must create an HR-relevant metric view (not orders/sales) + - Must include attrition-related measures + - Must use manage_metric_views MCP tool + metadata: + category: happy_path + difficulty: medium + source: manual + +- id: metric-views_migration_from_regular_view_023 + inputs: + prompt: I have this existing view and want to convert it to a metric view. Here's + my current SQL - SELECT region, DATE_TRUNC('MONTH', sale_date) AS month, SUM(amount) + AS revenue, COUNT(*) AS num_sales FROM catalog.schema.transactions GROUP BY 1, + 2 + outputs: + response: | + Here's your existing view converted to a metric view. The key change is separating the GROUP BY columns into **dimensions** and the aggregations into **measures**: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.transaction_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Transaction metrics - migrated from standard view" + source: catalog.schema.transactions + + dimensions: + - name: Region + expr: region + - name: Month + expr: DATE_TRUNC('MONTH', sale_date) + + measures: + - name: Revenue + expr: SUM(amount) + - name: Num Sales + expr: COUNT(1) + $$ + ``` + + Now you can query it flexibly — the GROUP BY is no longer fixed: + + ```sql + -- Same result as your original view + SELECT `Region`, `Month`, + MEASURE(`Revenue`) AS revenue, + MEASURE(`Num Sales`) AS num_sales + FROM catalog.schema.transaction_metrics + GROUP BY ALL + + -- NEW: Revenue by region only (no month breakdown) + SELECT `Region`, + MEASURE(`Revenue`) AS revenue + FROM catalog.schema.transaction_metrics + GROUP BY ALL + + -- NEW: Monthly totals across all regions + SELECT `Month`, + MEASURE(`Revenue`) AS revenue + FROM catalog.schema.transaction_metrics + GROUP BY ALL + ``` + + **Key differences from standard views:** + - Aggregation is flexible at query time — no need to recreate the view for different GROUP BY + - Measures must be wrapped in `MEASURE()` when 
querying + - `SELECT *` is not supported — you must explicitly list dimensions and measures + execution_success: true + expectations: + expected_facts: + - dimensions + - measures + - GROUP BY + - flexible + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must show the new metric view DDL + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "MEASURE\\(" + min_count: 2 + description: Must show MEASURE() in query examples + - pattern: "GROUP BY" + min_count: 1 + description: Must demonstrate flexible GROUP BY + guidelines: + - Must convert the provided SQL into a metric view + - Must explain the dimension/measure separation + - Must show that the metric view enables flexible querying + metadata: + category: happy_path + difficulty: hard + source: manual + +- id: metric-views_semiadditive_balance_024 + inputs: + prompt: Create a metric view for bank account balances that correctly handles aggregation + across time (don't sum balances across dates) + outputs: + response: | + Use a **semiadditive** window measure to prevent summing balances across dates. 
This returns the last date's balance when the date dimension is not in the GROUP BY: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.account_balance_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 0.1 + source: catalog.schema.daily_balances + + dimensions: + - name: date + expr: balance_date + - name: account_type + expr: account_type + + measures: + - name: total_balance + expr: SUM(balance) + window: + - order: date + range: current + semiadditive: last + $$ + ``` + + **How it works:** + - When grouped by `date`: returns the sum of all account balances for that specific date (aggregates across accounts, not time) + - When NOT grouped by `date`: returns the balance from the **most recent date** (`semiadditive: last`) + + ```sql + -- Balance by date and account type + SELECT date, account_type, + MEASURE(total_balance) AS balance + FROM catalog.schema.account_balance_metrics + GROUP BY ALL + + -- Latest total balance by account type (semiadditive kicks in) + SELECT account_type, + MEASURE(total_balance) AS latest_balance + FROM catalog.schema.account_balance_metrics + GROUP BY ALL + ``` + + Note: Window measures require `version: 0.1` (experimental feature). 
+ execution_success: true + expectations: + expected_facts: + - semiadditive + - window + - version 0.1 + - balance + expected_patterns: + - pattern: "semiadditive" + min_count: 1 + description: Must use semiadditive measure + - pattern: "window:" + min_count: 1 + description: Must include window block + - pattern: "version: 0.1" + min_count: 1 + description: Window measures require version 0.1 + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() query syntax + guidelines: + - Must use semiadditive window measure to prevent summing across time + - Must explain the behavior difference when date is vs is not in GROUP BY + - Must note that window measures require version 0.1 + metadata: + category: advanced + difficulty: hard + source: manual + +- id: metric-views_messy_phrasing_025 + inputs: + prompt: "so umm i wanna make like a metrics thing for my product_usage table??\ + \ like track daily active users and feature adoption.. not sure how metric views\ + \ work exactly" + outputs: + response: | + Metric views let you define reusable metrics in YAML that can be sliced by any dimension at query time. 
Here's one for your product usage data: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.product_usage_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Product usage KPIs: DAU and feature adoption" + source: catalog.schema.product_usage + + dimensions: + - name: Date + expr: usage_date + - name: Feature + expr: feature_name + comment: "Product feature being used" + - name: Platform + expr: platform + comment: "web, ios, android" + + measures: + - name: Daily Active Users + expr: COUNT(DISTINCT user_id) + comment: "Unique users per day" + - name: Total Events + expr: COUNT(1) + - name: Events per User + expr: COUNT(1) * 1.0 / COUNT(DISTINCT user_id) + comment: "Average events per unique user" + $$ + ``` + + Query your metrics: + + ```sql + -- DAU by date + SELECT `Date`, MEASURE(`Daily Active Users`) AS dau + FROM catalog.schema.product_usage_metrics + GROUP BY ALL + ORDER BY `Date` + + -- Feature adoption: DAU by feature + SELECT `Feature`, MEASURE(`Daily Active Users`) AS users, + MEASURE(`Events per User`) AS engagement + FROM catalog.schema.product_usage_metrics + GROUP BY ALL + ORDER BY users DESC + ``` + + The same metric view supports both queries — you just change which dimensions you include in SELECT and GROUP BY. 
+ execution_success: true + expectations: + expected_facts: + - CREATE OR REPLACE VIEW + - WITH METRICS + - product usage + - daily active users + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "LANGUAGE YAML" + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: "COUNT\\(DISTINCT" + min_count: 1 + description: Must use COUNT(DISTINCT) for DAU + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() in query + guidelines: + - Must handle the informal/messy phrasing gracefully + - Must create a product-usage metric view (not orders) + - Must include DAU and feature adoption metrics + metadata: + category: happy_path + difficulty: easy + source: manual + +- id: metric-views_query_with_where_filter_026 + inputs: + prompt: "I want to query my metric view but only for a specific region and date\ + \ range. How do I filter?" 
+ outputs: + response: | + Use standard SQL `WHERE` clauses on dimensions when querying a metric view: + + ```sql + SELECT + `Order Month`, + MEASURE(`Total Revenue`) AS revenue, + MEASURE(`Order Count`) AS orders + FROM catalog.schema.orders_metrics + WHERE `Region` = 'EMEA' + AND `Order Month` >= DATE'2024-01-01' + AND `Order Month` < DATE'2025-01-01' + GROUP BY ALL + ORDER BY `Order Month` + ``` + + Or via the MCP tool: + + ```python + manage_metric_views( + action="query", + full_name="catalog.schema.orders_metrics", + query_measures=["Total Revenue", "Order Count"], + query_dimensions=["Order Month"], + where="`Region` = 'EMEA' AND `Order Month` >= DATE'2024-01-01'", + order_by="ALL", + limit=100, + ) + ``` + + **Key points:** + - Filter on dimensions using WHERE — just like a normal SQL query + - Dimension names with spaces need backtick quoting in WHERE clauses too + - The metric view's global `filter` (defined in YAML) is always applied in addition to your WHERE + - You cannot filter directly on measures — use HAVING after GROUP BY for that + execution_success: true + expectations: + expected_facts: + - WHERE + - backtick + - MEASURE() + expected_patterns: + - pattern: "WHERE" + min_count: 2 + description: Must show WHERE clause for filtering + - pattern: "MEASURE\\(" + min_count: 1 + description: Must use MEASURE() for measures + - pattern: "GROUP BY" + min_count: 1 + description: Must include GROUP BY + guidelines: + - Must show how to filter on dimensions in WHERE clause + - Must mention backtick quoting for dimension names with spaces + - Should show both SQL and MCP approaches + metadata: + category: happy_path + difficulty: easy + source: manual diff --git a/.test/skills/databricks-metric-views/manifest.yaml b/.test/skills/databricks-metric-views/manifest.yaml new file mode 100644 index 00000000..32eb0a49 --- /dev/null +++ b/.test/skills/databricks-metric-views/manifest.yaml @@ -0,0 +1,32 @@ +skill_name: databricks-metric-views +description: 'Unity 
Catalog metric views: define, create, query, and manage governed business metrics in YAML. Use when building + standardized KPIs, revenue metrics, order analytics, or any reusable business metrics that need consistent definitions across + teams and tools.' +scorers: + enabled: + - sql_syntax + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml b/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml new file mode 100644 index 00000000..c67cbc61 --- /dev/null +++ b/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml @@ -0,0 +1,365 @@ +metadata: + skill_name: databricks-mlflow-evaluation + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: mlflow_eval_basic_evaluate_001 + inputs: + prompt: "How do I run a basic evaluation of my GenAI agent using MLflow?" 
outputs: + response: | + ```python + import mlflow + + # Define your agent function + def my_agent(query: str) -> str: + # Your agent logic here + return response + + # Prepare evaluation data (nested input structure required) + eval_data = [ + {"inputs": {"query": "What is MLflow?"}, "expectations": {"expected_response": "MLflow is..."}}, + {"inputs": {"query": "How do I log a model?"}, "expectations": {"expected_response": "Use mlflow.log_model..."}} + ] + + # Run evaluation with built-in scorers + results = mlflow.genai.evaluate( + predict_fn=my_agent, + data=eval_data, + scorers=[ + mlflow.genai.scorers.RelevanceToQuery(), + mlflow.genai.scorers.Safety() + ] + ) + print(results.metrics) + ``` + execution_success: true + expectations: + expected_facts: + - "Use mlflow.genai.evaluate(), not mlflow.evaluate()" + - "Data must use nested structure: {inputs: {query: ...}}" + - "predict_fn receives unpacked kwargs, not a dict" + expected_patterns: + - pattern: "mlflow\\.genai\\.evaluate" + min_count: 1 + description: "Uses the correct mlflow.genai.evaluate() API" + - pattern: "predict_fn" + min_count: 1 + description: "Passes the agent function as predict_fn" + - pattern: "scorers" + min_count: 1 + description: "Specifies scorers for evaluation" + - pattern: "\"inputs\"" + min_count: 1 + description: "Uses nested inputs structure in data" + guidelines: + - "Must use mlflow.genai.evaluate(), NOT mlflow.evaluate()" + - "Data format must use nested {inputs: {query: ...}} structure" + - "Must include at least one scorer" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: mlflow_eval_custom_scorer_002 + inputs: + prompt: "How do I create a custom scorer for MLflow GenAI evaluation?"
outputs: + response: | + ```python + import mlflow + from mlflow.genai.scorers import scorer + from mlflow.entities import Feedback + + @scorer + def conciseness(outputs): + if outputs is None: + return Feedback(value=0, rationale="No output") + + word_count = len(outputs.split()) + if word_count <= 50: + return Feedback(value=1, rationale=f"Concise: {word_count} words") + elif word_count <= 150: + return Feedback(value=0.5, rationale=f"Moderate: {word_count} words") + else: + return Feedback(value=0, rationale=f"Verbose: {word_count} words") + + # Use in evaluation + results = mlflow.genai.evaluate( + predict_fn=my_agent, + data=eval_data, + scorers=[conciseness] + ) + ``` + execution_success: true + expectations: + expected_facts: + - "Custom scorers are plain functions decorated with @scorer" + - "Scorer parameter names are drawn from inputs, outputs, expectations, and traces" + - "Scorers should return a Feedback with a value and a rationale" + expected_patterns: + - pattern: "@scorer" + min_count: 1 + description: "Uses the @scorer decorator" + - pattern: "Feedback\\(" + min_count: 1 + description: "Returns Feedback objects" + - pattern: "rationale" + min_count: 1 + description: "Includes a rationale with each score" + guidelines: + - "Custom scorer must use the @scorer decorator from mlflow.genai.scorers" + - "Scorer parameters must be named from inputs, outputs, expectations, traces" + - "Return value should include both a value and a rationale" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: mlflow_eval_dataset_from_traces_003 + inputs: + prompt: "How do I build an evaluation dataset from production traces in MLflow?"
outputs: + response: | + ```python + import mlflow + + # Search for production traces + traces = mlflow.search_traces( + experiment_ids=["123"], + filter_string="status = 'OK'", + max_results=100 + ) + + # search_traces returns a pandas DataFrame, so iterate its rows + eval_data = [] + for _, trace in traces.iterrows(): + eval_data.append({ + "inputs": {"query": trace["request"]["query"]}, + "expectations": {"expected_response": trace["response"]} + }) + + # Run evaluation on the collected data + results = mlflow.genai.evaluate( + predict_fn=my_agent, + data=eval_data, + scorers=[mlflow.genai.scorers.RelevanceToQuery()] + ) + ``` + execution_success: true + expectations: + expected_facts: + - "mlflow.search_traces() retrieves production traces for evaluation" + - "Traces can be filtered by status and other attributes" + - "Trace data is converted to the nested {inputs, expectations} format" + expected_patterns: + - pattern: "search_traces" + min_count: 1 + description: "Uses search_traces to retrieve production data" + - pattern: "filter_string" + min_count: 1 + description: "Filters traces by criteria" + - pattern: "\"inputs\"" + min_count: 1 + description: "Converts to required nested data format" + guidelines: + - "Must show the full workflow from trace retrieval to evaluation" + - "Data must be converted to the correct nested format" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: mlflow_eval_gotchas_004 + inputs: + prompt: "What are the critical gotchas I should know before writing MLflow evaluation code?" + outputs: + response: | + Key gotchas for MLflow GenAI evaluation: + + 1. **Use `mlflow.genai.evaluate()`** - NOT `mlflow.evaluate()`. The genai module is required for LLM evaluation. + + 2. **Data format must be nested**: `{"inputs": {"query": "..."}}`, not `{"query": "..."}`. + + 3. **predict_fn receives unpacked kwargs** - Your function gets `query="..."` not `{"query": "..."}`. + + 4.
**Always read GOTCHAS.md first** - Contains 15+ common mistakes that cause failures. + + 5. **Always read CRITICAL-interfaces.md** - Contains exact API signatures and data schemas. + + 6. **Scorer names matter** - Built-in scorers have specific names; custom scorers need unique names. + execution_success: true + expectations: + expected_facts: + - "Use mlflow.genai.evaluate(), not mlflow.evaluate()" + - "Data format requires nested structure: {inputs: {query: ...}}" + - "predict_fn receives unpacked kwargs, not a dict" + - "GOTCHAS.md documents 15+ common mistakes" + expected_patterns: + - pattern: "mlflow\\.genai\\.evaluate" + min_count: 1 + description: "Specifies the correct API" + - pattern: "GOTCHAS" + min_count: 1 + description: "References the GOTCHAS.md file" + - pattern: "inputs" + min_count: 1 + description: "Mentions the nested inputs requirement" + guidelines: + - "Must warn against using mlflow.evaluate() instead of mlflow.genai.evaluate()" + - "Must explain the nested data format requirement" + - "Must mention GOTCHAS.md and CRITICAL-interfaces.md reference files" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: mlflow_eval_compare_runs_005 + inputs: + prompt: "How do I compare evaluation results across multiple agent versions?" 
outputs: + response: | + ```python + import mlflow + + # Evaluate version A + with mlflow.start_run(run_name="agent-v1"): + results_v1 = mlflow.genai.evaluate( + predict_fn=agent_v1, + data=eval_data, + scorers=[mlflow.genai.scorers.RelevanceToQuery()] + ) + + # Evaluate version B + with mlflow.start_run(run_name="agent-v2"): + results_v2 = mlflow.genai.evaluate( + predict_fn=agent_v2, + data=eval_data, + scorers=[mlflow.genai.scorers.RelevanceToQuery()] + ) + + # Compare in the MLflow UI or programmatically + print(f"V1 avg relevance: {results_v1.metrics['relevance_to_query/mean']}") + print(f"V2 avg relevance: {results_v2.metrics['relevance_to_query/mean']}") + ``` + execution_success: true + expectations: + expected_facts: + - "Each evaluation should run in its own MLflow run for comparison" + - "Metrics are accessible via results.metrics dictionary" + - "The MLflow UI provides visual comparison across runs" + expected_patterns: + - pattern: "start_run" + min_count: 2 + description: "Creates separate MLflow runs for each version" + - pattern: "mlflow\\.genai\\.evaluate" + min_count: 2 + description: "Evaluates both versions" + - pattern: "\\.metrics\\[" + min_count: 1 + description: "Accesses metrics from results" + guidelines: + - "Must show evaluation in separate MLflow runs" + - "Must demonstrate how to access and compare metrics" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: mlflow_eval_builtin_scorers_006 + inputs: + prompt: "What built-in scorers are available in MLflow GenAI evaluation?" + outputs: + response: | + MLflow provides several built-in scorers in `mlflow.genai.scorers`: + + ```python + from mlflow.genai.scorers import ( + RelevanceToQuery, # Is the response relevant to the query? + Safety, # Is the response safe and appropriate? + RetrievalGroundedness, # Is the response grounded in retrieved context? + Correctness, # Does the response match expected output?
) + + results = mlflow.genai.evaluate( + predict_fn=my_agent, + data=eval_data, + scorers=[ + RelevanceToQuery(), + Safety(), + RetrievalGroundedness(), + Correctness() + ] + ) + ``` + + Use `RelevanceToQuery` and `Safety` for all evaluations. Add `Correctness` when you have expected responses, and `RetrievalGroundedness` for RAG applications. + execution_success: true + expectations: + expected_facts: + - "Built-in scorers include RelevanceToQuery, Safety, RetrievalGroundedness, and Correctness" + - "Scorers are imported from mlflow.genai.scorers" + - "RetrievalGroundedness is specifically useful for RAG applications" + expected_patterns: + - pattern: "RelevanceToQuery" + min_count: 1 + description: "Lists RelevanceToQuery" + - pattern: "Safety" + min_count: 1 + description: "Lists Safety" + - pattern: "RetrievalGroundedness" + min_count: 1 + description: "Lists RetrievalGroundedness" + - pattern: "Correctness" + min_count: 1 + description: "Lists Correctness" + guidelines: + - "Must list at least four built-in scorers" + - "Must explain when to use each scorer" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: mlflow_eval_reference_files_007 + inputs: + prompt: "Which reference files should I read before writing MLflow evaluation code?"
+ outputs: + response: | + Read these reference files in order: + + | Reference | Purpose | When to Read | + |-----------|---------|--------------| + | `GOTCHAS.md` | Common mistakes | **Always read first** before writing code | + | `CRITICAL-interfaces.md` | API signatures, schemas | When writing any evaluation code | + | `patterns-evaluation.md` | Running evals, comparing | When executing evaluations | + | `patterns-scorers.md` | Custom scorer creation | When built-in scorers aren't enough | + | `patterns-datasets.md` | Dataset building | When preparing evaluation data | + + Always read GOTCHAS.md and CRITICAL-interfaces.md before writing any evaluation code. + execution_success: true + expectations: + expected_facts: + - "GOTCHAS.md must be read first before writing any code" + - "CRITICAL-interfaces.md contains API signatures and data schemas" + - "patterns-evaluation.md covers running and comparing evaluations" + - "patterns-scorers.md covers custom scorer creation" + - "patterns-datasets.md covers dataset preparation" + expected_patterns: + - pattern: "GOTCHAS\\.md" + min_count: 1 + description: "References GOTCHAS.md" + - pattern: "CRITICAL-interfaces\\.md" + min_count: 1 + description: "References CRITICAL-interfaces.md" + - pattern: "patterns-evaluation\\.md" + min_count: 1 + description: "References patterns-evaluation.md" + guidelines: + - "Must list all five reference files" + - "Must emphasize reading GOTCHAS.md first" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-mlflow-evaluation/manifest.yaml b/.test/skills/databricks-mlflow-evaluation/manifest.yaml new file mode 100644 index 00000000..f4240e27 --- /dev/null +++ b/.test/skills/databricks-mlflow-evaluation/manifest.yaml @@ -0,0 +1,31 @@ +skill_name: databricks-mlflow-evaluation +description: MLflow 3 GenAI agent evaluation. 
Use when writing mlflow.genai.evaluate() code, creating @scorer functions, using + built-in scorers (Guidelines, Correctness, Safety, RetrievalGroundedness), building eval datasets from traces, setting up + trace ingestion and production monitoring, aligning judges with MemAlign from domain expert feedback, or running optimize_prompts() + with GEPA for automated prompt improvement. +scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-python-sdk/ground_truth.yaml b/.test/skills/databricks-python-sdk/ground_truth.yaml new file mode 100644 index 00000000..a783aa3a --- /dev/null +++ b/.test/skills/databricks-python-sdk/ground_truth.yaml @@ -0,0 +1,378 @@ +metadata: + skill_name: databricks-python-sdk + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: pysdk_authentication_001 + inputs: + prompt: "What are the different ways to authenticate with the Databricks Python SDK?" + outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + # Auto-detect credentials from environment + w = WorkspaceClient() + + # Explicit token authentication + w = WorkspaceClient( + host="https://your-workspace.cloud.databricks.com", + token="dapi..." 
+ ) + + # Azure Service Principal + w = WorkspaceClient( + host="https://adb-xxx.azuredatabricks.net", + azure_workspace_resource_id="/subscriptions/.../resourceGroups/.../providers/Microsoft.Databricks/workspaces/...", + azure_tenant_id="tenant-id", + azure_client_id="client-id", + azure_client_secret="secret" + ) + + # Use a named profile from ~/.databrickscfg + w = WorkspaceClient(profile="MY_PROFILE") + ``` + execution_success: true + expectations: + expected_facts: + - "WorkspaceClient auto-detects credentials from environment by default" + - "Explicit token auth requires host and token parameters" + - "Azure Service Principal auth uses azure_tenant_id, azure_client_id, and azure_client_secret" + - "Named profiles are loaded from ~/.databrickscfg" + expected_patterns: + - pattern: "WorkspaceClient\\(\\)" + min_count: 1 + description: "Shows auto-detect authentication" + - pattern: "token=" + min_count: 1 + description: "Shows explicit token authentication" + - pattern: "azure_client_id" + min_count: 1 + description: "Shows Azure Service Principal authentication" + - pattern: "profile=" + min_count: 1 + description: "Shows named profile authentication" + guidelines: + - "Must show at least three different authentication methods" + - "Must include auto-detect, explicit token, and profile-based auth" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_clusters_api_002 + inputs: + prompt: "How do I list, create, and manage clusters using the Databricks Python SDK?" 
+  outputs:
+    response: |
+      ```python
+      from databricks.sdk import WorkspaceClient
+
+      w = WorkspaceClient()
+
+      # List all clusters
+      for cluster in w.clusters.list():
+          print(f"{cluster.cluster_name}: {cluster.state}")
+
+      # Get cluster details
+      cluster = w.clusters.get(cluster_id="0123-456789-abcdef")
+
+      # Create a cluster (returns Wait object)
+      wait = w.clusters.create(
+          cluster_name="my-cluster",
+          spark_version=w.clusters.select_spark_version(latest=True),
+          node_type_id=w.clusters.select_node_type(local_disk=True),
+          num_workers=2
+      )
+      cluster = wait.result()  # Wait for cluster to be running
+
+      # Start, terminate, and permanently delete
+      w.clusters.start(cluster_id="...").result()
+      w.clusters.delete(cluster_id="...")  # terminates (stops) the cluster
+      w.clusters.permanent_delete(cluster_id="...")
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "w.clusters.list() returns an iterator of cluster objects"
+    - "w.clusters.create() returns a Wait object for async cluster creation"
+    - "select_spark_version() and select_node_type() are helper methods for discovery"
+    - "delete() terminates a running cluster while permanent_delete() removes it entirely"
+    expected_patterns:
+    - pattern: "w\\.clusters\\.list"
+      min_count: 1
+      description: "Lists clusters"
+    - pattern: "w\\.clusters\\.create"
+      min_count: 1
+      description: "Creates a cluster"
+    - pattern: "select_spark_version"
+      min_count: 1
+      description: "Uses spark version selector"
+    - pattern: "select_node_type"
+      min_count: 1
+      description: "Uses node type selector"
+    guidelines:
+    - "Must show list, create, start, terminate (delete), and permanent_delete operations"
+    - "Must use select_spark_version and select_node_type helpers"
+    - "Must demonstrate the Wait object pattern for async operations"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: pysdk_jobs_api_003
+  inputs:
+    prompt: "Show me how to create and run a Databricks job using the Python SDK."
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.jobs import Task, NotebookTask + + w = WorkspaceClient() + + # List jobs + for job in w.jobs.list(): + print(f"{job.job_id}: {job.settings.name}") + + # Create a job + created = w.jobs.create( + name="my-job", + tasks=[ + Task( + task_key="main", + notebook_task=NotebookTask(notebook_path="/Users/me/notebook"), + existing_cluster_id="0123-456789-abcdef" + ) + ] + ) + + # Run a job and wait for completion + run = w.jobs.run_now_and_wait(job_id=created.job_id) + print(f"Run completed: {run.state.result_state}") + + # Get run output + output = w.jobs.get_run_output(run_id=run.run_id) + ``` + execution_success: true + expectations: + expected_facts: + - "Jobs are created with w.jobs.create() returning a job with job_id" + - "run_now_and_wait() triggers a job and blocks until completion" + - "get_run_output() retrieves the output of a completed run" + - "Task requires task_key and a task type like NotebookTask" + expected_patterns: + - pattern: "w\\.jobs\\.create" + min_count: 1 + description: "Creates a job" + - pattern: "run_now_and_wait" + min_count: 1 + description: "Runs job and waits for completion" + - pattern: "get_run_output" + min_count: 1 + description: "Gets run output" + - pattern: "NotebookTask" + min_count: 1 + description: "Uses NotebookTask for task configuration" + guidelines: + - "Must show the full job lifecycle: create, run, get output" + - "Must use run_now_and_wait for blocking execution" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pysdk_sql_execution_004 + inputs: + prompt: "How do I execute SQL queries against a Databricks SQL warehouse using the Python SDK?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.sql import StatementState + + w = WorkspaceClient() + + # Execute SQL query + response = w.statement_execution.execute_statement( + warehouse_id="abc123", + statement="SELECT * FROM catalog.schema.table LIMIT 10", + wait_timeout="30s" + ) + + # Check status and get results + if response.status.state == StatementState.SUCCEEDED: + for row in response.result.data_array: + print(row) + + # For large results, fetch chunks + chunk = w.statement_execution.get_statement_result_chunk_n( + statement_id=response.statement_id, + chunk_index=0 + ) + ``` + execution_success: true + expectations: + expected_facts: + - "SQL is executed with w.statement_execution.execute_statement()" + - "warehouse_id specifies which SQL warehouse to use" + - "StatementState.SUCCEEDED indicates a successful query" + - "Large results are fetched in chunks with get_statement_result_chunk_n()" + expected_patterns: + - pattern: "execute_statement" + min_count: 1 + description: "Executes SQL statement" + - pattern: "warehouse_id" + min_count: 1 + description: "Specifies SQL warehouse" + - pattern: "StatementState" + min_count: 1 + description: "Checks statement execution status" + - pattern: "get_statement_result_chunk_n" + min_count: 1 + description: "Fetches result chunks for large queries" + guidelines: + - "Must show execute, status check, and result retrieval" + - "Must demonstrate chunked result fetching for large datasets" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pysdk_databricks_connect_005 + inputs: + prompt: "How do I use Databricks Connect to run Spark code locally?" 
+ outputs: + response: | + ```python + from databricks.connect import DatabricksSession + + # Auto-detects 'DEFAULT' profile from ~/.databrickscfg + spark = DatabricksSession.builder.getOrCreate() + + # With explicit profile + spark = DatabricksSession.builder.profile("MY_PROFILE").getOrCreate() + + # Use spark as normal + df = spark.sql("SELECT * FROM catalog.schema.table") + df.show() + ``` + + Databricks Connect requires `databricks-connect` package and a running cluster to execute against. + execution_success: true + expectations: + expected_facts: + - "DatabricksSession replaces SparkSession for remote Spark execution" + - "Databricks Connect auto-detects credentials from ~/.databrickscfg" + - "The databricks-connect package must be installed separately" + expected_patterns: + - pattern: "DatabricksSession" + min_count: 1 + description: "Uses DatabricksSession for Spark" + - pattern: "getOrCreate" + min_count: 1 + description: "Builds session with getOrCreate()" + - pattern: "profile" + min_count: 1 + description: "Shows profile-based configuration" + guidelines: + - "Must show DatabricksSession, not SparkSession" + - "Must demonstrate both auto-detect and explicit profile usage" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_rest_api_006 + inputs: + prompt: "How can I make direct REST API calls using the Databricks Python SDK?" + outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Direct API call using authenticated client + response = w.api_client.do( + method="GET", + path="/api/2.0/clusters/list" + ) + + # POST with body + response = w.api_client.do( + method="POST", + path="/api/2.0/jobs/run-now", + body={"job_id": 123} + ) + ``` + + Use direct REST API calls for operations not yet available in the SDK or for advanced use cases. 
+ execution_success: true + expectations: + expected_facts: + - "w.api_client.do() makes authenticated REST API calls" + - "The method parameter accepts GET, POST, PUT, DELETE" + - "The body parameter accepts a dictionary for POST/PUT requests" + expected_patterns: + - pattern: "api_client\\.do" + min_count: 2 + description: "Uses api_client.do() for direct REST calls" + - pattern: "method=" + min_count: 2 + description: "Specifies HTTP method" + - pattern: "path=" + min_count: 2 + description: "Specifies API endpoint path" + guidelines: + - "Must show both GET and POST examples" + - "Must explain when to use direct REST over SDK methods" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_cli_commands_007 + inputs: + prompt: "What are the most common Databricks CLI commands?" + outputs: + response: | + ```bash + # Check version (should be >= 0.278.0) + databricks --version + + # Use specific profile + databricks --profile MY_PROFILE clusters list + + # Common commands + databricks clusters list + databricks jobs list + databricks workspace ls /Users/me + ``` + + The CLI uses the same authentication as the Python SDK (`~/.databrickscfg` profiles, environment variables, or explicit flags). 
+ execution_success: true + expectations: + expected_facts: + - "The Databricks CLI shares authentication with the Python SDK" + - "Profiles are specified with the --profile flag" + - "Common subcommands include clusters, jobs, and workspace" + expected_patterns: + - pattern: "databricks --version" + min_count: 1 + description: "Shows version check command" + - pattern: "databricks clusters list" + min_count: 1 + description: "Shows cluster listing command" + - pattern: "databricks jobs list" + min_count: 1 + description: "Shows job listing command" + guidelines: + - "Must show at least three common CLI commands" + - "Must mention the --profile flag for profile selection" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-python-sdk/manifest.yaml b/.test/skills/databricks-python-sdk/manifest.yaml new file mode 100644 index 00000000..62377187 --- /dev/null +++ b/.test/skills/databricks-python-sdk/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-python-sdk +description: Databricks development guidance including Python SDK, Databricks Connect, CLI, and REST API. Use when working + with databricks-sdk, databricks-connect, or Databricks APIs. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml b/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml new file mode 100644 index 00000000..6f642133 --- /dev/null +++ b/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml @@ -0,0 +1,227 @@ +metadata: + skill_name: databricks-spark-structured-streaming + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: spark-streaming_kafka_to_delta_001 + inputs: + prompt: Create a streaming pipeline that reads from Kafka and writes to a Delta table + outputs: + response: | + Here's a basic Kafka to Delta streaming pipeline: + + ```python + from pyspark.sql.functions import col, from_json + + # Read from Kafka + df = (spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "broker:9092") + .option("subscribe", "topic") + .load() + .select(from_json(col("value").cast("string"), schema).alias("data")) + .select("data.*") + ) + + # Write to Delta + df.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/checkpoints/stream") \ + .trigger(processingTime="30 seconds") \ + .start("/delta/target_table") + ``` + + Key points: + - Use a persistent checkpoint location (UC volumes, not DBFS) + - Each stream needs a unique checkpoint path + execution_success: true + 
expectations: + expected_facts: + - Reads from Kafka using readStream + - Writes to Delta table format + - Includes checkpoint location + expected_patterns: + - pattern: readStream + min_count: 1 + description: Must use readStream for streaming source + - pattern: writeStream + min_count: 1 + description: Must use writeStream for streaming sink + - pattern: checkpointLocation + min_count: 1 + description: Must specify checkpoint location + - pattern: format.*kafka + min_count: 1 + description: Must specify Kafka format + guidelines: + - Code must be valid Python syntax + - Must include checkpoint location for fault tolerance + - Checkpoint should use UC volumes, not DBFS + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_trigger_modes_002 + inputs: + prompt: What are the different trigger modes for Spark Structured Streaming? + outputs: + response: | + Spark Structured Streaming supports several trigger modes: + + ```python + # Fixed interval processing + .trigger(processingTime="30 seconds") + + # Process all available data once + .trigger(availableNow=True) + + # Continuous processing (low latency) + .trigger(continuous="1 second") + ``` + + - **processingTime**: Fixed interval micro-batches. Best for most workloads. + - **availableNow**: Processes all available data then stops. Good for batch-like streaming. + - **continuous**: Low-latency mode (~1ms). Limited operator support. 
+ execution_success: true + expectations: + expected_facts: + - processingTime trigger for fixed interval + - availableNow trigger processes all data then stops + - continuous trigger for low latency + expected_patterns: + - pattern: processingTime + min_count: 1 + description: Must mention processingTime trigger + - pattern: trigger + min_count: 1 + description: Must use trigger configuration + guidelines: + - Must explain multiple trigger modes + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_checkpoint_practices_003 + inputs: + prompt: What are best practices for Spark Structured Streaming checkpoints? + outputs: + response: | + Checkpoint best practices: + + 1. **Use persistent storage**: UC volumes, not DBFS + 2. **Unique per stream**: Each stream needs its own checkpoint path + 3. **Never delete**: Checkpoints track stream state and offsets + 4. **Use volumes**: `/Volumes/catalog/schema/checkpoints/stream_name` + + ```python + df.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/schema/checkpoints/my_stream") \ + .start("catalog.schema.target_table") + ``` + execution_success: true + expectations: + expected_facts: + - Checkpoint location should be persistent (UC volumes) + - Each stream needs a unique checkpoint + expected_patterns: + - pattern: checkpointLocation + min_count: 1 + description: Must reference checkpointLocation + - pattern: /Volumes/ + min_count: 1 + description: Should use UC volumes path + guidelines: + - Must recommend persistent storage for checkpoints + - Must warn about unique checkpoints per stream + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_production_checklist_004 + inputs: + prompt: What is the production checklist for Spark Structured Streaming? 
+ outputs: + response: | + Production checklist for Spark Structured Streaming: + + - Checkpoint location is persistent (UC volumes, not DBFS) + - Unique checkpoint per stream + - Fixed-size cluster (no autoscaling for streaming) + - Monitoring configured (input rate, lag, batch duration) + - Exactly-once verified (txnVersion/txnAppId) + - Watermark configured for stateful operations + - Left joins for stream-static (not inner) + execution_success: true + expectations: + expected_facts: + - Use persistent checkpoint location + - Fixed-size cluster, no autoscaling + - Configure monitoring for streaming + expected_patterns: + - pattern: checkpoint + min_count: 1 + description: Must mention checkpoint requirements + - pattern: autoscal + min_count: 1 + description: Must mention autoscaling restriction + guidelines: + - Must cover checkpoint, cluster, and monitoring requirements + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: spark-streaming_stream_joins_005 + inputs: + prompt: How do I join a streaming DataFrame with a static table? 
+ outputs: + response: | + For stream-static joins, use a left join (not inner) to avoid dropping streaming rows when the static side has no match: + + ```python + # Read streaming source + stream_df = spark.readStream.format("delta").table("catalog.schema.events") + + # Read static dimension table + dim_df = spark.read.table("catalog.schema.dimensions") + + # Left join: stream on left, static on right + joined = stream_df.join(dim_df, "join_key", "left") + + joined.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/schema/checkpoints/joined_stream") \ + .start("catalog.schema.enriched_events") + ``` + execution_success: true + expectations: + expected_facts: + - Use left join for stream-static joins + - Static table is on the right side + expected_patterns: + - pattern: readStream + min_count: 1 + description: Must use streaming read + - pattern: join + min_count: 1 + description: Must demonstrate join operation + - pattern: left + min_count: 1 + description: Must use left join + guidelines: + - Must use left join for stream-static joins, not inner + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-spark-structured-streaming/manifest.yaml b/.test/skills/databricks-spark-structured-streaming/manifest.yaml new file mode 100644 index 00000000..7e6e9521 --- /dev/null +++ b/.test/skills/databricks-spark-structured-streaming/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-spark-structured-streaming +description: Comprehensive guide to Spark Structured Streaming for production workloads. Use when building streaming pipelines, + implementing real-time data processing, handling stateful operations, or optimizing streaming performance. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml b/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml new file mode 100644 index 00000000..b0e1b0f6 --- /dev/null +++ b/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml @@ -0,0 +1,392 @@ +metadata: + skill_name: databricks-synthetic-data-generation + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: synth_data_script_structure_001 + inputs: + prompt: "What is the recommended structure for a synthetic data generation script on Databricks?" 
+ outputs: + response: | + ```python + """Generate synthetic data for [use case].""" + import numpy as np + import pandas as pd + from datetime import datetime, timedelta + from faker import Faker + import holidays + from pyspark.sql import SparkSession + + # ============================================================================= + # CONFIGURATION - Edit these values + # ============================================================================= + CATALOG = "my_catalog" + SCHEMA = "my_schema" + VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/raw_data" + + N_CUSTOMERS = 2500 + N_ORDERS = 25000 + N_TICKETS = 8000 + + END_DATE = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + START_DATE = END_DATE - timedelta(days=180) + + SEED = 42 + + # ============================================================================= + # SETUP + # ============================================================================= + np.random.seed(SEED) + Faker.seed(SEED) + fake = Faker() + spark = SparkSession.builder.getOrCreate() + ``` + execution_success: true + expectations: + expected_facts: + - "Configuration variables go at the top of the script" + - "Seeds must be set for both numpy and Faker for reproducibility" + - "Volume path follows /Volumes/{catalog}/{schema}/{volume} pattern" + - "SparkSession is used for saving data to Delta/Parquet" + expected_patterns: + - pattern: "CATALOG\\s*=" + min_count: 1 + description: "Defines catalog configuration variable" + - pattern: "np\\.random\\.seed" + min_count: 1 + description: "Sets numpy random seed" + - pattern: "Faker\\.seed" + min_count: 1 + description: "Sets Faker random seed" + - pattern: "SparkSession" + min_count: 1 + description: "Uses SparkSession for data saving" + guidelines: + - "Configuration must be at the top of the script" + - "Must set seeds for both numpy and Faker for reproducibility" + - "Must use SparkSession for saving to Databricks storage" + metadata: + category: happy_path + difficulty: 
easy + source: auto_generated + +- id: synth_data_infrastructure_002 + inputs: + prompt: "Should I create the catalog, schema, and volume separately or inside the generation script?" + outputs: + response: | + Always create infrastructure inside the Python script using `spark.sql()`: + + ```python + # ============================================================================= + # CREATE INFRASTRUCTURE (inside the Python script) + # ============================================================================= + spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}") + spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}") + spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG}.{SCHEMA}.raw_data") + ``` + + Do NOT make separate MCP SQL calls - it's much slower. The `spark` variable is available by default on Databricks clusters. + execution_success: true + expectations: + expected_facts: + - "Infrastructure creation belongs inside the Python script, not as separate SQL calls" + - "spark.sql() is used for CREATE CATALOG, CREATE SCHEMA, and CREATE VOLUME" + - "IF NOT EXISTS makes the script idempotent" + expected_patterns: + - pattern: "CREATE CATALOG IF NOT EXISTS" + min_count: 1 + description: "Creates catalog idempotently" + - pattern: "CREATE SCHEMA IF NOT EXISTS" + min_count: 1 + description: "Creates schema idempotently" + - pattern: "CREATE VOLUME IF NOT EXISTS" + min_count: 1 + description: "Creates volume idempotently" + guidelines: + - "Must use spark.sql() inside the script, not separate SQL calls" + - "All CREATE statements must use IF NOT EXISTS" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_distributions_003 + inputs: + prompt: "How do I generate synthetic data with realistic statistical distributions?" 
+ outputs: + response: | + Use numpy distributions with weighted probabilities: + + ```python + import numpy as np + from faker import Faker + + fake = Faker() + + # Categorical with weighted distribution + customers_pdf["tier"] = np.random.choice( + ['Free', 'Pro', 'Enterprise'], N_CUSTOMERS, p=[0.6, 0.3, 0.1] + ) + + # Log-normal for financial data (right-skewed, no negatives) + customers_pdf["arr"] = customers_pdf["tier"].apply( + lambda t: round(np.random.lognormal(11, 0.5), 2) if t == 'Enterprise' + else round(np.random.lognormal(8, 0.6), 2) if t == 'Pro' else 0 + ) + + # Correlated fields: order amount varies by tier + amount = np.random.lognormal( + 7 if tier == 'Enterprise' else 5 if tier == 'Pro' else 3.5, + 0.7 + ) + ``` + + Key distributions: use log-normal for financial amounts, weighted choice for categories, and exponential for time durations. + execution_success: true + expectations: + expected_facts: + - "Log-normal distribution is ideal for financial data (right-skewed, non-negative)" + - "np.random.choice with p parameter creates weighted categorical distributions" + - "Field values should correlate with related fields (e.g., order amount varies by tier)" + expected_patterns: + - pattern: "np\\.random\\.choice" + min_count: 1 + description: "Uses weighted random choice for categories" + - pattern: "np\\.random\\.lognormal" + min_count: 1 + description: "Uses log-normal for financial data" + - pattern: "p=\\[" + min_count: 1 + description: "Specifies probability weights" + guidelines: + - "Must show weighted distributions, not uniform random" + - "Financial values must use log-normal, not normal distribution" + - "Must demonstrate correlation between related fields" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: synth_data_time_patterns_004 + inputs: + prompt: "How do I add realistic time-based patterns like weekday/weekend effects and event spikes?" 
+ outputs: + response: | + ```python + import holidays + + US_HOLIDAYS = holidays.US(years=[START_DATE.year, END_DATE.year]) + + def get_daily_multiplier(date): + """Calculate volume multiplier for a given date.""" + multiplier = 1.0 + + # Weekend drop + if date.weekday() >= 5: + multiplier *= 0.6 + + # Holiday drop (even lower than weekends) + if date in US_HOLIDAYS: + multiplier *= 0.3 + + # Q4 seasonality (higher in Oct-Dec) + multiplier *= 1 + 0.15 * (date.month - 6) / 6 + + # Incident spike + if INCIDENT_START <= date <= INCIDENT_END: + multiplier *= 3.0 + + # Random noise + multiplier *= np.random.normal(1, 0.1) + + return max(0.1, multiplier) + + # Apply to generate daily volumes + date_range = pd.date_range(START_DATE, END_DATE, freq='D') + daily_volumes = [int(BASE_DAILY * get_daily_multiplier(d)) for d in date_range] + ``` + execution_success: true + expectations: + expected_facts: + - "The holidays library provides country-specific holiday calendars" + - "Weekend volumes should drop to 60% of weekday volumes" + - "Holiday volumes should drop further to 30%" + - "Event spikes multiply baseline volume by 3x or more" + expected_patterns: + - pattern: "holidays\\.US" + min_count: 1 + description: "Uses US holiday calendar" + - pattern: "weekday\\(\\)" + min_count: 1 + description: "Checks for weekday/weekend" + - pattern: "INCIDENT_START" + min_count: 1 + description: "Implements event spike period" + - pattern: "np\\.random\\.normal" + min_count: 1 + description: "Adds random noise to volumes" + guidelines: + - "Must include weekday/weekend effects" + - "Must include holiday effects using the holidays library" + - "Must include at least one event spike period" + - "Must add random noise to avoid unrealistic regularity" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: synth_data_save_volume_005 + inputs: + prompt: "How should I save generated synthetic data to Databricks?" 
+ outputs: + response: | + Always save data to a Unity Catalog Volume as parquet files: + + ```python + VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/raw_data" + + # Save as parquet files (raw data for downstream pipelines) + spark.createDataFrame(customers_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/customers") + spark.createDataFrame(orders_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/orders") + spark.createDataFrame(tickets_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/tickets") + ``` + + Save to Volumes as parquet, not directly to tables. This serves as raw input for a downstream Spark Declarative Pipeline that handles bronze/silver/gold layers. + execution_success: true + expectations: + expected_facts: + - "Data should be saved to Volumes as parquet, not directly to tables" + - "Volume path pattern is /Volumes/{catalog}/{schema}/{volume}" + - "spark.createDataFrame() converts pandas DataFrames to Spark for saving" + - "Parquet files serve as input for downstream pipeline processing" + expected_patterns: + - pattern: "createDataFrame" + min_count: 1 + description: "Converts pandas to Spark DataFrame" + - pattern: "\\.parquet\\(" + min_count: 1 + description: "Saves as parquet format" + - pattern: "mode\\(\"overwrite\"\\)" + min_count: 1 + description: "Uses overwrite mode" + - pattern: "/Volumes/" + min_count: 1 + description: "Saves to Unity Catalog Volume path" + guidelines: + - "Must save to Volumes, not directly to tables" + - "Must use parquet format" + - "Must explain that this is raw data for downstream pipeline" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_raw_only_006 + inputs: + prompt: "Should I include pre-aggregated fields like total_orders or avg_revenue in my synthetic data?" + outputs: + response: | + No. By default, generate only raw transactional data. One row equals one event. 
+ + ```python + # CORRECT - Raw transactional data + customers_data.append({ + "customer_id": cid, + "name": fake.company(), + "tier": "Enterprise", + "region": "North", + }) + + orders_data.append({ + "order_id": f"ORD-{i:06d}", + "customer_id": cid, + "amount": 150.00, + "order_date": "2024-10-15", + }) + + # WRONG - Do not add pre-aggregated fields + # customers_data.append({ + # "total_orders": 47, # aggregation + # "total_revenue": 12500.00, # sum + # "avg_order_value": 265.95, # average + # }) + ``` + + Aggregations belong in downstream pipeline transformations, not in raw data. + execution_success: true + expectations: + expected_facts: + - "Raw data should have one row per event/transaction/record" + - "Pre-aggregated fields like total_orders, sum_revenue, avg_csat should not be included" + - "Aggregations belong in downstream pipeline transformations" + expected_patterns: + - pattern: "customer_id" + min_count: 1 + description: "Shows raw transactional customer data" + - pattern: "order_id" + min_count: 1 + description: "Shows raw transactional order data" + - pattern: "WRONG|BAD|Do not" + min_count: 1 + description: "Shows anti-pattern example" + guidelines: + - "Must clearly show the correct raw data pattern" + - "Must show the anti-pattern with pre-aggregated fields" + - "Must explain that aggregations belong downstream" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_common_libraries_007 + inputs: + prompt: "What Python libraries should I use for generating synthetic data?" 
+ outputs: + response: | + Key libraries for synthetic data generation: + + - **faker**: Generates realistic names, addresses, emails, companies, and dates + - **numpy**: Statistical distributions (log-normal, exponential, weighted choice) + - **pandas**: DataFrame construction and manipulation + - **holidays**: Country-specific holiday calendars for realistic date patterns + - **pyspark**: SparkSession for saving data to Volumes as parquet + + ```python + import numpy as np + import pandas as pd + from datetime import datetime, timedelta + from faker import Faker + import holidays + from pyspark.sql import SparkSession + ``` + + All of these are pre-installed on Databricks clusters. + execution_success: true + expectations: + expected_facts: + - "Faker generates realistic personal and business data" + - "Numpy provides statistical distributions for realistic value generation" + - "The holidays library provides holiday calendars for date patterns" + - "All listed libraries are pre-installed on Databricks clusters" + expected_patterns: + - pattern: "faker" + min_count: 1 + description: "Mentions Faker library" + - pattern: "numpy" + min_count: 1 + description: "Mentions numpy library" + - pattern: "holidays" + min_count: 1 + description: "Mentions holidays library" + - pattern: "pyspark" + min_count: 1 + description: "Mentions PySpark library" + guidelines: + - "Must list at least four key libraries with their purposes" + - "Must mention that libraries are pre-installed on Databricks" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-synthetic-data-generation/manifest.yaml b/.test/skills/databricks-synthetic-data-generation/manifest.yaml new file mode 100644 index 00000000..9216d106 --- /dev/null +++ b/.test/skills/databricks-synthetic-data-generation/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-synthetic-data-generation +description: Generate realistic synthetic data using Faker and Spark, 
with non-linear distributions, integrity constraints, + and save to Databricks. Use when creating test data, demo datasets, or synthetic tables. +scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-unity-catalog/ground_truth.yaml b/.test/skills/databricks-unity-catalog/ground_truth.yaml new file mode 100644 index 00000000..487ae52e --- /dev/null +++ b/.test/skills/databricks-unity-catalog/ground_truth.yaml @@ -0,0 +1,255 @@ +metadata: + skill_name: databricks-unity-catalog + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: uc_volume_operations_001 + inputs: + prompt: "How do I upload, download, and list files in a Unity Catalog Volume?" 
+ outputs: + response: | + Use MCP tools for Volume file operations: + + ```python + # List files in a volume + list_volume_files(volume_path="/Volumes/catalog/schema/volume/folder/") + + # Upload file to volume + upload_to_volume( + local_path="/tmp/data.csv", + volume_path="/Volumes/catalog/schema/volume/data.csv" + ) + + # Download file from volume + download_from_volume( + volume_path="/Volumes/catalog/schema/volume/data.csv", + local_path="/tmp/downloaded.csv" + ) + + # Create directory + create_volume_directory(volume_path="/Volumes/catalog/schema/volume/new_folder") + ``` + execution_success: true + expectations: + expected_facts: + - "Volume paths follow the pattern /Volumes/{catalog}/{schema}/{volume}" + - "list_volume_files, upload_to_volume, download_from_volume, and create_volume_directory are available MCP tools" + expected_patterns: + - pattern: "list_volume_files" + min_count: 1 + description: "Uses list_volume_files MCP tool" + - pattern: "upload_to_volume" + min_count: 1 + description: "Uses upload_to_volume MCP tool" + - pattern: "download_from_volume" + min_count: 1 + description: "Uses download_from_volume MCP tool" + - pattern: "/Volumes/" + min_count: 3 + description: "Uses Unity Catalog Volume paths" + guidelines: + - "Must show at least list, upload, and download operations" + - "Volume paths must follow the /Volumes/{catalog}/{schema}/{volume} pattern" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: uc_table_lineage_002 + inputs: + prompt: "How do I query table lineage to find which tables feed into my target table?" + outputs: + response: | + ```sql + -- Table lineage: What tables feed into this table? + SELECT source_table_full_name, source_column_name + FROM system.access.table_lineage + WHERE target_table_full_name = 'catalog.schema.table' + AND event_date >= current_date() - 7; + ``` + + System table lineage is captured automatically by Unity Catalog. 
Query `system.access.table_lineage` for table-level dependencies and `system.access.column_lineage` for column-level lineage. + execution_success: true + expectations: + expected_facts: + - "Table lineage is stored in system.access.table_lineage" + - "Column lineage is available in system.access.column_lineage" + - "Lineage is captured automatically by Unity Catalog" + expected_patterns: + - pattern: "system\\.access\\.table_lineage" + min_count: 1 + description: "Queries the table lineage system table" + - pattern: "target_table_full_name" + min_count: 1 + description: "Filters by target table name" + - pattern: "source_table_full_name" + min_count: 1 + description: "Returns source table names" + guidelines: + - "Must query system.access.table_lineage" + - "Must show filtering by target table name" + - "Should mention that lineage is captured automatically" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: uc_audit_logs_003 + inputs: + prompt: "How do I check recent permission changes using Unity Catalog audit logs?" + outputs: + response: | + ```sql + -- Audit: Recent permission changes + SELECT event_time, user_identity.email, action_name, request_params + FROM system.access.audit + WHERE action_name LIKE '%GRANT%' OR action_name LIKE '%REVOKE%' + ORDER BY event_time DESC + LIMIT 100; + ``` + + The `system.access.audit` table records all access control changes. Filter by action_name to find GRANT, REVOKE, or other permission-related events. 
+ execution_success: true + expectations: + expected_facts: + - "Audit logs are stored in system.access.audit" + - "Permission changes include GRANT and REVOKE actions" + - "user_identity.email identifies who made the change" + expected_patterns: + - pattern: "system\\.access\\.audit" + min_count: 1 + description: "Queries the audit log system table" + - pattern: "action_name" + min_count: 1 + description: "Filters by action type" + - pattern: "GRANT" + min_count: 1 + description: "Checks for GRANT operations" + - pattern: "REVOKE" + min_count: 1 + description: "Checks for REVOKE operations" + guidelines: + - "Must query system.access.audit" + - "Must filter for permission-related actions" + - "Must include user identity in the output" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: uc_billing_usage_004 + inputs: + prompt: "How do I query DBU usage and billing data across workspaces?" + outputs: + response: | + ```sql + -- Billing: DBU usage by workspace + SELECT workspace_id, sku_name, SUM(usage_quantity) AS total_dbus + FROM system.billing.usage + WHERE usage_date >= current_date() - 30 + GROUP BY workspace_id, sku_name; + ``` + + The `system.billing.usage` table provides DBU consumption data. Use it to track costs by workspace, SKU type, or time period. 
+ execution_success: true + expectations: + expected_facts: + - "Billing data is stored in system.billing.usage" + - "usage_quantity contains the DBU consumption" + - "Data can be filtered by usage_date for time-based analysis" + expected_patterns: + - pattern: "system\\.billing\\.usage" + min_count: 1 + description: "Queries the billing system table" + - pattern: "SUM\\(usage_quantity\\)" + min_count: 1 + description: "Aggregates DBU usage" + - pattern: "sku_name" + min_count: 1 + description: "Groups by SKU type" + guidelines: + - "Must query system.billing.usage" + - "Must aggregate by meaningful dimensions like workspace or SKU" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: uc_grant_system_access_005 + inputs: + prompt: "How do I grant a group access to Unity Catalog system tables?" + outputs: + response: | + ```sql + -- Grant access to system tables + GRANT USE CATALOG ON CATALOG system TO `data_engineers`; + GRANT USE SCHEMA ON SCHEMA system.access TO `data_engineers`; + GRANT SELECT ON SCHEMA system.access TO `data_engineers`; + ``` + + Three grants are needed: USE CATALOG on the `system` catalog, USE SCHEMA on the specific schema (e.g., `system.access`), and SELECT for read access. 
+ execution_success: true + expectations: + expected_facts: + - "System table access requires three levels of grants: USE CATALOG, USE SCHEMA, and SELECT" + - "The system catalog contains built-in system tables" + - "system.access schema contains lineage and audit tables" + expected_patterns: + - pattern: "GRANT USE CATALOG" + min_count: 1 + description: "Grants USE CATALOG permission" + - pattern: "GRANT USE SCHEMA" + min_count: 1 + description: "Grants USE SCHEMA permission" + - pattern: "GRANT SELECT" + min_count: 1 + description: "Grants SELECT permission" + guidelines: + - "Must show all three required GRANT statements" + - "Must explain why three separate grants are needed" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: uc_mcp_execute_sql_006 + inputs: + prompt: "How do I query system tables using the MCP SQL execution tool?" + outputs: + response: | + ```python + # Query lineage via MCP + mcp__databricks__execute_sql( + sql_query=""" + SELECT source_table_full_name, target_table_full_name + FROM system.access.table_lineage + WHERE event_date >= current_date() - 7 + """, + catalog="system" + ) + ``` + + Use `mcp__databricks__execute_sql` to execute system table queries. Set the `catalog` parameter to "system" for system table access. 
+ execution_success: true + expectations: + expected_facts: + - "mcp__databricks__execute_sql is the MCP tool for running SQL queries" + - "The catalog parameter should be set to 'system' for system table queries" + expected_patterns: + - pattern: "mcp__databricks__execute_sql" + min_count: 1 + description: "Uses the MCP SQL execution tool" + - pattern: "catalog=\"system\"" + min_count: 1 + description: "Sets catalog to system" + - pattern: "system\\.access" + min_count: 1 + description: "Queries system access tables" + guidelines: + - "Must show the mcp__databricks__execute_sql tool call" + - "Must set catalog parameter to 'system'" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-unity-catalog/manifest.yaml b/.test/skills/databricks-unity-catalog/manifest.yaml new file mode 100644 index 00000000..f679684b --- /dev/null +++ b/.test/skills/databricks-unity-catalog/manifest.yaml @@ -0,0 +1,31 @@ +skill_name: databricks-unity-catalog +description: Unity Catalog system tables and volumes. Use when querying system tables (audit, lineage, billing) or working + with volume file operations (upload, download, list files in /Volumes/). 
+scorers: + enabled: + - sql_syntax + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml b/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml new file mode 100644 index 00000000..6560e7c1 --- /dev/null +++ b/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml @@ -0,0 +1,271 @@ +metadata: + skill_name: databricks-unstructured-pdf-generation + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: pdf_gen_quick_start_001 + inputs: + prompt: "How do I generate synthetic PDF documents for my RAG pipeline?" + outputs: + response: | + Use the `generate_pdf_documents` MCP tool: + + ``` + generate_pdf_documents( + catalog="my_catalog", + schema="my_schema", + volume="my_volume", + folder="hr_policies", + topic="HR policy documents for a mid-size technology company", + num_documents=15, + complexity="medium" + ) + ``` + + This generates PDF files and corresponding JSON metadata files with questions and evaluation guidelines for RAG testing. 
+ execution_success: true + expectations: + expected_facts: + - "generate_pdf_documents is an MCP tool for creating synthetic PDFs" + - "Documents are saved to a Unity Catalog Volume" + - "Both PDF files and JSON metadata files are generated" + expected_patterns: + - pattern: "generate_pdf_documents" + min_count: 1 + description: "Uses the generate_pdf_documents MCP tool" + - pattern: "catalog" + min_count: 1 + description: "Specifies target catalog" + - pattern: "schema" + min_count: 1 + description: "Specifies target schema" + - pattern: "num_documents" + min_count: 1 + description: "Specifies number of documents to generate" + guidelines: + - "Must show the generate_pdf_documents tool with key parameters" + - "Must mention that both PDFs and JSON metadata are generated" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pdf_gen_json_structure_002 + inputs: + prompt: "What is the structure of the JSON metadata files generated alongside the PDFs?" + outputs: + response: | + Each PDF gets a corresponding JSON file with evaluation metadata: + + ```json + { + "title": "API Authentication Guide", + "category": "Technical", + "pdf_path": "/Volumes/catalog/schema/volume/folder/doc_001.pdf", + "question": "What authentication methods are supported by the API?", + "guideline": "Answer should mention OAuth 2.0, API keys, and JWT tokens with their use cases." 
+ } + ``` + + - `title`: Document title + - `category`: Document category + - `pdf_path`: Full Volume path to the generated PDF + - `question`: A test question whose answer is in the PDF + - `guideline`: Evaluation criteria for judging RAG responses + execution_success: true + expectations: + expected_facts: + - "Each PDF has a corresponding JSON metadata file" + - "JSON includes title, category, pdf_path, question, and guideline fields" + - "The question field contains a query answerable from the PDF content" + - "The guideline field provides evaluation criteria for RAG responses" + expected_patterns: + - pattern: "\"title\"" + min_count: 1 + description: "JSON includes title field" + - pattern: "\"question\"" + min_count: 1 + description: "JSON includes question field" + - pattern: "\"guideline\"" + min_count: 1 + description: "JSON includes guideline field" + - pattern: "\"pdf_path\"" + min_count: 1 + description: "JSON includes pdf_path field" + guidelines: + - "Must show the complete JSON structure with all fields" + - "Must explain the purpose of each field" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pdf_gen_rag_integration_003 + inputs: + prompt: "How do I use the generated PDFs and JSON files to test my RAG pipeline?" + outputs: + response: | + ```python + # Load questions from JSON files + questions = load_json_files(f"/Volumes/{catalog}/{schema}/{volume}/{folder}/*.json") + + for q in questions: + # 1. Ingest the PDF into your vector database + ingest_document(q["pdf_path"]) + + # 2. Query RAG system with the test question + response = rag_system.query(q["question"]) + + # 3. Evaluate using the guideline + is_correct = evaluate_response(response, q["guideline"]) + ``` + + The workflow is: ingest PDFs as source documents, use `question` fields to query your RAG system, and use `guideline` fields to evaluate response quality. 
+ execution_success: true + expectations: + expected_facts: + - "PDFs serve as source documents for the vector database" + - "Questions from JSON test retrieval accuracy" + - "Guidelines from JSON evaluate response quality" + expected_patterns: + - pattern: "question" + min_count: 2 + description: "Uses question field for RAG queries" + - pattern: "guideline" + min_count: 2 + description: "Uses guideline field for evaluation" + - pattern: "pdf_path" + min_count: 1 + description: "References PDF path for ingestion" + guidelines: + - "Must show the three-step workflow: ingest, query, evaluate" + - "Must use both question and guideline fields from the JSON" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pdf_gen_hr_documents_004 + inputs: + prompt: "Generate HR policy documents for testing a company knowledge base RAG system." + outputs: + response: | + ``` + generate_pdf_documents( + catalog="ai_dev_kit", + schema="hr_demo", + volume="documents", + folder="hr_policies", + topic="HR policy documents for a mid-size technology company including PTO, benefits, code of conduct, and remote work policies", + num_documents=20, + complexity="medium" + ) + ``` + + This generates 20 HR policy PDFs covering common topics like PTO, benefits enrollment, code of conduct, and remote work guidelines, along with JSON evaluation files. 
+ execution_success: true + expectations: + expected_facts: + - "The topic parameter guides the content theme of generated documents" + - "num_documents controls how many PDFs are generated" + - "complexity affects the depth and length of generated content" + expected_patterns: + - pattern: "generate_pdf_documents" + min_count: 1 + description: "Uses the generate_pdf_documents MCP tool" + - pattern: "topic=" + min_count: 1 + description: "Specifies document topic" + - pattern: "num_documents=" + min_count: 1 + description: "Specifies document count" + guidelines: + - "Must specify a descriptive topic parameter" + - "Must use realistic catalog/schema/volume names" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pdf_gen_tech_docs_005 + inputs: + prompt: "Generate technical documentation PDFs for testing a developer documentation search system." + outputs: + response: | + ``` + generate_pdf_documents( + catalog="ai_dev_kit", + schema="tech_docs", + volume="documents", + folder="api_docs", + topic="Technical API documentation including REST endpoints, authentication, SDK usage, error handling, and rate limiting", + num_documents=25, + complexity="high" + ) + ``` + + Set complexity to "high" for technical documentation to generate detailed content with code examples, tables, and structured sections. 
+ execution_success: true + expectations: + expected_facts: + - "complexity='high' generates more detailed content with code examples" + - "The topic parameter should describe the specific technical domain" + expected_patterns: + - pattern: "generate_pdf_documents" + min_count: 1 + description: "Uses the generate_pdf_documents MCP tool" + - pattern: "complexity=\"high\"" + min_count: 1 + description: "Uses high complexity for technical docs" + - pattern: "topic=" + min_count: 1 + description: "Specifies technical topic" + guidelines: + - "Must use high complexity for technical documentation" + - "Topic must describe specific technical domains" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pdf_gen_financial_reports_006 + inputs: + prompt: "Generate synthetic financial report PDFs for testing a compliance document search system." + outputs: + response: | + ``` + generate_pdf_documents( + catalog="ai_dev_kit", + schema="finance_demo", + volume="documents", + folder="financial_reports", + topic="Corporate financial reports including quarterly earnings, annual reports, budget summaries, and audit findings", + num_documents=15, + complexity="high" + ) + ``` + + Financial documents benefit from high complexity to include tables, figures, and structured financial data that tests RAG parsing capabilities. 
+ execution_success: true + expectations: + expected_facts: + - "Financial documents should use high complexity for realistic structure" + - "Generated PDFs include tables and structured data" + expected_patterns: + - pattern: "generate_pdf_documents" + min_count: 1 + description: "Uses the generate_pdf_documents MCP tool" + - pattern: "financial" + min_count: 1 + description: "Topic relates to financial documents" + - pattern: "complexity=\"high\"" + min_count: 1 + description: "Uses high complexity" + guidelines: + - "Must use high complexity for financial documents" + - "Must specify financial-specific topics" + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml b/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml new file mode 100644 index 00000000..27c09770 --- /dev/null +++ b/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-unstructured-pdf-generation +description: Generate synthetic PDF documents for RAG and unstructured data use cases. Use when creating test PDFs, demo documents, + or evaluation datasets for retrieval systems. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-vector-search/ground_truth.yaml b/.test/skills/databricks-vector-search/ground_truth.yaml new file mode 100644 index 00000000..7e8acd2b --- /dev/null +++ b/.test/skills/databricks-vector-search/ground_truth.yaml @@ -0,0 +1,345 @@ +metadata: + skill_name: databricks-vector-search + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: vs_create_endpoint_001 + inputs: + prompt: "How do I create a vector search endpoint in Databricks?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Create a standard endpoint + endpoint = w.vector_search_endpoints.create_endpoint( + name="my-vs-endpoint", + endpoint_type="STANDARD" # or "STORAGE_OPTIMIZED" + ) + # Endpoint creation is asynchronous; check status with get_endpoint() + ``` + + Two endpoint types are available: + - **STANDARD**: Best for low-latency queries + - **STORAGE_OPTIMIZED**: Best for large-scale, cost-effective deployments + execution_success: true + expectations: + expected_facts: + - "Endpoints are created with w.vector_search_endpoints.create_endpoint()" + - "Two endpoint types: STANDARD and STORAGE_OPTIMIZED" + - "Endpoint creation is asynchronous" + expected_patterns: + - pattern: "create_endpoint" + min_count: 1 + description: "Uses create_endpoint() method" + - pattern: "STANDARD" + min_count: 1 + description: "Shows STANDARD endpoint type" + - pattern: "STORAGE_OPTIMIZED" + min_count: 1 + description: "Shows STORAGE_OPTIMIZED endpoint type" + guidelines: + - "Must show both STANDARD and STORAGE_OPTIMIZED endpoint types" + - "Must mention that endpoint creation is asynchronous" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: vs_delta_sync_managed_002 + inputs: + prompt: "How do I create a Delta Sync index with managed embeddings?" + outputs: + response: | + ```python + # Source table must have: primary key column + text column + index = w.vector_search_indexes.create_index( + name="catalog.schema.my_index", + endpoint_name="my-vs-endpoint", + primary_key="id", + index_type="DELTA_SYNC", + delta_sync_index_spec={ + "source_table": "catalog.schema.documents", + "embedding_source_columns": [ + { + "name": "content", + "embedding_model_endpoint_name": "databricks-gte-large-en" + } + ], + "pipeline_type": "TRIGGERED" # or "CONTINUOUS" + } + ) + ``` + + Databricks automatically embeds the text column using the specified model. 
Pipeline types: TRIGGERED (manual sync) or CONTINUOUS (auto-sync on source changes). + execution_success: true + expectations: + expected_facts: + - "Delta Sync index automatically syncs from a source Delta table" + - "Managed embeddings use embedding_source_columns with a model endpoint" + - "TRIGGERED pipeline syncs on demand; CONTINUOUS syncs automatically" + - "databricks-gte-large-en is a built-in embedding model" + expected_patterns: + - pattern: "create_index" + min_count: 1 + description: "Uses create_index() method" + - pattern: "DELTA_SYNC" + min_count: 1 + description: "Specifies Delta Sync index type" + - pattern: "embedding_source_columns" + min_count: 1 + description: "Configures managed embedding columns" + - pattern: "databricks-gte-large-en" + min_count: 1 + description: "Uses built-in embedding model" + guidelines: + - "Must show the full create_index call with delta_sync_index_spec" + - "Must explain TRIGGERED vs CONTINUOUS pipeline types" + - "Must specify a Databricks embedding model endpoint" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: vs_delta_sync_self_managed_003 + inputs: + prompt: "How do I create a Delta Sync index with pre-computed embeddings?" + outputs: + response: | + ```python + # Source table must have: primary key + embedding vector column + index = w.vector_search_indexes.create_index( + name="catalog.schema.my_index", + endpoint_name="my-vs-endpoint", + primary_key="id", + index_type="DELTA_SYNC", + delta_sync_index_spec={ + "source_table": "catalog.schema.documents", + "embedding_vector_columns": [ + { + "name": "embedding", + "embedding_dimension": 768 + } + ], + "pipeline_type": "TRIGGERED" + } + ) + ``` + + Use self-managed embeddings when you need a specific embedding model not available as a Databricks endpoint, or when you pre-compute embeddings in a pipeline. 
+ execution_success: true + expectations: + expected_facts: + - "Self-managed embeddings use embedding_vector_columns instead of embedding_source_columns" + - "embedding_dimension must match the dimension of your pre-computed vectors" + - "The source table must already contain an embedding column" + expected_patterns: + - pattern: "embedding_vector_columns" + min_count: 1 + description: "Uses self-managed embedding columns" + - pattern: "embedding_dimension" + min_count: 1 + description: "Specifies embedding dimension" + - pattern: "DELTA_SYNC" + min_count: 1 + description: "Uses Delta Sync index type" + guidelines: + - "Must use embedding_vector_columns (not embedding_source_columns)" + - "Must include embedding_dimension parameter" + - "Must explain when to use self-managed vs managed embeddings" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: vs_query_index_004 + inputs: + prompt: "How do I query a vector search index and get similarity results?" + outputs: + response: | + ```python + results = w.vector_search_indexes.query_index( + index_name="catalog.schema.my_index", + columns=["id", "content", "metadata"], + query_text="What is machine learning?", + num_results=5 + ) + + for doc in results.result.data_array: + score = doc[-1] # Similarity score is last column + print(f"Score: {score}, Content: {doc[1][:100]}...") + ``` + + For managed embedding indexes, use `query_text`. For self-managed, use `query_vector` with a pre-computed embedding. 
+ execution_success: true + expectations: + expected_facts: + - "query_index() searches for similar documents by text or vector" + - "query_text is used for managed embedding indexes" + - "query_vector is used for self-managed embedding indexes" + - "Similarity score is the last column in the result array" + expected_patterns: + - pattern: "query_index" + min_count: 1 + description: "Uses query_index() method" + - pattern: "query_text" + min_count: 1 + description: "Uses text-based query" + - pattern: "num_results" + min_count: 1 + description: "Specifies number of results" + - pattern: "data_array" + min_count: 1 + description: "Accesses result data array" + guidelines: + - "Must show query_text for managed embeddings" + - "Must mention query_vector as alternative for self-managed" + - "Must show how to access similarity scores" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: vs_query_with_filters_005 + inputs: + prompt: "How do I apply metadata filters when querying a vector search index?" + outputs: + response: | + ```python + results = w.vector_search_indexes.query_index( + index_name="catalog.schema.my_index", + columns=["id", "content", "category"], + query_text="machine learning best practices", + num_results=10, + filters_json='{"category": "technical", "status": "published"}' + ) + ``` + + The `filters_json` parameter accepts a JSON string of key-value pairs that filter results by metadata columns before similarity ranking. Only documents matching all filters are returned. 
+ execution_success: true + expectations: + expected_facts: + - "filters_json accepts a JSON string for metadata filtering" + - "Filters are applied before similarity ranking" + - "All filter conditions must be satisfied (AND logic)" + expected_patterns: + - pattern: "filters_json" + min_count: 1 + description: "Uses filters_json parameter" + - pattern: "query_index" + min_count: 1 + description: "Uses query_index() method" + - pattern: "query_text" + min_count: 1 + description: "Combines text query with filters" + guidelines: + - "Must show filters_json parameter with a JSON string" + - "Must explain that filters are applied before similarity ranking" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: vs_cli_reference_006 + inputs: + prompt: "What CLI commands are available for managing Databricks vector search?" + outputs: + response: | + ```bash + # List endpoints + databricks vector-search endpoints list + + # Create endpoint + databricks vector-search endpoints create \ + --name my-endpoint \ + --endpoint-type STANDARD + + # List indexes on endpoint + databricks vector-search indexes list-indexes \ + --endpoint-name my-endpoint + + # Get index status + databricks vector-search indexes get-index \ + --index-name catalog.schema.my_index + + # Sync index (for TRIGGERED pipeline) + databricks vector-search indexes sync-index \ + --index-name catalog.schema.my_index + + # Delete index + databricks vector-search indexes delete-index \ + --index-name catalog.schema.my_index + ``` + execution_success: true + expectations: + expected_facts: + - "Vector search CLI commands are under databricks vector-search" + - "sync-index triggers a manual sync for TRIGGERED pipeline indexes" + - "Endpoints and indexes are managed as separate subcommands" + expected_patterns: + - pattern: "databricks vector-search endpoints" + min_count: 1 + description: "Shows endpoint management commands" + - pattern: "databricks vector-search indexes" + 
min_count: 1 + description: "Shows index management commands" + - pattern: "sync-index" + min_count: 1 + description: "Shows manual sync command" + guidelines: + - "Must show both endpoint and index management commands" + - "Must include sync-index for triggered pipelines" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: vs_embedding_models_007 + inputs: + prompt: "What embedding models are available for vector search indexes?" + outputs: + response: | + Databricks provides built-in embedding models for managed indexes: + + | Model | Dimensions | + |-------|-----------| + | `databricks-gte-large-en` | 1024 | + | `databricks-bge-large-en` | 1024 | + + ```python + # Use with managed embeddings in index creation + embedding_source_columns=[ + { + "name": "content", + "embedding_model_endpoint_name": "databricks-gte-large-en" + } + ] + ``` + + You can also use custom embedding models deployed as serving endpoints. + execution_success: true + expectations: + expected_facts: + - "databricks-gte-large-en produces 1024-dimensional embeddings" + - "databricks-bge-large-en produces 1024-dimensional embeddings" + - "Custom embedding models can also be used via serving endpoints" + expected_patterns: + - pattern: "databricks-gte-large-en" + min_count: 1 + description: "Lists GTE embedding model" + - pattern: "databricks-bge-large-en" + min_count: 1 + description: "Lists BGE embedding model" + - pattern: "1024" + min_count: 1 + description: "Specifies embedding dimensions" + guidelines: + - "Must list at least two built-in embedding models with dimensions" + - "Must mention that custom models can also be used" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-vector-search/manifest.yaml b/.test/skills/databricks-vector-search/manifest.yaml new file mode 100644 index 00000000..a053e80a --- /dev/null +++ b/.test/skills/databricks-vector-search/manifest.yaml @@ -0,0 +1,31 @@ 
+skill_name: databricks-vector-search +description: 'Patterns for Databricks Vector Search: create endpoints and indexes, query with filters, manage embeddings. + Use when building RAG applications, semantic search, or similarity matching. Covers both storage-optimized and standard + endpoints.' +scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/src/skill_test/cli/__init__.py b/.test/src/skill_test/cli/__init__.py index b8352981..4afe3a7d 100644 --- a/.test/src/skill_test/cli/__init__.py +++ b/.test/src/skill_test/cli/__init__.py @@ -17,6 +17,7 @@ review, trace_eval, list_traces, + optimize, ) @@ -36,6 +37,7 @@ def main(): review - Review pending candidates interactively trace-eval - Evaluate trace against skill expectations list-traces - List available trace runs from MLflow + optimize - Optimize a skill using GEPA """ args = sys.argv[1:] @@ -51,6 +53,7 @@ def main(): print(" review Review pending candidates interactively") print(" trace-eval Evaluate trace against skill expectations") print(" list-traces List available trace runs from MLflow") + print(" optimize Optimize a skill using GEPA") sys.exit(0) skill_name = args[0] @@ -142,6 +145,48 @@ def main(): } else: result = list_traces(experiment, ctx, limit) + elif subcommand == "optimize": + # Parse optimize-specific arguments + opt_preset = "standard" + opt_mode = "static" + opt_task_lm = None + opt_reflection_lm = None + opt_dry_run = False + 
opt_apply = False + + i = 2 + while i < len(args): + if args[i] in ("--preset", "-p") and i + 1 < len(args): + opt_preset = args[i + 1] + i += 2 + elif args[i] in ("--mode", "-m") and i + 1 < len(args): + opt_mode = args[i + 1] + i += 2 + elif args[i] == "--task-lm" and i + 1 < len(args): + opt_task_lm = args[i + 1] + i += 2 + elif args[i] == "--reflection-lm" and i + 1 < len(args): + opt_reflection_lm = args[i + 1] + i += 2 + elif args[i] == "--dry-run": + opt_dry_run = True + i += 1 + elif args[i] == "--apply": + opt_apply = True + i += 1 + else: + i += 1 + + result = optimize( + skill_name, + ctx, + preset=opt_preset, + mode=opt_mode, + task_lm=opt_task_lm, + reflection_lm=opt_reflection_lm, + dry_run=opt_dry_run, + apply=opt_apply, + ) else: print(f"Unknown subcommand: {subcommand}") sys.exit(1) @@ -171,5 +216,6 @@ def main(): "review", "trace_eval", "list_traces", + "optimize", "main", ] diff --git a/.test/src/skill_test/cli/commands.py b/.test/src/skill_test/cli/commands.py index e84637e2..be619d5a 100644 --- a/.test/src/skill_test/cli/commands.py +++ b/.test/src/skill_test/cli/commands.py @@ -1399,3 +1399,84 @@ def list_traces( "experiment_name": experiment_name, "hint": "Check experiment name and MLflow connection", } + + +def optimize( + skill_name: str, + ctx: CLIContext, + preset: str = "standard", + mode: str = "static", + task_lm: Optional[str] = None, + reflection_lm: Optional[str] = None, + dry_run: bool = False, + apply: bool = False, +) -> Dict[str, Any]: + """Optimize a skill using GEPA. + + Runs the full optimization pipeline: evaluate -> optimize -> review. + Optionally applies the optimized result to the SKILL.md. 
+ + Args: + skill_name: Name of the skill to optimize + ctx: CLI context + preset: GEPA preset ("quick", "standard", "thorough") + mode: "static" (uses ground truth) or "generative" (generates fresh responses) + task_lm: LLM model for generative mode + reflection_lm: Override GEPA reflection model + dry_run: Show config and estimate cost without running + apply: Apply the optimized result to SKILL.md + + Returns: + Dictionary with optimization results + """ + try: + from ..optimize.runner import optimize_skill + from ..optimize.review import review_optimization, apply_optimization + except ImportError as e: + return { + "success": False, + "error": f"GEPA optimization requires the 'optimize' extra: {e}", + "hint": "Install with: pip install skill-test[optimize]", + } + + try: + result = optimize_skill( + skill_name=skill_name, + mode=mode, + preset=preset, + task_lm=task_lm, + reflection_lm=reflection_lm, + dry_run=dry_run, + ) + + review_optimization(result) + + if apply and not dry_run: + apply_optimization(result) + + return { + "success": True, + "skill_name": skill_name, + "original_score": result.original_score, + "optimized_score": result.optimized_score, + "improvement": result.improvement, + "original_tokens": result.original_token_count, + "optimized_tokens": result.optimized_token_count, + "token_reduction_pct": result.token_reduction_pct, + "applied": apply and not dry_run, + "dry_run": dry_run, + "mlflow_run_id": result.mlflow_run_id, + } + except FileNotFoundError as e: + return { + "success": False, + "error": str(e), + "skill_name": skill_name, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "skill_name": skill_name, + "hint": "Check GEPA installation and API keys", + } diff --git a/.test/src/skill_test/grp/executor.py b/.test/src/skill_test/grp/executor.py index 5cd393bb..0413ea3c 100644 --- a/.test/src/skill_test/grp/executor.py +++ b/.test/src/skill_test/grp/executor.py @@ -4,8 +4,8 @@ import re import time import 
yaml -from dataclasses import dataclass, field -from typing import List, Tuple, Optional, Dict, Any, Callable, Protocol +from dataclasses import dataclass +from typing import List, Tuple, Optional, Dict, Any, Protocol @dataclass diff --git a/.test/src/skill_test/optimize/__init__.py b/.test/src/skill_test/optimize/__init__.py new file mode 100644 index 00000000..f55bad01 --- /dev/null +++ b/.test/src/skill_test/optimize/__init__.py @@ -0,0 +1,22 @@ +"""GEPA-powered skill optimization using optimize_anything API. + +Public API: + optimize_skill() - End-to-end optimize a SKILL.md (and optionally tools) + create_skillbench_evaluator() - Create a judge-based evaluator for a skill + OptimizationResult - Dataclass with optimization results + PRESETS - GEPA config presets (quick, standard, thorough) +""" + +from .runner import optimize_skill, OptimizationResult +from .skillbench_evaluator import create_skillbench_evaluator +from .config import PRESETS +from .review import review_optimization, apply_optimization + +__all__ = [ + "optimize_skill", + "OptimizationResult", + "create_skillbench_evaluator", + "PRESETS", + "review_optimization", + "apply_optimization", +] diff --git a/.test/src/skill_test/optimize/alignment.py b/.test/src/skill_test/optimize/alignment.py new file mode 100644 index 00000000..6151f3f2 --- /dev/null +++ b/.test/src/skill_test/optimize/alignment.py @@ -0,0 +1,99 @@ +"""MemAlign integration for aligning judges with human feedback. + +MemAlign aligns judges with human feedback via dual-memory: + - Semantic memory: generalizable evaluation principles + - Episodic memory: specific edge cases and corrections + +Alignment traces are stored per-skill in: + .test/skills//alignment_traces.yaml + +Populated via ``scripts/review.py --align`` where a human corrects +judge verdicts. MemAlign learns principles from corrections, +improving judge accuracy over time. + +Only 2-10 examples are needed for visible improvement. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import yaml + +logger = logging.getLogger(__name__) + + +def load_alignment_traces(skill_name: str) -> list[dict[str, Any]]: + """Load human-corrected alignment traces for a skill. + + Traces are stored in .test/skills//alignment_traces.yaml + with format: + - inputs: {prompt: "..."} + outputs: {response: "..."} + expected_value: true/false or 0.0-1.0 + rationale: "Human explanation of correct verdict" + + Returns: + List of trace dicts, or empty list if no traces found. + """ + traces_path = Path(".test/skills") / skill_name / "alignment_traces.yaml" + if not traces_path.exists(): + return [] + + try: + with open(traces_path) as f: + data = yaml.safe_load(f) + return data if isinstance(data, list) else [] + except Exception as e: + logger.warning("Failed to load alignment traces for %s: %s", skill_name, e) + return [] + + +def align_judge( + skill_name: str, + judge: Any, + reflection_lm: str = "openai:/gpt-4o-mini", +) -> Any: + """Align a judge with human feedback using MemAlign. + + If fewer than 3 alignment traces exist, returns the judge unchanged. + Otherwise, uses MemAlignOptimizer to learn evaluation principles + from human corrections and returns an aligned judge. + + Args: + skill_name: Name of the skill to load traces for. + judge: An MLflow judge (from make_judge or similar). + reflection_lm: LLM for MemAlign's reflection step. + + Returns: + Aligned judge if enough traces exist, otherwise original judge. + """ + traces = load_alignment_traces(skill_name) + if len(traces) < 3: + if traces: + logger.info( + "Only %d alignment traces for %s (need >=3). 
Using base judge.", + len(traces), + skill_name, + ) + return judge + + try: + from mlflow.genai.judges.optimizers import MemAlignOptimizer + + optimizer = MemAlignOptimizer(reflection_lm=reflection_lm) + aligned = judge.align(traces=traces, optimizer=optimizer) + logger.info( + "Aligned judge with %d traces for %s", + len(traces), + skill_name, + ) + return aligned + except ImportError: + logger.warning("MemAlignOptimizer not available. Install mlflow-deepeval for alignment support.") + return judge + except Exception as e: + logger.warning("MemAlign alignment failed for %s: %s", skill_name, e) + return judge diff --git a/.test/src/skill_test/optimize/asi.py b/.test/src/skill_test/optimize/asi.py new file mode 100644 index 00000000..fe79a6cf --- /dev/null +++ b/.test/src/skill_test/optimize/asi.py @@ -0,0 +1,81 @@ +"""ASI diagnostics: convert MLflow Feedback to optimize_anything SideInfo. + +Thin adapter that passes judge rationale through to GEPA's reflection LM +WITHOUT truncation. The critical fix: GEPA's reflection LM gets full +diagnostic text from judges, not truncated snippets. + +Also provides ``feedback_to_score()`` for backward compatibility with tests. +""" + +from __future__ import annotations + +from typing import Any + +from mlflow.entities import Feedback + + +def feedback_to_score(feedback: Feedback) -> float | None: + """Convert a single MLflow Feedback to a numeric score. + + Mapping: + "yes" -> 1.0 + "no" -> 0.0 + "skip" -> None (excluded from scoring) + numeric -> float(value) + """ + value = feedback.value + if value == "yes": + return 1.0 + elif value == "no": + return 0.0 + elif value == "skip": + return None + else: + try: + return float(value) + except (TypeError, ValueError): + return None + + +def feedback_to_asi(feedbacks: list[Feedback]) -> tuple[float, dict[str, Any]]: + """Convert MLflow Feedback objects to optimize_anything (score, SideInfo). 
+ + Computes the mean score across non-skipped feedbacks and builds a + SideInfo dict with full rationale (no truncation). + """ + scores = [] + side_info: dict[str, Any] = {} + + for fb in feedbacks: + score = feedback_to_score(fb) + name = fb.name or "unnamed" + + if score is None: + side_info[name] = { + "score": None, + "value": fb.value, + "rationale": fb.rationale or "", + "status": "skipped", + } + continue + + scores.append(score) + side_info[name] = { + "score": score, + "value": fb.value, + "rationale": fb.rationale or "", + "status": "pass" if score >= 0.5 else "fail", + } + + composite = sum(scores) / len(scores) if scores else 0.0 + + side_info["_summary"] = { + "composite_score": composite, + "total_scorers": len(feedbacks), + "scored": len(scores), + "skipped": len(feedbacks) - len(scores), + "passed": sum(1 for s in scores if s >= 0.5), + "failed": sum(1 for s in scores if s < 0.5), + } + + return composite, side_info diff --git a/.test/src/skill_test/optimize/assertions.py b/.test/src/skill_test/optimize/assertions.py new file mode 100644 index 00000000..5906e5fa --- /dev/null +++ b/.test/src/skill_test/optimize/assertions.py @@ -0,0 +1,194 @@ +"""Binary assertion layer for SkillBench-style evaluation. + +Wraps pattern and fact checks into binary pass/fail assertions, +mirroring SkillBench's pytest-style binary approach. No fuzzy keyword +scoring -- each assertion either passes or fails. +""" + +import re +from dataclasses import dataclass +from typing import Any + + +@dataclass +class AssertionResult: + """Result of a single binary assertion.""" + + name: str + passed: bool + rationale: str + assertion_type: str # "pattern" | "fact" + + +def _run_pattern_assertions(response: str, expected_patterns: list) -> list[AssertionResult]: + """Run pattern assertions against a response. + + Each pattern spec can be a plain regex string or a dict with + ``pattern``, ``min_count``, ``max_count``, ``description`` keys. 
+ """ + results = [] + for pattern_spec in expected_patterns: + if isinstance(pattern_spec, str): + pattern = pattern_spec + min_count = 1 + max_count = None + description = pattern[:40] + else: + pattern = pattern_spec["pattern"] + min_count = pattern_spec.get("min_count", 1) + max_count = pattern_spec.get("max_count", None) + description = pattern_spec.get("description", pattern[:40]) + + matches = len(re.findall(pattern, response, re.IGNORECASE)) + + if max_count is not None: + passed = min_count <= matches <= max_count + rationale = f"Found {matches} matches (need {min_count}-{max_count})" + else: + passed = matches >= min_count + rationale = f"Found {matches} matches (need >={min_count})" + + results.append( + AssertionResult( + name=f"pattern_{description}", + passed=passed, + rationale=rationale, + assertion_type="pattern", + ) + ) + return results + + +def _run_fact_assertions(response: str, expected_facts: list[str]) -> list[AssertionResult]: + """Run fact assertions against a response. + + Exact substring match (case-insensitive). No fuzzy keyword overlap. + """ + response_lower = response.lower() + results = [] + for fact in expected_facts: + found = fact.lower() in response_lower + results.append( + AssertionResult( + name=f"fact_{fact[:40]}", + passed=found, + rationale=f"{'Found' if found else 'Missing'}: {fact}", + assertion_type="fact", + ) + ) + return results + + +def run_all_assertions(response: str, expectations: dict[str, Any]) -> list[AssertionResult]: + """Run all pattern + fact assertions, return binary pass/fail per assertion. + + Args: + response: The text to check assertions against. + expectations: Dict with optional ``expected_patterns`` and ``expected_facts`` keys. + + Returns: + List of AssertionResult with binary pass/fail for each assertion. 
+ """ + results: list[AssertionResult] = [] + + patterns = expectations.get("expected_patterns", []) + if patterns: + results.extend(_run_pattern_assertions(response, patterns)) + + facts = expectations.get("expected_facts", []) + if facts: + results.extend(_run_fact_assertions(response, facts)) + + return results + + +def _classify_assertion( + with_result: AssertionResult, + without_result: AssertionResult, +) -> str: + """Classify a single assertion by comparing with-skill vs without-skill. + + Returns one of: + POSITIVE — fails without skill, passes with (skill is helping) + REGRESSION — passes without skill, fails with (skill is confusing the agent) + NEEDS_SKILL — fails both with and without (skill must add this content) + NEUTRAL — same result either way (agent already knows this) + """ + if with_result.passed and not without_result.passed: + return "POSITIVE" + elif not with_result.passed and without_result.passed: + return "REGRESSION" + elif not with_result.passed and not without_result.passed: + return "NEEDS_SKILL" + else: + return "NEUTRAL" + + +def _extract_content(result: AssertionResult) -> str: + """Extract the actual expected content from an assertion result. + + For facts, strips the ``Missing: `` / ``Found: `` prefix to get the raw + fact text. For patterns, uses the description embedded in the assertion + name (strips the ``pattern_`` prefix). + """ + if result.assertion_type == "fact": + for prefix in ("Missing: ", "Found: "): + if result.rationale.startswith(prefix): + return result.rationale[len(prefix) :] + return result.rationale + else: + # Pattern: name is "pattern_{description}", rationale is match count + return result.name.removeprefix("pattern_") + + +def summarize_failures( + with_results: list[AssertionResult], + without_results: list[AssertionResult], +) -> dict[str, str]: + """Build GEPA-friendly diagnostic strings from assertion results. 
+ + Collects only NEEDS_SKILL and REGRESSION assertions (skips NEUTRAL/POSITIVE) + and produces structured output that maps to GEPA's standard diagnostic keys. + + Only non-empty keys are included in the returned dict so that GEPA does not + render empty ``## Header`` sections that waste tokens and confuse the + reflection LM. + + Returns: + Dict with a subset of: ``Error``, ``Regressions``. + ``Error`` carries compact NEEDS_SKILL/REGRESSION tokens that downstream + consumers (``_review_skillbench``, ``build_skillbench_background``) parse. + ``Regressions`` is a concise NL summary only present when regressions exist. + """ + needs_skill: list[tuple[AssertionResult, AssertionResult]] = [] + regressions: list[tuple[AssertionResult, AssertionResult]] = [] + + for w, wo in zip(with_results, without_results, strict=True): + label = _classify_assertion(w, wo) + if label == "NEEDS_SKILL": + needs_skill.append((w, wo)) + elif label == "REGRESSION": + regressions.append((w, wo)) + + result: dict[str, str] = {} + + # Error: compact assertion labels (NEEDS_SKILL/REGRESSION tokens preserved) + error_lines: list[str] = [] + for w, _ in needs_skill: + content = _extract_content(w) + error_lines.append(f"NEEDS_SKILL: {w.assertion_type} — '{content}'") + for w, _ in regressions: + content = _extract_content(w) + error_lines.append(f"REGRESSION: {w.assertion_type} — '{content}'") + if error_lines: + result["Error"] = "\n".join(error_lines) + + # Regressions: concise NL (only when non-empty) + if regressions: + lines: list[str] = [] + for i, (w, _wo) in enumerate(regressions, 1): + content = _extract_content(w) + lines.append(f"{i}. 
'{content}' — passes without skill, fails with it") + result["Regressions"] = "\n".join(lines) + + return result diff --git a/.test/src/skill_test/optimize/config.py b/.test/src/skill_test/optimize/config.py new file mode 100644 index 00000000..e361b3d6 --- /dev/null +++ b/.test/src/skill_test/optimize/config.py @@ -0,0 +1,369 @@ +"""GEPA configuration presets for skill optimization. + +Uses the optimize_anything API with GEPAConfig/EngineConfig/ReflectionConfig. +""" + +import os +import warnings + +from gepa.optimize_anything import GEPAConfig, EngineConfig, ReflectionConfig, RefinerConfig + +DEFAULT_REFLECTION_LM = os.environ.get("GEPA_REFLECTION_LM", "databricks/databricks-claude-opus-4-6") + +DEFAULT_GEN_LM = os.environ.get("GEPA_GEN_LM", "databricks/databricks-claude-sonnet-4-6") + +DEFAULT_TOKEN_BUDGET: int | None = int(os.environ.get("GEPA_TOKEN_BUDGET", "0")) or None + + +# --------------------------------------------------------------------------- +# Register Databricks models with litellm so it knows their true context +# windows. Without this, litellm may fuzzy-match to a similar model with +# different limits, or worse, the Databricks serving endpoint may reject +# requests that exceed a vLLM-configured max_model_len. +# +# NOTE: This does NOT override the endpoint's own max_model_len setting. +# If the Databricks endpoint itself is configured with a low limit (e.g. +# 8192), you must either reconfigure the endpoint or use a different +# provider (openai/, anthropic/) whose endpoints support larger contexts. +# --------------------------------------------------------------------------- +def _configure_litellm_retries() -> None: + """Configure litellm to retry on transient errors (429, 529, 500). + + GEPA calls litellm.completion() without passing num_retries, so we + set it globally. This handles Anthropic 529 "Overloaded" errors, + rate limits, and other transient failures with exponential backoff. 
+ + Rate-limit retries get extra attempts (10) since --include-tools sends + large contexts that easily hit token-per-minute ceilings on Opus. + """ + try: + import litellm + from litellm import RetryPolicy + + litellm.num_retries = 5 + litellm.request_timeout = 180 # seconds per attempt + litellm.retry_policy = RetryPolicy( + RateLimitErrorRetries=10, + InternalServerErrorRetries=5, + TimeoutErrorRetries=5, + ) + # Drop log noise from retries + litellm.suppress_debug_info = True + except ImportError: + pass + + +def _register_litellm_models() -> None: + """Register Databricks model context windows with litellm.""" + try: + import litellm + + _models = { + "databricks/databricks-claude-opus-4-6": { + "max_tokens": 32_000, + "max_input_tokens": 200_000, + "max_output_tokens": 32_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-claude-sonnet-4-6": { + "max_tokens": 16_000, + "max_input_tokens": 200_000, + "max_output_tokens": 16_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-gpt-5-2": { + "max_tokens": 128_000, + "max_input_tokens": 272_000, + "max_output_tokens": 128_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-gemini-3-1-pro": { + "max_tokens": 65_536, + "max_input_tokens": 1_048_576, + "max_output_tokens": 65_536, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-claude-opus-4-5": { + "max_tokens": 32_000, + "max_input_tokens": 200_000, + "max_output_tokens": 32_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-gpt-5": { + "max_tokens": 100_000, + "max_input_tokens": 1_048_576, + 
"max_output_tokens": 100_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-claude-sonnet-4-5": { + "max_tokens": 16_000, + "max_input_tokens": 200_000, + "max_output_tokens": 16_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + } + for model_name, model_info in _models.items(): + litellm.model_cost[model_name] = model_info + except ImportError: + pass + + +_register_litellm_models() +_configure_litellm_retries() + + +# Overhead multiplier: the reflection prompt is roughly this many times +# the raw candidate tokens (includes background, ASI, framing). +_REFLECTION_OVERHEAD_MULTIPLIER = 3 + +PRESETS: dict[str, GEPAConfig] = { + "quick": GEPAConfig( + engine=EngineConfig(max_metric_calls=15, parallel=True), + reflection=ReflectionConfig(reflection_lm=DEFAULT_REFLECTION_LM), + refiner=RefinerConfig(max_refinements=1), + ), + "standard": GEPAConfig( + engine=EngineConfig(max_metric_calls=50, parallel=True), + reflection=ReflectionConfig( + reflection_lm=DEFAULT_REFLECTION_LM, + reflection_minibatch_size=3, + ), + refiner=RefinerConfig(max_refinements=1), + ), + "thorough": GEPAConfig( + engine=EngineConfig(max_metric_calls=150, parallel=True), + reflection=ReflectionConfig( + reflection_lm=DEFAULT_REFLECTION_LM, + reflection_minibatch_size=3, + ), + refiner=RefinerConfig(max_refinements=1), + ), +} + +# Base max_metric_calls per preset (used to scale by component count) +PRESET_BASE_CALLS: dict[str, int] = { + "quick": 15, + "standard": 50, + "thorough": 150, +} + +# Per-preset caps: safety net so component scaling never exceeds a reasonable +# ceiling. Important for --tools-only mode which has many tool components. +PRESET_MAX_CALLS: dict[str, int] = { + "quick": 45, + "standard": 150, + "thorough": 300, +} + +# Maximum total metric calls per pass to avoid runaway runtimes. 
+# With many components, uncapped scaling (e.g., 50 * 17 = 850) can cause +# multi-hour hangs with slower reflection models like Sonnet. +MAX_METRIC_CALLS_PER_PASS = 300 + +# Models known to be fast enough for large multi-component optimization. +# Other models get the metric-call cap applied. +_FAST_REFLECTION_MODELS = { + "databricks/databricks-claude-opus-4-6", + "databricks/databricks-gpt-5-2", + "openai/gpt-4o", + "anthropic/claude-opus-4-6", +} + + +def validate_databricks_env() -> None: + """Check that DATABRICKS_API_BASE is set correctly for litellm. + + litellm's Databricks provider requires: + DATABRICKS_API_BASE=https://.cloud.databricks.com/serving-endpoints + + A common mistake is omitting /serving-endpoints, which causes 404 errors. + """ + api_base = os.environ.get("DATABRICKS_API_BASE", "") + if api_base and not api_base.rstrip("/").endswith("/serving-endpoints"): + fixed = api_base.rstrip("/") + "/serving-endpoints" + warnings.warn( + f"DATABRICKS_API_BASE={api_base!r} is missing '/serving-endpoints' suffix. " + f"litellm will get 404 errors. Automatically fixing to: {fixed}", + stacklevel=2, + ) + os.environ["DATABRICKS_API_BASE"] = fixed + + +def validate_reflection_context( + reflection_lm: str, + total_candidate_tokens: int, +) -> None: + """Warn if the candidate is likely too large for the reflection model. + + Queries litellm's model registry for the model's max_input_tokens and + compares against the estimated reflection prompt size. + + Note: this checks litellm's *client-side* knowledge of the model. The + Databricks serving endpoint may have a *different* (lower) limit set via + vLLM's ``max_model_len``. If you see ``BadRequestError`` with + ``max_model_len`` in the message, the endpoint itself is the bottleneck -- + switch to a provider whose endpoint supports your context needs (e.g. + ``openai/gpt-4o`` or ``anthropic/claude-sonnet-4-5-20250514``). 
+ """ + try: + import litellm + + info = litellm.get_model_info(reflection_lm) + limit = info.get("max_input_tokens") or info.get("max_tokens") or 0 + except Exception: + return # can't determine limit -- skip check + + if limit <= 0: + return + + estimated_prompt = total_candidate_tokens * _REFLECTION_OVERHEAD_MULTIPLIER + if estimated_prompt > limit: + raise ValueError( + f"\nReflection model '{reflection_lm}' has a {limit:,}-token input limit " + f"(per litellm), but the estimated reflection prompt is ~{estimated_prompt:,} " + f"tokens ({total_candidate_tokens:,} candidate tokens x " + f"{_REFLECTION_OVERHEAD_MULTIPLIER} overhead).\n\n" + f"Fix: use a model with a larger context window:\n" + f" --reflection-lm 'databricks/databricks-claude-opus-4-6' (200K)\n" + f" --reflection-lm 'openai/gpt-4o' (128K)\n" + f" --reflection-lm 'anthropic/claude-sonnet-4-5-20250514' (200K)\n\n" + f"Or set the environment variable:\n" + f" export GEPA_REFLECTION_LM='databricks/databricks-claude-opus-4-6'\n\n" + f"If you already use a large-context model and still see 'max_model_len'\n" + f"errors, the Databricks serving endpoint itself has a low context limit.\n" + f"Switch to a non-Databricks provider (openai/ or anthropic/) instead.\n\n" + f" Current GEPA_REFLECTION_LM={os.environ.get('GEPA_REFLECTION_LM', '(not set)')}" + ) + + +def estimate_pass_duration( + num_metric_calls: int, + reflection_lm: str, + total_candidate_tokens: int, + num_dataset_examples: int = 7, +) -> float | None: + """Estimate wall-clock seconds for one optimization pass. + + Metric calls are mostly fast local evaluations. The slow part is + reflection LLM calls, which happen roughly once per iteration + (num_metric_calls / num_dataset_examples iterations). + + Returns None if estimation is not possible. 
+ """ + # Rough per-reflection latency (seconds) based on model class + if reflection_lm in _FAST_REFLECTION_MODELS: + secs_per_reflection = 5.0 + elif "sonnet" in reflection_lm.lower(): + secs_per_reflection = 20.0 + elif "haiku" in reflection_lm.lower(): + secs_per_reflection = 8.0 + else: + secs_per_reflection = 15.0 + + # Scale by candidate size (larger candidates → slower) + size_factor = min(max(1.0, total_candidate_tokens / 10_000), 2.5) + adjusted = secs_per_reflection * size_factor + + # Approximate iterations (each iteration evaluates all dataset examples) + num_iterations = max(1, num_metric_calls // max(num_dataset_examples, 1)) + + return num_iterations * adjusted + + +def get_preset( + name: str, + reflection_lm: str | None = None, + num_components: int = 1, + max_metric_calls_override: int | None = None, +) -> GEPAConfig: + """Get a GEPA config preset by name, scaled by component count. + + When optimizing multiple components (skill + tool modules), GEPA's + round-robin selector divides the budget across all of them. We scale + ``max_metric_calls`` so that *each component* receives the preset's + base budget rather than splitting it. + + For slower reflection models (non-Opus/GPT-4o), the total metric calls + are capped at ``MAX_METRIC_CALLS_PER_PASS`` to avoid multi-hour hangs. + + Args: + name: One of "quick", "standard", "thorough" + reflection_lm: Override reflection LM model string + num_components: Number of GEPA components (used to scale budget) + max_metric_calls_override: Explicit cap on metric calls per pass + + Returns: + GEPAConfig instance + """ + if name not in PRESETS: + raise KeyError(f"Unknown preset '{name}'. 
Choose from: {list(PRESETS.keys())}") + + # Validate Databricks env if using databricks/ prefix + effective_lm = reflection_lm or DEFAULT_REFLECTION_LM + if isinstance(effective_lm, str) and effective_lm.startswith("databricks/"): + validate_databricks_env() + + base_calls = PRESET_BASE_CALLS[name] + scaled_calls = base_calls * max(num_components, 1) + + # Apply explicit override if provided + if max_metric_calls_override is not None: + scaled_calls = max_metric_calls_override + else: + # Apply per-preset cap first (safety net for multi-component modes) + preset_cap = PRESET_MAX_CALLS[name] + if scaled_calls > preset_cap: + scaled_calls = preset_cap + + # Cap for slower models to avoid multi-hour hangs + if ( + max_metric_calls_override is None + and effective_lm not in _FAST_REFLECTION_MODELS + and scaled_calls > MAX_METRIC_CALLS_PER_PASS + ): + warnings.warn( + f"Capping metric calls from {scaled_calls} to {MAX_METRIC_CALLS_PER_PASS} " + f"for reflection model '{effective_lm}'. " + f"Use --max-metric-calls to override, or use a faster model " + f"(e.g., databricks/databricks-claude-opus-4-6).", + stacklevel=2, + ) + scaled_calls = MAX_METRIC_CALLS_PER_PASS + + config = PRESETS[name] + config = GEPAConfig( + engine=EngineConfig( + max_metric_calls=scaled_calls, + parallel=config.engine.parallel, + ), + reflection=ReflectionConfig( + reflection_lm=reflection_lm or config.reflection.reflection_lm, + reflection_minibatch_size=config.reflection.reflection_minibatch_size, + skip_perfect_score=config.reflection.skip_perfect_score, + ), + merge=config.merge, + refiner=config.refiner, + tracking=config.tracking, + ) + return config diff --git a/.test/src/skill_test/optimize/judges.py b/.test/src/skill_test/optimize/judges.py new file mode 100644 index 00000000..dfc748a3 --- /dev/null +++ b/.test/src/skill_test/optimize/judges.py @@ -0,0 +1,503 @@ +"""MLflow judge factories for skill evaluation. 
+ +Replaces the 6 separate judge calls and binary assertion layer with three +focused judges that provide both scores AND rich rationale for GEPA's +reflection LM. + +Judges: + quality_judge — Scores a single response (0.0-1.0) against expectations. + effectiveness_judge — Compares WITH vs WITHOUT responses, returns verdict. + regression_judge — Identifies specific ways a skill harms responses. + +Judge model resolution (highest priority first): + 1. Explicit ``judge_model`` argument to factory functions + 2. ``GEPA_JUDGE_LM`` environment variable + 3. ``databricks:/databricks-claude-sonnet-4-6`` (default) + +Model fallback: + On rate limit errors (REQUEST_LIMIT_EXCEEDED), automatically retries with + fallback models. Configure via ``GEPA_FALLBACK_MODELS`` env var (comma-separated) + or use the built-in Databricks fallback chain. + +AI Gateway support: + Set ``DATABRICKS_AI_GATEWAY_URL`` to route calls through Databricks AI Gateway. + Example: https://1444828305810485.ai-gateway.cloud.databricks.com/mlflow/v1 + Works alongside the standard serving endpoint approach. 
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+from dataclasses import dataclass
+from typing import Any
+
+from mlflow.genai.judges import make_judge
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_JUDGE_LM = os.environ.get("GEPA_JUDGE_LM", "databricks:/databricks-claude-sonnet-4-6")
+
+# ---------------------------------------------------------------------------
+# Fallback model chain for rate limit errors
+# ---------------------------------------------------------------------------
+
+_DEFAULT_FALLBACK_MODELS = [
+    "databricks/databricks-gpt-5-2",
+    "databricks/databricks-gemini-3-1-pro",
+    "databricks/databricks-claude-opus-4-5",
+    "databricks/databricks-gpt-5",
+    "databricks/databricks-claude-sonnet-4-6",
+    "databricks/databricks-claude-sonnet-4-5",
+]
+
+
+def _get_fallback_models() -> list[str]:
+    """Get fallback model chain from env or defaults."""
+    custom = os.environ.get("GEPA_FALLBACK_MODELS", "")
+    if custom.strip():
+        return [m.strip() for m in custom.split(",") if m.strip()]
+    return list(_DEFAULT_FALLBACK_MODELS)
+
+
+def _is_rate_limit_error(exc: Exception) -> bool:
+    """Check if an exception is a rate limit / request limit exceeded error."""
+    msg = str(exc).lower()
+    # Plain substring checks — patterns here must be literal phrases, not regexes
+    return any(
+        phrase in msg
+        for phrase in [
+            "rate_limit",
+            "rate limit",
+            "request_limit_exceeded",
+            "request limit exceeded",
+            "too many requests",
+            "429",
+            "per minute",
+        ]
+    )
+
+
+# ---------------------------------------------------------------------------
+# AI Gateway support
+# ---------------------------------------------------------------------------
+
+DATABRICKS_AI_GATEWAY_URL = os.environ.get("DATABRICKS_AI_GATEWAY_URL", "")
+
+
+def _get_gateway_base_url() -> str | None:
+    """Return the AI Gateway base URL if configured, else None."""
+    url = DATABRICKS_AI_GATEWAY_URL.strip()
+    if not url:
+        return None
+    return url.rstrip("/")
+
+
+def _to_litellm_model(model: str) -> tuple[str, str | None]:
+    """Convert a 
model string to (litellm_model, base_url) for completion calls. + + If AI Gateway is configured and model is a databricks/ model, routes + through the gateway as an OpenAI-compatible endpoint. Otherwise returns + the model unchanged with no base_url override. + + Returns: + (model_string, base_url_or_None) + """ + gateway = _get_gateway_base_url() + if gateway and model.startswith("databricks/"): + # Route through AI Gateway as OpenAI-compatible endpoint + endpoint_name = model.split("/", 1)[1] + return f"openai/{endpoint_name}", gateway + return model, None + + +# --------------------------------------------------------------------------- +# URI conversion +# --------------------------------------------------------------------------- + + +def _to_judge_uri(model: str) -> str: + """Convert litellm-style model strings to MLflow judge URI format. + + litellm uses ``provider/model`` (e.g. ``databricks/databricks-claude-sonnet-4-6``). + MLflow judges use ``provider:/model`` (e.g. ``databricks:/databricks-claude-sonnet-4-6``). + """ + if ":/" in model: + return model + if "/" in model: + provider, name = model.split("/", 1) + return f"{provider}:/{name}" + return model + + +def _judge_inference_params() -> dict[str, Any] | None: + """Build inference_params for make_judge if AI Gateway is configured.""" + gateway = _get_gateway_base_url() + if gateway: + return {"base_url": gateway} + return None + + +def _to_judge_model_and_params(model: str) -> tuple[str, dict[str, Any] | None]: + """Convert a model string to (judge_uri, inference_params) for make_judge. + + If AI Gateway is configured, uses ``openai:/endpoint-name`` with + ``inference_params.base_url`` pointing to the gateway. Otherwise + uses standard ``provider:/model`` format. 
+ """ + gateway = _get_gateway_base_url() + if gateway and model.startswith(("databricks/", "databricks:/")): + # Extract the endpoint name + if ":/" in model: + endpoint_name = model.split(":/", 1)[1] + else: + endpoint_name = model.split("/", 1)[1] + return f"openai:/{endpoint_name}", {"base_url": gateway} + return _to_judge_uri(model), _judge_inference_params() + + +# --------------------------------------------------------------------------- +# Completion with fallback +# --------------------------------------------------------------------------- + + +def completion_with_fallback(*, model: str, max_retries: int = 3, **kwargs) -> Any: + """Call litellm.completion with model fallback on rate limit errors. + + Tries the primary model first. On rate limit errors, cycles through + the fallback chain. Each model gets ``max_retries`` attempts with + exponential backoff before moving to the next. + + Also supports AI Gateway: if DATABRICKS_AI_GATEWAY_URL is set, + databricks/ models are routed through the gateway. 
+ """ + import litellm + + models_to_try = [model] + [m for m in _get_fallback_models() if m != model] + + last_err: Exception | None = None + for model_str in models_to_try: + litellm_model, base_url = _to_litellm_model(model_str) + + call_kwargs = dict(kwargs) + call_kwargs["model"] = litellm_model + if base_url: + call_kwargs["base_url"] = base_url + + for attempt in range(max_retries): + if attempt > 0: + delay = min(2**attempt, 30) + time.sleep(delay) + try: + return litellm.completion(**call_kwargs) + except Exception as e: + last_err = e + if _is_rate_limit_error(e): + if attempt == max_retries - 1: + logger.warning( + "Model '%s' rate limited after %d attempts, trying next fallback", + model_str, + max_retries, + ) + continue + # Non-rate-limit error: don't retry, try next model + logger.warning("Model '%s' failed (non-rate-limit): %s", model_str, e) + break + + raise last_err # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + + +@dataclass +class JudgeFeedback: + """Structured feedback from a judge call.""" + + value: float | str + rationale: str + name: str + + +def _safe_parse_score(raw_value: Any) -> float: + """Convert judge output to a float score in [0.0, 1.0]. + + Handles: bool, "yes"/"no", numeric, float-as-string. 
+    """
+    # Check bool before int/float: bool is a subclass of int, so the
+    # numeric branch would otherwise shadow this one.
+    if isinstance(raw_value, bool):
+        return 1.0 if raw_value else 0.0
+    if isinstance(raw_value, (int, float)):
+        return max(0.0, min(1.0, float(raw_value)))
+    if isinstance(raw_value, str):
+        low = raw_value.strip().lower()
+        if low == "yes":
+            return 1.0
+        if low == "no":
+            return 0.0
+        try:
+            return max(0.0, min(1.0, float(low)))
+        except ValueError:
+            pass
+    return 0.0
+
+
+# ---------------------------------------------------------------------------
+# Quality judge — primary scorer for a single response
+# ---------------------------------------------------------------------------
+
+_QUALITY_INSTRUCTIONS = """\
+You are an expert evaluator for Databricks skill documentation quality.
+Rate the response on a scale from 0.0 to 1.0 based on how well it addresses
+the user's question using correct, complete, and relevant information.
+
+## Evaluation Criteria
+
+1. **Relevance** (does the response address the question?)
+2. **Completeness** (are all parts of the question answered?)
+3. **Correctness** (are the facts and API references accurate?)
+4. **Pattern adherence** (does the response follow expected code patterns?)
+5. **API accuracy** (are function names, parameters, and syntax correct?) 
+ +## Expected Facts, Patterns, and Guidelines + +{{ expectations }} + +## Input + +Question: {{ inputs }} +Response: {{ outputs }} + +## Instructions + +Return a score between 0.0 and 1.0 where: +- 1.0 = perfect response, all facts present, all patterns correct +- 0.7 = good response, most facts present, minor gaps +- 0.4 = partial response, significant gaps or inaccuracies +- 0.1 = poor response, mostly wrong or off-topic +- 0.0 = completely wrong or empty + +Provide detailed rationale explaining: +- Which expected facts are present vs missing +- Which patterns are correctly followed vs violated +- Specific API or syntax errors found +- What would need to change to improve the score +""" + + +def create_skill_quality_judge( + skill_guidelines: list[str] | None = None, + judge_model: str | None = None, +) -> Any: + """Create a universal quality judge for scoring responses. + + Args: + skill_guidelines: Optional per-skill evaluation principles from + ground_truth.yaml guidelines across all test cases. + judge_model: LLM model for the judge. Defaults to GEPA_JUDGE_LM env + or databricks/databricks-claude-sonnet-4-6. + """ + instructions = _QUALITY_INSTRUCTIONS + if skill_guidelines: + principles = "\n".join(f"- {g}" for g in skill_guidelines) + instructions += f"\n\n## Domain-Specific Principles\n{principles}\n" + + model_uri, inference_params = _to_judge_model_and_params(judge_model or DEFAULT_JUDGE_LM) + return make_judge( + name="skill_quality", + model=model_uri, + instructions=instructions, + feedback_value_type=float, + inference_params=inference_params, + ) + + +# --------------------------------------------------------------------------- +# Effectiveness judge — WITH vs WITHOUT comparison +# --------------------------------------------------------------------------- + +_EFFECTIVENESS_INSTRUCTIONS = """\ +You are comparing two responses to the same question to determine whether +a skill document helped or hurt the agent's response quality. 
+ +The inputs contain three fields separated by markers: +- QUESTION: the user's question +- WITH-SKILL RESPONSE: generated with the skill document in context +- WITHOUT-SKILL RESPONSE: generated without any skill document + +The expectations contain the expected facts and patterns. + +## Inputs + +{{ inputs }} + +## Expected Information + +{{ expectations }} + +## Instructions + +Determine whether the skill IMPROVED, maintained (SAME), or REGRESSED the +response quality. Return one of exactly: "improved", "same", "regressed". + +An "improved" verdict means the WITH-skill response is meaningfully better: +more accurate facts, better code patterns, correct API usage that the +WITHOUT response got wrong. + +A "regressed" verdict means the skill actively HURT the response: introduced +incorrect information, deprecated APIs, or confused the agent. + +"same" means no meaningful difference. + +Provide detailed rationale explaining: +- What the skill added or removed from the response +- Specific facts/patterns that differ between WITH and WITHOUT +- Whether the skill taught something the model didn't already know +- If regressed: what specifically the skill got wrong +""" + + +def create_effectiveness_judge(judge_model: str | None = None) -> Any: + """Create a WITH vs WITHOUT comparison judge. + + Args: + judge_model: LLM model for the judge. Defaults to GEPA_JUDGE_LM env + or databricks/databricks-claude-sonnet-4-6. 
+ """ + model_uri, inference_params = _to_judge_model_and_params(judge_model or DEFAULT_JUDGE_LM) + return make_judge( + name="skill_effectiveness", + model=model_uri, + instructions=_EFFECTIVENESS_INSTRUCTIONS, + feedback_value_type=str, + inference_params=inference_params, + ) + + +# --------------------------------------------------------------------------- +# Regression judge — identifies how a skill harms responses +# --------------------------------------------------------------------------- + +_REGRESSION_INSTRUCTIONS = """\ +You are a regression detector for Databricks skill documents. Your job is +to identify specific ways that a skill document HARMS agent responses. + +The inputs contain three fields separated by markers: +- QUESTION: the user's question +- WITH-SKILL RESPONSE: generated with the skill document in context +- WITHOUT-SKILL RESPONSE: generated without any skill document + +## Input + +{{ inputs }} + +## Instructions + +Identify specific regressions introduced by the skill. Return "yes" if +regressions are found, "no" if the skill is harmless. + +Common regression patterns: +1. **Deprecated APIs** — skill teaches old APIs the model already uses correctly +2. **Verbosity** — skill adds noise that confuses the model +3. **Contradicting correct knowledge** — model was right, skill made it wrong +4. **Wrong examples** — skill's code examples have errors the model copies +5. **Over-specification** — skill's rigid patterns prevent correct alternatives + +For each regression found, explain: +- WHAT specific content in the skill caused the regression +- WHY it made the response worse +- WHAT to remove or change in the skill to fix it +""" + + +def create_regression_judge(judge_model: str | None = None) -> Any: + """Create a regression detection judge. + + Args: + judge_model: LLM model for the judge. Defaults to GEPA_JUDGE_LM env + or databricks/databricks-claude-sonnet-4-6. 
+ """ + model_uri, inference_params = _to_judge_model_and_params(judge_model or DEFAULT_JUDGE_LM) + return make_judge( + name="skill_regression", + model=model_uri, + instructions=_REGRESSION_INSTRUCTIONS, + feedback_value_type=bool, + inference_params=inference_params, + ) + + +# --------------------------------------------------------------------------- +# Helper: run a judge safely with fallback on rate limit +# --------------------------------------------------------------------------- + + +def run_judge_safe( + judge: Any, + *, + inputs: Any, + outputs: Any | None = None, + expectations: Any | None = None, + name: str = "judge", +) -> JudgeFeedback: + """Run a judge with error handling and model fallback. + + On rate limit errors, recreates the judge with fallback models and + retries. On other errors, returns zero-score feedback so evaluation + never crashes from a judge failure. + """ + kwargs: dict[str, Any] = {"inputs": inputs} + if outputs is not None: + kwargs["outputs"] = outputs + if expectations is not None: + kwargs["expectations"] = expectations + + # Try the primary judge first + try: + fb = judge(**kwargs) + return JudgeFeedback( + value=fb.value, + rationale=fb.rationale or "", + name=name, + ) + except Exception as e: + if not _is_rate_limit_error(e): + logger.warning("Judge '%s' failed: %s", name, e) + return JudgeFeedback(value=0.0, rationale=f"Judge error: {e}", name=name) + + # Rate limit hit — try fallback models + logger.warning("Judge '%s' rate limited, trying fallback models", name) + fallbacks = _get_fallback_models() + + for fallback_model in fallbacks: + model_uri, inference_params = _to_judge_model_and_params(fallback_model) + try: + fallback_judge = make_judge( + name=judge.name, + model=model_uri, + instructions=judge._instructions, + feedback_value_type=judge._feedback_value_type, + inference_params=inference_params, + ) + fb = fallback_judge(**kwargs) + logger.info("Judge '%s' succeeded with fallback model '%s'", name, 
fallback_model)
+            return JudgeFeedback(
+                value=fb.value,
+                rationale=fb.rationale or "",
+                name=name,
+            )
+        except Exception as fallback_err:
+            if _is_rate_limit_error(fallback_err):
+                logger.warning("Fallback '%s' also rate limited, trying next", fallback_model)
+                continue
+            logger.warning("Fallback '%s' failed: %s", fallback_model, fallback_err)
+            continue
+
+    # All fallbacks exhausted
+    logger.error("Judge '%s': all models rate limited", name)
+    return JudgeFeedback(
+        value=0.0,
+        rationale="All models rate limited — no judge score available",
+        name=name,
+    )
diff --git a/.test/src/skill_test/optimize/review.py b/.test/src/skill_test/optimize/review.py
new file mode 100644
index 00000000..2374d429
--- /dev/null
+++ b/.test/src/skill_test/optimize/review.py
@@ -0,0 +1,390 @@
+"""Review and apply workflow for optimization results.
+
+Provides human-readable output of optimization results and the ability
+to apply the optimized SKILL.md to the repository.
+
+After each optimization run, results are saved to:
+    .test/skills/<skill_name>/optimized_SKILL.md — the optimized content
+    .test/skills/<skill_name>/last_optimization.json — metadata with scores and diff summary
+
+Use ``--apply-last`` to apply a saved result without re-running optimization. 
+""" + +import difflib +import json +from datetime import datetime, timezone +from pathlib import Path + +from .runner import OptimizationResult +from .utils import find_skill_md as _find_skill_md + + +def _get_results_dir(skill_name: str) -> Path: + """Get the results directory for a skill.""" + # Try standard skills dir first + candidates = [ + Path(".test/skills") / skill_name, + Path(__file__).resolve().parent.parent.parent / "skills" / skill_name, + ] + for d in candidates: + if d.exists(): + return d + # Fallback: create under .test/skills + d = Path(".test/skills") / skill_name + d.mkdir(parents=True, exist_ok=True) + return d + + +def save_result(result: OptimizationResult) -> tuple[Path | None, Path | None]: + """Save optimization results to disk for later application. + + Writes two files: + - ``optimized_SKILL.md`` — the raw optimized content (can be diffed/reviewed) + - ``last_optimization.json`` — metadata for ``--apply-last`` + + Returns: + Tuple of (optimized_skill_path, metadata_path), either may be None on error. 
+ """ + if result.improvement <= 0 and result.original_content == result.optimized_content: + return None, None + + results_dir = _get_results_dir(result.skill_name) + + optimized_path = None + metadata_path = None + + # Write the optimized SKILL.md + if result.optimized_content and result.optimized_content != result.original_content: + optimized_path = results_dir / "optimized_SKILL.md" + optimized_path.write_text(result.optimized_content) + + # Write metadata for --apply-last + metadata = { + "skill_name": result.skill_name, + "timestamp": datetime.now(timezone.utc).isoformat(), + "original_score": result.original_score, + "optimized_score": result.optimized_score, + "improvement": result.improvement, + "original_token_count": result.original_token_count, + "optimized_token_count": result.optimized_token_count, + "token_reduction_pct": result.token_reduction_pct, + "diff_summary": result.diff_summary, + "mlflow_run_id": result.mlflow_run_id, + "evaluator_type": getattr(result, "evaluator_type", "legacy"), + } + + # Save tool components if present + if result.components: + tool_components = {k: v for k, v in result.components.items() if k.startswith("tools_")} + if tool_components: + metadata["has_tool_components"] = True + # Save each tool component + for comp_name, comp_text in tool_components.items(): + comp_path = results_dir / f"optimized_{comp_name}.txt" + comp_path.write_text(comp_text) + + metadata_path = results_dir / "last_optimization.json" + metadata_path.write_text(json.dumps(metadata, indent=2)) + + return optimized_path, metadata_path + + +def load_last_result(skill_name: str) -> OptimizationResult | None: + """Load the last saved optimization result for a skill. + + Returns: + OptimizationResult reconstructed from saved files, or None if not found. 
+ """ + results_dir = _get_results_dir(skill_name) + metadata_path = results_dir / "last_optimization.json" + optimized_path = results_dir / "optimized_SKILL.md" + + if not metadata_path.exists(): + return None + + metadata = json.loads(metadata_path.read_text()) + + # Load optimized content + optimized_content = "" + if optimized_path.exists(): + optimized_content = optimized_path.read_text() + + # Load original content + original_content = "" + skill_path = _find_skill_md(skill_name) + if skill_path: + original_content = skill_path.read_text() + + # Reconstruct tool components + components = None + if metadata.get("has_tool_components"): + components = {} + if optimized_content: + components["skill_md"] = optimized_content + for f in results_dir.glob("optimized_tools_*.txt"): + comp_name = f.stem.replace("optimized_", "") + components[comp_name] = f.read_text() + + return OptimizationResult( + skill_name=skill_name, + original_score=metadata.get("original_score", 0.0), + optimized_score=metadata.get("optimized_score", 0.0), + improvement=metadata.get("improvement", 0.0), + original_content=original_content, + optimized_content=optimized_content, + original_token_count=metadata.get("original_token_count", 0), + optimized_token_count=metadata.get("optimized_token_count", 0), + token_reduction_pct=metadata.get("token_reduction_pct", 0.0), + diff_summary=metadata.get("diff_summary", ""), + val_scores={}, + mlflow_run_id=metadata.get("mlflow_run_id"), + gepa_result=None, + components=components, + ) + + +def review_optimization(result: OptimizationResult) -> None: + """Print optimization summary for human review. + + Shows: score improvement, token reduction, judge-based effectiveness, + per-test-case score breakdown, and diff of changes. 
+ """ + print(f"\n{'=' * 60}") + print(f" Optimization Results: {result.skill_name}") + print(f"{'=' * 60}") + + si = result.skillbench_side_info or {} + + # Aggregate judge-based scores from per-task side_info + task_count = 0 + sum_with = 0.0 + sum_without = 0.0 + sum_eff = 0.0 + per_task_lines: list[str] = [] + + for task_id in sorted(si.keys()): + info = si[task_id] + scores = info.get("scores", {}) + pw = scores.get("quality_with", 0.0) + pwo = scores.get("quality_without", 0.0) + eff = scores.get("skill_effectiveness", 0.0) + sum_with += pw + sum_without += pwo + sum_eff += eff + task_count += 1 + + # Build per-task notes + error = info.get("Error", "") + notes = [] + if "NEEDS_SKILL" in error: + notes.append("NEEDS_SKILL") + if "REGRESSION" in error: + notes.append("REGRESSION") + if not notes: + notes.append("OK") + note_str = f" [{'; '.join(notes)}]" + per_task_lines.append(f" {task_id:<30s} WITH {pw:.2f} WITHOUT {pwo:.2f} delta {eff:+.2f}{note_str}") + + if task_count > 0: + agg_with = sum_with / task_count + agg_without = sum_without / task_count + agg_eff = sum_eff / task_count + else: + agg_with = agg_without = agg_eff = 0.0 + + # Score summary + improvement_sign = "+" if result.improvement >= 0 else "" + print( + f" Score: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({improvement_sign}{result.improvement:.3f})" + ) + print(f" Skill Effectiveness: {agg_eff:.2f}") + print(f" Quality (with): {agg_with:.2f}") + print(f" Quality (without): {agg_without:.2f} (baseline)") + + # Token counts + reduction_sign = "+" if result.token_reduction_pct >= 0 else "" + print( + f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,} " + f"({reduction_sign}{result.token_reduction_pct:.1f}%)" + ) + + if result.gepa_result and hasattr(result.gepa_result, "iterations"): + print(f" Iterations: {result.gepa_result.iterations}") + if result.mlflow_run_id: + print(f" MLflow run: {result.mlflow_run_id}") + + print() + + # Per-task 
breakdown + if per_task_lines: + print(" Per-task:") + for line in per_task_lines: + print(line) + print() + + # Diff summary + if result.diff_summary and result.diff_summary != "No changes": + print(" Changes:") + for line in result.diff_summary.split("\n"): + print(f" {line}") + print() + + # Detailed diff (first 50 lines) + if result.original_content != result.optimized_content: + diff_lines = list( + difflib.unified_diff( + result.original_content.splitlines(keepends=True), + result.optimized_content.splitlines(keepends=True), + fromfile="original SKILL.md", + tofile="optimized SKILL.md", + n=2, + ) + ) + if len(diff_lines) > 50: + print(f" Diff (first 50 of {len(diff_lines)} lines):") + for line in diff_lines[:50]: + print(f" {line}", end="") + print(f"\n ... ({len(diff_lines) - 50} more lines)") + else: + print(" Diff:") + for line in diff_lines: + print(f" {line}", end="") + print() + else: + print(" No changes to SKILL.md content.") + + # Validation breakdown + if result.val_scores: + print(" Validation scores by test case:") + for task_id, score in sorted(result.val_scores.items()): + status = "PASS" if score >= 0.5 else "FAIL" + print(f" {status} {task_id}: {score:.3f}") + print() + + # Auto-save result to disk + saved_skill, saved_meta = save_result(result) + if saved_skill: + print(f" Saved: {saved_skill}") + print(f" Apply: uv run python .test/scripts/optimize.py {result.skill_name} --apply-last") + elif result.original_content == result.optimized_content: + print(" No improvement found -- nothing saved.") + print(f"{'=' * 60}\n") + + +def apply_optimization(result: OptimizationResult) -> Path | None: + """Apply optimized SKILL.md and/or tool descriptions. 
+ + Writes back: + - SKILL.md (if changed) + - MCP tool docstrings (if tools were included in optimization) + + Args: + result: OptimizationResult from optimize_skill() + + Returns: + Path to the updated SKILL.md (or None if tools_only) + + Raises: + ValueError: If optimization did not improve the skill + """ + if result.improvement < 0: + raise ValueError( + f"Optimization regressed quality ({result.improvement:+.3f}). Refusing to apply. Use --force to override." + ) + + skill_path = None + + # Apply SKILL.md changes + if result.optimized_content and result.optimized_content != result.original_content: + skill_path = _find_skill_md(result.skill_name) + if skill_path: + skill_path.write_text(result.optimized_content) + print(f"Applied optimized SKILL.md to {skill_path}") + + # Apply tool description changes + if result.tool_map and result.components: + from .tools import parse_gepa_component, write_tool_descriptions + + all_optimized_tools = {} + for comp_name, comp_text in result.components.items(): + if comp_name.startswith("tools_"): + parsed = parse_gepa_component(comp_text) + all_optimized_tools.update(parsed) + + if all_optimized_tools: + modified = write_tool_descriptions(all_optimized_tools, result.tool_map) + if modified: + print(f"Applied optimized tool descriptions to {len(modified)} files:") + for f in modified: + print(f" {f}") + + print(f" Quality: {result.original_score:.3f} -> {result.optimized_score:.3f} ({result.improvement:+.3f})") + print( + f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,} " + f"({result.token_reduction_pct:+.1f}%)" + ) + + # Try to update baseline + try: + from ..runners.compare import save_baseline + + if result.mlflow_run_id: + save_baseline( + skill_name=result.skill_name, + run_id=result.mlflow_run_id, + metrics={"optimized_score": result.optimized_score}, + test_count=len(result.val_scores) if result.val_scores else 0, + ) + print(" Baseline updated.") + except Exception: + pass + + return 
skill_path + + +def format_cost_estimate( + train_count: int, + val_count: int | None, + preset: str, + mode: str, +) -> str: + """Estimate the cost of running optimization. + + Args: + train_count: Number of training tasks + val_count: Number of validation tasks (or None) + preset: Preset name + mode: "static" or "generative" + + Returns: + Human-readable cost estimate string + """ + # Rough estimates based on preset + max_calls = {"quick": 15, "standard": 50, "thorough": 150}.get(preset, 50) + + # Each metric call runs all scorers on all train tasks + calls_per_iteration = train_count + if val_count: + calls_per_iteration += val_count + + total_scorer_calls = max_calls * calls_per_iteration + + if mode == "static": + # Static mode: ~$0.001 per scorer call (just deterministic checks) + est_cost = total_scorer_calls * 0.001 + else: + # Generative mode: ~$0.01 per call (LLM generation + scoring) + est_cost = total_scorer_calls * 0.01 + + # GEPA reflection calls + reflection_cost = max_calls * 0.02 # ~$0.02 per reflection + + total = est_cost + reflection_cost + + return ( + f"Estimated cost: ~${total:.2f}\n" + f" Scorer calls: {total_scorer_calls:,} x {'$0.001' if mode == 'static' else '$0.01'}\n" + f" Reflection calls: {max_calls} x $0.02\n" + f" Max iterations: {max_calls}" + ) diff --git a/.test/src/skill_test/optimize/runner.py b/.test/src/skill_test/optimize/runner.py new file mode 100644 index 00000000..d1d0ef16 --- /dev/null +++ b/.test/src/skill_test/optimize/runner.py @@ -0,0 +1,547 @@ +"""End-to-end orchestrator for GEPA skill optimization. + +Uses optimize_anything API: evaluator function + GEPAConfig. +Single evaluator path using SkillBench judge-based evaluation. 
+""" + +import copy +import difflib +from dataclasses import dataclass +from typing import Any + +from gepa.optimize_anything import optimize_anything + +from ..config import SkillTestConfig +from ..runners.evaluate import setup_mlflow +from .config import ( + get_preset, + validate_reflection_context, + estimate_pass_duration, + DEFAULT_GEN_LM, + DEFAULT_TOKEN_BUDGET, +) +from .utils import SKILL_KEY, count_tokens, find_skill_md +from .skillbench_evaluator import ( + create_skillbench_evaluator, + build_skillbench_background, +) +from .splitter import create_gepa_datasets, generate_bootstrap_tasks, to_gepa_instances, create_cross_skill_dataset +from .tools import ( + extract_tool_descriptions, + tools_to_gepa_components, + get_tool_stats, +) + + +@dataclass +class OptimizationResult: + """Result of a GEPA optimization run.""" + + skill_name: str + original_score: float + optimized_score: float + improvement: float + original_content: str + optimized_content: str + original_token_count: int + optimized_token_count: int + token_reduction_pct: float + diff_summary: str + val_scores: dict[str, float] + mlflow_run_id: str | None + gepa_result: Any + components: dict[str, str] | None = None + original_components: dict[str, str] | None = None + tool_map: Any = None + evaluator_type: str = "skillbench" + skillbench_side_info: dict[str, dict] | None = None + + +def _compute_diff_summary(original: str, optimized: str) -> str: + """Generate a human-readable summary of changes.""" + original_lines = original.splitlines(keepends=True) + optimized_lines = optimized.splitlines(keepends=True) + diff = list(difflib.unified_diff(original_lines, optimized_lines, fromfile="original", tofile="optimized", n=1)) + + if not diff: + return "No changes" + + added = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++")) + removed = sum(1 for line in diff if line.startswith("-") and not line.startswith("---")) + + parts = [] + if added: + parts.append(f"+{added} lines 
added") + if removed: + parts.append(f"-{removed} lines removed") + + changed_sections = set() + for line in diff: + content = line[1:].strip() if line.startswith(("+", "-")) and not line.startswith(("+++", "---")) else "" + if content.startswith("#"): + changed_sections.add(content) + + summary = ", ".join(parts) + if changed_sections: + sections = "\n".join(f" ~ {s}" for s in sorted(changed_sections)[:10]) + summary += f"\n\nChanged sections:\n{sections}" + + return summary + + +def _evaluate_on_tasks(evaluator, candidate, tasks, label: str = "Evaluating"): + """Run evaluator on tasks and return mean score, per-task scores, and per-task side_info. + + Returns: + (mean_score, per_task_scores, side_info_by_id, side_info_by_input) + """ + + gepa_instances = to_gepa_instances(tasks) + total = len(gepa_instances) + per_task = {} + side_info_by_id = {} + side_info_by_input = {} + for i, inst in enumerate(gepa_instances): + task_id = tasks[i].get("id", f"task_{i}") + print(f"\r {label}: {i + 1}/{total} ({task_id})...", end="", flush=True) + score, side_info = evaluator(candidate, inst) + per_task[task_id] = score + side_info_by_id[task_id] = side_info + side_info_by_input[inst.get("input", f"task_{i}")] = side_info + mean = sum(per_task.values()) / len(per_task) if per_task else 0.0 + print(f"\r {label}: {total}/{total} done. 
Mean: {mean:.3f} ") + return mean, per_task, side_info_by_id, side_info_by_input + + +def optimize_skill( + skill_name: str, + preset: str = "standard", + gen_model: str | None = None, + reflection_lm: str | None = None, + include_tools: bool = False, + tool_modules: list[str] | None = None, + tools_only: bool = False, + dry_run: bool = False, + max_passes: int = 5, + max_metric_calls: int | None = None, + token_budget: int | None = None, + judge_model: str | None = None, + align: bool = False, + run_dir: str | None = None, + # Deprecated params kept for backward compat + mode: str = "static", + task_lm: str | None = None, + evaluator_type: str = "skillbench", + use_judges: bool = True, +) -> OptimizationResult: + """Run end-to-end GEPA optimization on a skill and/or tools. + + Uses optimize_anything API with judge-based evaluation. + Runs up to ``max_passes`` optimization passes per component, feeding + each pass's best candidate as the seed for the next. + + Args: + skill_name: Name of the skill to optimize + preset: "quick" (15), "standard" (50), or "thorough" (150) max metric calls per pass + gen_model: LLM for generative evaluation + reflection_lm: Override reflection LM + include_tools: Include MCP tool descriptions as additional components + tool_modules: Specific tool modules (None = all) + tools_only: Optimize ONLY tool descriptions + dry_run: Show config and baseline scores without optimizing + max_passes: Maximum optimization passes (default 5) + max_metric_calls: Override max metric calls per pass + token_budget: Hard token ceiling + judge_model: Override the judge model used for scoring + align: Use MemAlign alignment (future use) + run_dir: Directory for GEPA checkpoints. Resumes from last state if dir exists. + """ + # 1. Load SKILL.md + skill_path = find_skill_md(skill_name) + if not tools_only and skill_path is None: + raise FileNotFoundError(f"Could not find SKILL.md for '{skill_name}'") + + original_content = skill_path.read_text() if skill_path else "" + + # 1b. 
Load MCP tool descriptions + tool_map = None + tool_components: dict[str, str] = {} + tool_context_str: str | None = None + + # Always load tool descriptions for context + try: + tool_map = extract_tool_descriptions(modules=tool_modules) + tool_components = tools_to_gepa_components(tool_map, per_module=True) + stats = get_tool_stats() + print( + f"Tool modules: {stats['modules']}, tools: {stats['total_tools']}, " + f"description chars: {stats['total_description_chars']:,}" + ) + except FileNotFoundError: + pass # No MCP tools directory — skip + + # Build read-only tool context string (for skill optimization) + if tool_components: + tool_context_str = "\n\n".join(tool_components[k] for k in sorted(tool_components)) + + # 2. Build seed_candidate (multi-component dict) + seed_candidate: dict[str, str] = {} + original_token_counts: dict[str, int] = {} + + if tools_only: + # Tools-only mode: tool descriptions ARE the GEPA components + for comp_name, comp_text in tool_components.items(): + seed_candidate[comp_name] = comp_text + original_token_counts[comp_name] = count_tokens(comp_text) + tool_context_str = None # tools are in candidate, not read-only context + elif include_tools: + # Explicit --include-tools: both skill and tools are GEPA components + seed_candidate[SKILL_KEY] = original_content + original_token_counts[SKILL_KEY] = count_tokens(original_content) + for comp_name, comp_text in tool_components.items(): + seed_candidate[comp_name] = comp_text + original_token_counts[comp_name] = count_tokens(comp_text) + tool_context_str = None # tools are in candidate, not read-only context + else: + # Default: skill is the only GEPA component; tools are read-only context + seed_candidate[SKILL_KEY] = original_content + original_token_counts[SKILL_KEY] = count_tokens(original_content) + + total_original_tokens = sum(original_token_counts.values()) + + # Resolve token budget + token_budget = token_budget or DEFAULT_TOKEN_BUDGET + + # 3. 
Load datasets + if tools_only: + # Cross-skill dataset for tool optimization + train = create_cross_skill_dataset(max_per_skill=5) + val = None + if train: + source_skills = {t.get("metadata", {}).get("source_skill", "?") for t in train} + print(f"Cross-skill dataset: {len(train)} tasks from {len(source_skills)} skill(s)") + else: + # Fall back to single-skill dataset + try: + train, val = create_gepa_datasets(skill_name) + except FileNotFoundError: + train, val = [], None + else: + try: + train, val = create_gepa_datasets(skill_name) + except FileNotFoundError: + train, val = [], None + + if not train: + train = generate_bootstrap_tasks(skill_name) + val = None + print(f"No test cases found. Using {len(train)} auto-generated tasks.") + + # 4. Build evaluator + effective_gen_model = gen_model or task_lm or DEFAULT_GEN_LM + if effective_gen_model: + print(f"Generation model: {effective_gen_model}") + + from .judges import DEFAULT_JUDGE_LM + + effective_judge_model = judge_model or DEFAULT_JUDGE_LM + print(f"Judge model: {effective_judge_model}") + print("Evaluator: skillbench (judge-driven)") + + if not effective_gen_model: + raise ValueError("SkillBench evaluator requires a gen_model. Pass --gen-model or set GEPA_GEN_LM env var.") + evaluator = create_skillbench_evaluator( + skill_name, + gen_model=effective_gen_model, + original_token_counts=original_token_counts, + token_budget=token_budget, + judge_model=judge_model, + tool_context=tool_context_str, + ) + + # 5. Get config (scaled by component count) + num_components = len(seed_candidate) + config = get_preset( + preset, + reflection_lm=reflection_lm, + num_components=num_components, + max_metric_calls_override=max_metric_calls, + ) + print(f"Reflection model: {config.reflection.reflection_lm}") + + # 5b. Validate reflection model context window + validate_reflection_context( + config.reflection.reflection_lm, + total_original_tokens, + ) + + # 5c. 
Replace GEPA's reflection_lm string with a fallback-aware callable. + # GEPA internally calls make_litellm_lm() which wraps litellm.completion + # with no fallback. We pre-convert it so GEPA uses our version with + # model fallback on rate limit errors. + from .judges import completion_with_fallback + + _reflection_model_name = config.reflection.reflection_lm or "" + if isinstance(config.reflection.reflection_lm, str): + + def _reflection_lm_with_fallback(prompt): + if isinstance(prompt, str): + messages = [{"role": "user", "content": prompt}] + else: + messages = prompt + result = completion_with_fallback( + model=_reflection_model_name, + messages=messages, + ) + return result.choices[0].message.content + + config.reflection.reflection_lm = _reflection_lm_with_fallback + + # Same for refiner_lm if present + if config.refiner is not None and isinstance(config.refiner.refiner_lm, str): + _refiner_model_name = config.refiner.refiner_lm + + def _refiner_lm_with_fallback(prompt): + if isinstance(prompt, str): + messages = [{"role": "user", "content": prompt}] + else: + messages = prompt + result = completion_with_fallback( + model=_refiner_model_name, + messages=messages, + ) + return result.choices[0].message.content + + config.refiner.refiner_lm = _refiner_lm_with_fallback + + # Dry run + if dry_run: + print(f"\n=== Dry Run: {skill_name} (skillbench) ===") + if not tools_only: + print(f"SKILL.md path: {skill_path}") + print(f"Components: {list(seed_candidate.keys())}") + print(f"Total original tokens: {total_original_tokens:,}") + for comp, tokens in original_token_counts.items(): + print(f" {comp}: {tokens:,} tokens") + if tool_context_str: + print(f"Tool context (read-only): {count_tokens(tool_context_str):,} tokens") + print(f"Train tasks: {len(train)}") + print(f"Val tasks: {len(val) if val else 'None (single-task mode)'}") + print(f"Generation model: {effective_gen_model}") + print( + f"Preset: {preset} (max_metric_calls={config.engine.max_metric_calls}, " + 
f"scaled for {num_components} component(s))" + ) + print(f"Max passes: {max_passes}") + if run_dir: + print(f"Run dir: {run_dir}") + # reflection_lm may have been swapped for a callable above; print the original model name + print(f"Reflection LM: {_reflection_model_name or config.reflection.reflection_lm}") + + print(f"\nScoring baseline ({len(train)} tasks, ~5 LLM calls each)...") + original_score, original_per_task, si_by_id, _ = _evaluate_on_tasks( + evaluator, seed_candidate, train, label="Baseline" + ) + print(f"Current score: {original_score:.3f}") + for task_id, score in original_per_task.items(): + print(f" {task_id}: {score:.3f}") + + background = build_skillbench_background( + skill_name, + total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + token_budget=token_budget, + ) + print(f"\nBackground preview:\n{background[:500]}...") + + return OptimizationResult( + skill_name=skill_name, + original_score=original_score, + optimized_score=original_score, + improvement=0.0, + original_content=original_content, + optimized_content=original_content, + original_token_count=total_original_tokens, + optimized_token_count=total_original_tokens, + token_reduction_pct=0.0, + diff_summary="Dry run - no optimization performed", + val_scores={}, + mlflow_run_id=None, + gepa_result=None, + components=dict(seed_candidate), + original_components=dict(seed_candidate), + tool_map=tool_map, + evaluator_type="skillbench", + skillbench_side_info=si_by_id, + ) + + # Evaluate original and capture per-task detail for baseline context + print(f"\nScoring baseline ({len(train)} tasks, ~5 LLM calls each)...") + original_score, original_per_task, si_by_id, si_by_input = _evaluate_on_tasks( + evaluator, seed_candidate, train, label="Baseline" + ) + + # 6. 
Build background and objective + background = build_skillbench_background( + skill_name, + total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + token_budget=token_budget, + ) + objective = ( + f"Refine and improve the existing '{skill_name}' skill. " + "Score is based on SKILL EFFECTIVENESS (40%), ABSOLUTE QUALITY (30%), STRUCTURE (5%), and TOKEN EFFICIENCY (25%). " + "Judge rationale in side_info explains exactly what failed. " + "Focus on what the agent would otherwise get wrong. " + "Be concise — remove redundant examples and verbose explanations." + ) + + # 7. Convert datasets to GEPA format + trainset = to_gepa_instances(train) + valset = to_gepa_instances(val) if val else None + + # 8. Multi-pass optimization loop + current_seed = dict(seed_candidate) + best = dict(seed_candidate) + best_score = original_score + last_result = None + total_metric_calls = 0 + improvement_threshold = 0.0005 + + print( + f"\n Starting multi-pass optimization (up to {max_passes} passes, " + f"{num_components} component(s), {config.engine.max_metric_calls} metric calls/pass)" + ) + + # estimate_pass_duration expects the model name string, not the callable + _est_reflection_lm = _reflection_model_name if _reflection_model_name else str(reflection_lm or DEFAULT_GEN_LM) + est_secs = estimate_pass_duration( + config.engine.max_metric_calls, + _est_reflection_lm, + total_original_tokens, + num_dataset_examples=len(train), + ) + if est_secs is not None: + est_mins = est_secs / 60 + if est_mins > 5: + print( + f" Estimated ~{est_mins:.0f} min/pass ({est_mins * max_passes:.0f} min total for {max_passes} passes)" + ) + + for pass_num in range(1, max_passes + 1): + print(f"\n --- Pass {pass_num}/{max_passes} (best score so far: {best_score:.4f}) ---") + + pass_config = copy.deepcopy(config) + + # Set per-pass checkpoint directory + if run_dir: + pass_config.engine.run_dir = f"{run_dir}/pass_{pass_num}" + + result = optimize_anything( 
seed_candidate=current_seed, + evaluator=evaluator, + dataset=trainset, + valset=valset, + objective=objective, + background=background, + config=pass_config, + ) + total_metric_calls += result.total_metric_calls or 0 + + candidate = result.best_candidate + pass_score, _, _, _ = _evaluate_on_tasks(evaluator, candidate, train, label=f"Pass {pass_num}") + improvement = pass_score - best_score + + print(f" Pass {pass_num} score: {pass_score:.4f} (delta: {'+' if improvement >= 0 else ''}{improvement:.4f})") + + if pass_score > best_score + improvement_threshold: + best = dict(candidate) + best_score = pass_score + last_result = result + current_seed = dict(candidate) + else: + print(f" No significant improvement in pass {pass_num} -- stopping early.") + if last_result is None: + last_result = result + break + else: + print(f" Completed all {max_passes} passes.") + + if last_result is None: + last_result = result + + # 9. Extract results + optimized_content = best.get(SKILL_KEY, original_content) + optimized_token_count = sum(count_tokens(v) for v in best.values()) + + optimized_score = best_score + + val_scores: dict[str, float] = {} + if val: + _, val_scores, _, _ = _evaluate_on_tasks(evaluator, best, val, label="Validation") + + token_reduction_pct = ( + (total_original_tokens - optimized_token_count) / total_original_tokens * 100 + if total_original_tokens > 0 + else 0.0 + ) + + diff_summary = _compute_diff_summary(original_content, optimized_content) + + # 10. 
MLflow logging (best-effort) + mlflow_run_id = None + try: + import mlflow + + stc = SkillTestConfig() + setup_mlflow(stc) + with mlflow.start_run(run_name=f"{skill_name}_optimize_{preset}"): + mlflow.set_tags( + { + "optimizer": "gepa", + "skill_name": skill_name, + "preset": preset, + "evaluator_type": "skillbench", + } + ) + mlflow.log_metrics( + { + "original_score": original_score, + "optimized_score": optimized_score, + "improvement": optimized_score - original_score, + "original_tokens": float(total_original_tokens), + "optimized_tokens": float(optimized_token_count), + "token_reduction_pct": token_reduction_pct, + "total_metric_calls": float(total_metric_calls), + } + ) + mlflow_run_id = mlflow.active_run().info.run_id + except Exception: + pass + + # Capture final side_info for review output + _, _, final_si_by_id, _ = _evaluate_on_tasks(evaluator, best, train, label="Final eval") + + return OptimizationResult( + skill_name=skill_name, + original_score=original_score, + optimized_score=optimized_score, + improvement=optimized_score - original_score, + original_content=original_content, + optimized_content=optimized_content, + original_token_count=total_original_tokens, + optimized_token_count=optimized_token_count, + token_reduction_pct=token_reduction_pct, + diff_summary=diff_summary, + val_scores=val_scores, + mlflow_run_id=mlflow_run_id, + gepa_result=last_result, + components=dict(best), + original_components=dict(seed_candidate), + tool_map=tool_map, + evaluator_type="skillbench", + skillbench_side_info=final_si_by_id, + ) diff --git a/.test/src/skill_test/optimize/skillbench_evaluator.py b/.test/src/skill_test/optimize/skillbench_evaluator.py new file mode 100644 index 00000000..d13601e8 --- /dev/null +++ b/.test/src/skill_test/optimize/skillbench_evaluator.py @@ -0,0 +1,524 @@ +"""SkillBench evaluator: measure skill effectiveness via WITH vs WITHOUT comparison. 
+ +Evaluates skills by measuring agent performance WITH the skill vs WITHOUT it +on real tasks. Uses MLflow judges as the primary scoring mechanism — judges +provide both scores AND rich rationale for GEPA's reflection LM. + + Phase 1: WITH-SKILL -- LLM generates response with SKILL.md in context + Phase 2: WITHOUT-SKILL -- LLM generates response with NO skill (cached once) + Phase 3: JUDGE -- quality_judge scores both, effectiveness derived from delta + +Scoring weights: + 40% Skill Effectiveness (quality_with - quality_without delta) + 30% Absolute Quality (quality_with score from judge) + 5% Structure (syntax validity) + 25% Token Efficiency (smaller candidates score higher) +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import threading +import time +from typing import Any, Callable + +from mlflow.entities import Feedback + +from ..scorers.universal import python_syntax, sql_syntax, no_hallucinated_apis +from .judges import ( + JudgeFeedback, + create_skill_quality_judge, + create_regression_judge, + run_judge_safe, + _safe_parse_score, + completion_with_fallback, +) +from .utils import count_tokens + +logger = logging.getLogger(__name__) + + +def _prompt_hash(prompt: str) -> str: + """Stable hash for caching baseline results by prompt.""" + return hashlib.sha256(prompt.encode()).hexdigest()[:16] + + +class _RateLimiter: + """Thread-safe rate limiter for LLM API calls: bounded concurrency plus a minimum interval between calls.""" + + def __init__(self, max_concurrent: int = 2, min_interval: float = 1.0): + self._semaphore = threading.Semaphore(max_concurrent) + self._min_interval = min_interval + self._lock = threading.Lock() + self._last_call: float = 0.0 + + def acquire(self) -> None: + self._semaphore.acquire() + with self._lock: + now = time.monotonic() + wait = self._last_call + self._min_interval - now + if wait > 0: + time.sleep(wait) + self._last_call = time.monotonic() + + def release(self) -> None: + self._semaphore.release() + + +# Module-level rate 
limiter shared across evaluator instances. +_rate_limiter = _RateLimiter(max_concurrent=4, min_interval=0.2) + + +def _completion_with_backoff(*, max_retries: int = 3, **kwargs) -> Any: + """Call litellm.completion with rate limiting and model fallback. + + Uses the centralized completion_with_fallback which handles: + - Rate limit errors with exponential backoff + - Model fallback chain on persistent rate limits + - AI Gateway routing when configured + """ + _rate_limiter.acquire() + try: + return completion_with_fallback(max_retries=max_retries, **kwargs) + finally: + _rate_limiter.release() + + +def _run_structure_scorers(text: str) -> float: + """Run structure validation scorers on text, return 0.0-1.0 composite.""" + outputs = {"response": text} + scores: list[float] = [] + for scorer_fn in [python_syntax, sql_syntax, no_hallucinated_apis]: + try: + result = scorer_fn(outputs=outputs) + if isinstance(result, list): + for fb in result: + if fb.value == "yes": + scores.append(1.0) + elif fb.value == "no": + scores.append(0.0) + elif isinstance(result, Feedback): + if result.value == "yes": + scores.append(1.0) + elif result.value == "no": + scores.append(0.0) + except Exception: + pass + return sum(scores) / len(scores) if scores else 1.0 + + +def _effectiveness_score(verdict: str | float) -> float: + """Convert effectiveness verdict to numeric score for weighting.""" + if isinstance(verdict, (int, float)): + return max(0.0, min(1.0, float(verdict))) + v = str(verdict).strip().lower() + if v == "improved": + return 1.0 + elif v == "same": + return 0.5 + elif v == "regressed": + return 0.0 + # Fallback: try bool-like + if v in ("yes", "true"): + return 1.0 + if v in ("no", "false"): + return 0.0 + return 0.5 + + +class SkillBenchEvaluator: + """GEPA-compatible evaluator using judges for scoring + diagnostics. + + Args: + gen_model: LLM model for generating responses. Required. + original_token_counts: Token counts of original artifacts for efficiency scoring. 
+ token_budget: Hard token ceiling; candidates exceeding this are penalized. + skill_guidelines: Deduplicated guidelines from ground_truth.yaml for the quality judge. + judge_model: LLM model for judges. Defaults to GEPA_JUDGE_LM env + or databricks/databricks-claude-sonnet-4-6. + """ + + def __init__( + self, + gen_model: str, + original_token_counts: dict[str, int] | None = None, + token_budget: int | None = None, + skill_guidelines: list[str] | None = None, + judge_model: str | None = None, + tool_context: str | None = None, + ): + if not gen_model: + raise ValueError("SkillBench evaluator requires a gen_model. Pass --gen-model or set GEPA_GEN_LM env var.") + self.gen_model = gen_model + self._baseline_response_cache: dict[str, str] = {} + self._baseline_judge_cache: dict[str, JudgeFeedback] = {} + self._original_token_counts = original_token_counts or {} + self._total_original_tokens = sum(self._original_token_counts.values()) + self._token_budget = token_budget + self._tool_context = tool_context or "" + + # Create judge instances with configurable model + self._quality_judge = create_skill_quality_judge(skill_guidelines, judge_model=judge_model) + self._regression_judge = create_regression_judge(judge_model=judge_model) + + def _generate_response(self, prompt: str, skill_context: str | None = None) -> str: + """Generate a response with or without skill context.""" + messages = [] + if skill_context: + messages.append( + { + "role": "system", + "content": ( + "Use ONLY the following skill documentation to answer " + "the user's question. 
Do not use any other knowledge.\n\n" + f"{skill_context}" + ), + } + ) + messages.append({"role": "user", "content": prompt}) + + resp = _completion_with_backoff( + model=self.gen_model, + messages=messages, + temperature=0, + ) + return resp.choices[0].message.content or "" + + def _get_baseline_response(self, prompt: str) -> str: + """Get WITHOUT-skill baseline response, computing once then caching.""" + key = _prompt_hash(prompt) + if key not in self._baseline_response_cache: + response = self._generate_response(prompt, skill_context=None) + self._baseline_response_cache[key] = response + return self._baseline_response_cache[key] + + def __call__( + self, + candidate: dict[str, str], + example: dict, + ) -> tuple[float, dict]: + """Evaluate a candidate skill against a single task example. + + GEPA-compatible signature: (candidate, example) -> (score, side_info) + """ + skill_md = candidate.get("skill_md", "") + + # Build combined context: skill + read-only tool descriptions + # During skill optimization, tools come from self._tool_context (read-only). + # During tool optimization, tools come from candidate keys (optimizable). 
+ tool_parts = [] + for key in sorted(candidate): + if key.startswith("tools_"): + tool_parts.append(candidate[key]) + + full_context = skill_md + if tool_parts: + full_context += "\n\n## Available MCP Tools\n\n" + "\n\n".join(tool_parts) + elif self._tool_context: + full_context += "\n\n## Available MCP Tools\n\n" + self._tool_context + + prompt = example.get("input", "") + + # Decode expectations + expectations: dict[str, Any] = {} + expectations_json = example.get("additional_context", {}).get("expectations", "") + if expectations_json: + try: + expectations = json.loads(expectations_json) + except (json.JSONDecodeError, TypeError): + pass + + if not prompt or not expectations: + return 0.0, {"_error": "No prompt or expectations for this task"} + + # Phase 1: Generate WITH-skill response + with_response = self._generate_response(prompt, skill_context=full_context) + + # Phase 2: Generate WITHOUT-skill response (cached) + without_response = self._get_baseline_response(prompt) + + # Phase 3: Judge-driven scoring + facts = expectations.get("expected_facts", []) + patterns = expectations.get("expected_patterns", []) + guidelines = expectations.get("guidelines", []) + + # Build flat strings for judge templates — make_judge only supports + # top-level {{ inputs }}, {{ outputs }}, {{ expectations }} variables. + facts_str = "\n".join(f"- {f}" for f in facts) if facts else "None specified" + patterns_str = ( + "\n".join( + f"- {p}" if isinstance(p, str) else f"- {p.get('description', p.get('pattern', ''))}" for p in patterns + ) + if patterns + else "None specified" + ) + guidelines_str = "\n".join(f"- {g}" for g in guidelines) if guidelines else "None specified" + + expectations_text = ( + f"Expected facts:\n{facts_str}\n\nExpected patterns:\n{patterns_str}\n\nGuidelines:\n{guidelines_str}" + ) + + # make_judge requires expectations as dict, inputs/outputs as Any. 
+ # The template renders {{ expectations }} as the dict's string repr, + # so we pack our formatted text into a single-key dict. + expectations_dict = {"criteria": expectations_text} + + # Quality judge: score WITH response + quality_with_fb = run_judge_safe( + self._quality_judge, + inputs=prompt, + outputs=with_response, + expectations=expectations_dict, + name="quality_with", + ) + + # Quality judge: score WITHOUT response (cached — baseline never changes) + baseline_key = _prompt_hash(prompt) + if baseline_key not in self._baseline_judge_cache: + self._baseline_judge_cache[baseline_key] = run_judge_safe( + self._quality_judge, + inputs=prompt, + outputs=without_response, + expectations=expectations_dict, + name="quality_without", + ) + quality_without_fb = self._baseline_judge_cache[baseline_key] + + # Parse scores + score_with = _safe_parse_score(quality_with_fb.value) + score_without = _safe_parse_score(quality_without_fb.value) + effectiveness_delta = score_with - score_without + + # Derive effectiveness verdict from quality delta (no LLM call needed) + if effectiveness_delta > 0.05: + effectiveness_verdict = 1.0 # improved + elif effectiveness_delta < -0.05: + effectiveness_verdict = 0.0 # regressed + else: + effectiveness_verdict = 0.5 # same + + # Structure validation on the skill itself + structure = _run_structure_scorers(skill_md) if skill_md else 1.0 + + # Token efficiency scoring + total_candidate_tokens = sum(count_tokens(v) for v in candidate.values()) + + if self._total_original_tokens > 0: + ratio = total_candidate_tokens / self._total_original_tokens + if ratio <= 1.0: + efficiency = 1.0 + 0.15 * (1.0 - ratio) + else: + efficiency = max(0.0, 2.0 - ratio) + + if self._token_budget and total_candidate_tokens > self._token_budget: + over_ratio = total_candidate_tokens / self._token_budget + efficiency = min(efficiency, max(0.0, 2.0 - over_ratio)) + else: + efficiency = 1.0 + + # Weighted final score + final_score = 0.40 * max(0.0, 
effectiveness_delta) + 0.30 * score_with + 0.05 * structure + 0.25 * efficiency + + # Build side info with FULL judge rationale (not truncated!) + reference_answer = example.get("answer", "") + + side_info: dict[str, Any] = {} + + # Task context + if prompt: + side_info["Task"] = prompt[:200] + + # Full judge feedback — the critical fix for GEPA optimization + side_info["Judge_quality_with"] = { + "score": score_with, + "rationale": quality_with_fb.rationale, + } + side_info["Judge_quality_without"] = { + "score": score_without, + "rationale": quality_without_fb.rationale, + } + side_info["Judge_effectiveness"] = { + "verdict": ( + "improved" if effectiveness_verdict == 1.0 else "regressed" if effectiveness_verdict == 0.0 else "same" + ), + "delta": effectiveness_delta, + } + + # Expected vs Actual for GEPA reflection + if reference_answer: + side_info["Expected"] = reference_answer[:500] + if with_response: + side_info["Actual"] = with_response[:500] + + # Score breakdown + side_info["scores"] = { + "quality_with": score_with, + "quality_without": score_without, + "skill_effectiveness": effectiveness_delta, + "effectiveness_verdict": effectiveness_verdict, + "structure": structure, + "token_efficiency": efficiency, + "final": final_score, + } + + # Token counts for GEPA Pareto tracking + side_info["token_counts"] = { + "candidate_total": total_candidate_tokens, + "original_total": self._total_original_tokens, + } + if self._token_budget: + side_info["token_counts"]["budget"] = self._token_budget + + # Derive diagnostic labels from judge verdicts for backward compat + if effectiveness_delta < -0.05: + side_info["Error"] = ( + f"REGRESSION: skill_effectiveness delta={effectiveness_delta:.2f} " + f"(with={score_with:.2f}, without={score_without:.2f})" + ) + side_info["skill_md_specific_info"] = { + "Regressions": quality_with_fb.rationale, + } + elif score_with < 0.5: + side_info["Error"] = ( + f"NEEDS_SKILL: quality_with={score_with:.2f}, missing content. 
Judge: {quality_with_fb.rationale[:200]}" + ) + + return final_score, side_info + + +def _collect_skill_guidelines(skill_name: str) -> list[str]: + """Collect and deduplicate all guidelines from a skill's ground_truth.yaml.""" + from pathlib import Path + import yaml + + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + if not gt_path.exists(): + return [] + + try: + with open(gt_path) as f: + data = yaml.safe_load(f) or {} + except Exception: + return [] + + seen: set[str] = set() + guidelines: list[str] = [] + for tc in data.get("test_cases", []): + for g in tc.get("expectations", {}).get("guidelines", []): + g_norm = g.strip() + if g_norm and g_norm not in seen: + seen.add(g_norm) + guidelines.append(g_norm) + + return guidelines + + +def create_skillbench_evaluator( + skill_name: str, + gen_model: str, + original_token_counts: dict[str, int] | None = None, + token_budget: int | None = None, + judge_model: str | None = None, + tool_context: str | None = None, +) -> Callable: + """Factory for SkillBench-style evaluator. + + Returns a GEPA-compatible callable: (candidate, example) -> (score, side_info) + + Judges are always enabled — they are the primary scoring mechanism. + Guidelines from ground_truth.yaml are incorporated into the quality judge. + + Args: + skill_name: Name of the skill being evaluated. + gen_model: LLM model for generating responses. Required. + original_token_counts: Token counts of original artifacts for efficiency scoring. + token_budget: Hard token ceiling; candidates exceeding this are penalized. + judge_model: LLM model for judges. Defaults to GEPA_JUDGE_LM env + or databricks/databricks-claude-sonnet-4-6. + tool_context: Read-only tool descriptions included in generation context + but not optimized. Used during skill optimization so tools provide + context without being GEPA components. 
+ """ + skill_guidelines = _collect_skill_guidelines(skill_name) + if skill_guidelines: + logger.info( + "Loaded %d domain guidelines for quality judge", + len(skill_guidelines), + ) + + from .judges import DEFAULT_JUDGE_LM + + effective_judge_model = judge_model or DEFAULT_JUDGE_LM + logger.info("Judge model: %s", effective_judge_model) + + return SkillBenchEvaluator( + gen_model=gen_model, + original_token_counts=original_token_counts, + token_budget=token_budget, + skill_guidelines=skill_guidelines, + judge_model=judge_model, + tool_context=tool_context, + ) + + +def build_skillbench_background( + skill_name: str, + original_token_count: int, + component_names: list[str] | None = None, + baseline_scores: dict[str, float] | None = None, + baseline_side_info: dict[str, dict] | None = None, + token_budget: int | None = None, +) -> str: + """Build concise GEPA reflection context for SkillBench optimization. + + Kept short so GEPA's reflection LM spends its context on the per-example + diagnostics (judge rationale) rather than methodology. + """ + baseline_desc = "" + if baseline_scores: + mean_score = sum(baseline_scores.values()) / len(baseline_scores) + baseline_desc = f"\nBASELINE: mean {mean_score:.3f} across {len(baseline_scores)} tasks." 
+ + if baseline_side_info: + needs_skill_ids = [] + regression_ids = [] + for tid, info in baseline_side_info.items(): + error = info.get("Error", "") + if "NEEDS_SKILL" in error: + needs_skill_ids.append(tid) + if "REGRESSION" in error: + regression_ids.append(tid) + if needs_skill_ids: + baseline_desc += f"\n NEEDS_SKILL ({len(needs_skill_ids)} tasks): {', '.join(needs_skill_ids[:5])}" + if regression_ids: + baseline_desc += f"\n REGRESSION ({len(regression_ids)} tasks): {', '.join(regression_ids[:5])}" + + components_desc = "" + if component_names and any(c.startswith("tools_") for c in component_names): + tool_modules = [c.replace("tools_", "") for c in component_names if c.startswith("tools_")] + components_desc = ( + f"\nAlso optimizing MCP tool descriptions for: {', '.join(tool_modules)}. " + "Keep docstrings accurate and concise — every token counts toward the budget." + ) + + token_desc = ( + f"\nTOKEN EFFICIENCY (25% of score): Current artifacts total {original_token_count:,} tokens. " + "Smaller candidates score HIGHER. Be ruthlessly concise." + ) + if token_budget: + token_desc += f"\nTOKEN BUDGET: {token_budget:,} tokens. Candidates exceeding this are heavily penalized." + + return ( + f"You are refining SKILL.md for '{skill_name}'.\n" + "The skill is scored by MLflow judges that evaluate how much it HELPS an agent.\n" + "Judge rationale in side_info explains exactly WHAT failed and WHY.\n" + "Use Judge_quality_with to see missing facts/patterns.\n" + "Use Judge_effectiveness to see if the skill helped or hurt.\n" + "Focus on: specific API syntax, version requirements, non-obvious patterns.\n" + "Do NOT add generic knowledge the agent already has." 
+        f"{baseline_desc}"
+        f"{components_desc}"
+        f"{token_desc}"
+    )
diff --git a/.test/src/skill_test/optimize/splitter.py b/.test/src/skill_test/optimize/splitter.py
new file mode 100644
index 00000000..0fc6f170
--- /dev/null
+++ b/.test/src/skill_test/optimize/splitter.py
@@ -0,0 +1,296 @@
+"""Train/val dataset splitting for GEPA optimization.
+
+Loads ground_truth.yaml test cases and splits them into train/val sets,
+stratified by metadata.category when possible.
+
+GEPA's DefaultDataInst format: {"input": str, "additional_context": dict[str, str], "answer": str}
+
+We store our internal task representation alongside, and convert to GEPA format
+when needed via to_gepa_instances().
+"""
+
+import json
+import random
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, TypedDict
+
+from ..dataset import EvalRecord, get_dataset_source
+
+
+class SkillTask(TypedDict, total=False):
+    """Internal task representation (superset of GEPA DefaultDataInst)."""
+
+    id: str
+    input: str  # The prompt (maps to DefaultDataInst.input)
+    answer: str  # Expected response (maps to DefaultDataInst.answer)
+    additional_context: dict[str, str]  # Extra context (maps to DefaultDataInst.additional_context)
+    expectations: dict[str, Any]  # Scorer expectations (not sent to GEPA directly)
+    metadata: dict[str, Any]  # Category, difficulty, etc.
+
+
+def _summarize_expectations(expectations: dict[str, Any]) -> str:
+    """Produce a human-readable summary of what a task tests.
+
+    Included in additional_context so GEPA's reflection LM understands
+    what each test case is checking without parsing JSON.
+    """
+    parts = []
+
+    patterns = expectations.get("expected_patterns", [])
+    if patterns:
+        descs = []
+        for p in patterns:
+            if isinstance(p, str):
+                descs.append(p[:40])
+            elif isinstance(p, dict):
+                descs.append(p.get("description", p.get("pattern", "")[:40]))
+        parts.append(f"Patterns: {', '.join(descs)}")
+
+    facts = expectations.get("expected_facts", [])
+    if facts:
+        parts.append(f"Facts: {', '.join(str(f) for f in facts)}")
+
+    guidelines = expectations.get("guidelines", [])
+    if guidelines:
+        parts.append(f"Guidelines: {'; '.join(str(g) for g in guidelines[:3])}")
+
+    return " | ".join(parts) if parts else "No specific expectations"
+
+
+def _record_to_task(record: EvalRecord) -> SkillTask:
+    """Convert an EvalRecord to our internal task format."""
+    task: SkillTask = {
+        "id": record.id,
+        "input": record.inputs.get("prompt", ""),
+        "additional_context": {},
+        "answer": "",
+        "metadata": record.metadata or {},
+    }
+    if record.outputs:
+        task["answer"] = record.outputs.get("response", "")
+    if record.expectations:
+        task["expectations"] = record.expectations
+        # Also encode expectations into additional_context for GEPA reflection
+        task["additional_context"]["expectations"] = json.dumps(record.expectations)
+        # Human-readable summary for GEPA's reflection LM
+        task["additional_context"]["evaluation_criteria"] = _summarize_expectations(record.expectations)
+    return task
+
+
+def to_gepa_instances(tasks: list[SkillTask]) -> list[dict[str, Any]]:
+    """Convert internal tasks to GEPA DefaultDataInst format.
+
+    Returns list of {"input": str, "additional_context": dict[str,str], "answer": str}
+    """
+    return [
+        {
+            "input": t["input"],
+            "additional_context": t.get("additional_context", {}),
+            "answer": t.get("answer", ""),
+        }
+        for t in tasks
+    ]
+
+
+def create_gepa_datasets(
+    skill_name: str,
+    val_ratio: float = 0.2,
+    base_path: Path | None = None,
+    seed: int = 42,
+) -> tuple[list[SkillTask], list[SkillTask] | None]:
+    """Load ground_truth.yaml, stratify by metadata.category, split into train/val.
+
+    For skills with <5 test cases: uses all as train, val=None (single-task mode).
+    For skills with >=5 test cases: stratified train/val split (generalization mode).
+
+    Args:
+        skill_name: Name of the skill to load test cases for
+        val_ratio: Fraction of test cases to hold out for validation
+        base_path: Override base path for skills directory
+        seed: Random seed for reproducible splits
+
+    Returns:
+        Tuple of (train_tasks, val_tasks). val_tasks is None if <5 test cases.
+    """
+    source = get_dataset_source(skill_name, base_path)
+    records = source.load()
+
+    if not records:
+        return [], None
+
+    tasks = [_record_to_task(r) for r in records]
+
+    # Too few for a meaningful val split
+    if len(tasks) < 5:
+        return tasks, None
+
+    # Stratify by category
+    by_category: dict[str, list[SkillTask]] = defaultdict(list)
+    for task in tasks:
+        cat = task.get("metadata", {}).get("category", "_uncategorized")
+        by_category[cat].append(task)
+
+    rng = random.Random(seed)
+    train: list[SkillTask] = []
+    val: list[SkillTask] = []
+
+    for _cat, cat_tasks in by_category.items():
+        rng.shuffle(cat_tasks)
+        n_val = max(1, int(len(cat_tasks) * val_ratio))
+
+        # Ensure at least 1 train sample per category
+        if len(cat_tasks) - n_val < 1:
+            n_val = len(cat_tasks) - 1
+
+        if n_val <= 0:
+            train.extend(cat_tasks)
+        else:
+            val.extend(cat_tasks[:n_val])
+            train.extend(cat_tasks[n_val:])
+
+    # If val ended up empty, fall back
+    if not val:
+        return tasks, None
+
+    return train, val
+
+
+def create_cross_skill_dataset(
+    skill_names: list[str] | None = None,
+    max_per_skill: int = 5,
+    base_path: Path | None = None,
+    seed: int = 42,
+) -> list[SkillTask]:
+    """Create a merged dataset from multiple skills for cross-skill tool optimization.
+
+    If ``skill_names`` is None, discovers all skills that have a ``ground_truth.yaml``.
+    Loads tasks from each, caps at ``max_per_skill``, and tags each task with
+    ``metadata["source_skill"]``.
+
+    Args:
+        skill_names: Specific skills to include. None = auto-discover all.
+        max_per_skill: Maximum tasks per skill to keep the dataset balanced.
+        base_path: Override base path for skills directory.
+        seed: Random seed for reproducible sampling.
+
+    Returns:
+        Merged list of SkillTask dicts, each tagged with source_skill.
+    """
+    if base_path is None:
+        base_path = Path(".test/skills")
+
+    # Auto-discover skills with ground_truth.yaml
+    if skill_names is None:
+        if not base_path.exists():
+            return []
+        skill_names = sorted(
+            d.name
+            for d in base_path.iterdir()
+            if d.is_dir() and (d / "ground_truth.yaml").exists() and not d.name.startswith("_")
+        )
+
+    if not skill_names:
+        return []
+
+    rng = random.Random(seed)
+    merged: list[SkillTask] = []
+
+    for skill_name in skill_names:
+        try:
+            source = get_dataset_source(skill_name, base_path)
+            records = source.load()
+        except Exception:
+            continue
+
+        tasks = [_record_to_task(r) for r in records]
+
+        # Tag with source skill
+        for t in tasks:
+            meta = t.get("metadata", {})
+            meta["source_skill"] = skill_name
+            t["metadata"] = meta
+
+        # Cap per skill
+        if len(tasks) > max_per_skill:
+            rng.shuffle(tasks)
+            tasks = tasks[:max_per_skill]
+
+        merged.extend(tasks)
+
+    return merged
+
+
+def generate_bootstrap_tasks(skill_name: str, base_path: Path | None = None) -> list[SkillTask]:
+    """Generate synthetic tasks from a SKILL.md when no ground_truth.yaml exists.
+
+    Parses the SKILL.md for documented patterns and generates basic test prompts
+    that exercise each pattern.
+
+    Args:
+        skill_name: Name of the skill
+        base_path: Override base path for skills directory
+
+    Returns:
+        List of synthetic SkillTask dicts
+    """
+    if base_path is None:
+        # Find repo root for path resolution
+        from .utils import find_repo_root
+
+        repo_root = find_repo_root()
+        skill_md_candidates = [
+            repo_root / ".claude" / "skills" / skill_name / "SKILL.md",
+            repo_root / "databricks-skills" / skill_name / "SKILL.md",
+        ]
+    else:
+        skill_md_candidates = [base_path.parent / skill_name / "SKILL.md"]
+
+    skill_content = None
+    for path in skill_md_candidates:
+        if path.exists():
+            skill_content = path.read_text()
+            break
+
+    if not skill_content:
+        return []
+
+    tasks: list[SkillTask] = []
+
+    # Extract h2/h3 headers as topic areas
+    headers = re.findall(r"^#{2,3}\s+(.+)$", skill_content, re.MULTILINE)
+
+    for i, header in enumerate(headers):
+        tasks.append(
+            {
+                "id": f"bootstrap_{i:03d}",
+                "input": f"Using the {skill_name} skill, help me with: {header}",
+                "additional_context": {},
+                "answer": "",
+                "metadata": {"category": "bootstrap", "source": "auto_generated"},
+            }
+        )
+
+    # Extract code block language hints for targeted prompts
+    code_langs = set(re.findall(r"```(\w+)\n", skill_content))
+    for lang in code_langs:
+        tasks.append(
+            {
+                "id": f"bootstrap_lang_{lang}",
+                "input": f"Show me a {lang} example using {skill_name} patterns",
+                "additional_context": {},
+                "answer": "",
+                "metadata": {"category": "bootstrap", "source": "auto_generated"},
+            }
+        )
+
+    return tasks or [
+        {
+            "id": "bootstrap_general",
+            "input": f"Explain the key patterns in {skill_name}",
+            "additional_context": {},
+            "answer": "",
+            "metadata": {"category": "bootstrap", "source": "auto_generated"},
+        }
+    ]
diff --git a/.test/src/skill_test/optimize/tools.py b/.test/src/skill_test/optimize/tools.py
new file mode 100644
index 00000000..f519d099
--- /dev/null
+++ b/.test/src/skill_test/optimize/tools.py
@@ -0,0 +1,265 @@
+"""Tool description extraction and writing for MCP server tools.
+
+Extracts @mcp.tool docstrings from Python source files, formats them for GEPA
+optimization, and writes optimized descriptions back to source files.
+
+Each tool module (sql.py, compute.py, etc.) becomes one GEPA component so
+GEPA's round-robin selector cycles through modules efficiently.
+"""
+
+import ast
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+# Separator used between tools within a module's GEPA component text
+TOOL_SEPARATOR = "\n\n### TOOL: "
+
+MCP_TOOLS_DIR = Path(__file__).resolve().parents[4] / "databricks-mcp-server" / "databricks_mcp_server" / "tools"
+
+
+@dataclass
+class ToolDescription:
+    """A single tool's description extracted from source."""
+
+    name: str  # Function name
+    docstring: str  # The full docstring text
+    module: str  # Module name (e.g., "sql", "compute")
+    lineno: int  # Line number of the function def
+    source_path: Path  # Path to the source file
+
+
+def _find_tools_dir() -> Path:
+    """Find the MCP tools directory."""
+    if MCP_TOOLS_DIR.exists():
+        return MCP_TOOLS_DIR
+    # Fallback: search from repo root
+    from .utils import find_repo_root
+
+    repo_root = find_repo_root()
+    candidate = repo_root / "databricks-mcp-server" / "databricks_mcp_server" / "tools"
+    if candidate.exists():
+        return candidate
+    raise FileNotFoundError(
+        "Could not find MCP tools directory. Expected at databricks-mcp-server/databricks_mcp_server/tools/"
+    )
+
+
+def extract_tool_descriptions(
+    modules: list[str] | None = None,
+    tools_dir: Path | None = None,
+) -> dict[str, list[ToolDescription]]:
+    """Extract all @mcp.tool docstrings from MCP server tool files.
+
+    Args:
+        modules: Optional list of module names to extract (e.g., ["sql", "compute"]).
+            If None, extracts all modules.
+        tools_dir: Override path to tools directory.
+
+    Returns:
+        Dict mapping module_name -> list of ToolDescription.
+    """
+    if tools_dir is None:
+        tools_dir = _find_tools_dir()
+
+    results: dict[str, list[ToolDescription]] = {}
+
+    for py_file in sorted(tools_dir.glob("*.py")):
+        module_name = py_file.stem
+        if module_name == "__init__":
+            continue
+        if modules and module_name not in modules:
+            continue
+
+        source = py_file.read_text()
+        tree = ast.parse(source)
+
+        tool_descs = []
+        for node in ast.walk(tree):
+            if not isinstance(node, ast.FunctionDef):
+                continue
+            # Check if decorated with @mcp.tool
+            for dec in node.decorator_list:
+                is_mcp_tool = False
+                if isinstance(dec, ast.Attribute) and isinstance(dec.value, ast.Name):
+                    if dec.value.id == "mcp" and dec.attr == "tool":
+                        is_mcp_tool = True
+                elif isinstance(dec, ast.Name) and dec.id == "mcp":
+                    is_mcp_tool = True
+                if is_mcp_tool:
+                    docstring = ast.get_docstring(node) or ""
+                    tool_descs.append(
+                        ToolDescription(
+                            name=node.name,
+                            docstring=docstring,
+                            module=module_name,
+                            lineno=node.lineno,
+                            source_path=py_file,
+                        )
+                    )
+                    break
+
+        if tool_descs:
+            results[module_name] = tool_descs
+
+    return results
+
+
+def tools_to_gepa_components(
+    tool_map: dict[str, list[ToolDescription]],
+    per_module: bool = True,
+) -> dict[str, str]:
+    """Convert extracted tool descriptions into GEPA component text blocks.
+
+    Args:
+        tool_map: Output of extract_tool_descriptions()
+        per_module: If True, one GEPA component per module (e.g., "tools_sql").
+            If False, all tools in a single "tool_descriptions" component.
+
+    Returns:
+        Dict mapping component_name -> text block.
+    """
+    if per_module:
+        components = {}
+        for module_name, tools in tool_map.items():
+            text_parts = []
+            for td in tools:
+                text_parts.append(f"### TOOL: {td.name}\n{td.docstring}")
+            components[f"tools_{module_name}"] = "\n\n".join(text_parts)
+        return components
+    else:
+        all_parts = []
+        for module_name, tools in sorted(tool_map.items()):
+            for td in tools:
+                all_parts.append(f"### TOOL: {td.name} (module: {module_name})\n{td.docstring}")
+        return {"tool_descriptions": "\n\n".join(all_parts)}
+
+
+def parse_gepa_component(component_text: str) -> dict[str, str]:
+    """Parse a GEPA component text block back into individual tool descriptions.
+
+    Args:
+        component_text: Text block with ### TOOL: markers
+
+    Returns:
+        Dict mapping tool_name -> optimized docstring
+    """
+    tools = {}
+    parts = re.split(r"### TOOL:\s*", component_text)
+    for part in parts:
+        part = part.strip()
+        if not part:
+            continue
+        # First line has the tool name (possibly with module annotation)
+        lines = part.split("\n", 1)
+        name_line = lines[0].strip()
+        # Remove module annotation if present: "execute_sql (module: sql)"
+        name = re.match(r"(\w+)", name_line).group(1) if re.match(r"(\w+)", name_line) else name_line
+        docstring = lines[1].strip() if len(lines) > 1 else ""
+        tools[name] = docstring
+    return tools
+
+
+def write_tool_descriptions(
+    optimized: dict[str, str],
+    tool_map: dict[str, list[ToolDescription]],
+) -> list[Path]:
+    """Write optimized docstrings back to MCP server source files.
+
+    Uses AST to locate the exact docstring positions and replaces them
+    in the source text while preserving all other code.
+
+    Args:
+        optimized: Dict mapping tool_name -> optimized docstring text
+        tool_map: Original extraction map (for source file locations)
+
+    Returns:
+        List of modified file paths
+    """
+    # Group updates by file
+    updates_by_file: dict[Path, list[tuple[ToolDescription, str]]] = {}
+    for _module_name, tools in tool_map.items():
+        for td in tools:
+            if td.name in optimized:
+                updates_by_file.setdefault(td.source_path, []).append((td, optimized[td.name]))
+
+    modified_files = []
+    for file_path, updates in updates_by_file.items():
+        source = file_path.read_text()
+        tree = ast.parse(source)
+        source_lines = source.splitlines(keepends=True)
+
+        # Process updates in reverse line order to preserve positions
+        updates_sorted = sorted(updates, key=lambda x: x[0].lineno, reverse=True)
+
+        for td, new_docstring in updates_sorted:
+            # Find the function node
+            for node in ast.walk(tree):
+                if isinstance(node, ast.FunctionDef) and node.name == td.name and node.lineno == td.lineno:
+                    # Find the docstring node (first Expr with a Constant string)
+                    if (
+                        node.body
+                        and isinstance(node.body[0], ast.Expr)
+                        and isinstance(node.body[0].value, ast.Constant)
+                        and isinstance(node.body[0].value.value, str)
+                    ):
+                        doc_node = node.body[0]
+                        # Get the docstring's line range
+                        start_line = doc_node.lineno - 1  # 0-indexed
+                        end_line = doc_node.end_lineno  # exclusive
+
+                        # Detect indentation from the original docstring line
+                        original_line = source_lines[start_line]
+                        indent = re.match(r"(\s*)", original_line).group(1)
+
+                        # Build new docstring with proper indentation
+                        new_doc_lines = [f'{indent}"""\n']
+                        for line in new_docstring.split("\n"):
+                            if line.strip():
+                                new_doc_lines.append(f"{indent}{line}\n")
+                            else:
+                                new_doc_lines.append("\n")
+                        new_doc_lines.append(f'{indent}"""\n')
+
+                        # Replace lines
+                        source_lines[start_line:end_line] = new_doc_lines
+                    break
+
+        new_source = "".join(source_lines)
+
+        # Validate the new source parses
+        try:
+            ast.parse(new_source)
+        except SyntaxError as e:
+            print(f"WARNING: Optimized source for {file_path.name} has syntax error: {e}")
+            print("Skipping this file.")
+            continue
+
+        file_path.write_text(new_source)
+        modified_files.append(file_path)
+
+    return modified_files
+
+
+def list_tool_modules(tools_dir: Path | None = None) -> list[str]:
+    """List available tool module names."""
+    if tools_dir is None:
+        tools_dir = _find_tools_dir()
+    return sorted(f.stem for f in tools_dir.glob("*.py") if f.stem != "__init__")
+
+
+def get_tool_stats(tools_dir: Path | None = None) -> dict[str, Any]:
+    """Get statistics about available MCP tools."""
+    tool_map = extract_tool_descriptions(tools_dir=tools_dir)
+    total_tools = sum(len(tools) for tools in tool_map.values())
+    total_chars = sum(len(td.docstring) for tools in tool_map.values() for td in tools)
+    return {
+        "modules": len(tool_map),
+        "total_tools": total_tools,
+        "total_description_chars": total_chars,
+        "per_module": {
+            name: {"tools": len(tools), "chars": sum(len(td.docstring) for td in tools)}
+            for name, tools in tool_map.items()
+        },
+    }
diff --git a/.test/src/skill_test/optimize/utils.py b/.test/src/skill_test/optimize/utils.py
new file mode 100644
index 00000000..464ee02c
--- /dev/null
+++ b/.test/src/skill_test/optimize/utils.py
@@ -0,0 +1,69 @@
+"""Shared utilities for skill optimization.
+
+Extracted from evaluator.py — provides path resolution, token counting,
+and the SKILL_KEY constant used across the optimization package.
+"""
+
+from pathlib import Path
+
+import tiktoken
+
+SKILL_KEY = "skill_md"
+
+
+# ---------------------------------------------------------------------------
+# Path utilities
+# ---------------------------------------------------------------------------
+
+
+def find_repo_root() -> Path:
+    """Find the repo root by searching upward for .test/src/."""
+    current = Path(__file__).resolve().parent
+    while current != current.parent:
+        if (current / ".test" / "src").exists():
+            return current
+        if (current / "src" / "skill_test").exists() and current.name == ".test":
+            return current.parent
+        current = current.parent
+    return Path.cwd()
+
+
+def find_skill_md(skill_name: str) -> Path | None:
+    """Locate the SKILL.md file for a given skill name."""
+    repo_root = find_repo_root()
+    candidates = [
+        repo_root / ".claude" / "skills" / skill_name / "SKILL.md",
+        repo_root / "databricks-skills" / skill_name / "SKILL.md",
+    ]
+    for p in candidates:
+        if p.exists():
+            return p
+    return None
+
+
+# ---------------------------------------------------------------------------
+# Token utilities
+# ---------------------------------------------------------------------------
+
+
+def count_tokens(text: str) -> int:
+    """Count tokens using cl100k_base encoding."""
+    enc = tiktoken.get_encoding("cl100k_base")
+    return len(enc.encode(text))
+
+
+def token_efficiency_score(candidate_text: str, original_token_count: int) -> float:
+    """Score based on how concise the candidate is vs. the original.
+
+    Smaller than original = bonus up to 1.15, same size = 1.0,
+    larger = linear penalty to 0.0 at 2x.
+    """
+    if original_token_count <= 0:
+        return 1.0
+    enc = tiktoken.get_encoding("cl100k_base")
+    candidate_tokens = len(enc.encode(candidate_text))
+    ratio = candidate_tokens / original_token_count
+    if ratio <= 1.0:
+        return 1.0 + 0.15 * (1.0 - ratio)
+    else:
+        return max(0.0, 2.0 - ratio)
diff --git a/.test/src/skill_test/runners/compare.py b/.test/src/skill_test/runners/compare.py
index 460d03db..d2d95b56 100644
--- a/.test/src/skill_test/runners/compare.py
+++ b/.test/src/skill_test/runners/compare.py
@@ -3,7 +3,7 @@
 import json
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Any, Optional, List
+from typing import Dict, Optional, List
 
 from dataclasses import dataclass, asdict
 
diff --git a/.test/src/skill_test/runners/evaluate.py b/.test/src/skill_test/runners/evaluate.py
index 1dff1009..dc2530f4 100644
--- a/.test/src/skill_test/runners/evaluate.py
+++ b/.test/src/skill_test/runners/evaluate.py
@@ -4,7 +4,7 @@
 from typing import Optional, Dict, Any, List
 import yaml
 import mlflow
-from mlflow.genai.scorers import Guidelines, Safety
+from mlflow.genai.scorers import Safety
 
 from ..config import SkillTestConfig
 from ..dataset import get_dataset_source
@@ -25,7 +25,6 @@
     file_existence,
     tool_sequence,
     category_limits,
-    get_trace_scorers,
 )
 
 
diff --git a/.test/src/skill_test/scorers/dynamic.py b/.test/src/skill_test/scorers/dynamic.py
index 0c2e994e..adc8c164 100644
--- a/.test/src/skill_test/scorers/dynamic.py
+++ b/.test/src/skill_test/scorers/dynamic.py
@@ -1,4 +1,8 @@
-"""Dynamic scorers that use test case expectations from YAML config."""
+"""Dynamic scorers that use test case expectations from YAML config.
+
+DEPRECATED: For optimization, the quality judge in judges.py replaces this.
+This module is kept for backward compatibility with runners/evaluate.py.
+"""
 
 from mlflow.genai.scorers import scorer, Guidelines
 from mlflow.entities import Feedback
diff --git a/.test/src/skill_test/scorers/trace.py b/.test/src/skill_test/scorers/trace.py
index 2532a01d..8cd6bf13 100644
--- a/.test/src/skill_test/scorers/trace.py
+++ b/.test/src/skill_test/scorers/trace.py
@@ -12,8 +12,6 @@
 from mlflow.entities import Feedback
 from mlflow.genai.scorers import scorer
 
-from ..trace.models import TraceMetrics
-
 
 @scorer
 def tool_count(
diff --git a/.test/src/skill_test/scorers/universal.py b/.test/src/skill_test/scorers/universal.py
index 3b839bf4..d1e5ac47 100644
--- a/.test/src/skill_test/scorers/universal.py
+++ b/.test/src/skill_test/scorers/universal.py
@@ -125,20 +125,22 @@
 
 
 @scorer
-def expected_facts_present(outputs: Dict[str, Any], expectations: Dict[str, Any]) -> Feedback:
-    """Check if expected facts are mentioned in response."""
+def expected_facts_present(outputs: Dict[str, Any], expectations: Dict[str, Any]) -> List[Feedback]:
+    """Check if expected facts are mentioned in response (per-fact granularity)."""
     response = outputs.get("response", "").lower()
     expected_facts = expectations.get("expected_facts", [])
 
     if not expected_facts:
-        return Feedback(name="expected_facts", value="skip", rationale="No expected_facts defined")
+        return [Feedback(name="expected_facts", value="skip", rationale="No expected_facts defined")]
 
-    missing = []
+    feedbacks = []
     for fact in expected_facts:
-        if fact.lower() not in response:
-            missing.append(fact)
-
-    if missing:
-        return Feedback(name="expected_facts", value="no", rationale=f"Missing facts: {missing}")
-
-    return Feedback(name="expected_facts", value="yes", rationale=f"All {len(expected_facts)} expected facts present")
+        found = fact.lower() in response
+        feedbacks.append(
+            Feedback(
+                name=f"fact_{fact[:40]}",
+                value="yes" if found else "no",
+                rationale=f"{'Found' if found else 'Missing'}: {fact}",
+            )
+        )
+    return feedbacks
diff --git a/.test/src/skill_test/trace/mlflow_integration.py b/.test/src/skill_test/trace/mlflow_integration.py
index 62f8ffb1..179391f7 100644
--- a/.test/src/skill_test/trace/mlflow_integration.py
+++ b/.test/src/skill_test/trace/mlflow_integration.py
@@ -13,7 +13,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-from .models import TraceMetrics, ToolCall, FileOperation
+from .models import TraceMetrics, ToolCall
 from .parser import parse_and_compute_metrics
diff --git a/.test/src/skill_test/trace/source.py b/.test/src/skill_test/trace/source.py
index 74918875..8b7abef9 100644
--- a/.test/src/skill_test/trace/source.py
+++ b/.test/src/skill_test/trace/source.py
@@ -5,7 +5,6 @@
 2. Local fallback (~/.claude/projects/{hash}/*.jsonl)
 """
 
-import os
 import subprocess
 from dataclasses import dataclass
 from pathlib import Path
diff --git a/.test/tests/test_optimize_e2e.py b/.test/tests/test_optimize_e2e.py
new file mode 100644
index 00000000..422a8828
--- /dev/null
+++ b/.test/tests/test_optimize_e2e.py
@@ -0,0 +1,232 @@
+"""End-to-end tests for GEPA skill optimization using optimize_anything API.
+
+Unit tests run without API keys. E2E tests require GEPA reflection LM access.
+
+Run unit tests:
+    cd .test && uv run pytest tests/test_optimize_e2e.py -v -k "not TestOptimizeE2E"
+
+Run everything (slow, requires API key):
+    cd .test && uv run pytest tests/test_optimize_e2e.py -v -s
+"""
+
+import pytest
+
+from skill_test.optimize.utils import token_efficiency_score, count_tokens, SKILL_KEY
+from skill_test.optimize.splitter import create_gepa_datasets, generate_bootstrap_tasks, to_gepa_instances
+from skill_test.optimize.asi import feedback_to_score, feedback_to_asi
+
+try:
+    from mlflow.entities import Feedback
+    HAS_MLFLOW = True
+except ImportError:
+    Feedback = None
+    HAS_MLFLOW = False
+
+try:
+    from gepa.optimize_anything import GEPAConfig, EngineConfig, ReflectionConfig
+    HAS_GEPA = True
+except ImportError:
+    HAS_GEPA = False
+
+
+# --------------------------------------------------------------------------
+# Unit tests (no GEPA/LLM required)
+# --------------------------------------------------------------------------
+
+class TestTokenEfficiency:
+    def test_same_size_scores_one(self):
+        text = "Hello world, this is a test."
+        tokens = count_tokens(text)
+        assert token_efficiency_score(text, tokens) == 1.0
+
+    def test_smaller_scores_bonus(self):
+        # Smaller than original gets a bonus (up to 1.15)
+        score = token_efficiency_score("short", 100)
+        assert score > 1.0
+        assert score <= 1.15
+
+    def test_double_size_scores_zero(self):
+        text = "word " * 200
+        tokens = count_tokens(text)
+        assert token_efficiency_score(text + text, tokens) == pytest.approx(0.0, abs=0.05)
+
+    def test_zero_original_returns_one(self):
+        assert token_efficiency_score("anything", 0) == 1.0
+
+
+class TestSplitter:
+    def test_model_serving_has_split(self):
+        try:
+            train, val = create_gepa_datasets("databricks-model-serving")
+            assert len(train) > 0
+            if len(train) + (len(val) if val else 0) >= 5:
+                assert val is not None
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_reproducible_splits(self):
+        try:
+            t1, v1 = create_gepa_datasets("databricks-model-serving", seed=42)
+            t2, v2 = create_gepa_datasets("databricks-model-serving", seed=42)
+            assert [t["id"] for t in t1] == [t["id"] for t in t2]
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_tasks_have_correct_keys(self):
+        try:
+            train, _ = create_gepa_datasets("databricks-model-serving")
+            for task in train:
+                assert "id" in task
+                assert "input" in task
+                assert "answer" in task
+                assert "additional_context" in task
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_to_gepa_instances(self):
+        try:
+            train, _ = create_gepa_datasets("databricks-model-serving")
+            instances = to_gepa_instances(train)
+            assert len(instances) == len(train)
+            for inst in instances:
+                assert "input" in inst
+                assert "additional_context" in inst
+                assert "answer" in inst
+                assert "id" not in inst
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_bootstrap_tasks_generated(self):
+        tasks = generate_bootstrap_tasks("databricks-model-serving")
+        assert len(tasks) > 0
+        for task in tasks:
+            assert "id" in task
+            assert "input" in task
+
+
+@pytest.mark.skipif(not HAS_MLFLOW, reason="mlflow not installed")
+class TestASI:
+    def test_yes_scores_one(self):
+        assert feedback_to_score(Feedback(name="test", value="yes")) == 1.0
+
+    def test_no_scores_zero(self):
+        assert feedback_to_score(Feedback(name="test", value="no")) == 0.0
+
+    def test_skip_returns_none(self):
+        assert feedback_to_score(Feedback(name="test", value="skip")) is None
+
+    def test_feedback_to_asi_composite(self):
+        feedbacks = [
+            Feedback(name="syntax", value="yes", rationale="Valid"),
+            Feedback(name="pattern", value="no", rationale="Missing X"),
+            Feedback(name="optional", value="skip", rationale="N/A"),
+        ]
+        score, si = feedback_to_asi(feedbacks)
+        assert score == pytest.approx(0.5)
+        assert si["syntax"]["score"] == 1.0
+        assert si["pattern"]["score"] == 0.0
+        assert si["optional"]["status"] == "skipped"
+        assert si["_summary"]["scored"] == 2
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestConfig:
+    def test_presets_exist(self):
+        from skill_test.optimize.config import PRESETS
+        assert "quick" in PRESETS
+        assert "standard" in PRESETS
+        assert "thorough" in PRESETS
+
+    def test_quick_has_fewer_calls(self):
+        from skill_test.optimize.config import PRESETS
+        assert PRESETS["quick"].engine.max_metric_calls < PRESETS["standard"].engine.max_metric_calls
+
+    def test_presets_are_gepa_configs(self):
+        from skill_test.optimize.config import PRESETS
+        for name, cfg in PRESETS.items():
+            assert isinstance(cfg, GEPAConfig), f"{name} is not GEPAConfig"
+            assert isinstance(cfg.engine, EngineConfig)
+            assert isinstance(cfg.reflection, ReflectionConfig)
+
+
+class TestBootstrapMode:
+    def test_nonexistent_skill_returns_empty(self):
+        tasks = generate_bootstrap_tasks("nonexistent-skill-xyz")
+        assert tasks == []
+
+    def test_bootstrap_has_gepa_format(self):
+        tasks = generate_bootstrap_tasks("databricks-model-serving")
+        if not tasks:
+            pytest.skip("No SKILL.md found")
+        instances = to_gepa_instances(tasks)
+        for inst in instances:
+            assert isinstance(inst["input"], str)
+            assert isinstance(inst["additional_context"], dict)
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestToolExtraction:
+    def test_extract_tools(self):
+        from skill_test.optimize.tools import extract_tool_descriptions, get_tool_stats
+        stats = get_tool_stats()
+        assert stats["modules"] > 0
+        assert stats["total_tools"] > 0
+
+    def test_tools_to_gepa_components(self):
+        from skill_test.optimize.tools import extract_tool_descriptions, tools_to_gepa_components
+        tool_map = extract_tool_descriptions(modules=["sql"])
+        components = tools_to_gepa_components(tool_map)
+        assert "tools_sql" in components
+        assert "### TOOL:" in components["tools_sql"]
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestDryRun:
+    def test_dry_run_skill_only(self):
+        from skill_test.optimize.runner import optimize_skill
+        try:
+            result = optimize_skill("databricks-model-serving", preset="quick", dry_run=True)
+            assert result.improvement == 0.0
+            assert result.original_content == result.optimized_content
+            assert result.gepa_result is None
+            assert result.original_token_count > 0
+        except FileNotFoundError:
+            pytest.skip("SKILL.md not found")
+
+    def test_dry_run_with_tools(self):
+        from skill_test.optimize.runner import optimize_skill
+        try:
+            result = optimize_skill(
+                "databricks-model-serving", preset="quick", dry_run=True,
+                include_tools=True, tool_modules=["serving"],
+            )
+            assert SKILL_KEY in result.components
+            assert "tools_serving" in result.components
+            assert result.original_token_count > 0
+        except FileNotFoundError:
+            pytest.skip("SKILL.md not found")
+
+
+# --------------------------------------------------------------------------
+# E2E integration (requires GEPA + LLM API key)
+# --------------------------------------------------------------------------
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+@pytest.mark.slow
+class TestOptimizeE2E:
+    def test_optimize_improves_quality_and_reduces_tokens(self):
+        from skill_test.optimize.runner import optimize_skill
+        result = optimize_skill(
+            skill_name="databricks-spark-declarative-pipelines",
+            mode="static",
+            preset="quick",
+        )
+        assert result.optimized_score >= result.original_score
+        assert result.optimized_token_count <= result.original_token_count * 1.05
+
+        if result.val_scores:
+            avg_val = sum(result.val_scores.values()) / len(result.val_scores)
+            assert avg_val >= result.optimized_score - 0.05
+
+        print(f"\nQuality: {result.original_score:.3f} -> {result.optimized_score:.3f}")
+        print(f"Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,}")
diff --git a/.test/uv.lock b/.test/uv.lock
index 0a1e71f8..67d8cfe0 100644
--- a/.test/uv.lock
+++ b/.test/uv.lock
@@ -838,7 +838,7 @@ name = "exceptiongroup"
 version = "1.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
 wheels = [
@@ -1149,6 +1149,15 @@
     { url = "https://files.pythonhosted.org/packages/01/c9/97cc5aae1648dcb851958a3ddf73ccd7dbe5650d95203ecb4d7720b4cdbf/fsspec-2026.1.0-py3-none-any.whl", hash = "sha256:cb76aa913c2285a3b49bdd5fc55b1d7c708d7208126b60f2eb8194fe1b4cbdcc", size = 201838, upload-time = "2026-01-09T15:21:34.041Z" },
 ]
 
+[[package]]
+name = "gepa"
+version = "0.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f5/30/511e52916956508f56eca721260fcd524cfffd580e57782dd471be925f7e/gepa-0.1.0.tar.gz", hash = "sha256:f8b3d7918d4cdcf8593f39ef1cc757c4ba1a4e6793e3ffb622e6c0bc60a1efd9", size = 226064, upload-time = "2026-02-19T19:43:08.272Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1f/32/fe8afb3d2a6605a6bcbc8f119f0a2adae96e9e5d57ebed055490219956a8/gepa-0.1.0-py3-none-any.whl", hash = "sha256:4e3f8fe8ca20169e60518b2e9d416e8c4a579459848adffdcad12223fbf9643e", size = 191392, upload-time = "2026-02-19T19:43:07.065Z" },
+]
+
 [[package]]
 name = "gitdb"
 version = "4.0.12"
@@ -3919,8 +3928,10 @@ dependencies = [
 
 [package.optional-dependencies]
 all = [
     { name = "databricks-sdk" },
+    { name = "gepa" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
+    { name = "tiktoken" },
 ]
 databricks = [
     { name = "databricks-sdk" },
@@ -3929,18 +3940,24 @@ dev = [
     { name = "pytest" },
     { name = "pytest-asyncio" },
 ]
+optimize = [
+    { name = "gepa" },
+    { name = "tiktoken" },
+]
 
 [package.metadata]
 requires-dist = [
     { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.20.0" },
+    { name = "gepa", marker = "extra == 'optimize'", specifier = ">=0.1.0" },
     { name = "mlflow", extras = ["databricks"], specifier = ">=3.6.1" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" },
     { name = "python-dotenv", specifier = ">=1.0.0" },
     { name = "pyyaml", specifier = ">=6.0" },
-    { name = "skill-test", extras = ["databricks", "dev"], marker = "extra == 'all'" },
+    { name = "skill-test", extras = ["databricks", "dev", "optimize"], marker = "extra == 'all'" },
+    { name = "tiktoken", marker = "extra == 'optimize'", specifier = ">=0.7.0" },
 ]
-provides-extras = ["databricks", "dev", "all"]
+provides-extras = ["databricks", "dev", "optimize", "all"]
 
 [[package]]
 name = "skops"
diff --git a/databricks-builder-app/alembic/env.py b/databricks-builder-app/alembic/env.py
index 07021cd5..0b8110db 100644
--- a/databricks-builder-app/alembic/env.py
+++ b/databricks-builder-app/alembic/env.py
@@ -150,6 +150,12 @@
 def run_migrations_online():
     """Run migrations in 'online' mode using sync engine."""
     url, connect_args = get_url_and_connect_args()
+
+    # Get schema name from Alembic config or environment
+    schema_name = config.get_main_option('lakebase_schema_name') or os.environ.get('LAKEBASE_SCHEMA_NAME', 'builder_app')
+
+    # Add search_path to connect_args so tables are created in the custom schema
+    connect_args.setdefault('options', f'-c search_path={schema_name},public')
     connectable = create_engine(
         url,
         poolclass=pool.NullPool,
@@ -157,6 +163,11 @@
     )
 
     with connectable.connect() as connection:
+        # Create the schema if it doesn't exist (SP has CREATE on the database)
+        from sqlalchemy import text
+        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS {schema_name}'))
+        connection.commit()
+
         context.configure(
             connection=connection,
             target_metadata=target_metadata,
diff --git a/databricks-builder-app/app.yaml b/databricks-builder-app/app.yaml
index 5efe3f00..a57d8526 100644
--- a/databricks-builder-app/app.yaml
+++ b/databricks-builder-app/app.yaml
@@ -55,6 +55,8 @@ env:
     value: "fe-shared-demo"
   - name: LAKEBASE_DATABASE_NAME
     value: "databricks_postgres"
+  - name: LAKEBASE_SCHEMA_NAME
+    value: "builder_app"
 
 # =============================================================================
 # LLM Provider Configuration
@@ -92,9 +94,9 @@
 # instead of hitting Anthropic directly. The app dynamically sets ANTHROPIC_BASE_URL
 # and ANTHROPIC_AUTH_TOKEN from the user's Databricks credentials at runtime.
- name: ANTHROPIC_MODEL - value: "databricks-claude-opus-4-5" + value: "databricks-claude-opus-4-6" - name: ANTHROPIC_MODEL_MINI - value: "databricks-claude-sonnet-4-5" + value: "databricks-claude-sonnet-4-6" - name: CLAUDE_CODE_STREAM_CLOSE_TIMEOUT value: "3600000" diff --git a/databricks-builder-app/client/src/lib/types.ts b/databricks-builder-app/client/src/lib/types.ts index 05f45511..efaf8cff 100644 --- a/databricks-builder-app/client/src/lib/types.ts +++ b/databricks-builder-app/client/src/lib/types.ts @@ -61,6 +61,7 @@ export interface Warehouse { state: string; cluster_size?: string | null; creator_name?: string | null; + is_serverless?: boolean; } /** Todo item from agent TodoWrite tool */ diff --git a/databricks-builder-app/client/src/pages/ProjectPage.tsx b/databricks-builder-app/client/src/pages/ProjectPage.tsx index 7b7ea193..0a59d91c 100644 --- a/databricks-builder-app/client/src/pages/ProjectPage.tsx +++ b/databricks-builder-app/client/src/pages/ProjectPage.tsx @@ -1161,7 +1161,8 @@ export default function ProjectPage() { } const cluster = clusters.find(c => c.cluster_id === selectedClusterId); if (cluster) { - chips.push({ label: cluster.cluster_name || 'Cluster', color: cluster.state === 'RUNNING' ? 'text-[var(--color-success)]' : 'text-[var(--color-text-muted)]' }); + const isServerless = cluster.cluster_id === '__serverless__'; + chips.push({ label: isServerless ? 'Serverless Compute' : (cluster.cluster_name || 'Cluster'), color: cluster.state === 'RUNNING' ? 
'text-[var(--color-success)]' : 'text-[var(--color-text-muted)]' }); } const warehouse = warehouses.find(w => w.warehouse_id === selectedWarehouseId); if (warehouse) { diff --git a/databricks-builder-app/requirements.txt b/databricks-builder-app/requirements.txt index e9287669..48b5cb7c 100644 --- a/databricks-builder-app/requirements.txt +++ b/databricks-builder-app/requirements.txt @@ -16,7 +16,7 @@ greenlet>=3.0.0 psycopg2-binary>=2.9.11 # Claude Agent SDK -claude-agent-sdk>=0.1.19 +claude-agent-sdk>=0.1.39 anthropic>=0.42.0 # MLflow for tracing Claude Code conversations diff --git a/databricks-builder-app/scripts/deploy.sh b/databricks-builder-app/scripts/deploy.sh index 42651f9a..dc8cccc0 100755 --- a/databricks-builder-app/scripts/deploy.sh +++ b/databricks-builder-app/scripts/deploy.sh @@ -213,6 +213,11 @@ cp -r server "$STAGING_DIR/" cp app.yaml "$STAGING_DIR/" cp requirements.txt "$STAGING_DIR/" +# Copy Alembic migrations +echo " Copying Alembic migrations..." +cp alembic.ini "$STAGING_DIR/" +cp -r alembic "$STAGING_DIR/" + # Copy frontend build (server expects it at client/out/) echo " Copying frontend build..." 
mkdir -p "$STAGING_DIR/client" diff --git a/databricks-builder-app/server/db/database.py b/databricks-builder-app/server/db/database.py index 156e82dd..37bdba4c 100644 --- a/databricks-builder-app/server/db/database.py +++ b/databricks-builder-app/server/db/database.py @@ -375,6 +375,7 @@ def init_database(database_url: Optional[str] = None) -> AsyncEngine: # Connect args for psycopg3 with DNS workaround connect_args = { "sslmode": "require", + "options": f"-c search_path={os.environ.get('LAKEBASE_SCHEMA_NAME', 'builder_app')},public", } # Add hostaddr if DNS resolution was needed (bypasses Python's getaddrinfo) if _resolved_hostaddr: @@ -562,6 +563,10 @@ def run_migrations() -> None: if alembic_dir.exists(): alembic_cfg.set_main_option("script_location", str(alembic_dir)) + # Pass the schema name to Alembic env.py via config + schema_name = os.environ.get("LAKEBASE_SCHEMA_NAME", "builder_app") + alembic_cfg.set_main_option("lakebase_schema_name", schema_name) + command.upgrade(alembic_cfg, "head") logger.info("Database migrations completed") except Exception as e: diff --git a/databricks-builder-app/server/services/agent.py b/databricks-builder-app/server/services/agent.py index b2766523..30d2cf47 100644 --- a/databricks-builder-app/server/services/agent.py +++ b/databricks-builder-app/server/services/agent.py @@ -423,11 +423,15 @@ async def stream_agent_response( claude_env['ANTHROPIC_AUTH_TOKEN'] = effective_fmapi_token # Set the model to use (required for Databricks FMAPI) - anthropic_model = os.environ.get('ANTHROPIC_MODEL', 'databricks-claude-opus-4-5') + anthropic_model = os.environ.get('ANTHROPIC_MODEL', 'databricks-claude-opus-4-6') claude_env['ANTHROPIC_MODEL'] = anthropic_model - # Disable beta headers for Databricks FMAPI compatibility - claude_env['ANTHROPIC_CUSTOM_HEADERS'] = 'x-databricks-disable-beta-headers: true' + # Disable beta headers and experimental betas for Databricks FMAPI compatibility + # ANTHROPIC_CUSTOM_HEADERS enables coding agent 
mode on FMAPI + # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS prevents context_management and other + # experimental body parameters that FMAPI doesn't support (400: Extra inputs not permitted) + claude_env['ANTHROPIC_CUSTOM_HEADERS'] = 'x-databricks-use-coding-agent-mode: true' + claude_env['CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS'] = '1' logger.info(f'Configured Databricks model serving: {anthropic_base_url} with model {anthropic_model}') logger.info(f'Claude env vars: BASE_URL={claude_env.get("ANTHROPIC_BASE_URL")}, MODEL={claude_env.get("ANTHROPIC_MODEL")}') diff --git a/databricks-builder-app/server/services/clusters.py b/databricks-builder-app/server/services/clusters.py index 74dad5cd..7e3f24b4 100644 --- a/databricks-builder-app/server/services/clusters.py +++ b/databricks-builder-app/server/services/clusters.py @@ -23,11 +23,14 @@ _cache_lock = Lock() +SERVERLESS_CLUSTER_ID = '__serverless__' + + def _fetch_clusters_sync(limit: int = 50, timeout: int = 15) -> list[dict]: """Synchronously fetch clusters from Databricks using SDK. - Returns clusters sorted by: running first, "shared" in name second, then alphabetically. - Filters out serverless clusters. + Returns a "Serverless Compute" entry first (always), followed by real clusters + sorted by: running first, "shared" in name second, then alphabetically. 
Args: limit: Maximum number of clusters to return @@ -57,7 +60,17 @@ def sort_key(c): filtered_clusters.sort(key=sort_key) - return [ + # Build result with Serverless Compute as the first (default) entry + result = [ + { + 'cluster_id': SERVERLESS_CLUSTER_ID, + 'cluster_name': 'Serverless Compute', + 'state': 'RUNNING', + 'creator_user_name': None, + }, + ] + + result.extend( { 'cluster_id': c.cluster_id, 'cluster_name': c.cluster_name, @@ -65,7 +78,9 @@ def sort_key(c): 'creator_user_name': c.creator_user_name, } for c in filtered_clusters[:limit] - ] + ) + + return result async def _refresh_cache(timeout_seconds: int = 30) -> None: @@ -142,4 +157,16 @@ async def list_clusters_async() -> list[dict]: # No cache - we must wait for the first fetch await _refresh_cache() - return _get_cached_clusters() or [] + result = _get_cached_clusters() + if result: + return result + + # Even if the API call failed, always return the serverless option + return [ + { + 'cluster_id': SERVERLESS_CLUSTER_ID, + 'cluster_name': 'Serverless Compute', + 'state': 'RUNNING', + 'creator_user_name': None, + }, + ] diff --git a/databricks-builder-app/server/services/system_prompt.py b/databricks-builder-app/server/services/system_prompt.py index 5b7b4fef..91f73f6b 100644 --- a/databricks-builder-app/server/services/system_prompt.py +++ b/databricks-builder-app/server/services/system_prompt.py @@ -107,7 +107,17 @@ def get_system_prompt( """ cluster_section = '' - if cluster_id: + if cluster_id == 'serverless' or cluster_id == '__serverless__': + cluster_section = """ +## Compute: Serverless + +You are configured to use **Databricks Serverless Compute** for code execution. + +When using `execute_databricks_command` or `run_python_file_on_databricks`: +- **Do NOT pass a cluster_id parameter** — serverless compute is used automatically when no cluster is specified. +- Serverless compute starts instantly with no cluster startup wait time. 
+""" + elif cluster_id: cluster_section = f""" ## Selected Cluster diff --git a/databricks-builder-app/server/services/test_clusters.py b/databricks-builder-app/server/services/test_clusters.py index a7cfc66a..80d949c9 100644 --- a/databricks-builder-app/server/services/test_clusters.py +++ b/databricks-builder-app/server/services/test_clusters.py @@ -24,17 +24,24 @@ def test_fetch_clusters_sync_performance(): def test_clusters_sorted_correctly(): - """Test that clusters are sorted: running first, shared second, then alphabetically.""" - from server.services.clusters import _fetch_clusters_sync + """Test that clusters are sorted: serverless first, then running, shared, then rest.""" + from server.services.clusters import _fetch_clusters_sync, SERVERLESS_CLUSTER_ID clusters = _fetch_clusters_sync(limit=50) - if len(clusters) < 2: - pytest.skip('Not enough clusters to test sorting') + # First entry should always be Serverless Compute + assert clusters[0]['cluster_id'] == SERVERLESS_CLUSTER_ID + assert clusters[0]['cluster_name'] == 'Serverless Compute' + + # Skip the synthetic serverless entry for sort order checking + real_clusters = [c for c in clusters if c['cluster_id'] != SERVERLESS_CLUSTER_ID] + + if len(real_clusters) < 2: + pytest.skip('Not enough real clusters to test sorting') - # Check running clusters come first + # Check running clusters come first among real clusters found_non_running = False - for c in clusters: + for c in real_clusters: if c['state'] != 'RUNNING': found_non_running = True elif found_non_running: diff --git a/databricks-builder-app/server/services/warehouses.py b/databricks-builder-app/server/services/warehouses.py index 655b76a1..7922a7cf 100644 --- a/databricks-builder-app/server/services/warehouses.py +++ b/databricks-builder-app/server/services/warehouses.py @@ -25,10 +25,12 @@ def _fetch_warehouses_sync(limit: int = 50, timeout: int = 15) -> list[dict]: """Synchronously fetch SQL warehouses from Databricks using SDK. 
Returns warehouses sorted by priority: - 1. Running + "shared" in name (highest priority) - 2. Running (without "shared") - 3. Not running + "shared" in name - 4. Everything else + 1. Serverless + running (highest priority) + 2. Serverless + not running + 3. Running + "shared" in name + 4. Running (without "shared") + 5. Not running + "shared" in name + 6. Everything else Args: limit: Maximum number of warehouses to return @@ -42,19 +44,24 @@ def _fetch_warehouses_sync(limit: int = 50, timeout: int = 15) -> list[dict]: # Fetch warehouses warehouses = list(islice(client.warehouses.list(), limit * 2)) - # Sort by priority: running + shared > running > shared > rest + # Sort by priority: serverless first, then running + shared > running > shared > rest def sort_key(w): is_running = w.state == State.RUNNING if w.state else False is_shared = 'shared' in (w.name or '').lower() - # Priority: (running + shared) = 0, running = 1, shared = 2, other = 3 - if is_running and is_shared: + is_serverless = getattr(w, 'enable_serverless_compute', False) or False + # Serverless warehouses always come first + if is_serverless and is_running: priority = 0 - elif is_running: + elif is_serverless: priority = 1 - elif is_shared: + elif is_running and is_shared: priority = 2 - else: + elif is_running: priority = 3 + elif is_shared: + priority = 4 + else: + priority = 5 return priority warehouses.sort(key=sort_key) @@ -66,6 +73,7 @@ def sort_key(w): 'state': w.state.value if w.state else 'UNKNOWN', 'cluster_size': w.cluster_size, 'creator_name': w.creator_name, + 'is_serverless': getattr(w, 'enable_serverless_compute', False) or False, } for w in warehouses[:limit] ] diff --git a/databricks-mcp-server/databricks_mcp_server/tools/compute.py b/databricks-mcp-server/databricks_mcp_server/tools/compute.py index e2462d44..06e52eff 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/compute.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/compute.py @@ -1,6 +1,6 @@ 
"""Compute tools - Execute code on Databricks clusters.""" -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List from databricks_tools_core.compute import ( list_clusters as _list_clusters, diff --git a/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py b/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py index 08324bd6..2727c999 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py @@ -31,14 +31,12 @@ update_project as _update_project, delete_project as _delete_project, create_branch as _create_branch, - get_branch as _get_branch, list_branches as _list_branches, update_branch as _update_branch, delete_branch as _delete_branch, create_endpoint as _create_endpoint, list_endpoints as _list_endpoints, update_endpoint as _update_endpoint, - delete_endpoint as _delete_endpoint, generate_credential as _generate_autoscale_credential, ) diff --git a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py b/databricks-mcp-server/databricks_mcp_server/tools/manifest.py index 1adc30c5..500a77a7 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/manifest.py @@ -30,7 +30,8 @@ def _delete_from_databricks(resource_type: str, resource_id: str) -> Optional[st @mcp.tool def list_tracked_resources(type: Optional[str] = None) -> Dict[str, Any]: - """List resources tracked in the project manifest. + """ + List resources tracked in the project manifest. The manifest records every resource created through the MCP server (dashboards, jobs, pipelines, Genie spaces, KAs, MAS, schemas, volumes, etc.). @@ -60,7 +61,8 @@ def delete_tracked_resource( resource_id: str, delete_from_databricks: bool = False, ) -> Dict[str, Any]: - """Delete a resource from the project manifest, and optionally from Databricks. 
+ """ + Delete a resource from the project manifest, and optionally from Databricks. Use this to clean up resources that were created during development/testing. diff --git a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py b/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py index 26b91d8c..0a77c322 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py @@ -1,33 +1,24 @@ -"""Pipeline tools - Manage Spark Declarative Pipelines (SDP). +"""Pipeline tools - Manage Spark Declarative Pipelines (SDP).""" -Provides 4 workflow-oriented tools following the Lakebase pattern: -- create_or_update_pipeline: idempotent resource management -- get_pipeline: get details by name/ID, or list all -- delete_pipeline: delete by ID -- run_pipeline: start, stop, and wait for pipeline runs -""" - -import logging -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any from databricks_tools_core.identity import get_default_tags from databricks_tools_core.spark_declarative_pipelines.pipelines import ( - create_or_update_pipeline as _create_or_update_pipeline, + create_pipeline as _create_pipeline, get_pipeline as _get_pipeline, + update_pipeline as _update_pipeline, delete_pipeline as _delete_pipeline, start_update as _start_update, get_update as _get_update, stop_pipeline as _stop_pipeline, get_pipeline_events as _get_pipeline_events, + create_or_update_pipeline as _create_or_update_pipeline, find_pipeline_by_name as _find_pipeline_by_name, - wait_for_pipeline_update as _wait_for_pipeline_update, ) from ..manifest import register_deleter from ..server import mcp -logger = logging.getLogger(__name__) - def _delete_pipeline_resource(resource_id: str) -> None: _delete_pipeline(pipeline_id=resource_id) @@ -36,192 +27,113 @@ def _delete_pipeline_resource(resource_id: str) -> None: register_deleter("pipeline", _delete_pipeline_resource) -# 
============================================================================ -# Tool 1: create_or_update_pipeline -# ============================================================================ - - @mcp.tool -def create_or_update_pipeline( +def create_pipeline( name: str, root_path: str, catalog: str, schema: str, workspace_file_paths: List[str], - start_run: bool = False, - wait_for_completion: bool = False, - full_refresh: bool = True, - timeout: int = 1800, extra_settings: Dict[str, Any] = None, ) -> Dict[str, Any]: """ - Create a new pipeline or update an existing one with the same name. - - This is the main tool for pipeline resource management. It: - 1. Searches for an existing pipeline with the same name (or uses 'id' from extra_settings) - 2. Creates a new pipeline or updates the existing one - 3. Optionally starts a run and waits for completion - - For more granular run control (stop, selective refresh, validate-only), - use run_pipeline() separately. - - Uses Unity Catalog and serverless compute by default. + Create a new Spark Declarative Pipeline (Unity Catalog, serverless by default). Args: - name: Pipeline name (used for lookup and creation) + name: Pipeline name root_path: Root folder for source code (added to Python sys.path for imports) - catalog: Unity Catalog name for output tables + catalog: Unity Catalog name schema: Schema name for output tables workspace_file_paths: List of workspace file paths (raw .sql or .py files) - start_run: If True, start a pipeline run after create/update (default: False) - wait_for_completion: If True and start_run=True, wait for the run to finish (default: False) - full_refresh: If True, performs full refresh when starting (default: True) - timeout: Maximum wait time in seconds when wait_for_completion=True (default: 1800) - extra_settings: Optional dict with additional pipeline settings. 
Supports all SDK - options: clusters, continuous, development, photon, edition, channel, event_log, - configuration, notifications, tags, serverless, etc. - If 'id' is provided, the pipeline will be updated instead of created. - Explicit parameters (name, root_path, catalog, schema) take precedence. + extra_settings: Optional dict with additional pipeline settings (clusters, + continuous, development, photon, edition, channel, event_log, configuration, + notifications, tags, serverless, etc.). Explicit parameters take precedence. Returns: - Dictionary with: - - pipeline_id: The pipeline ID - - pipeline_name: The pipeline name - - created: True if newly created, False if updated - - success: True if operation succeeded - - message: Human-readable status message - - Example: - >>> create_or_update_pipeline( - ... name="my_pipeline", - ... root_path="/Workspace/project", - ... catalog="my_catalog", - ... schema="my_schema", - ... workspace_file_paths=["/Workspace/project/pipeline.py"] - ... ) - >>> create_or_update_pipeline( - ... name="my_pipeline", ..., - ... start_run=True, wait_for_completion=True - ... ) + Dictionary with pipeline_id of the created pipeline. 
""" # Auto-inject default tags into extra_settings; user tags take precedence extra_settings = extra_settings or {} extra_settings.setdefault("tags", {}) extra_settings["tags"] = {**get_default_tags(), **extra_settings["tags"]} - result = _create_or_update_pipeline( + result = _create_pipeline( name=name, root_path=root_path, catalog=catalog, schema=schema, workspace_file_paths=workspace_file_paths, - start_run=start_run, - wait_for_completion=wait_for_completion, - full_refresh=full_refresh, - timeout=timeout, extra_settings=extra_settings, ) - # Track resource on successful create/update + # Track resource on successful create try: - result_dict = result.to_dict() - pipeline_id = result_dict.get("pipeline_id") - if pipeline_id: + if result.pipeline_id: from ..manifest import track_resource track_resource( resource_type="pipeline", name=name, - resource_id=pipeline_id, + resource_id=result.pipeline_id, ) except Exception: pass # best-effort tracking - return result.to_dict() - - -# ============================================================================ -# Tool 2: get_pipeline -# ============================================================================ + return {"pipeline_id": result.pipeline_id} @mcp.tool -def get_pipeline( - pipeline_id: Optional[str] = None, - name: Optional[str] = None, -) -> Dict[str, Any]: +def get_pipeline(pipeline_id: str) -> Dict[str, Any]: """ - Get pipeline details, or list all pipelines. - - Pass pipeline_id or name to get one pipeline's details enriched with - latest update status and recent events. Omit both to list all pipelines. + Get pipeline details and configuration. Args: - pipeline_id: Pipeline ID. Takes precedence over name. - name: Pipeline name. Used to look up pipeline ID if pipeline_id not provided. + pipeline_id: Pipeline ID Returns: - Single pipeline dict with enriched details (if ID/name provided), - or {"pipelines": [...]} when listing all. 
- - Example: - >>> get_pipeline(pipeline_id="abc-123") - {"pipeline_id": "abc-123", "name": "my_pipeline", "state": "IDLE", ...} - >>> get_pipeline(name="my_pipeline") - {"pipeline_id": "abc-123", "name": "my_pipeline", ...} - >>> get_pipeline() - {"pipelines": [{"pipeline_id": "abc-123", "name": "my_pipeline", ...}]} + Dictionary with pipeline configuration and state. """ - # Resolve name to pipeline_id if needed - if not pipeline_id and name: - pipeline_id = _find_pipeline_by_name(name=name) - if not pipeline_id: - return {"error": f"Pipeline '{name}' not found."} - - if pipeline_id: - result = _get_pipeline(pipeline_id=pipeline_id) - pipeline_dict = result.as_dict() if hasattr(result, "as_dict") else vars(result) - - # Enrich with latest update status - try: - latest_updates = pipeline_dict.get("latest_updates", []) - if latest_updates: - latest = latest_updates[0] - update_id = latest.get("update_id") - if update_id: - update_result = _get_update(pipeline_id=pipeline_id, update_id=update_id) - update_dict = update_result.as_dict() if hasattr(update_result, "as_dict") else vars(update_result) - pipeline_dict["latest_update_status"] = update_dict - except Exception: - pass - - # Enrich with recent events - try: - events = _get_pipeline_events(pipeline_id=pipeline_id, max_results=10) - pipeline_dict["recent_events"] = [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events] - except Exception: - pass - - return pipeline_dict - - # List all pipelines - from databricks_tools_core.auth import get_workspace_client - - w = get_workspace_client() - pipelines = [] - try: - for p in w.pipelines.list_pipelines(): - entry = p.as_dict() if hasattr(p, "as_dict") else vars(p) - pipelines.append(entry) - except Exception as e: - return {"error": f"Failed to list pipelines: {e}"} + result = _get_pipeline(pipeline_id=pipeline_id) + return result.as_dict() if hasattr(result, "as_dict") else vars(result) - return {"pipelines": pipelines} +@mcp.tool +def update_pipeline( 
+ pipeline_id: str, + name: str = None, + root_path: str = None, + catalog: str = None, + schema: str = None, + workspace_file_paths: List[str] = None, + extra_settings: Dict[str, Any] = None, +) -> Dict[str, str]: + """ + Update pipeline configuration. + + Args: + pipeline_id: Pipeline ID + name: New pipeline name + root_path: New root folder for source code + catalog: New catalog name + schema: New schema name + workspace_file_paths: New list of file paths (raw .sql or .py files) + extra_settings: Optional dict with additional pipeline settings (clusters, + continuous, development, photon, edition, channel, event_log, configuration, + notifications, tags, serverless, etc.). Explicit parameters take precedence. -# ============================================================================ -# Tool 3: delete_pipeline -# ============================================================================ + Returns: + Dictionary with status message. + """ + _update_pipeline( + pipeline_id=pipeline_id, + name=name, + root_path=root_path, + catalog=catalog, + schema=schema, + workspace_file_paths=workspace_file_paths, + extra_settings=extra_settings, + ) + return {"status": "updated"} @mcp.tool @@ -245,59 +157,27 @@ def delete_pipeline(pipeline_id: str) -> Dict[str, str]: return {"status": "deleted"} -# ============================================================================ -# Tool 4: run_pipeline -# ============================================================================ - - @mcp.tool -def run_pipeline( +def start_update( pipeline_id: str, + refresh_selection: List[str] = None, full_refresh: bool = False, - refresh_selection: Optional[List[str]] = None, - full_refresh_selection: Optional[List[str]] = None, + full_refresh_selection: List[str] = None, validate_only: bool = False, - wait_for_completion: bool = False, - timeout: int = 1800, - stop: bool = False, -) -> Dict[str, Any]: +) -> Dict[str, str]: """ - Start, stop, or monitor a pipeline run. 
- - Set stop=True to stop a running pipeline. Otherwise starts an update. - Optionally waits for the run to complete with timeout. + Start a pipeline update or dry-run validation. Args: pipeline_id: Pipeline ID - full_refresh: If True, performs full refresh of all tables refresh_selection: List of table names to refresh + full_refresh: If True, performs full refresh of all tables full_refresh_selection: List of table names for full refresh validate_only: If True, validates without updating data (dry run) - wait_for_completion: If True, wait for run to complete (default: False) - timeout: Maximum wait time in seconds (default: 1800 = 30 minutes) - stop: If True, stop the currently running pipeline instead of starting Returns: - Dictionary with: - - update_id: Update ID (if started) - - state: Final state (if waited) - - success: True if completed successfully - - duration_seconds: Time taken (if waited) - - errors: List of error details (if failed) - - status: "stopped" (if stop=True) - - Example: - >>> run_pipeline(pipeline_id="abc-123", full_refresh=True) - {"update_id": "xyz-456", "status": "started"} - >>> run_pipeline(pipeline_id="abc-123", stop=True) - {"status": "stopped"} - >>> run_pipeline(pipeline_id="abc-123", wait_for_completion=True, timeout=600) - {"update_id": "xyz-456", "state": "COMPLETED", "success": True, ...} + Dictionary with update_id for polling status. """ - if stop: - _stop_pipeline(pipeline_id=pipeline_id) - return {"pipeline_id": pipeline_id, "status": "stopped"} - update_id = _start_update( pipeline_id=pipeline_id, refresh_selection=refresh_selection, @@ -305,55 +185,188 @@ def run_pipeline( full_refresh_selection=full_refresh_selection, validate_only=validate_only, ) + return {"update_id": update_id} + + +@mcp.tool +def get_update(pipeline_id: str, update_id: str) -> Dict[str, Any]: + """ + Get pipeline update status and results. 
+
+    Args:
+        pipeline_id: Pipeline ID
+        update_id: Update ID from start_update
+
+    Returns:
+        Dictionary with update status (QUEUED, RUNNING, COMPLETED, FAILED, etc.)
+    """
+    result = _get_update(pipeline_id=pipeline_id, update_id=update_id)
+    return result.as_dict() if hasattr(result, "as_dict") else vars(result)
 
-    result: Dict[str, Any] = {
-        "pipeline_id": pipeline_id,
-        "update_id": update_id,
-        "status": "started",
-    }
-    if not wait_for_completion:
-        result["message"] = (
-            f"Pipeline update started. Use get_pipeline(pipeline_id='{pipeline_id}') "
-            f"to check status, or run_pipeline with wait_for_completion=True to wait."
 
+@mcp.tool
+def stop_pipeline(pipeline_id: str) -> Dict[str, str]:
+    """
+    Stop a running pipeline.
+
+    Args:
+        pipeline_id: Pipeline ID
+
+    Returns:
+        Dictionary with status message.
+    """
+    _stop_pipeline(pipeline_id=pipeline_id)
+    return {"status": "stopped"}
+
+
+@mcp.tool
+def get_pipeline_events(
+    pipeline_id: str,
+    max_results: int = 100,
+) -> List[Dict[str, Any]]:
+    """
+    Get pipeline events, issues, and error messages.
+
+    Use this to debug pipeline failures.
+
+    Args:
+        pipeline_id: Pipeline ID
+        max_results: Maximum number of events to return (default: 100)
+
+    Returns:
+        List of event dictionaries with error details.
+    """
+    events = _get_pipeline_events(pipeline_id=pipeline_id, max_results=max_results)
+    return [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events]
+
+
+@mcp.tool
+def create_or_update_pipeline(
+    name: str,
+    root_path: str,
+    catalog: str,
+    schema: str,
+    workspace_file_paths: List[str],
+    start_run: bool = False,
+    wait_for_completion: bool = False,
+    full_refresh: bool = True,
+    timeout: int = 1800,
+    extra_settings: Dict[str, Any] = None,
+) -> Dict[str, Any]:
+    """
+    Create a new pipeline or update an existing one with the same name.
+
+    This is the main tool for pipeline management. It:
+    1. Searches for an existing pipeline with the same name (or uses 'id' from extra_settings)
+    2. Creates a new pipeline or updates the existing one
+    3. Optionally starts a pipeline run with full refresh
+    4. Optionally waits for the run to complete and returns detailed results
+
+    Uses Unity Catalog and serverless compute by default.
+
+    Args:
+        name: Pipeline name (used for lookup and creation)
+        root_path: Root folder for source code (added to Python sys.path for imports)
+        catalog: Unity Catalog name for output tables
+        schema: Schema name for output tables
+        workspace_file_paths: List of workspace file paths (raw .sql or .py files)
+        start_run: If True, start a pipeline update after create/update (default: False)
+        wait_for_completion: If True, wait for run to complete (default: False)
+        full_refresh: If True, perform full refresh when starting (default: True)
+        timeout: Maximum wait time in seconds (default: 1800 = 30 minutes)
+        extra_settings: Optional dict with additional pipeline settings. Supports all SDK
+            options: clusters, continuous, development, photon, edition, channel, event_log,
+            configuration, notifications, tags, serverless, etc.
+            If 'id' is provided, the pipeline will be updated instead of created.
+            Explicit parameters (name, root_path, catalog, schema) take precedence.
+
+    Returns:
+        Dictionary with detailed status:
+        - pipeline_id: The pipeline ID
+        - pipeline_name: The pipeline name
+        - created: True if newly created, False if updated
+        - success: True if all operations succeeded
+        - state: Final state if run was started (COMPLETED, FAILED, etc.)
+        - duration_seconds: Time taken if waited
+        - error_message: Error message if failed
+        - errors: List of detailed errors if failed
+        - message: Human-readable status message
+
+    Example usage:
+        # Just create/update the pipeline
+        create_or_update_pipeline(name="my_pipeline", ...)
+
+        # Create/update and run immediately
+        create_or_update_pipeline(name="my_pipeline", ..., start_run=True)
+
+        # Create/update, run, and wait for completion
+        create_or_update_pipeline(
+            name="my_pipeline", ...,
+            start_run=True,
+            wait_for_completion=True
         )
-        return result
-
-    try:
-        wait_result = _wait_for_pipeline_update(
-            pipeline_id=pipeline_id,
-            update_id=update_id,
-            timeout=timeout,
+
+        # Create with custom settings (non-serverless, development mode)
+        create_or_update_pipeline(
+            name="my_pipeline", ...,
+            extra_settings={
+                "serverless": False,
+                "development": True,
+                "clusters": [{"label": "default", "num_workers": 2}]
+            }
         )
-        result["state"] = wait_result["state"]
-        result["success"] = wait_result["success"]
-        result["duration_seconds"] = wait_result["duration_seconds"]
-        result["status"] = "completed" if wait_result["success"] else "failed"
-
-        if not wait_result["success"]:
-            result["errors"] = wait_result.get("errors", [])
-            if result["errors"]:
-                first_error = result["errors"][0]
-                error_msg = first_error.get("message", "")
-                if first_error.get("exceptions"):
-                    exc = first_error["exceptions"][0]
-                    error_msg = exc.get("message", error_msg)
-                result["error_message"] = error_msg
-            result["message"] = (
-                f"Pipeline run failed with state: {result['state']}. "
-                f"Use get_pipeline(pipeline_id='{pipeline_id}') for full details."
+    """
+    # Auto-inject default tags into extra_settings; user tags take precedence
+    extra_settings = extra_settings or {}
+    extra_settings.setdefault("tags", {})
+    extra_settings["tags"] = {**get_default_tags(), **extra_settings["tags"]}
+
+    result = _create_or_update_pipeline(
+        name=name,
+        root_path=root_path,
+        catalog=catalog,
+        schema=schema,
+        workspace_file_paths=workspace_file_paths,
+        start_run=start_run,
+        wait_for_completion=wait_for_completion,
+        full_refresh=full_refresh,
+        timeout=timeout,
+        extra_settings=extra_settings,
+    )
+
+    # Track resource on successful create/update
+    try:
+        result_dict = result.to_dict()
+        pipeline_id = result_dict.get("pipeline_id")
+        if pipeline_id:
+            from ..manifest import track_resource
+
+            track_resource(
+                resource_type="pipeline",
+                name=name,
+                resource_id=pipeline_id,
            )
-        else:
-            result["message"] = f"Pipeline completed successfully in {result['duration_seconds']}s."
-
-    except TimeoutError as e:
-        result["state"] = "TIMEOUT"
-        result["success"] = False
-        result["status"] = "timeout"
-        result["error_message"] = str(e)
-        result["message"] = (
-            f"Pipeline run timed out after {timeout}s. The pipeline may still be running. "
-            f"Use get_pipeline(pipeline_id='{pipeline_id}') to check status."
-        )
+    except Exception:
+        pass  # best-effort tracking
 
-    return result
+    return result.to_dict()
+
+
+@mcp.tool
+def find_pipeline_by_name(name: str) -> Dict[str, Any]:
+    """
+    Find a pipeline by name and return its ID.
+
+    Args:
+        name: Pipeline name to search for
+
+    Returns:
+        Dictionary with:
+        - found: True if pipeline exists
+        - pipeline_id: Pipeline ID if found, None otherwise
+    """
+    pipeline_id = _find_pipeline_by_name(name=name)
+    return {
+        "found": pipeline_id is not None,
+        "pipeline_id": pipeline_id,
+    }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/sql.py b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
index 01355525..2be905d3 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/sql.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
@@ -28,9 +28,8 @@ def execute_sql(
     If no warehouse_id is provided, automatically selects the best available warehouse.
 
     IMPORTANT: For creating or dropping schemas, catalogs, and volumes, use the
-    manage_uc_objects tool instead of SQL DDL. It handles resource tracking and
-    auto-tagging. Only use execute_sql for queries (SELECT, INSERT, UPDATE) and
-    table DDL (CREATE TABLE, DROP TABLE).
+    manage_uc_objects tool instead of SQL DDL. Only use execute_sql for queries
+    (SELECT, INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
 
     Args:
         sql_query: SQL query to execute
@@ -67,9 +66,8 @@ def execute_sql_multi(
     in optimal order. Independent queries run in parallel.
 
     IMPORTANT: For creating or dropping schemas, catalogs, and volumes, use the
-    manage_uc_objects tool instead of SQL DDL. It handles resource tracking and
-    auto-tagging. Only use execute_sql/execute_sql_multi for queries (SELECT,
-    INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
+    manage_uc_objects tool instead of SQL DDL. Only use execute_sql/execute_sql_multi
+    for queries (SELECT, INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
 
     Args:
         sql_content: SQL content with multiple statements separated by ;
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py b/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py
index e4f1a04d..a75da32e 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py
@@ -1,6 +1,6 @@
 """Volume file tools - Manage files in Unity Catalog Volumes."""
 
-from typing import Dict, Any, List
+from typing import Dict, Any
 
 from databricks_tools_core.unity_catalog import (
     list_volume_files as _list_volume_files,
diff --git a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py b/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py
index 13789798..7b829bbc 100644
--- a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py
+++ b/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py
@@ -13,7 +13,7 @@
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import requests
 from databricks.sdk import WorkspaceClient
@@ -21,7 +21,6 @@
 from ..auth import get_workspace_client, get_current_username
 from .models import (
     EndpointStatus,
-    EvaluationRunDict,
     GenieIds,
     GenieListInstructionsResponseDict,
     GenieListQuestionsResponseDict,
diff --git a/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py b/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py
index beabb5e2..7f3cceb2 100644
--- a/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py
+++ b/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py
@@ -8,7 +8,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from databricks.sdk.service.dashboards import Dashboard
diff --git a/databricks-tools-core/databricks_tools_core/jobs/runs.py b/databricks-tools-core/databricks_tools_core/jobs/runs.py
index 665280b4..6b74c055 100644
--- a/databricks-tools-core/databricks_tools_core/jobs/runs.py
+++ b/databricks-tools-core/databricks_tools_core/jobs/runs.py
@@ -13,7 +13,7 @@
 )
 
 from ..auth import get_workspace_client
-from .models import JobRunResult, JobError, RunLifecycleState, RunResultState as RunResultStateEnum
+from .models import JobRunResult, JobError
 
 # Terminal states - run has finished (success or failure)
diff --git a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py b/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py
index 1808e0a1..44e53763 100644
--- a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py
+++ b/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py
@@ -1,3 +1,3 @@
 """Spark Declarative Pipelines (SDP) operations"""
 
-from . import pipelines, workspace_files
+from . import pipelines as pipelines, workspace_files as workspace_files
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py b/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py
index a8a2d208..958684d8 100644
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py
+++ b/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py
@@ -12,7 +12,6 @@
 from pathlib import Path
 from typing import List, Optional
 
-from databricks.sdk.service.files import DirectoryEntry
 
 from ..auth import get_workspace_client
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py b/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py
index bfe4fb68..edef3450 100644
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py
+++ b/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py
@@ -11,7 +11,6 @@
 """
 
 import logging
-import os
 import time
 import uuid
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py
index 6a27d934..f0f24635 100644
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py
+++ b/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py
@@ -11,7 +11,6 @@
 
 import logging
 import time
-import uuid
 
 import pytest
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py
index b46b5fae..b4317718 100644
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py
+++ b/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py
@@ -10,7 +10,6 @@
 
 import logging
 import time
-import uuid
 
 import pytest
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py
index 1da5ff66..f2a7e465 100644
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py
+++ b/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py
@@ -7,7 +7,6 @@
 
 import logging
 
-import pytest
 from databricks_tools_core.lakebase_autoscale import (
     generate_credential,
diff --git a/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py b/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py
index b5f445db..e6bdd178 100644
--- a/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py
+++ b/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py
@@ -6,7 +6,6 @@
 - txt_files/: Simple text files
 """
 
-import os
 from pathlib import Path
 
 import pandas as pd
diff --git a/databricks-tools-core/tests/integration/sql/test_table_stats.py b/databricks-tools-core/tests/integration/sql/test_table_stats.py
index 4894108b..66c9515b 100644
--- a/databricks-tools-core/tests/integration/sql/test_table_stats.py
+++ b/databricks-tools-core/tests/integration/sql/test_table_stats.py
@@ -13,7 +13,6 @@
     get_table_details,
     TableStatLevel,
     TableSchemaResult,
-    TableInfo,
 )
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py b/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py
index 824f3996..5134baae 100644
--- a/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py
+++ b/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py
@@ -18,9 +18,7 @@
 from databricks_tools_core.unity_catalog import (
     create_monitor,
     get_monitor,
-    run_monitor_refresh,
     list_monitor_refreshes,
-    delete_monitor,
 )
 
 logger = logging.getLogger(__name__)
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py b/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py
index 7f427768..4eb1dfc9 100644
--- a/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py
+++ b/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py
@@ -21,13 +21,11 @@
     create_share,
     add_table_to_share,
     remove_table_from_share,
-    delete_share,
     grant_share_to_recipient,
     revoke_share_from_recipient,
     list_recipients,
     get_recipient,
     create_recipient,
-    delete_recipient,
     list_providers,
 )