Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-realtime.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ jobs:
AZURE_OPENAI_API_VERSION: ${{ secrets.AZURE_OPENAI_API_VERSION }}
# XAI_API_KEY not set — xAI rate-limits GitHub Actions IPs (429 on ws handshake)
run: |
uv run pytest -v tests/test_realtime/ --tb=long -p no:xdist
uv run pytest -v tests/test_realtime/ -s --tb=long -p no:xdist
18 changes: 12 additions & 6 deletions livekit-agents/livekit/agents/beta/workflows/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def __init__(
return_exceptions (bool): Whether or not to directly propagate an error. When set to True, the exception is added to the results dictionary and the sequence continues. Defaults to False.
on_task_completed (Callable[]): A callable that executes upon each task completion. The callback takes in a single argument of a TaskCompletedEvent.
"""
super().__init__(instructions="*empty*", chat_ctx=chat_ctx, llm=None)
super().__init__(
instructions="*empty*", chat_ctx=chat_ctx, llm=NOT_GIVEN
) # the LLM is set as NOT_GIVEN to allow session reusage if supported

self._summarize_chat_ctx = summarize_chat_ctx
self._return_exceptions = return_exceptions
Expand Down Expand Up @@ -137,9 +139,11 @@ async def on_enter(self) -> None:
self.complete(e)
return

try:
if self._summarize_chat_ctx:
assert isinstance(self.session.llm, llm.LLM)
if self._summarize_chat_ctx:
try:
assert isinstance(self.session.llm, llm.LLM), (
"llm must be a LLM instance to summarize the chat_ctx"
)

# when a task is done, the chat_ctx is going to be merged with the "caller" chat_ctx
# enabling summarization will result on only one ChatMessage added.
Expand All @@ -153,8 +157,10 @@ async def on_enter(self) -> None:
)._summarize(llm_v=self.session.llm, keep_last_turns=0)

await self.update_chat_ctx(summarized_chat_ctx)
except Exception as e:
self.complete(RuntimeError(f"failed to summarize the chat_ctx: {e}"))
except Exception as e:
self.complete(e)
return

self.complete(TaskGroupResult(task_results=task_results))

def _build_out_of_scope_tool(self, *, active_task_id: str) -> FunctionTool | None:
Expand Down
36 changes: 35 additions & 1 deletion livekit-agents/livekit/agents/llm/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from livekit import rtc

from ..log import logger
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import is_given
from .chat_context import ChatContext, ChatItem, FunctionCall
from .tool_context import Tool, ToolChoice, ToolContext

Expand Down Expand Up @@ -62,7 +64,14 @@ class RealtimeCapabilities:
auto_tool_reply_generation: bool
audio_output: bool
manual_function_calls: bool
per_response_tool_choice: bool
mutable_chat_context: bool = False
"""Whether the chat context can be updated mid-session"""
mutable_instructions: bool = False
"""Whether the instructions can be updated mid-session"""
mutable_tools: bool = False
"""Whether the tools can be updated mid-session"""
per_response_tool_choice: bool = False
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should find a better name for per_response_tool_choice

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about mutable_*: mutable_chat_context, mutable_instructions, mutable_tools, and mutable_tool_choice?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i second mutable_*!

"""Whether the tool and tool choice can be specified per response"""


class RealtimeError(Exception):
Expand Down Expand Up @@ -237,6 +246,31 @@ def truncate(
@abstractmethod
async def aclose(self) -> None: ...

async def _update_session(
    self,
    *,
    instructions: NotGivenOr[str] = NOT_GIVEN,
    chat_ctx: NotGivenOr[ChatContext] = NOT_GIVEN,
    tools: NotGivenOr[list[Tool]] = NOT_GIVEN,
) -> None:
    """Apply the given session updates, swallowing ``RealtimeError`` per item.

    Each update (instructions, chat_ctx, tools) is attempted independently and
    in that order; a failure in one is logged and does not prevent the others.
    Arguments left as ``NOT_GIVEN`` are skipped entirely.
    """
    updates = (
        ("instructions", instructions, self.update_instructions),
        ("chat_ctx", chat_ctx, self.update_chat_ctx),
        ("tools", tools, self.update_tools),
    )
    for label, value, apply_update in updates:
        if not is_given(value):
            continue
        try:
            await apply_update(value)
        except RealtimeError:
            logger.exception(f"failed to update the {label}")

def start_user_activity(self) -> None:
"""notifies the model that user activity has started"""
pass
193 changes: 150 additions & 43 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ class _OnEnterData:
_OnEnterContextVar = contextvars.ContextVar["_OnEnterData"]("agents_activity_on_enter")


@dataclass
class _ReusableResources:
    """Resources detached from a finished activity that a new activity may adopt.

    Ownership semantics: whoever takes a resource sets the corresponding field
    back to ``None``; ``cleanup`` closes whatever is still attached.
    """

    stt_pipeline: _STTPipeline | None = None
    rt_session: llm.RealtimeSession | None = None

    async def cleanup(self) -> None:
        """Close any resources that were never handed off; log failures, never raise."""
        closers = []
        if self.stt_pipeline is not None:
            closers.append(self.stt_pipeline.aclose())
            self.stt_pipeline = None
        if self.rt_session is not None:
            closers.append(self.rt_session.aclose())
            self.rt_session = None

        if not closers:
            return

        # close concurrently; collect exceptions instead of propagating them
        for res in await asyncio.gather(*closers, return_exceptions=True):
            if isinstance(res, Exception):
                logger.error("error cleaning up reusable resources", exc_info=res)


@dataclass
class _PreemptiveGeneration:
speech_handle: SpeechHandle
Expand Down Expand Up @@ -511,7 +532,7 @@ def _mark_done_if_needed(_: asyncio.Task) -> None:

return task

async def start(self, *, reuse_stt_pipeline: _STTPipeline | None = None) -> None:
async def start(self, *, reuse_resources: _ReusableResources | None = None) -> None:
# `start` must only be called by AgentSession

async with self._lock:
Expand All @@ -536,7 +557,7 @@ async def start(self, *, reuse_stt_pipeline: _STTPipeline | None = None) -> None
self.tts.prewarm()

# don't use start_span for _start_session, avoid nested user/assistant turns
await self._start_session(reuse_stt_pipeline=reuse_stt_pipeline)
await self._start_session(reuse_resources=reuse_resources)
self._started = True

@tracer.start_as_current_span(
Expand All @@ -560,24 +581,72 @@ async def _traceable_on_enter() -> None:
finally:
start_span.end()

async def _detach_stt_pipeline_if_reusable(
self, new_activity: AgentActivity
) -> _STTPipeline | None:
"""Detach and return the STT pipeline if it can be handed off to *new_activity*.
async def _detach_reusable_resources(self, new_activity: AgentActivity) -> _ReusableResources:
"""Detach reusable resources for handoff to *new_activity*."""
resources = _ReusableResources()

Requires the same STT instance and the same stt_node implementation.
"""
if (
self._audio_recognition
and self.stt is not None
and type(self.agent).stt_node is type(new_activity.agent).stt_node
and self.stt is new_activity.stt
):
return await self._audio_recognition.detach_stt()
try:
# stt pipeline
if (
self._audio_recognition
and self.stt is not None
and type(self.agent).stt_node is type(new_activity.agent).stt_node
and self.stt is new_activity.stt
):
resources.stt_pipeline = await self._audio_recognition.detach_stt()

return None
# rt session
if (
self._rt_session is not None
and isinstance(self.llm, llm.RealtimeModel)
and self.llm is new_activity.llm
):
# context update is supported or chat context is equivalent
reusable = self.llm.capabilities.mutable_chat_context or (
self._rt_session.chat_ctx.copy(
exclude_instructions=True, exclude_handoff=True, exclude_config_update=True
).is_equivalent(
new_activity.agent.chat_ctx.copy(
exclude_instructions=True,
exclude_handoff=True,
exclude_config_update=True,
)
)
)
# instructions update is supported or instructions are the same
reusable = reusable and (
self.llm.capabilities.mutable_instructions
or self.agent.instructions == new_activity.agent.instructions
)
# tools update is supported or tools are the same
reusable = reusable and (
self.llm.capabilities.mutable_tools
or llm.ToolContext(self.tools) == llm.ToolContext(new_activity.tools)
)

async def _start_session(self, *, reuse_stt_pipeline: _STTPipeline | None = None) -> None:
if reusable:
# detach: remove event listeners but don't close the session
self._rt_session.off("generation_created", self._on_generation_created)
self._rt_session.off("input_speech_started", self._on_input_speech_started)
self._rt_session.off("input_speech_stopped", self._on_input_speech_stopped)
self._rt_session.off(
"input_audio_transcription_completed",
self._on_input_audio_transcription_completed,
)
self._rt_session.off("metrics_collected", self._on_metrics_collected)
self._rt_session.off("remote_item_added", self._on_remote_item_added)
self._rt_session.off("error", self._on_error)
resources.rt_session = self._rt_session
self._rt_session = None # prevent _close_session from closing it

except Exception:
# avoid leaking resources
await resources.cleanup()
raise

return resources

async def _start_session(self, *, reuse_resources: _ReusableResources | None = None) -> None:
assert self._lock.locked(), "_start_session should only be used when locked."

if isinstance(self.llm, llm.LLM):
Expand Down Expand Up @@ -625,7 +694,19 @@ async def _setup_toolset(toolset: llm.Toolset) -> None:
)

if isinstance(self.llm, llm.RealtimeModel):
self._rt_session = self.llm.session()
rt_reused = reuse_resources is not None and reuse_resources.rt_session is not None
if rt_reused:
assert reuse_resources and reuse_resources.rt_session is not None
logger.debug("reusing realtime session from previous activity")
self._rt_session = reuse_resources.rt_session
reuse_resources.rt_session = None # ownership transferred

# clear any stale audio/generation state
self._rt_session.interrupt()
self._rt_session.clear_audio()
else:
self._rt_session = self.llm.session()

self._rt_session.on("generation_created", self._on_generation_created)
self._rt_session.on("input_speech_started", self._on_input_speech_started)
self._rt_session.on("input_speech_stopped", self._on_input_speech_stopped)
Expand All @@ -639,27 +720,23 @@ async def _setup_toolset(toolset: llm.Toolset) -> None:

remove_instructions(self._agent._chat_ctx)

try:
await self._rt_session.update_instructions(self._agent.instructions)
except llm.RealtimeError:
logger.exception("failed to update the instructions")

try:
await self._rt_session.update_chat_ctx(self._agent.chat_ctx)
except llm.RealtimeError:
logger.exception("failed to update the chat_ctx")

try:
await self._rt_session.update_tools(llm.ToolContext(self.tools).flatten())
except llm.RealtimeError:
logger.exception("failed to update the tools")
capabilities = self.llm.capabilities
reset_instructions = reset_chat_ctx = reset_tools = True
if rt_reused:
# skip the update if the session is reused and no mid-session update is supported
# this means the content is the same as the previous session
reset_instructions = capabilities.mutable_instructions
reset_chat_ctx = capabilities.mutable_chat_context
reset_tools = capabilities.mutable_tools

await self._rt_session._update_session(
instructions=self._agent.instructions if reset_instructions else NOT_GIVEN,
chat_ctx=self._agent.chat_ctx if reset_chat_ctx else NOT_GIVEN,
tools=llm.ToolContext(self.tools).flatten() if reset_tools else NOT_GIVEN,
)

self._realtime_spans = utils.BoundedDict[str, trace.Span](maxsize=100)
if (
not self.llm.capabilities.audio_output
and not self.tts
and self._session.output.audio
):
if not capabilities.audio_output and not self.tts and self._session.output.audio:
logger.error(
"audio output is enabled but RealtimeModel has no audio modality "
"and no TTS is set. Either enable audio modality in the RealtimeModel "
Expand Down Expand Up @@ -697,14 +774,17 @@ async def _setup_toolset(toolset: llm.Toolset) -> None:
stt_model=self.stt.model if self.stt else None,
stt_provider=self.stt.provider if self.stt else None,
)
if reuse_stt_pipeline is not None:
if reuse_resources and reuse_resources.stt_pipeline is not None:
logger.debug("reusing STT pipeline from previous activity")
self._audio_recognition.start(stt_pipeline=reuse_stt_pipeline)
self._audio_recognition.start(stt_pipeline=reuse_resources.stt_pipeline)
reuse_resources.stt_pipeline = None # ownership transferred
else:
self._audio_recognition.start()

@tracer.start_as_current_span("drain_agent_activity")
async def drain(self) -> None:
async def drain(
self, *, new_activity: AgentActivity | None = None
) -> _ReusableResources | None:
# `drain` must only be called by AgentSession
# AgentSession makes sure there is always one agent available to the users.
current_span = trace.get_current_span()
Expand Down Expand Up @@ -733,6 +813,15 @@ async def _traceable_on_exit() -> None:

await self._pause_scheduling_task()

# detach after speech tasks are done but before _close_session
if new_activity is not None:
try:
return await self._detach_reusable_resources(new_activity)
except BaseException:
logger.exception("failed to detach reusable resources")

return None

async def _pause_scheduling_task(
self, *, blocked_tasks: list[asyncio.Task] | None = None
) -> None:
Expand Down Expand Up @@ -762,7 +851,7 @@ async def _resume_scheduling_task(self) -> None:
self._scheduling_task(), name="_scheduling_task"
)

async def resume(self, *, reuse_stt_pipeline: _STTPipeline | None = None) -> None:
async def resume(self, *, reuse_resources: _ReusableResources | None = None) -> None:
# `resume` must only be called by AgentSession

async with self._lock:
Expand All @@ -771,14 +860,19 @@ async def resume(self, *, reuse_stt_pipeline: _STTPipeline | None = None) -> Non
attributes={trace_types.ATTR_AGENT_LABEL: self.agent.label},
)
try:
await self._start_session(reuse_stt_pipeline=reuse_stt_pipeline)
await self._start_session(reuse_resources=reuse_resources)
finally:
span.end()

def _wake_up_scheduling_task(self) -> None:
self._q_updated.set()

async def pause(self, *, blocked_tasks: list[asyncio.Task]) -> None:
async def pause(
self,
*,
blocked_tasks: list[asyncio.Task],
new_activity: AgentActivity | None = None,
) -> _ReusableResources | None:
# `pause` must only be called by AgentSession

# When draining, the tasks that have done the "premption" must be ignored.
Expand All @@ -791,12 +885,25 @@ async def pause(self, *, blocked_tasks: list[asyncio.Task]) -> None:
"pause_agent_activity",
attributes={trace_types.ATTR_AGENT_LABEL: self._agent.label},
)

resources: _ReusableResources | None = None
try:
await self._pause_scheduling_task(blocked_tasks=blocked_tasks)

# detach after speech tasks are done but before _close_session
if new_activity is not None:
resources = await self._detach_reusable_resources(new_activity)

await self._close_session()
except BaseException:
if resources is not None:
await resources.cleanup()
raise
finally:
span.end()

return resources

async def _close_session(self) -> None:
assert self._lock.locked(), "_close_session should only be used when locked."

Expand Down
Loading
Loading