Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,15 @@ def poll_result():
def ask_question(self, question, conversation_id: Optional[str] = None):
import mlflow

with mlflow.start_span(name="ask_question"):
with mlflow.start_span(name="ask_question") as span:
span.set_attributes(
{
"space_id": self.space_id,
"input.question": question,
"input.conversation_id": conversation_id or "",
}
)

# check if a conversation_id is supplied
# if yes, continue an existing genie conversation
# otherwise start a new conversation
Expand All @@ -472,4 +480,13 @@ def ask_question(self, question, conversation_id: Optional[str] = None):
genie_response = self.poll_for_result(resp["conversation_id"], resp["message_id"])
if not genie_response.conversation_id:
genie_response.conversation_id = resp["conversation_id"]

span.set_attributes(
{
"output.query": genie_response.query or "",
"output.description": genie_response.description or "",
"output.conversation_id": genie_response.conversation_id or "",
}
)

return genie_response
80 changes: 80 additions & 0 deletions tests/databricks_ai_bridge/test_genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,83 @@ def test_poll_null_attachments(genie, mock_workspace_client):
result = genie.poll_for_result("conv_123", "msg_456")
assert result.suggested_questions is None
assert result.text_attachment_content == ""


def test_ask_question_mlflow_trace_logs_inputs_and_outputs(genie, mock_workspace_client):
    """Check that ``ask_question`` records its inputs and outputs on the ``ask_question`` span."""
    # Canned payloads handed back by the mocked API, one per call, in this order.
    conversation_reply = {"conversation_id": "conv_123", "message_id": "msg_456"}
    message_reply = {
        "status": "COMPLETED",
        "conversation_id": "conv_123",
        "attachments": [
            {
                "attachment_id": "att_1",
                "query": {"query": "SELECT * FROM sales", "description": "All sales"},
            },
        ],
    }
    statement_reply = {
        "statement_response": {
            "status": {"state": "SUCCEEDED"},
            "conversation_id": "conv_123",
            "manifest": {"schema": {"columns": [{"name": "id", "type_name": "INT"}]}},
            "result": {"data_array": [["1"], ["2"]]},
        }
    }
    mock_workspace_client.genie._api.do.side_effect = [
        conversation_reply,
        message_reply,
        statement_reply,
    ]

    response = genie.ask_question("What is the meaning of life?")
    assert isinstance(response.result, str)

    # Fetch the most recent trace and locate the single ask_question span in it.
    last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert last_trace is not None

    matching_spans = last_trace.search_spans(name="ask_question")
    assert len(matching_spans) == 1
    span = matching_spans[0]

    # Inputs: no conversation_id was passed, so it is recorded as an empty string.
    expected_inputs = (
        ("space_id", "test_space_id"),
        ("input.question", "What is the meaning of life?"),
        ("input.conversation_id", ""),
    )
    for key, expected in expected_inputs:
        assert span.attributes[key] == expected

    # Outputs: output.result is intentionally not logged.
    assert "output.result" not in span.attributes
    expected_outputs = (
        ("output.query", "SELECT * FROM sales"),
        ("output.description", "All sales"),
        ("output.conversation_id", "conv_123"),
    )
    for key, expected in expected_outputs:
        assert span.attributes[key] == expected


def test_ask_question_mlflow_trace_with_conversation_id(genie, mock_workspace_client):
    """Check that a supplied conversation_id is recorded as a span input on follow-up questions."""
    # Canned payloads handed back by the mocked API, one per call, in this order.
    conversation_reply = {"conversation_id": "conv_existing", "message_id": "msg_789"}
    message_reply = {
        "status": "COMPLETED",
        "conversation_id": "conv_existing",
        "attachments": [
            {"text": {"content": "Follow-up answer"}},
            {"suggested_questions": {"questions": ["Next question?"]}},
        ],
    }
    mock_workspace_client.genie._api.do.side_effect = [conversation_reply, message_reply]

    response = genie.ask_question("Follow-up question", conversation_id="conv_existing")
    assert response.result == "Follow-up answer"

    # Fetch the most recent trace and locate the single ask_question span in it.
    last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert last_trace is not None

    matching_spans = last_trace.search_spans(name="ask_question")
    assert len(matching_spans) == 1
    span = matching_spans[0]

    # Inputs must carry the conversation_id that was passed in.
    expected_inputs = (
        ("space_id", "test_space_id"),
        ("input.question", "Follow-up question"),
        ("input.conversation_id", "conv_existing"),
    )
    for key, expected in expected_inputs:
        assert span.attributes[key] == expected

    # Outputs: the reply has no query attachment, so query/description are empty strings.
    assert "output.result" not in span.attributes
    expected_outputs = (
        ("output.query", ""),
        ("output.description", ""),
        ("output.conversation_id", "conv_existing"),
    )
    for key, expected in expected_outputs:
        assert span.attributes[key] == expected