Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,9 @@ async def _send_session_update(self, msg: dict[str, Any]) -> None:

async def _run(self) -> None:
"""Main loop for streaming transcription."""
# reset per-session state since the gateway starts a fresh session on reconnect
self._speech_duration = 0
self._speaking = False
closing_ws = False

@utils.log_exceptions(logger=logger)
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/inference/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
health_check_cb=self._check_ws_health,
max_session_duration=300,
mark_refreshed_on_get=True,
)
Expand Down Expand Up @@ -485,6 +486,9 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:

return ws

async def _check_ws_health(self, ws: aiohttp.ClientWebSocketResponse) -> bool:
    """Tell the connection pool whether ``ws`` may still be reused.

    Returns:
        False once ``ws.closed`` is set, True otherwise.
    """
    if ws.closed:
        return False
    return True

async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
    """Close callback for the connection pool: await the websocket's own ``close()``."""
    closing = ws.close()
    await closing

Expand Down
35 changes: 34 additions & 1 deletion livekit-agents/livekit/agents/stt/stt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import collections
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator
Expand Down Expand Up @@ -310,6 +311,14 @@ def __init__(
self._pushed_sr = 0
self._resampler: rtc.AudioResampler | None = None

# input buffer for replay on reconnect
self._input_buffer: collections.deque[rtc.AudioFrame | RecognizeStream._FlushSentinel] = (
collections.deque()
)
self._input_buffer_duration: float = 0.0
self._max_buffer_duration: float = 10.0 # seconds of audio to retain
self._input_ended: bool = False

self._start_time_offset: float = 0.0

@property
Expand Down Expand Up @@ -339,6 +348,17 @@ def _report_connection_acquired(self, acquire_time: float, connection_reused: bo
),
)

def _append_to_buffer(self, item: rtc.AudioFrame | RecognizeStream._FlushSentinel) -> None:
    """Record ``item`` in the input buffer and trim the buffer to its duration budget.

    Only audio frames count toward the buffered duration; flush sentinels are
    stored alongside them but contribute nothing to it.
    """
    buffered = self._input_buffer
    buffered.append(item)

    if isinstance(item, rtc.AudioFrame):
        self._input_buffer_duration += item.duration

    # drop entries from the front until the retained audio fits the budget again
    while buffered and self._input_buffer_duration > self._max_buffer_duration:
        dropped = buffered.popleft()
        if not isinstance(dropped, rtc.AudioFrame):
            continue  # sentinels carry no duration to give back
        self._input_buffer_duration -= dropped.duration

@abstractmethod
async def _run(self) -> None: ...

Expand Down Expand Up @@ -376,6 +396,13 @@ async def _main_task(self) -> None:
)
await asyncio.sleep(retry_interval)

# replay buffered input into a fresh channel for retry
self._input_ch = aio.Chan[rtc.AudioFrame | RecognizeStream._FlushSentinel]()
for item in self._input_buffer:
self._input_ch.send_nowait(item)
if self._input_ended:
self._input_ch.close()

self._num_retries += 1

except Exception as e:
Expand Down Expand Up @@ -443,8 +470,10 @@ def push_frame(self, frame: rtc.AudioFrame) -> None:
frames = self._resampler.push(frame)
for frame in frames:
self._input_ch.send_nowait(frame)
self._append_to_buffer(frame)
else:
self._input_ch.send_nowait(frame)
self._append_to_buffer(frame)

def flush(self) -> None:
"""Mark the end of the current segment"""
Expand All @@ -454,13 +483,17 @@ def flush(self) -> None:
if self._resampler:
for frame in self._resampler.flush():
self._input_ch.send_nowait(frame)
self._append_to_buffer(frame)

self._input_ch.send_nowait(self._FlushSentinel())
sentinel = self._FlushSentinel()
self._input_ch.send_nowait(sentinel)
self._append_to_buffer(sentinel)

def end_input(self) -> None:
    """Signal that the caller has finished pushing audio.

    Flushes the current segment, closes the input channel, and then records
    that input has ended.
    """
    # flush must run before the channel is closed: it still sends into it
    self.flush()

    channel = self._input_ch
    channel.close()

    self._input_ended = True

async def aclose(self) -> None:
"""Close the stream immediately"""
Expand Down
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,9 @@ async def _main_task(self) -> None:
pushed_duration = output_emitter.pushed_duration()
should_retry = (
e.retryable
and pushed_duration == 0.0
and self._conn_options.max_retry > 0
and i < self._conn_options.max_retry
and (pushed_duration == 0.0 or self._conn_options.tts_replay_on_partial)
)

if not should_retry:
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class APIConnectOptions:
Timeout for connecting to the API in seconds.
"""

tts_replay_on_partial: bool = False
"""If True, TTS will retry and replay all tokens even after partial audio was emitted."""

def __post_init__(self) -> None:
if self.max_retry < 0:
raise ValueError("max_retry must be greater than or equal to 0")
Expand Down
34 changes: 25 additions & 9 deletions livekit-agents/livekit/agents/utils/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
mark_refreshed_on_get: bool = False,
connect_cb: Callable[[float], Awaitable[T]] | None = None,
close_cb: Callable[[T], Awaitable[None]] | None = None,
health_check_cb: Callable[[T], Awaitable[bool]] | None = None,
connect_timeout: float = 10.0,
) -> None:
"""Initialize the connection wrapper.
Expand All @@ -34,11 +35,13 @@ def __init__(
mark_refreshed_on_get: If True, the session will be marked as fresh when get() is called. Only used when max_session_duration is set.
connect_cb: Optional async callback to create new connections
close_cb: Optional async callback to close connections
health_check_cb: Optional async callback to verify a pooled connection is still alive
""" # noqa: E501
self._max_session_duration = max_session_duration
self._mark_refreshed_on_get = mark_refreshed_on_get
self._connect_cb = connect_cb
self._close_cb = close_cb
self._health_check_cb = health_check_cb
self._connections: dict[T, float] = {} # conn -> connected_at timestamp
self._available: set[T] = set()
self._connect_timeout = connect_timeout
Expand Down Expand Up @@ -107,16 +110,29 @@ async def get(self, *, timeout: float) -> T:
while self._available:
conn = self._available.pop()
if (
self._max_session_duration is None
or now - self._connections[conn] <= self._max_session_duration
self._max_session_duration is not None
and now - self._connections[conn] > self._max_session_duration
):
if self._mark_refreshed_on_get:
self._connections[conn] = now
self.last_acquire_time = 0.0
self.last_connection_reused = True
return conn
# connection expired; mark it for resetting.
self.remove(conn)
# connection expired; mark it for resetting.
self.remove(conn)
continue

# health check pooled connection before reuse
if self._health_check_cb is not None:
try:
healthy = await asyncio.wait_for(self._health_check_cb(conn), timeout=2.0)
if not healthy:
self.remove(conn)
continue
except Exception:
self.remove(conn)
continue

if self._mark_refreshed_on_get:
self._connections[conn] = now
self.last_acquire_time = 0.0
self.last_connection_reused = True
return conn

t0 = time.perf_counter()
conn = await self._connect(timeout)
Expand Down
101 changes: 101 additions & 0 deletions tests/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import time

import pytest
Expand All @@ -8,6 +9,7 @@
class DummyConnection:
def __init__(self, id):
self.id = id
self.healthy = True

def __repr__(self):
return f"DummyConnection({self.id})"
Expand Down Expand Up @@ -81,3 +83,102 @@ async def test_get_expired():

conn2 = await pool.get(timeout=10.0)
assert conn2 is not conn, "Expected a new connection to be returned."


@pytest.mark.asyncio
async def test_health_check_healthy_connection():
    """A pooled connection whose health check returns True is handed out again."""

    async def report_health(conn: DummyConnection) -> bool:
        return conn.healthy

    pool = ConnectionPool(
        max_session_duration=60,
        connect_cb=dummy_connect_factory(),
        health_check_cb=report_health,
    )

    # borrow a connection and hand it straight back to the pool
    first = await pool.get(timeout=10.0)
    pool.put(first)

    second = await pool.get(timeout=10.0)
    assert first is second, "Healthy connection should be reused."


@pytest.mark.asyncio
async def test_health_check_unhealthy_connection():
    """A pooled connection whose health check returns False is replaced by a new one."""

    async def report_health(conn: DummyConnection) -> bool:
        return conn.healthy

    pool = ConnectionPool(
        max_session_duration=60,
        connect_cb=dummy_connect_factory(),
        health_check_cb=report_health,
    )

    stale = await pool.get(timeout=10.0)
    # flip the flag before returning it so the next get() sees a failing check
    stale.healthy = False
    pool.put(stale)

    replacement = await pool.get(timeout=10.0)
    assert replacement is not stale, "Unhealthy connection should not be reused."


@pytest.mark.asyncio
async def test_health_check_timeout():
    """A health check that does not finish in time makes the pool drop the connection."""

    async def never_in_time(conn: DummyConnection) -> bool:
        # sleeps well past the pool's 2s health-check timeout
        await asyncio.sleep(10)
        return True

    pool = ConnectionPool(
        max_session_duration=60,
        connect_cb=dummy_connect_factory(),
        health_check_cb=never_in_time,
    )

    pooled = await pool.get(timeout=10.0)
    pool.put(pooled)

    replacement = await pool.get(timeout=10.0)
    assert replacement is not pooled, (
        "Connection with timed-out health check should not be reused."
    )


@pytest.mark.asyncio
async def test_health_check_exception():
    """A health check that raises makes the pool drop the connection."""

    async def always_raises(conn: DummyConnection) -> bool:
        raise RuntimeError("health check failed")

    pool = ConnectionPool(
        max_session_duration=60,
        connect_cb=dummy_connect_factory(),
        health_check_cb=always_raises,
    )

    pooled = await pool.get(timeout=10.0)
    pool.put(pooled)

    replacement = await pool.get(timeout=10.0)
    assert replacement is not pooled, (
        "Connection with failing health check should not be reused."
    )


@pytest.mark.asyncio
async def test_no_health_check_by_default():
    """With no health_check_cb configured, a returned connection is simply reused."""
    pool = ConnectionPool(
        max_session_duration=60,
        connect_cb=dummy_connect_factory(),
    )

    first = await pool.get(timeout=10.0)
    pool.put(first)

    second = await pool.get(timeout=10.0)
    assert first is second, "Without health check, connection should be reused."
Loading
Loading