Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion livekit-agents/livekit/agents/inference/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,23 @@
"n",
}

# xAI reasoning models only restrict presence_penalty, frequency_penalty, stop.
# They still support temperature and top_p.
_XAI_REASONING_UNSUPPORTED_PARAMS: set[str] = {
    "presence_penalty",
    "frequency_penalty",
    "stop",
}

# Model prefix -> set of param names that should be dropped
# NOTE(review): assumes the lookup strips the provider prefix ("xai/") and
# matches by model-name prefix, so "grok-4.20-multi-agent" would also cover
# "grok-4.20-multi-agent-0309" — confirm against drop_unsupported_params.
_UNSUPPORTED_PARAMS: dict[str, set[str]] = {
    "o1": _REASONING_UNSUPPORTED_PARAMS,
    "o3": _REASONING_UNSUPPORTED_PARAMS,
    "o4": _REASONING_UNSUPPORTED_PARAMS,
    "gpt-5": _REASONING_UNSUPPORTED_PARAMS,
    "grok-4-1-fast-reasoning": _XAI_REASONING_UNSUPPORTED_PARAMS,
    "grok-4.20-0309-reasoning": _XAI_REASONING_UNSUPPORTED_PARAMS,
    "grok-4.20-multi-agent": _XAI_REASONING_UNSUPPORTED_PARAMS,
}

# models that don't support reasoning_effort when function tools are present
Expand Down Expand Up @@ -110,7 +121,15 @@ def drop_unsupported_params(
"deepseek-ai/deepseek-v3.2",
]

LLMModels = OpenAIModels | GoogleModels | KimiModels | DeepSeekModels
# xAI (Grok) model identifiers; the "xai/" prefix presumably selects the
# provider route in the inference gateway — TODO confirm against model routing.
XAIModels = Literal[
    "xai/grok-4-1-fast-non-reasoning",
    "xai/grok-4-1-fast-reasoning",
    "xai/grok-4.20-0309-non-reasoning",
    "xai/grok-4.20-0309-reasoning",
    "xai/grok-4.20-multi-agent-0309",
]

# Union of every model identifier accepted by the inference LLM.
LLMModels = OpenAIModels | GoogleModels | KimiModels | DeepSeekModels | XAIModels


class ChatCompletionOptions(TypedDict, total=False):
Expand Down Expand Up @@ -444,11 +463,13 @@ def _parse_choice(
return call_chunk

if choice.finish_reason in ("tool_calls", "stop") and self._tool_call_id:
finish_extra = getattr(delta, "extra_content", None)
call_chunk = llm.ChatChunk(
id=id,
delta=llm.ChoiceDelta(
role="assistant",
content=delta.content,
extra=finish_extra,
tool_calls=[
llm.FunctionToolCall(
arguments=self._fnc_raw_arguments or "",
Expand Down
22 changes: 13 additions & 9 deletions livekit-agents/livekit/agents/llm/_provider_format/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from .utils import group_tool_calls

_EXTRA_CONTENT_KEYS = ("google", "livekit", "xai")


def _filter_extra(extra: dict[str, Any]) -> dict[str, Any]:
return {k: extra[k] for k in _EXTRA_CONTENT_KEYS if extra.get(k)}


def to_chat_ctx(
chat_ctx: llm.ChatContext, *, inject_dummy_user_message: bool = True
Expand All @@ -26,9 +32,9 @@ def to_chat_ctx(
"type": "function",
"function": {"name": tool_call.name, "arguments": tool_call.arguments},
}
# Include provider-specific extra content (e.g., Google thought signatures)
if tool_call.extra.get("google"):
tc["extra_content"] = {"google": tool_call.extra["google"]}
extra_content = _filter_extra(tool_call.extra) if tool_call.extra else {}
if extra_content:
tc["extra_content"] = extra_content
tool_calls.append(tc)
if tool_calls:
msg["tool_calls"] = tool_calls
Expand Down Expand Up @@ -62,9 +68,7 @@ def _to_chat_item(msg: llm.ChatItem) -> dict[str, Any]:
list_content.append({"type": "text", "text": text_content})
result = {"role": msg.role, "content": list_content}

# Include provider-specific extra content (e.g., Google thought signatures)
provider_keys = ("google", "livekit")
extra_content = {k: msg.extra[k] for k in provider_keys if msg.extra.get(k)}
extra_content = _filter_extra(msg.extra)
if extra_content:
result["extra_content"] = extra_content
return result
Expand All @@ -78,9 +82,9 @@ def _to_chat_item(msg: llm.ChatItem) -> dict[str, Any]:
"arguments": msg.arguments,
},
}
# Include provider-specific extra content (e.g., Google thought signatures)
if msg.extra.get("google"):
tc["extra_content"] = {"google": msg.extra["google"]}
extra_content = _filter_extra(msg.extra)
if extra_content:
tc["extra_content"] = extra_content
return {
"role": "assistant",
"tool_calls": [tc],
Expand Down
7 changes: 7 additions & 0 deletions livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class CollectedResponse(BaseModel):
text: str = ""
tool_calls: list[FunctionToolCall] = Field(default_factory=list)
usage: CompletionUsage | None = None
extra: dict[str, Any] = Field(default_factory=dict)
"""Provider-specific extra data accumulated across chunks
(e.g., xAI encrypted reasoning, Google thought signatures)."""


class ChoiceDelta(BaseModel):
Expand Down Expand Up @@ -428,6 +431,7 @@ async def collect(self) -> CollectedResponse:
text_parts: list[str] = []
tool_calls: list[FunctionToolCall] = []
usage: CompletionUsage | None = None
extra: dict[str, Any] = {}

async with self:
async for chunk in self:
Expand All @@ -436,11 +440,14 @@ async def collect(self) -> CollectedResponse:
text_parts.append(chunk.delta.content)
if chunk.delta.tool_calls:
tool_calls.extend(chunk.delta.tool_calls)
if chunk.delta.extra:
extra.update(chunk.delta.extra)
if chunk.usage is not None:
usage = chunk.usage

return CollectedResponse(
text="".join(text_parts).strip(),
tool_calls=tool_calls,
usage=usage,
extra=extra,
)
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,13 +2477,17 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None:
)

if forwarded_text:
extra_kwargs: dict = {}
if llm_gen_data.generated_extra:
extra_kwargs["extra"] = llm_gen_data.generated_extra
msg = chat_ctx.add_message(
role="assistant",
content=forwarded_text,
id=llm_gen_data.id,
interrupted=speech_handle.interrupted,
created_at=reply_started_at,
metrics=assistant_metrics,
**extra_kwargs,
)
self._agent._chat_ctx.insert(msg)
self._session._conversation_item_added(msg)
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/voice/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class _LLMGenerationData:
function_ch: aio.Chan[llm.FunctionCall]
generated_text: str = ""
generated_functions: list[llm.FunctionCall] = field(default_factory=list)
generated_extra: dict[str, Any] = field(default_factory=dict)
id: str = field(default_factory=lambda: utils.shortuuid("item_"))
started_fut: asyncio.Future[None] = field(default_factory=asyncio.Future)
ttft: float | None = None
Expand Down Expand Up @@ -176,6 +177,9 @@ async def _llm_inference_task(
data.generated_functions.append(fnc_call)
function_ch.send_nowait(fnc_call)

if chunk.delta.extra:
data.generated_extra.update(chunk.delta.extra)

if chunk.delta.content:
data.generated_text += chunk.delta.content
text_ch.send_nowait(chunk.delta.content)
Expand Down
Loading