Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,6 +2484,7 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
interrupted=speech_handle.interrupted,
created_at=reply_started_at,
metrics=assistant_metrics,
extra=llm_gen_data.extra if llm_gen_data.extra else NOT_GIVEN,
)
self._agent._chat_ctx.insert(msg)
self._session._conversation_item_added(msg)
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/voice/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class _LLMGenerationData:
function_ch: aio.Chan[llm.FunctionCall]
generated_text: str = ""
generated_functions: list[llm.FunctionCall] = field(default_factory=list)
extra: dict[str, Any] = field(default_factory=dict)
id: str = field(default_factory=lambda: utils.shortuuid("item_"))
started_fut: asyncio.Future[None] = field(default_factory=asyncio.Future)
ttft: float | None = None
Expand Down Expand Up @@ -176,6 +177,9 @@ async def _llm_inference_task(
data.generated_functions.append(fnc_call)
function_ch.send_nowait(fnc_call)

if chunk.delta.extra:
data.extra.update(chunk.delta.extra)

if chunk.delta.content:
data.generated_text += chunk.delta.content
text_ch.send_nowait(chunk.delta.content)
Expand Down
10 changes: 8 additions & 2 deletions tests/fake_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class FakeLLMResponse(BaseModel):
ttft: float
duration: float
tool_calls: list[FunctionToolCall] = Field(default_factory=list)
extra: dict[str, Any] = Field(default_factory=dict)

def speed_up(self, factor: float) -> FakeLLMResponse:
obj = copy.deepcopy(self)
Expand Down Expand Up @@ -98,10 +99,14 @@ async def _run(self) -> None:

await asyncio.sleep(resp.duration - (time.perf_counter() - start_time))

self._send_chunk(tool_calls=resp.tool_calls)
self._send_chunk(tool_calls=resp.tool_calls, extra=resp.extra if resp.extra else None)

def _send_chunk(
self, *, delta: str | None = None, tool_calls: list[FunctionToolCall] | None = None
self,
*,
delta: str | None = None,
tool_calls: list[FunctionToolCall] | None = None,
extra: dict[str, Any] | None = None,
) -> None:
self._event_ch.send_nowait(
ChatChunk(
Expand All @@ -110,6 +115,7 @@ def _send_chunk(
role="assistant",
content=delta,
tool_calls=tool_calls or [],
extra=extra,
),
)
)
Expand Down
2 changes: 2 additions & 0 deletions tests/fake_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def add_llm(
input: NotGivenOr[str] = NOT_GIVEN,
ttft: float = 0.1,
duration: float = 0.3,
extra: dict[str, Any] | None = None,
) -> None:
if (
not utils.is_given(input)
Expand All @@ -167,6 +168,7 @@ def add_llm(
ttft=ttft,
duration=duration,
tool_calls=tool_calls or [],
extra=extra or {},
)
)

Expand Down
51 changes: 51 additions & 0 deletions tests/test_agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,3 +733,54 @@ def check_timestamp(
assert abs(t_event - t_target) <= max_abs_diff, (
f"event timestamp {t_event} is not within {max_abs_diff} of target {t_target}"
)


async def test_llm_extra_propagation() -> None:
    """Verify that ChoiceDelta.extra survives the trip into ChatMessage.extra.

    Message-level metadata (e.g. Gemini thought signatures) must flow through
    the inference proxy path and remain available on the chat-context item.
    """
    speed_factor = 5.0
    fake_actions = FakeActions()
    fake_actions.add_user_speech(0.5, 2.5, "Hello, how are you?", stt_delay=0.2)
    # Attach extra metadata to the fake LLM reply (stand-in for thought signatures).
    llm_extra = {"thought_signature": "test_signature_123", "other_metadata": {"key": "value"}}
    fake_actions.add_llm(
        "I'm doing well, thank you!",
        ttft=0.1,
        duration=0.3,
        extra=llm_extra,
    )
    fake_actions.add_tts(2.0, ttfb=0.2, duration=0.3)

    session = create_session(fake_actions, speed_factor=speed_factor)
    agent = MyAgent()

    added_events: list[ConversationItemAddedEvent] = []
    session.on("conversation_item_added", added_events.append)

    await asyncio.wait_for(run_session(session, agent), timeout=SESSION_TIMEOUT)

    # Exactly one assistant message should have been emitted.
    from_assistant = []
    for ev in added_events:
        if ev.item.type == "message" and ev.item.role == "assistant":
            from_assistant.append(ev)
    assert len(from_assistant) == 1

    msg = from_assistant[0].item
    assert msg.text_content == "I'm doing well, thank you!"

    # The extra payload must be preserved verbatim on the message.
    assert msg.extra is not None
    assert msg.extra.get("thought_signature") == "test_signature_123"
    assert msg.extra.get("other_metadata") == {"key": "value"}

    # The same message must also be reachable through the agent's chat context.
    ctx_assistant = [
        it for it in agent.chat_ctx.items if it.type == "message" and it.role == "assistant"
    ]
    assert len(ctx_assistant) == 1
    assert ctx_assistant[0].extra.get("thought_signature") == "test_signature_123"
Loading