diff --git a/galileo-adk/src/galileo_adk/trace_builder.py b/galileo-adk/src/galileo_adk/trace_builder.py index ad82528f..3a36fbae 100644 --- a/galileo-adk/src/galileo_adk/trace_builder.py +++ b/galileo-adk/src/galileo_adk/trace_builder.py @@ -21,11 +21,12 @@ from typing import Any from galileo_core.schemas.logging.agent import AgentType -from galileo_core.schemas.logging.span import AgentSpan, LlmSpan, RetrieverSpan, ToolSpan, WorkflowSpan -from galileo_core.schemas.logging.trace import Trace +from galileo_core.schemas.logging.span import LlmMetrics, RetrieverSpan, ToolSpan +from galileo_core.schemas.logging.step import Metrics from galileo_core.schemas.shared.traces_logger import TracesLogger from pydantic import PrivateAttr +from galileo.schema.logged import LoggedAgentSpan, LoggedLlmSpan, LoggedTrace, LoggedWorkflowSpan from galileo.schema.trace import TracesIngestRequest from galileo.utils.retrievers import convert_to_documents @@ -65,7 +66,7 @@ class TraceBuilder(TracesLogger): """ # Public fields (Pydantic fields) - traces: list[Trace] = [] + traces: list[LoggedTrace] = [] session_id: str | None = None # Private attributes (use PrivateAttr for non-field attributes) @@ -85,6 +86,46 @@ def __init__(self, ingestion_hook: Callable[[TracesIngestRequest], None], **data self._ingestion_hook = ingestion_hook self._session_external_id = None + def add_trace( + self, + input: str, + redacted_input: str | None = None, + output: str | None = None, + redacted_output: str | None = None, + name: str | None = None, + created_at: datetime | None = None, + duration_ns: int | None = None, + user_metadata: dict[str, str] | None = None, + tags: list[str] | None = None, + dataset_input: str | None = None, + dataset_output: str | None = None, + dataset_metadata: dict[str, str] | None = None, + external_id: str | None = None, + id: uuid.UUID | None = None, + ) -> LoggedTrace: + if self.current_parent() is not None: + raise ValueError("You must conclude the existing trace before 
adding a new one.") + trace = LoggedTrace( + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + name=name, + created_at=created_at, + user_metadata=user_metadata, + tags=tags, + metrics=Metrics(duration_ns=duration_ns), + dataset_input=dataset_input, + dataset_output=dataset_output, + dataset_metadata=dataset_metadata if dataset_metadata is not None else {}, + external_id=external_id, + id=id, + ) + trace._parent = None + self.traces.append(trace) + self._set_current_parent(trace) + return trace + @staticmethod def _convert_metadata_value(v: Any) -> str: """Convert a metadata value to string.""" @@ -107,7 +148,7 @@ def start_trace( dataset_output: str | None = None, dataset_metadata: dict[str, MetadataValue] | None = None, external_id: str | None = None, - ) -> Trace: + ) -> LoggedTrace: """Create a new trace and add it to the list of traces. This method mirrors GalileoLogger.start_trace() for API compatibility @@ -184,26 +225,30 @@ def add_workflow_span( tags: list[str] | None = None, step_number: int | None = None, status_code: int | None = None, - ) -> WorkflowSpan: + ) -> LoggedWorkflowSpan: """Add a workflow span to the current parent. This method wraps TracesLogger.add_workflow_span() to accept `metadata` parameter (for GalileoBaseHandler compatibility). 
""" - span = super().add_workflow_span( - id=uuid.uuid4(), + parent = self.current_parent() + span = LoggedWorkflowSpan( input=input, redacted_input=redacted_input, output=output, redacted_output=redacted_output, name=name, - duration_ns=duration_ns, - created_at=created_at, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, user_metadata=metadata, tags=tags, + metrics=Metrics(duration_ns=duration_ns), + id=uuid.uuid4(), step_number=step_number, ) - if span is not None and status_code is not None: + span._parent = parent + self.add_child_span_to_parent(span) + self._set_current_parent(span) + if status_code is not None: span.status_code = status_code return span @@ -221,27 +266,31 @@ def add_agent_span( agent_type: AgentType | None = None, step_number: int | None = None, status_code: int | None = None, - ) -> AgentSpan: + ) -> LoggedAgentSpan: """Add an agent span to the current parent. This method wraps TracesLogger.add_agent_span() to accept `metadata` parameter (for GalileoBaseHandler compatibility). 
""" - span = super().add_agent_span( - id=uuid.uuid4(), + parent = self.current_parent() + span = LoggedAgentSpan( input=input, redacted_input=redacted_input, output=output, redacted_output=redacted_output, name=name, - duration_ns=duration_ns, - created_at=created_at, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, user_metadata=metadata, tags=tags, + metrics=Metrics(duration_ns=duration_ns), agent_type=agent_type, + id=uuid.uuid4(), step_number=step_number, ) - if span is not None and status_code is not None: + span._parent = parent + self.add_child_span_to_parent(span) + self._set_current_parent(span) + if status_code is not None: span.status_code = status_code return span @@ -266,14 +315,13 @@ def add_llm_span( time_to_first_token_ns: int | None = None, step_number: int | None = None, events: list[Any] | None = None, - ) -> LlmSpan: + ) -> LoggedLlmSpan: """Add an LLM span to the current parent. This method wraps TracesLogger.add_llm_span() to accept `metadata` parameter (for GalileoBaseHandler compatibility). 
""" - return super().add_llm_span( - id=uuid.uuid4(), + span = LoggedLlmSpan( input=input, output=output, model=model, @@ -281,19 +329,24 @@ def add_llm_span( redacted_output=redacted_output, tools=tools, name=name, - created_at=created_at, - duration_ns=duration_ns, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, user_metadata=metadata, tags=tags, - num_input_tokens=num_input_tokens, - num_output_tokens=num_output_tokens, - total_tokens=total_tokens, + metrics=LlmMetrics( + duration_ns=duration_ns, + num_input_tokens=num_input_tokens, + num_output_tokens=num_output_tokens, + num_total_tokens=total_tokens, + time_to_first_token_ns=time_to_first_token_ns, + ), + events=events, temperature=temperature, status_code=status_code, - time_to_first_token_ns=time_to_first_token_ns, + id=uuid.uuid4(), step_number=step_number, - events=events, ) + self.add_child_span_to_parent(span) + return span def add_tool_span( self, diff --git a/src/galileo/logger/logger.py b/src/galileo/logger/logger.py index c22a33e5..46c7e48a 100644 --- a/src/galileo/logger/logger.py +++ b/src/galileo/logger/logger.py @@ -17,6 +17,7 @@ from galileo.log_streams import LogStreams from galileo.logger.task_handler import ThreadPoolTaskHandler from galileo.projects import Projects +from galileo.schema.logged import IngestInputType, LoggedAgentSpan, LoggedLlmSpan, LoggedTrace, LoggedWorkflowSpan from galileo.schema.metrics import LocalMetricConfig from galileo.schema.trace import ( LogRecordsSearchFilter, @@ -52,7 +53,7 @@ from galileo_core.schemas.logging.agent import AgentType from galileo_core.schemas.logging.llm import Event from galileo_core.schemas.logging.span import ( - AgentSpan, + LlmMetrics, LlmSpan, LlmSpanAllowedInputType, LlmSpanAllowedOutputType, @@ -60,9 +61,8 @@ Span, StepWithChildSpans, ToolSpan, - WorkflowSpan, ) -from galileo_core.schemas.logging.step import BaseStep, Metrics, StepAllowedInputType, StepType +from galileo_core.schemas.logging.step 
import BaseStep, Metrics, StepType from galileo_core.schemas.logging.trace import Trace from galileo_core.schemas.protect.payload import Payload from galileo_core.schemas.protect.response import Response @@ -376,7 +376,7 @@ def _init_distributed_trace_stubs(self) -> None: Note: trace_id and span_id are already validated as UUIDs in __init__ """ - stub_trace = Trace( + stub_trace = LoggedTrace( input="", name=STUB_TRACE_NAME, created_at=datetime.now(), @@ -391,7 +391,7 @@ def _init_distributed_trace_stubs(self) -> None: if self.span_id: # If span_id is provided, also add the span (it's the immediate parent) - stub_span = WorkflowSpan( + stub_span = LoggedWorkflowSpan( input="", name="stub_parent_span", created_at=datetime.now(), @@ -402,6 +402,46 @@ def _init_distributed_trace_stubs(self) -> None: stub_span._parent = stub_trace self._set_current_parent(stub_span) + def add_trace( + self, + input: str, + redacted_input: Optional[str] = None, + output: Optional[str] = None, + redacted_output: Optional[str] = None, + name: Optional[str] = None, + created_at: Optional[datetime] = None, + duration_ns: Optional[int] = None, + user_metadata: Optional[dict[str, str]] = None, + tags: Optional[list[str]] = None, + dataset_input: Optional[str] = None, + dataset_output: Optional[str] = None, + dataset_metadata: Optional[dict[str, str]] = None, + external_id: Optional[str] = None, + id: Optional[uuid.UUID] = None, + ) -> LoggedTrace: + if self.current_parent() is not None: + raise ValueError("You must conclude the existing trace before adding a new one.") + trace = LoggedTrace( + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + name=name, + created_at=created_at, + user_metadata=user_metadata, + tags=tags, + metrics=Metrics(duration_ns=duration_ns), + dataset_input=dataset_input, + dataset_output=dataset_output, + dataset_metadata=dataset_metadata if dataset_metadata is not None else {}, + external_id=external_id, + id=id, + ) + 
trace._parent = None + self.traces.append(trace) + self._set_current_parent(trace) + return trace + @staticmethod def _convert_metadata_value(v: Any) -> str: """Convert a metadata value to string. @@ -655,7 +695,7 @@ async def update_span_with_backoff(request: Any) -> None: @nop_sync @warn_catch_exception(exceptions=(Exception,)) def _ingest_step_streaming(self, step: StepWithChildSpans, is_complete: bool = False) -> None: - if isinstance(step, Trace): + if isinstance(step, LoggedTrace): self._ingest_trace_streaming(step, is_complete=is_complete) else: self._ingest_span_streaming(step) @@ -663,7 +703,7 @@ def _ingest_step_streaming(self, step: StepWithChildSpans, is_complete: bool = F @nop_sync @warn_catch_exception(exceptions=(Exception,)) def _update_step_streaming(self, step: StepWithChildSpans, is_complete: bool = False) -> None: - if isinstance(step, Trace): + if isinstance(step, LoggedTrace): self._update_trace_streaming(step, is_complete=is_complete) else: self._update_span_streaming(step) @@ -752,8 +792,8 @@ def get_tracing_headers(self) -> dict[str, str]: @warn_catch_exception(exceptions=(Exception,)) def start_trace( self, - input: StepAllowedInputType | dict, - redacted_input: Optional[StepAllowedInputType | dict] = None, + input: Union[IngestInputType, dict], + redacted_input: Optional[Union[IngestInputType, dict]] = None, name: Optional[str] = None, duration_ns: Optional[int] = None, created_at: Optional[datetime] = None, @@ -763,21 +803,22 @@ def start_trace( dataset_output: Optional[str] = None, dataset_metadata: Optional[dict[str, MetadataValue]] = None, external_id: Optional[str] = None, - ) -> Trace: + ) -> LoggedTrace: """ Create a new trace and add it to the list of traces. Once this trace is complete, you can close it out by calling conclude(). Parameters ---------- - input: StepAllowedInputType | dict + input: IngestInputType | dict Input to the node. - Expected format: String, dict (auto-converted to JSON), or sequence of Message objects. 
+ Expected format: String, dict (auto-converted to JSON), sequence of Message objects, + or sequence of LoggedMessage objects with multimodal content blocks. Examples - - String: "User query: What is the weather today?" - Dict: `{"query": "hello", "context": "world"}` (auto-converted to JSON string) - - Messages: `[Message(content="Hello", role=MessageRole.user)]` - redacted_input: Optional[StepAllowedInputType | dict] + - Messages: `[LoggedMessage(content="Hello", role=MessageRole.user)]` + redacted_input: Optional[IngestInputType | dict] Input that removes any sensitive information (redacted input). Same format as input parameter. name: Optional[str] @@ -809,7 +850,7 @@ def start_trace( Returns ------- - Trace + LoggedTrace The created trace. """ # Auto-convert dict input to JSON string (addresses common user mistake) @@ -876,7 +917,7 @@ def add_single_llm_span_trace( dataset_output: Optional[str] = None, dataset_metadata: Optional[dict[str, MetadataValue]] = None, span_step_number: Optional[int] = None, - ) -> Trace: + ) -> LoggedTrace: """ Create a new trace with a single span and add it to the list of traces. The trace is automatically concluded. @@ -955,7 +996,7 @@ def add_single_llm_span_trace( Returns ------- - Trace + LoggedTrace The created trace. 
""" # Auto-convert non-string metadata values to strings @@ -964,31 +1005,53 @@ def add_single_llm_span_trace( if dataset_metadata: dataset_metadata = {k: GalileoLogger._convert_metadata_value(v) for k, v in dataset_metadata.items()} - trace = super().add_single_llm_span_trace( + if self.current_parent() is not None: + raise ValueError("A trace cannot be created within a parent trace or span, it must always be the root.") + + trace = LoggedTrace( input=input, - output=output, redacted_input=redacted_input, + output=output, redacted_output=redacted_output, - model=model, - tools=tools, name=name, created_at=created_at, - duration_ns=duration_ns, user_metadata=metadata, tags=tags, - num_input_tokens=num_input_tokens, - num_output_tokens=num_output_tokens, - total_tokens=total_tokens, - temperature=temperature, - status_code=status_code, - time_to_first_token_ns=time_to_first_token_ns, dataset_input=dataset_input, dataset_output=dataset_output, - dataset_metadata=dataset_metadata, - span_step_number=span_step_number, - trace_id=uuid.uuid4(), - span_id=uuid.uuid4(), + dataset_metadata=dataset_metadata if dataset_metadata is not None else {}, + id=uuid.uuid4(), ) + trace.add_child_span( + LoggedLlmSpan( + name=name, + created_at=created_at, + user_metadata=metadata, + tags=tags, + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + metrics=LlmMetrics( + duration_ns=duration_ns, + num_input_tokens=num_input_tokens, + num_output_tokens=num_output_tokens, + num_total_tokens=total_tokens, + time_to_first_token_ns=time_to_first_token_ns, + ), + tools=tools, + model=model, + temperature=temperature, + status_code=status_code, + dataset_input=dataset_input, + dataset_output=dataset_output, + dataset_metadata=dataset_metadata if dataset_metadata is not None else {}, + id=uuid.uuid4(), + step_number=span_step_number, + ) + ) + self.traces.append(trace) + self._set_current_parent(None) if self.mode == "distributed": self.traces = 
[trace] @@ -1097,30 +1160,31 @@ def add_llm_span( if metadata: metadata = {k: GalileoLogger._convert_metadata_value(v) for k, v in metadata.items()} - kwargs = { - "input": input, - "output": output, - "model": model, - "redacted_input": redacted_input, - "redacted_output": redacted_output, - "tools": tools, - "name": name, - "created_at": created_at, - "duration_ns": duration_ns, - "user_metadata": metadata, - "tags": tags, - "num_input_tokens": num_input_tokens, - "num_output_tokens": num_output_tokens, - "total_tokens": total_tokens, - "temperature": temperature, - "status_code": status_code, - "time_to_first_token_ns": time_to_first_token_ns, - "step_number": step_number, - "id": uuid.uuid4(), - "events": events, - } - - span = super().add_llm_span(**kwargs) + span = LoggedLlmSpan( + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + name=name, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, + user_metadata=metadata, + tags=tags, + metrics=LlmMetrics( + duration_ns=duration_ns, + num_input_tokens=num_input_tokens, + num_output_tokens=num_output_tokens, + num_total_tokens=total_tokens, + time_to_first_token_ns=time_to_first_token_ns, + ), + tools=tools, + events=events, + model=model, + temperature=temperature, + status_code=status_code, + id=uuid.uuid4(), + step_number=step_number, + ) + self.add_child_span_to_parent(span) if self.mode == "distributed": self._ingest_step_streaming(span) @@ -1370,6 +1434,19 @@ def add_protect_span( return span + def _attach_parentable_span( + self, span: StepWithChildSpans, status_code: Optional[int] = None + ) -> StepWithChildSpans: + parent = self.current_parent() + span._parent = parent + self.add_child_span_to_parent(span) + self._set_current_parent(span) + if status_code is not None: + span.status_code = status_code + if self.mode == "distributed": + self._ingest_step_streaming(span) + return span + @nop_sync 
@warn_catch_exception(exceptions=(Exception,)) def add_workflow_span( @@ -1385,7 +1462,7 @@ def add_workflow_span( tags: Optional[list[str]] = None, step_number: Optional[int] = None, status_code: Optional[int] = None, - ) -> WorkflowSpan: + ) -> LoggedWorkflowSpan: """ Add a workflow span to the current parent. This is useful when you want to create a nested workflow span within the trace or current workflow span. The next span you add will be a child of the current parent. To @@ -1427,35 +1504,27 @@ def add_workflow_span( Returns ------- - WorkflowSpan + LoggedWorkflowSpan The created span. """ # Auto-convert non-string metadata values to strings if metadata: metadata = {k: GalileoLogger._convert_metadata_value(v) for k, v in metadata.items()} - kwargs = { - "input": input, - "redacted_input": redacted_input, - "output": output, - "redacted_output": redacted_output, - "name": name, - "duration_ns": duration_ns, - "created_at": created_at, - "user_metadata": metadata, - "tags": tags, - "step_number": step_number, - "id": uuid.uuid4(), - } - span = super().add_workflow_span(**kwargs) - - if span is not None and status_code is not None: - span.status_code = status_code - - if self.mode == "distributed": - self._ingest_step_streaming(span) - - return span + span = LoggedWorkflowSpan( + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + name=name, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, + user_metadata=metadata, + tags=tags, + metrics=Metrics(duration_ns=duration_ns), + id=uuid.uuid4(), + step_number=step_number, + ) + return self._attach_parentable_span(span, status_code) @nop_sync @warn_catch_exception(exceptions=(Exception,)) @@ -1473,7 +1542,7 @@ def add_agent_span( agent_type: Optional[AgentType] = None, step_number: Optional[int] = None, status_code: Optional[int] = None, - ) -> AgentSpan: + ) -> LoggedAgentSpan: """ Add an agent type span to the current parent. 
@@ -1517,36 +1586,28 @@ def add_agent_span( Returns ------- - AgentSpan + LoggedAgentSpan The created span. """ # Auto-convert non-string metadata values to strings if metadata: metadata = {k: GalileoLogger._convert_metadata_value(v) for k, v in metadata.items()} - kwargs = { - "input": input, - "redacted_input": redacted_input, - "output": output, - "redacted_output": redacted_output, - "name": name, - "duration_ns": duration_ns, - "created_at": created_at, - "user_metadata": metadata, - "tags": tags, - "agent_type": agent_type, - "step_number": step_number, - "id": uuid.uuid4(), - } - span = super().add_agent_span(**kwargs) - - if span is not None and status_code is not None: - span.status_code = status_code - - if self.mode == "distributed": - self._ingest_step_streaming(span) - - return span + span = LoggedAgentSpan( + input=input, + redacted_input=redacted_input, + output=output, + redacted_output=redacted_output, + name=name, + created_at=self._get_child_span_timestamp() if created_at is None else created_at, + user_metadata=metadata, + tags=tags, + metrics=Metrics(duration_ns=duration_ns), + agent_type=agent_type, + id=uuid.uuid4(), + step_number=step_number, + ) + return self._attach_parentable_span(span, status_code) @warn_catch_exception(exceptions=(Exception,)) def _conclude( @@ -1639,13 +1700,13 @@ def conclude( @nop_sync @warn_catch_exception(exceptions=(Exception,)) - def flush(self) -> list[Trace]: + def flush(self) -> list[LoggedTrace]: """ Upload all traces to Galileo. Returns ------- - List[Trace] + list[LoggedTrace] The list of uploaded traces. """ try: @@ -1659,13 +1720,13 @@ def flush(self) -> list[Trace]: @nop_async @async_warn_catch_exception(exceptions=(Exception,)) - async def async_flush(self) -> list[Trace]: + async def async_flush(self) -> list[LoggedTrace]: """ Async upload all traces to Galileo. Returns ------- - List[Trace] + list[LoggedTrace] The list of uploaded traces. 
""" try: @@ -1777,7 +1838,7 @@ def _auto_conclude_trace(self) -> None: self._update_trace_streaming(trace, is_complete=True) @async_warn_catch_exception(exceptions=(Exception,)) - async def _flush_distributed(self) -> list[Trace]: + async def _flush_distributed(self) -> list[LoggedTrace]: """Flush in distributed mode: conclude traces and wait for pending tasks. In distributed mode, traces/spans are sent immediately via conclude(). This method: @@ -1807,7 +1868,7 @@ async def _flush_distributed(self) -> list[Trace]: return [] @async_warn_catch_exception(exceptions=(Exception,)) - async def _flush_batch(self) -> list[Trace]: + async def _flush_batch(self) -> list[LoggedTrace]: """Flush in batch mode: conclude unconcluded traces and send all traces to backend.""" if not self.traces: self._logger.info("No traces to flush.") diff --git a/src/galileo/schema/__init__.py b/src/galileo/schema/__init__.py index e69de29b..f27be5de 100644 --- a/src/galileo/schema/__init__.py +++ b/src/galileo/schema/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa: F401 +from galileo.schema.content_blocks import DataContentBlock, IngestContentBlock, IngestMessageContent, TextContentBlock +from galileo.schema.logged import ( + IngestInputType, + IngestOutputType, + LoggedAgentSpan, + LoggedLlmSpan, + LoggedSpan, + LoggedTrace, + LoggedWorkflowSpan, +) +from galileo.schema.message import LoggedMessage diff --git a/src/galileo/schema/content_blocks.py b/src/galileo/schema/content_blocks.py new file mode 100644 index 00000000..d6a4735f --- /dev/null +++ b/src/galileo/schema/content_blocks.py @@ -0,0 +1,51 @@ +"""SDK-local content block types for ingestion. + +These types define what the SDK sends to the ingest service for multimodal content. +They are separate from the read-side ContentPart types in galileo-core, which +represent what the API/UI returns after storage. 
+ +Ingestion blocks support inline data (base64 or a URL), +while read-side ContentParts reference stored files by file_id. +""" + +from typing import Annotated, Literal, Optional, Union + +from pydantic import BaseModel, Field, model_validator + +from galileo_core.schemas.shared.multimodal import ContentModality + + +class TextContentBlock(BaseModel): + """A text segment for ingestion.""" + + type: Literal["text"] = "text" + text: str + index: Optional[int] = None + metadata: Optional[dict[str, str]] = None + + +class DataContentBlock(BaseModel): + """A binary/media content block for ingestion. + + Exactly one of base64 or url must be set. + """ + + type: Literal["data"] = "data" + modality: ContentModality + mime_type: Optional[str] = None + base64: Optional[str] = None + url: Optional[str] = None + index: Optional[int] = None + metadata: Optional[dict[str, str]] = None + + @model_validator(mode="after") + def _exactly_one_source(self) -> "DataContentBlock": + sources = sum(v is not None for v in (self.base64, self.url)) + if sources != 1: + raise ValueError("Exactly one of base64 or url must be set.") + return self + + +IngestContentBlock = Annotated[Union[TextContentBlock, DataContentBlock], Field(discriminator="type")] + +IngestMessageContent = Union[str, list[IngestContentBlock]] diff --git a/src/galileo/schema/logged.py b/src/galileo/schema/logged.py new file mode 100644 index 00000000..d4322f0a --- /dev/null +++ b/src/galileo/schema/logged.py @@ -0,0 +1,127 @@ +"""SDK-local ingestion models that widen input/output for multimodal content. + +These models subclass the core Trace/Span types and override only the fields +that differ for ingestion (input, output, redacted_input, redacted_output, and child spans). +Read-side code continues to use the core types directly. 
+""" + +from collections.abc import Sequence +from json import dumps +from typing import Annotated, Any, Optional, Union + +from pydantic import Field + +from galileo.schema.message import LoggedMessage +from galileo_core.schemas.logging.llm import Message, MessageRole +from galileo_core.schemas.logging.span import ( + AgentSpan, + LlmSpan, + LlmSpanAllowedInputType, + LlmSpanAllowedOutputType, + RetrieverSpan, + Span, # noqa: F401 # needed for Pydantic model_rebuild to resolve forward refs + ToolSpan, + WorkflowSpan, +) +from galileo_core.schemas.logging.step import BaseStep +from galileo_core.schemas.logging.trace import Trace +from galileo_core.schemas.shared.document import Document + +IngestInputType = Union[str, Sequence[LoggedMessage]] +IngestOutputType = Union[str, LoggedMessage, Sequence[Document]] + +_INPUT_FIELD = Field(default="", description=BaseStep.model_fields["input"].description, union_mode="left_to_right") +_REDACTED_INPUT_FIELD = Field( + default=None, description=BaseStep.model_fields["redacted_input"].description, union_mode="left_to_right" +) +_OUTPUT_FIELD = Field(default=None, description=BaseStep.model_fields["output"].description, union_mode="left_to_right") +_REDACTED_OUTPUT_FIELD = Field( + default=None, description=BaseStep.model_fields["redacted_output"].description, union_mode="left_to_right" +) + + +class LoggedTrace(Trace): + """Trace with widened input/output for multimodal ingestion.""" + + input: IngestInputType = _INPUT_FIELD + redacted_input: Optional[IngestInputType] = _REDACTED_INPUT_FIELD + output: Optional[IngestOutputType] = _OUTPUT_FIELD + redacted_output: Optional[IngestOutputType] = _REDACTED_OUTPUT_FIELD + spans: list["LoggedSpan"] = Field(default_factory=list) + + +class LoggedWorkflowSpan(WorkflowSpan): + """WorkflowSpan with widened input/output for multimodal ingestion.""" + + input: IngestInputType = _INPUT_FIELD + redacted_input: Optional[IngestInputType] = _REDACTED_INPUT_FIELD + output: 
Optional[IngestOutputType] = _OUTPUT_FIELD + redacted_output: Optional[IngestOutputType] = _REDACTED_OUTPUT_FIELD + spans: list["LoggedSpan"] = Field(default_factory=list) + + +class LoggedAgentSpan(AgentSpan): + """AgentSpan with widened input/output for multimodal ingestion.""" + + input: IngestInputType = _INPUT_FIELD + redacted_input: Optional[IngestInputType] = _REDACTED_INPUT_FIELD + output: Optional[IngestOutputType] = _OUTPUT_FIELD + redacted_output: Optional[IngestOutputType] = _REDACTED_OUTPUT_FIELD + spans: list["LoggedSpan"] = Field(default_factory=list) + + +class LoggedLlmSpan(LlmSpan): + """LlmSpan for ingestion using LoggedMessage (supports IngestContentBlocks, not ContentParts).""" + + input: Sequence[LoggedMessage] = Field( + default_factory=list, validate_default=True, description=BaseStep.model_fields["input"].description + ) + redacted_input: Optional[Sequence[LoggedMessage]] = Field( + default=None, description=BaseStep.model_fields["redacted_input"].description + ) + output: LoggedMessage = Field( + default_factory=lambda: LoggedMessage(content="", role=MessageRole.assistant), + validate_default=True, + description=BaseStep.model_fields["output"].description, + ) + redacted_output: Optional[LoggedMessage] = Field( + default=None, description=BaseStep.model_fields["redacted_output"].description + ) + + @classmethod + def _to_logged_message(cls, msg: Message) -> LoggedMessage: + if isinstance(msg, LoggedMessage): + return msg + return LoggedMessage( + content=msg.content, role=msg.role, tool_call_id=msg.tool_call_id, tool_calls=msg.tool_calls + ) + + @classmethod + def _convert_dict_to_message( + cls, value: dict[str, Any], default_role: MessageRole = MessageRole.user + ) -> LoggedMessage: + try: + return LoggedMessage.model_validate(value) + except Exception: + return LoggedMessage(content=dumps(value), role=default_role) + + @classmethod + def _convert_input_to_messages(cls, value: LlmSpanAllowedInputType) -> Sequence[LoggedMessage]: + 
messages = super()._convert_input_to_messages(value) + return [cls._to_logged_message(m) for m in messages] + + @classmethod + def _convert_output_to_message(cls, value: LlmSpanAllowedOutputType) -> LoggedMessage: + message = super()._convert_output_to_message(value) + return cls._to_logged_message(message) + + +# RetrieverSpan and ToolSpan use plain string/document I/O and don't need multimodal widening. +LoggedSpan = Annotated[ + Union[LoggedAgentSpan, LoggedWorkflowSpan, LoggedLlmSpan, RetrieverSpan, ToolSpan], Field(discriminator="type") +] + +LoggedTrace.model_rebuild() +LoggedWorkflowSpan.model_rebuild() +LoggedAgentSpan.model_rebuild() +LoggedLlmSpan.model_rebuild() diff --git a/src/galileo/schema/message.py b/src/galileo/schema/message.py index 1fbc4fcb..20019f9f 100644 --- a/src/galileo/schema/message.py +++ b/src/galileo/schema/message.py @@ -1,6 +1,9 @@ # ruff: noqa: F401 from typing import Any +from pydantic import Field + +from galileo.schema.content_blocks import IngestMessageContent from galileo_core.schemas.logging.llm import Message as CoreMessage # These classes should not be removed. They are used to rebuild the new `Message` model @@ -24,3 +27,16 @@ def __eq__(self, value: CoreMessage) -> bool: # Without rebuilding the model, Message class we create here would not know and validate # constituent classes defined in core which build up message. Message.model_rebuild() + + +class LoggedMessage(Message): + """Message type for ingestion that accepts IngestContentBlocks. + + Unlike the read-side Message (whose content is str | List[ContentPart]), + this accepts str | List[IngestContentBlock] for multimodal ingestion. 
+ """ + + content: IngestMessageContent = Field(default="") + + +LoggedMessage.model_rebuild() diff --git a/src/galileo/schema/trace.py b/src/galileo/schema/trace.py index b213e985..63273856 100644 --- a/src/galileo/schema/trace.py +++ b/src/galileo/schema/trace.py @@ -4,9 +4,8 @@ from pydantic import UUID4, BaseModel, Field from galileo.resources.models import Document -from galileo_core.schemas.logging.span import Span +from galileo.schema.logged import LoggedSpan, LoggedTrace from galileo_core.schemas.logging.step import StepAllowedInputType, StepAllowedOutputType -from galileo_core.schemas.logging.trace import Trace SPAN_TYPE = Literal["llm", "retriever", "tool", "workflow", "agent"] @@ -33,14 +32,14 @@ class LogRecordsIngestRequest(BaseLogStreamOrExperimentModel): class TracesIngestRequest(LogRecordsIngestRequest): - traces: list[Trace] = Field(..., description="List of traces to log.", min_length=1) + traces: list[LoggedTrace] = Field(..., description="List of traces to log.", min_length=1) session_id: Optional[UUID4] = Field(default=None, description="Session id associated with the traces.") session_external_id: Optional[str] = Field(default=None, description="External id for session grouping.") is_complete: Optional[bool] = Field(default=True, description="Is complete.") class SpansIngestRequest(LogRecordsIngestRequest): - spans: list[Span] = Field(..., description="List of spans to log.", min_length=1) + spans: list[LoggedSpan] = Field(..., description="List of spans to log.", min_length=1) trace_id: UUID4 = Field(description="Trace id associated with the spans.") parent_id: UUID4 = Field(description="Parent trace or span id.") diff --git a/tests/schemas/test_logged.py b/tests/schemas/test_logged.py new file mode 100644 index 00000000..b600b90f --- /dev/null +++ b/tests/schemas/test_logged.py @@ -0,0 +1,402 @@ +"""Tests for SDK-local ingestion models (Logged variants and content blocks).""" + +import pytest +from pydantic import ValidationError + +from 
from galileo.schema.content_blocks import DataContentBlock, TextContentBlock
from galileo.schema.logged import LoggedAgentSpan, LoggedLlmSpan, LoggedTrace, LoggedWorkflowSpan
from galileo.schema.message import LoggedMessage
from galileo.schema.trace import TracesIngestRequest
from galileo_core.schemas.logging.llm import MessageRole
from galileo_core.schemas.logging.span import AgentSpan, LlmSpan, RetrieverSpan, ToolSpan, WorkflowSpan
from galileo_core.schemas.logging.trace import Trace
from galileo_core.schemas.shared.document import Document
from galileo_core.schemas.shared.multimodal import ContentModality


class TestTextContentBlock:
    """Construction checks for ``TextContentBlock``."""

    def test_basic_text_block(self) -> None:
        # Given: a block created from a text string only
        text_block = TextContentBlock(text="hello world")

        # Then: type and text carry the expected values, index and metadata are None
        assert text_block.type == "text"
        assert text_block.text == "hello world"
        assert text_block.index is None
        assert text_block.metadata is None

    def test_text_block_with_metadata(self) -> None:
        # Given: a block created with an explicit index and a metadata mapping
        text_block = TextContentBlock(text="chunk", index=0, metadata={"source": "doc1"})

        # Then: both values can be read back unchanged
        assert text_block.index == 0
        assert text_block.metadata == {"source": "doc1"}


class TestDataContentBlock:
    """Construction and source-validation checks for ``DataContentBlock``."""

    def test_base64_source(self) -> None:
        # Given: an image block whose payload is passed as base64
        data_block = DataContentBlock(
            modality=ContentModality.image,
            mime_type="image/png",
            base64="iVBORw0KGgoAAAANS",
        )

        # Then: base64 is populated and url is left unset
        assert data_block.type == "data"
        assert data_block.modality == ContentModality.image
        assert data_block.base64 is not None
        assert data_block.url is None

    def test_url_source(self) -> None:
        # Given: an image block whose payload is referenced by URL
        data_block = DataContentBlock(modality=ContentModality.image, url="https://example.com/image.png")

        # Then: url is populated and base64 is left unset
        assert data_block.url == "https://example.com/image.png"
        assert data_block.base64 is None

    def test_no_source_raises(self) -> None:
        # When/Then: constructing without any source field fails validation
        with pytest.raises(ValidationError, match="Exactly one of"):
            DataContentBlock(modality=ContentModality.image)

    def test_multiple_sources_raises(self) -> None:
        # When/Then: constructing with both base64 and url fails validation
        with pytest.raises(ValidationError, match="Exactly one of"):
            DataContentBlock(modality=ContentModality.image, base64="abc", url="https://example.com/img.png")


class TestTracesIngestRequestBoundary:
    """Checks which trace types ``TracesIngestRequest`` accepts."""

    def test_rejects_core_trace(self) -> None:
        # When/Then: a core Trace (not a LoggedTrace) in ``traces`` fails validation
        expected_message = "Input should be a valid dictionary or instance of LoggedTrace"
        with pytest.raises(ValidationError, match=expected_message):
            TracesIngestRequest(traces=[Trace(input="plain text")])


class TestJsonRoundtripNoCoercion:
    """Check that a model_dump(mode='json') -> model_validate round trip keeps types instead of producing str.

    Restored objects are compared with ``type(x) is T`` so that an instance of a
    base class does not satisfy the assertion.
    """

    def test_logged_trace_roundtrip(self) -> None:
        # Given: a LoggedTrace whose input is a multimodal message and which holds one LLM span
        multimodal_message = LoggedMessage(
            content=[
                TextContentBlock(text="Analyze"),
                DataContentBlock(modality=ContentModality.image, base64="abc"),
            ],
            role=MessageRole.user,
        )
        llm_span = LoggedLlmSpan(
            input=[LoggedMessage(content="prompt", role=MessageRole.user)],
            output=LoggedMessage(content="response", role=MessageRole.assistant),
        )
        original = LoggedTrace(input=[multimodal_message], output="done", spans=[llm_span])

        # When: dumped in JSON mode
        dumped = original.model_dump(mode="json")

        # Then: the dumped structure keeps the message list and the typed content blocks
        assert isinstance(dumped["input"], list)
        dumped_blocks = dumped["input"][0]["content"]
        assert dumped_blocks[0]["type"] == "text"
        assert dumped_blocks[1]["type"] == "data"
        assert dumped["output"] == "done"

        # When: the dump is validated back into a model
        reloaded = LoggedTrace.model_validate(dumped)

        # Then: every level has its exact original type
        assert isinstance(reloaded.input, list)
        first_message = reloaded.input[0]
        assert type(first_message) is LoggedMessage
        assert isinstance(first_message.content, list)
        assert type(first_message.content[0]) is TextContentBlock
        assert type(first_message.content[1]) is DataContentBlock
        assert first_message.content[1].base64 == "abc"
        assert isinstance(reloaded.output, str)
        assert type(reloaded.spans[0]) is LoggedLlmSpan

    def test_logged_workflow_span_roundtrip(self) -> None:
        # Given: a LoggedWorkflowSpan with message input/output and one child LLM span
        child_llm = LoggedLlmSpan(
            input=[LoggedMessage(content="inner", role=MessageRole.user)],
            output=LoggedMessage(content="inner out", role=MessageRole.assistant),
        )
        workflow = LoggedWorkflowSpan(
            input=[LoggedMessage(content="workflow input", role=MessageRole.user)],
            output=LoggedMessage(content="workflow output", role=MessageRole.assistant),
            spans=[child_llm],
        )

        # When: dumped in JSON mode and validated back
        dumped = workflow.model_dump(mode="json")
        reloaded = LoggedWorkflowSpan.model_validate(dumped)

        # Then: input is still a list of messages and the child span keeps its type
        assert isinstance(reloaded.input, list)
        assert type(reloaded.input[0]) is LoggedMessage
        assert reloaded.input[0].content == "workflow input"
        assert type(reloaded.output) is LoggedMessage
        assert type(reloaded.spans[0]) is LoggedLlmSpan

    def test_logged_agent_span_roundtrip(self) -> None:
        # Given: a LoggedAgentSpan whose input message mixes a text block and a document URL block
        multimodal_message = LoggedMessage(
            content=[
                TextContentBlock(text="agent task"),
                DataContentBlock(modality=ContentModality.document, url="https://example.com/doc.pdf"),
            ],
            role=MessageRole.user,
        )
        agent = LoggedAgentSpan(input=[multimodal_message], output="agent result")

        # When: dumped in JSON mode and validated back
        dumped = agent.model_dump(mode="json")
        reloaded = LoggedAgentSpan.model_validate(dumped)

        # Then: the content blocks come back as typed blocks
        assert isinstance(reloaded.input, list)
        first_message = reloaded.input[0]
        assert type(first_message) is LoggedMessage
        assert isinstance(first_message.content, list)
        assert type(first_message.content[0]) is TextContentBlock
        assert type(first_message.content[1]) is DataContentBlock
        assert first_message.content[1].url == "https://example.com/doc.pdf"

    def test_logged_llm_span_roundtrip(self) -> None:
        # Given: a LoggedLlmSpan with a multimodal input message and a plain-text output message
        question = LoggedMessage(
            content=[
                TextContentBlock(text="What is this?"),
                DataContentBlock(modality=ContentModality.image, base64="img_data"),
            ],
            role=MessageRole.user,
        )
        answer = LoggedMessage(content="A cat", role=MessageRole.assistant)
        llm_span = LoggedLlmSpan(input=[question], output=answer)

        # When: dumped in JSON mode
        dumped = llm_span.model_dump(mode="json")

        # Then: input content is a list of typed dicts and output content is a str
        dumped_blocks = dumped["input"][0]["content"]
        assert isinstance(dumped_blocks, list)
        assert dumped_blocks[0]["type"] == "text"
        assert dumped_blocks[1]["type"] == "data"
        assert isinstance(dumped["output"]["content"], str)

        # When: the dump is validated back into a model
        reloaded = LoggedLlmSpan.model_validate(dumped)

        # Then: messages and their content blocks keep their exact types
        first_message = reloaded.input[0]
        assert type(first_message) is LoggedMessage
        assert isinstance(first_message.content, list)
        assert type(first_message.content[0]) is TextContentBlock
        assert type(first_message.content[1]) is DataContentBlock
        assert first_message.content[1].base64 == "img_data"
        assert type(reloaded.output) is LoggedMessage
        assert reloaded.output.content == "A cat"

    def test_logged_trace_document_output_roundtrip(self) -> None:
        # Given: a LoggedTrace whose output is a list of Documents
        documents = [Document(content="doc 1"), Document(content="doc 2", metadata={"src": "wiki"})]
        original = LoggedTrace(input="query", output=documents)

        # When: dumped in JSON mode and validated back
        dumped = original.model_dump(mode="json")
        reloaded = LoggedTrace.model_validate(dumped)

        # Then: output is a two-element list of Document objects
        assert isinstance(reloaded.output, list)
        assert len(reloaded.output) == 2
        assert type(reloaded.output[0]) is Document
        assert reloaded.output[0].content == "doc 1"
        assert reloaded.output[1].metadata == {"src": "wiki"}

    def test_logged_workflow_span_document_output_roundtrip(self) -> None:
        # Given: a LoggedWorkflowSpan whose output is a list of Documents
        workflow = LoggedWorkflowSpan(input="search query", output=[Document(content="result 1")])

        # When: dumped in JSON mode and validated back
        dumped = workflow.model_dump(mode="json")
        reloaded = LoggedWorkflowSpan.model_validate(dumped)

        # Then: output is still a list holding a Document
        assert isinstance(reloaded.output, list)
        assert type(reloaded.output[0]) is Document

    def test_retriever_span_roundtrip(self) -> None:
        # Given: a LoggedTrace holding a single RetrieverSpan
        retriever = RetrieverSpan(
            input="search query",
            output=[Document(content="retrieved doc", metadata={"score": "0.95"})],
        )
        original = LoggedTrace(input="find docs", spans=[retriever])

        # When: dumped in JSON mode and validated back
        dumped = original.model_dump(mode="json")
        reloaded = LoggedTrace.model_validate(dumped)

        # Then: the span is a RetrieverSpan and its document content is unchanged
        restored_span = reloaded.spans[0]
        assert type(restored_span) is RetrieverSpan
        assert restored_span.output[0].content == "retrieved doc"

    def test_tool_span_roundtrip(self) -> None:
        # Given: a LoggedTrace holding a single ToolSpan
        tool_span = ToolSpan(input="tool_call(arg=1)", output="tool result")
        original = LoggedTrace(input="use tool", spans=[tool_span])

        # When: dumped in JSON mode and validated back
        dumped = original.model_dump(mode="json")
        reloaded = LoggedTrace.model_validate(dumped)

        # Then: the span is a ToolSpan with the same output
        restored_span = reloaded.spans[0]
        assert type(restored_span) is ToolSpan
        assert restored_span.output == "tool result"

    def test_full_ingest_request_roundtrip(self) -> None:
        # Given: a TracesIngestRequest nesting agent -> workflow -> (llm, retriever, tool) spans
        llm_span = LoggedLlmSpan(
            input=[
                LoggedMessage(
                    content=[
                        TextContentBlock(text="deep"),
                        DataContentBlock(modality=ContentModality.audio, base64="audio_b64"),
                    ],
                    role=MessageRole.user,
                )
            ],
            output=LoggedMessage(content="deep answer", role=MessageRole.assistant),
        )
        retriever_span = RetrieverSpan(input="retrieve context", output=[Document(content="ctx doc")])
        tool_span = ToolSpan(input="calc(2+2)", output="4")
        workflow_span = LoggedWorkflowSpan(
            input=[LoggedMessage(content="wf", role=MessageRole.user)],
            output=LoggedMessage(content="wf result", role=MessageRole.assistant),
            spans=[llm_span, retriever_span, tool_span],
        )
        agent_span = LoggedAgentSpan(
            input=[LoggedMessage(content="agent task", role=MessageRole.user)],
            output="agent done",
            spans=[workflow_span],
        )
        request = TracesIngestRequest(
            traces=[
                LoggedTrace(
                    input="top-level string",
                    output=[Document(content="trace-level doc output")],
                    spans=[agent_span],
                )
            ]
        )

        # When: dumped in JSON mode and validated back
        dumped = request.model_dump(mode="json")
        reloaded = TracesIngestRequest.model_validate(dumped)

        # Then: each level of the restored tree has its exact type
        restored_trace = reloaded.traces[0]
        assert type(restored_trace) is LoggedTrace
        assert isinstance(restored_trace.output, list)
        assert type(restored_trace.output[0]) is Document

        restored_agent = restored_trace.spans[0]
        assert type(restored_agent) is LoggedAgentSpan
        assert isinstance(restored_agent.input, list)
        assert type(restored_agent.input[0]) is LoggedMessage

        restored_workflow = restored_agent.spans[0]
        assert type(restored_workflow) is LoggedWorkflowSpan
        assert type(restored_workflow.output) is LoggedMessage

        restored_llm = restored_workflow.spans[0]
        assert type(restored_llm) is LoggedLlmSpan
        llm_blocks = restored_llm.input[0].content
        assert isinstance(llm_blocks, list)
        assert type(llm_blocks[0]) is TextContentBlock
        assert type(llm_blocks[1]) is DataContentBlock
        assert llm_blocks[1].base64 == "audio_b64"

        restored_retriever = restored_workflow.spans[1]
        assert type(restored_retriever) is RetrieverSpan
        assert restored_retriever.output[0].content == "ctx doc"

        restored_tool = restored_workflow.spans[2]
        assert type(restored_tool) is ToolSpan
        assert restored_tool.output == "4"


class TestLoggedAndCoreParity:
    """Compare core models with their Logged counterparts when input and output are plain strings."""

    def test_trace_string_io_parity(self) -> None:
        # Given: one set of plain-string field values
        field_values = {"input": "hello", "output": "world", "name": "t"}
        compared_fields = {"input", "output", "name"}

        # When: both variants are built from those values
        core_trace = Trace(**field_values)
        logged_trace = LoggedTrace(**field_values)

        # Then: attribute values and the dumped subset are equal
        assert core_trace.input == logged_trace.input
        assert core_trace.output == logged_trace.output
        assert core_trace.model_dump(include=compared_fields) == logged_trace.model_dump(include=compared_fields)

    def test_workflow_span_string_io_parity(self) -> None:
        field_values = {"input": "query", "output": "answer", "name": "wf"}
        compared_fields = {"input", "output", "name"}
        core_span = WorkflowSpan(**field_values)
        logged_span = LoggedWorkflowSpan(**field_values)

        assert core_span.input == logged_span.input
        assert core_span.output == logged_span.output
        assert core_span.model_dump(include=compared_fields) == logged_span.model_dump(include=compared_fields)

    def test_agent_span_string_io_parity(self) -> None:
        field_values = {"input": "task", "output": "result", "name": "ag"}
        compared_fields = {"input", "output", "name"}
        core_span = AgentSpan(**field_values)
        logged_span = LoggedAgentSpan(**field_values)

        assert core_span.input == logged_span.input
        assert core_span.output == logged_span.output
        assert core_span.model_dump(include=compared_fields) == logged_span.model_dump(include=compared_fields)

    def test_llm_span_string_io_parity(self) -> None:
        field_values = {"input": "prompt", "output": "completion", "name": "llm", "model": "gpt-4"}
        compared_fields = {"input", "output", "name", "model"}
        core_span = LlmSpan(**field_values)
        logged_span = LoggedLlmSpan(**field_values)

        assert core_span.input == logged_span.input
        assert core_span.output == logged_span.output
        assert core_span.model_dump(include=compared_fields) == logged_span.model_dump(include=compared_fields)

    def test_logged_trace_is_instance_of_trace(self) -> None:
        logged_trace = LoggedTrace(input="x")

        assert isinstance(logged_trace, Trace)
        assert isinstance(logged_trace, LoggedTrace)

    def test_logged_spans_are_instances_of_core(self) -> None:
        # Each Logged span class is paired with the core class it must be an instance of.
        class_pairs = (
            (LoggedWorkflowSpan, WorkflowSpan),
            (LoggedAgentSpan, AgentSpan),
            (LoggedLlmSpan, LlmSpan),
        )
        for logged_cls, core_cls in class_pairs:
            assert isinstance(logged_cls(input="x"), core_cls)
aacc6dc2..b676f930 100644 --- a/tests/test_logger_batch.py +++ b/tests/test_logger_batch.py @@ -10,9 +10,13 @@ import pytest from galileo.logger import GalileoLogger +from galileo.schema.content_blocks import DataContentBlock, TextContentBlock +from galileo.schema.logged import LoggedTrace +from galileo.schema.message import LoggedMessage from galileo.schema.metrics import LocalMetricConfig from galileo.schema.trace import TracesIngestRequest from galileo_core.schemas.logging.agent import AgentType +from galileo_core.schemas.logging.llm import MessageRole from galileo_core.schemas.logging.span import AgentSpan, LlmSpan, RetrieverSpan, Span, ToolSpan, WorkflowSpan from galileo_core.schemas.logging.step import Metrics from galileo_core.schemas.logging.trace import Trace @@ -20,6 +24,7 @@ from galileo_core.schemas.protect.payload import Payload from galileo_core.schemas.protect.response import Response, TraceMetadata from galileo_core.schemas.shared.document import Document +from galileo_core.schemas.shared.multimodal import ContentModality from tests.testutils.setup import ( setup_mock_experiments_client, setup_mock_logstreams_client, @@ -104,32 +109,16 @@ def test_single_span_trace_to_galileo( mock_traces_client_instance.ingest_traces.assert_called_once() payload: TracesIngestRequest = mock_traces_client_instance.ingest_traces.call_args.args[0] - expected_payload = TracesIngestRequest( - log_stream_id=None, # TODO: fix this - experiment_id=None, - traces=[ - Trace( - input="input", - output="output", - name="test-trace", - created_at=created_at, - user_metadata=metadata, - status_code=200, - spans=[span], - metrics=Metrics(duration_ns=1000000), - ) - ], - ) - trace = payload.traces[0] - assert trace.input == expected_payload.traces[0].input - assert trace.output == expected_payload.traces[0].output - assert trace.name == expected_payload.traces[0].name - assert trace.created_at == expected_payload.traces[0].created_at - assert trace.user_metadata == 
expected_payload.traces[0].user_metadata - assert trace.status_code == expected_payload.traces[0].status_code - assert trace.spans == expected_payload.traces[0].spans - assert trace.metrics == expected_payload.traces[0].metrics + assert isinstance(trace, LoggedTrace) + assert trace.input == "input" + assert trace.output == "output" + assert trace.name == "test-trace" + assert trace.created_at == created_at + assert trace.user_metadata == metadata + assert trace.status_code == 200 + assert trace.spans == [span] + assert trace.metrics == Metrics(duration_ns=1000000) assert logger.traces == [] assert logger._parent_stack == deque() @@ -339,31 +328,16 @@ def test_single_span_trace_to_galileo_experiment_id( mock_traces_client_instance.ingest_traces.assert_called_once() payload: TracesIngestRequest = mock_traces_client_instance.ingest_traces.call_args.args[0] - expected_payload = TracesIngestRequest( - log_stream_id=None, - experiment_id="6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9a", - traces=[ - Trace( - input="input", - output="output", - name="test-trace", - created_at=created_at, - user_metadata=metadata, - status_code=200, - spans=[], - metrics=Metrics(duration_ns=1000000), - ) - ], - ) trace = payload.traces[0] - assert trace.input == expected_payload.traces[0].input - assert trace.output == expected_payload.traces[0].output - assert trace.name == expected_payload.traces[0].name - assert trace.created_at == expected_payload.traces[0].created_at - assert trace.user_metadata == expected_payload.traces[0].user_metadata - assert trace.status_code == expected_payload.traces[0].status_code - assert trace.spans == expected_payload.traces[0].spans - assert trace.metrics == expected_payload.traces[0].metrics + assert isinstance(trace, LoggedTrace) + assert trace.input == "input" + assert trace.output == "output" + assert trace.name == "test-trace" + assert trace.created_at == created_at + assert trace.user_metadata == metadata + assert trace.status_code == 200 + assert trace.spans 
== [] + assert trace.metrics == Metrics(duration_ns=1000000) assert logger.traces == [] assert logger._parent_stack == deque() @@ -585,31 +559,16 @@ def test_multi_span_trace_to_galileo( logger.flush() payload: TracesIngestRequest = mock_traces_client_instance.ingest_traces.call_args[0][0] - expected_payload = TracesIngestRequest( - log_stream_id=None, # TODO: fix this - experiment_id=None, - traces=[ - Trace( - input="input", - output="response2", - name="test-trace", - created_at=created_at, - user_metadata=metadata, - status_code=200, - spans=[workflow_span, second_span], - metrics=Metrics(duration_ns=1000000), - ) - ], - ) trace = payload.traces[0] - assert trace.input == expected_payload.traces[0].input - assert trace.output == expected_payload.traces[0].output - assert trace.name == expected_payload.traces[0].name - assert trace.created_at == expected_payload.traces[0].created_at - assert trace.user_metadata == expected_payload.traces[0].user_metadata - assert trace.status_code == expected_payload.traces[0].status_code - assert trace.spans == expected_payload.traces[0].spans - assert trace.metrics == expected_payload.traces[0].metrics + assert isinstance(trace, LoggedTrace) + assert trace.input == "input" + assert trace.output == "response2" + assert trace.name == "test-trace" + assert trace.created_at == created_at + assert trace.user_metadata == metadata + assert trace.status_code == 200 + assert trace.spans == [workflow_span, second_span] + assert trace.metrics == Metrics(duration_ns=1000000) assert logger.traces == [] assert logger._parent_stack == deque() @@ -658,31 +617,16 @@ def local_scorer(step: Union[Trace, Span]) -> int: span.metrics.length = 6 payload: TracesIngestRequest = mock_traces_client_instance.ingest_traces.call_args[0][0] - expected_payload = TracesIngestRequest( - log_stream_id=None, # TODO: fix this - experiment_id=None, - traces=[ - Trace( - input="input", - output="output", - name="test-trace", - created_at=created_at, - 
user_metadata=metadata, - status_code=200, - spans=[span], - metrics=Metrics(duration_ns=1000000), - ) - ], - ) trace = payload.traces[0] - assert trace.input == expected_payload.traces[0].input - assert trace.output == expected_payload.traces[0].output - assert trace.name == expected_payload.traces[0].name - assert trace.created_at == expected_payload.traces[0].created_at - assert trace.user_metadata == expected_payload.traces[0].user_metadata - assert trace.status_code == expected_payload.traces[0].status_code - assert trace.spans == expected_payload.traces[0].spans - assert trace.metrics == expected_payload.traces[0].metrics + assert isinstance(trace, LoggedTrace) + assert trace.input == "input" + assert trace.output == "output" + assert trace.name == "test-trace" + assert trace.created_at == created_at + assert trace.user_metadata == metadata + assert trace.status_code == 200 + assert trace.spans == [span] + assert trace.metrics == Metrics(duration_ns=1000000) assert logger.traces == [] assert logger._parent_stack == deque() @@ -1546,7 +1490,7 @@ async def test_ingest_traces_methods( setup_mock_logstreams_client(mock_logstreams_client) logger = GalileoLogger(project="my_project", log_stream="my_log_stream") - trace = Trace(id=uuid4(), input="input", output="output") + trace = LoggedTrace(id=uuid4(), input="input", output="output") ingest_request = TracesIngestRequest(traces=[trace]) await logger.async_ingest_traces(ingest_request) @@ -1754,6 +1698,45 @@ def test_start_trace_auto_conversion( assert getattr(payload_trace, attr) == expected_value, f"payload.{attr} mismatch" +@patch("galileo.logger.logger.LogStreams") +@patch("galileo.logger.logger.Projects") +@patch("galileo.logger.logger.Traces") +def test_multimodal_input_not_stringified_at_trace_level( + mock_traces_client: Mock, mock_projects_client: Mock, mock_logstreams_client: Mock +) -> None: + """Multimodal content must be preserved at trace level, not serialized to string.""" + 
mock_traces_client_instance = setup_mock_traces_client(mock_traces_client) + setup_mock_projects_client(mock_projects_client) + setup_mock_logstreams_client(mock_logstreams_client) + + logger = GalileoLogger(project="my_project", log_stream="my_log_stream") + + # Given: multimodal message input with text + image content blocks + messages = [ + LoggedMessage( + content=[ + TextContentBlock(text="Describe this image"), + DataContentBlock(modality=ContentModality.image, url="https://example.com/img.png"), + ], + role=MessageRole.user, + ) + ] + logger.start_trace(input=messages) + logger.add_llm_span(input=messages, output="A sunset", model="gpt-4o") + logger.conclude("A sunset") + logger.flush() + + # Then: trace.input is the message list, not a stringified version + payload: TracesIngestRequest = mock_traces_client_instance.ingest_traces.call_args.args[0] + trace = payload.traces[0] + assert not isinstance(trace.input, str), "trace input should not be stringified" + assert isinstance(trace.input, list) + assert isinstance(trace.input[0], LoggedMessage) + assert trace.input[0].content[0].text == "Describe this image" + assert trace.input[0].content[1].modality == ContentModality.image + assert trace.input[0].content[1].url == "https://example.com/img.png" + + @pytest.mark.parametrize( "span_method,span_kwargs,expected_metadata", [ @@ -2134,7 +2117,7 @@ def test_ingest_traces_reuses_existing_client( assert logger._traces_client is not None # Given: a minimal trace payload - trace = Trace( + trace = LoggedTrace( input="test", name="test", created_at=datetime.datetime.now(), id=uuid4(), metrics=Metrics(duration_ns=0) ) request = TracesIngestRequest(traces=[trace])