Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,8 @@ type-check:

# Run test suite
test:
uv run pytest tests/
uv run pytest -m "not sandbox" tests/

# Create Sandbox
sandbox:
uv run pytest -m "sandbox"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Issues = "https://github.com/dlt-hub/dlt-mcp/issues"

[dependency-groups]
dev = [
"dlt[duckdb]>=1.17.1",
"griffe>=1.13.0",
"prek>=0.2.10",
"pytest>=8.4.1",
Expand Down
21 changes: 21 additions & 0 deletions src/dlt_mcp/_prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from types import FunctionType
from fastmcp.prompts import Prompt

from dlt_mcp._prompts.infer_table_reference import infer_table_reference


__all__ = [
"PROMPTS_REGISTRY",
]


PROMPTS_REGISTRY: dict[str, Prompt] = {}


def register_prompt(fn: FunctionType) -> FunctionType:
global PROMPTS_REGISTRY
PROMPTS_REGISTRY[fn.__name__] = Prompt.from_function(fn)
return fn


register_prompt(infer_table_reference)
75 changes: 75 additions & 0 deletions src/dlt_mcp/_prompts/infer_table_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Optional
from dlt.common.schema.typing import TTableReference


def infer_table_reference(pipeline_name: Optional[str] = None):
"""Generates guidelines to infer table references for a given pipeline"""
table_reference_documentation = _get_table_reference_documentation()

prompt = (
"You are an helpful assistant to data architect and data engineers using DLT tasked with analyzing table relationships within a dlt pipeline.\n"
"## Workflow Steps \n"
)

if pipeline_name is None:
prompt += (
"- **Pipeline Discovery**:\n"
" - First, list all available pipelines using `list_pipelines()`\n"
" - Ask the user which pipeline they want to investigate \n"
" - wait until the user has provided a valid pipeline name"
)
prompt += (
"- **Data Exploration**:\n"
" - Ask the User if they want to explore specific tables or all of them \n"
" - Get schema details for each table using `get_table_schema(pipeline_name, table_name)`\n"
" - Show schema using mermaid, this will help make conversation more easier and save the world from bad data schemas \n"
"- **Relationship Analysis**:\n"
" - Look for common column names across tables (e.g., 'user_id', 'customer_id')\n"
" - Identify foreign key patterns (e.g., 'parent_id', 'foreign_key')\n"
" - Check for date/time columns that might indicate relationships\n"
" - Examine table descriptions for hints about relationships\n"
" - Look for auto-incrementing IDs that might reference other tables\n"
" - Generate mermaid to showcase the relationships. This too helps save the world from bad data\n"
"- **Validate Relationships**:\n"
" - Suggest ways to confirm these relationships (e.g., sample data inspection, referential integrity checks)\n"
" - You can execute these validations using execute_sql_query tool \n"
"- **Generating Table References**: \n"
f" - {table_reference_documentation} \n"
" - To maintain the information across codebase it's important to generate the table reference in the above format \n"
"## Tips:\n"
"- Use only the tools available to go through the process \n"
"- keep explanations small and to the point until asked for more details \n"
"- Think before providing reasoning about the relationships and one by one confirm each of them with the user \n"
"- NO NEED TO EXPLAIN THE FULL STRATEGY IN THE BEGINING KEEP IT SMALL\n"
"- DON'T ASK FOR PERMISION TO CREATE A MERMAID DIAGRAM. JUST DO IT\n"
"- AT THE END GENERATE TABLE REFERENCES IN THE FORMAT DEFINED ABOVE THIS CAN ACCELERATE THE USERS WORKFLOW 10 FOLD BECAUSE THEY CAN DIRECTLY USE IT IN CODE\n"
"## Information Presentation**:\n"
" GENERATE MERMAID DIAGRAM TO REPRESENT THE SCHEMA AND RELATIONSHIPS"
)
return prompt


def _get_table_reference_documentation() -> str:
"""Generate documentation for TTableReference with examples."""
user_id_ref: TTableReference = {
"columns": ["user_id"],
"referenced_table": "users",
"referenced_columns": ["id"],
}

product_ref: TTableReference = {
"columns": ["category_id", "brand_id"],
"referenced_table": "categories",
"referenced_columns": ["id", "id"],
}

table_reference_documentation = (
"**TTableReference Documentation: \n"
f"{TTableReference.__doc__} \n"
f"Required Columns: {TTableReference.__dict__} \n"
"Example: \n"
f"{user_id_ref} \n"
f"{product_ref} \n"
)

return table_reference_documentation
4 changes: 2 additions & 2 deletions src/dlt_mcp/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def register_tool(fn: FunctionType) -> FunctionType:
from dlt_mcp._tools.core import ( # noqa: E402
list_pipelines,
list_tables,
get_table_schema,
get_table_schemas,
execute_sql_query,
get_load_table,
get_pipeline_local_state,
Expand All @@ -28,7 +28,7 @@ def register_tool(fn: FunctionType) -> FunctionType:

register_tool(list_pipelines)
register_tool(list_tables)
register_tool(get_table_schema)
register_tool(get_table_schemas)
register_tool(execute_sql_query)
register_tool(get_load_table)
register_tool(get_pipeline_local_state)
15 changes: 10 additions & 5 deletions src/dlt_mcp/_tools/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,27 @@ def list_tables(pipeline_name: str) -> list[str]:
return schema.data_table_names()


def get_table_schema(pipeline_name: str, table_name: str) -> TTableSchema:
"""Get the schema of the specified table."""
def get_table_schemas(
pipeline_name: str, table_names: list[str]
) -> dict[str, TTableSchema]:
"""Get the schema of the specified tables names from a given pipeline"""
# TODO refactor try/except to specific line or at the tool manager level
# the inconsistent errors are probably due to database locking
try:
pipeline = dlt.attach(pipeline_name)
table_schema = pipeline.default_schema.get_table(table_name)
return table_schema
table_schemas = {
table_name: pipeline.default_schema.get_table(table_name)
for table_name in table_names
}
return table_schemas
except Exception:
raise


def execute_sql_query(pipeline_name: str, sql_select_query: str) -> list[tuple]:
f"""Executes SELECT SQL statement for simple data analysis.

Use the `{list_tables.__name__}()` and `{get_table_schema.__name__}()` tools to
Use the `{list_tables.__name__}()` and `{get_table_schemas.__name__}()` tools to
retrieve the available tables and columns.
"""
pipeline = dlt.attach(pipeline_name)
Expand Down
5 changes: 5 additions & 0 deletions src/dlt_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastmcp import FastMCP

from dlt_mcp._prompts import PROMPTS_REGISTRY
from dlt_mcp._tools import TOOLS_REGISTRY


Expand All @@ -17,6 +18,10 @@ def create_server() -> FastMCP:
tools=tools, # type: ignore[invalid-argument-type]
)

prompts = tuple(PROMPTS_REGISTRY.values())
for prompt in prompts:
server.add_prompt(prompt)

return server


Expand Down
Empty file added tests/prompts/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions tests/prompts/test_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio

from dlt_mcp.server import create_server
from dlt_mcp._prompts import PROMPTS_REGISTRY


def test_expected_prompts_in_all_clause():
"""Ensures expected prompt references are in `PROMPTS_REGISTRY`.

This test is expected to be modified as prompts are updated.

Renaming or removing a prompt is technically a breaking change,
but it can be patched downstream.
"""

expected_prompt_names = [
"infer_table_reference",
]

assert len(PROMPTS_REGISTRY) == len(expected_prompt_names)
assert set(PROMPTS_REGISTRY) == set(expected_prompt_names)


def test_expected_prompts_are_registered():
"""Ensures expected prompts exist on the server instance."""
expected_prompt_names = [
"infer_table_reference",
]

mcp_server = create_server()

prompts = asyncio.run(mcp_server.get_prompts())

assert len(prompts) == len(expected_prompt_names)
assert set(prompts) == set(expected_prompt_names)
Loading