diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index 6257eacb..17cdb5bd 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -461,7 +461,15 @@ def poll_result(): def ask_question(self, question, conversation_id: Optional[str] = None): import mlflow - with mlflow.start_span(name="ask_question"): + with mlflow.start_span(name="ask_question") as span: + span.set_attributes( + { + "space_id": self.space_id, + "input.question": question, + "input.conversation_id": conversation_id or "", + } + ) + # check if a conversation_id is supplied # if yes, continue an existing genie conversation # otherwise start a new conversation @@ -472,4 +480,13 @@ def ask_question(self, question, conversation_id: Optional[str] = None): genie_response = self.poll_for_result(resp["conversation_id"], resp["message_id"]) if not genie_response.conversation_id: genie_response.conversation_id = resp["conversation_id"] + + span.set_attributes( + { + "output.query": genie_response.query or "", + "output.description": genie_response.description or "", + "output.conversation_id": genie_response.conversation_id or "", + } + ) + return genie_response diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index 3025bb8a..57b75e85 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -1023,3 +1023,83 @@ def test_poll_null_attachments(genie, mock_workspace_client): result = genie.poll_for_result("conv_123", "msg_456") assert result.suggested_questions is None assert result.text_attachment_content == "" + + +def test_ask_question_mlflow_trace_logs_inputs_and_outputs(genie, mock_workspace_client): + """Test that ask_question logs inputs and outputs as span attributes.""" + mock_workspace_client.genie._api.do.side_effect = [ + {"conversation_id": "conv_123", "message_id": "msg_456"}, + { + "status": "COMPLETED", + "conversation_id": "conv_123", + "attachments": [ + {"attachment_id": "att_1", "query": {"query": "SELECT * FROM sales", "description": "All sales"}}, + ], + }, + { + "statement_response": { + "status": {"state": "SUCCEEDED"}, + "conversation_id": "conv_123", + "manifest": {"schema": {"columns": [{"name": "id", "type_name": "INT"}]}}, + "result": {"data_array": [["1"], ["2"]]}, + } + }, + ] + + result = genie.ask_question("What is the meaning of life?") + assert isinstance(result.result, str) + + # Verify the mlflow trace was created with correct span attributes + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + assert trace is not None + + spans = trace.search_spans(name="ask_question") + assert len(spans) == 1 + ask_span = spans[0] + + # Verify input attributes + assert ask_span.attributes["space_id"] == "test_space_id" + assert ask_span.attributes["input.question"] == "What is the meaning of life?" + assert ask_span.attributes["input.conversation_id"] == "" + + # Verify output attributes (output.result is intentionally not logged) + assert "output.result" not in ask_span.attributes + assert ask_span.attributes["output.query"] == "SELECT * FROM sales" + assert ask_span.attributes["output.description"] == "All sales" + assert ask_span.attributes["output.conversation_id"] == "conv_123" + + +def test_ask_question_mlflow_trace_with_conversation_id(genie, mock_workspace_client): + """Test that ask_question logs conversation_id input when continuing a conversation.""" + mock_workspace_client.genie._api.do.side_effect = [ + {"conversation_id": "conv_existing", "message_id": "msg_789"}, + { + "status": "COMPLETED", + "conversation_id": "conv_existing", + "attachments": [ + {"text": {"content": "Follow-up answer"}}, + {"suggested_questions": {"questions": ["Next question?"]}}, + ], + }, + ] + + result = genie.ask_question("Follow-up question", conversation_id="conv_existing") + assert result.result == "Follow-up answer" + + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + assert trace is not None + + spans = trace.search_spans(name="ask_question") + assert len(spans) == 1 + ask_span = spans[0] + + # Verify input attributes include the conversation_id + assert ask_span.attributes["space_id"] == "test_space_id" + assert ask_span.attributes["input.question"] == "Follow-up question" + assert ask_span.attributes["input.conversation_id"] == "conv_existing" + + # Verify output attributes + assert "output.result" not in ask_span.attributes + assert ask_span.attributes["output.query"] == "" + assert ask_span.attributes["output.description"] == "" + assert ask_span.attributes["output.conversation_id"] == "conv_existing"