diff --git a/.env.example b/.env.example index 2e23b05..589d71c 100644 --- a/.env.example +++ b/.env.example @@ -6,9 +6,13 @@ REDIS_HOST=127.0.0.1 REDIS_PORT=6789 REDIS_PASSWORD= -# STT Provider: "gladia" (default) +# STT Provider: "gladia" (default) or "openai" # STT_PROVIDER=gladia +# ============================================================================= +# --- Gladia STT (STT_PROVIDER=gladia) --- +# ============================================================================= + GLADIA_API_KEY= # The following env vars serves as a translation locale mapper between # (Gladia) and - (BBB) locale formats. @@ -58,3 +62,17 @@ GLADIA_TRANSLATION_LANG_MAP="de:de-DE,en:en-US,es:es-ES,fr:fr-FR,hi:hi-IN,it:it- #GLADIA_PRE_PROCESSING_AUDIO_ENHANCER=false #GLADIA_PRE_PROCESSING_SPEECH_THRESHOLD=0.5 + +# ============================================================================= +# --- OpenAI STT (STT_PROVIDER=openai) --- +# Supports the official OpenAI API and any OpenAI-compatible endpoint. +# ============================================================================= + +# OpenAI API key (required; always sent as a Bearer token, so use any placeholder value for compatible endpoints that ignore auth) +#OPENAI_API_KEY= + +# Transcription model (default: gpt-4o-transcribe; use "whisper-1" for classic Whisper) +#OPENAI_STT_MODEL=gpt-4o-transcribe + +# Base URL override — set this to use a compatible provider (e.g. a local Whisper server). Give the server root without a trailing /v1: the agent appends /v1/audio/transcriptions itself (default: https://api.openai.com) +#OPENAI_BASE_URL= diff --git a/CHANGELOG.md b/CHANGELOG.md index 26aedf8..fa8bbf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Final releases will consolidate all intermediate changes in chronological order. 
## UNRELEASED +* feat(openai): add OpenAI STT provider support (official and compatible endpoints) * feat: add GladiaSttAgent provider and factory * refactor: move GladiaConfig to providers package, delete old agent module * feat(tests): add unit and integration tests with pytest diff --git a/README.md b/README.md index 4327e60..30a4bad 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@ This application provides Speech-to-Text (STT) for BigBlueButton meetings using LiveKit as their audio bridge. -Initially, the only supported STT engine is Gladia through the official [LiveKit Gladia Plugin](https://docs.livekit.io/agents/integrations/stt/gladia/). +Supported STT engines: -It'll be expanded in the future to support other STT plugins from the LiveKit Agents -ecosystem. +- **Gladia** — via the official [LiveKit Gladia plugin](https://docs.livekit.io/agents/integrations/stt/gladia/) (default) +- **OpenAI** — via direct calls to the REST `/v1/audio/transcriptions` endpoint, with audio segmented locally by silence detection (the [LiveKit OpenAI plugin](https://docs.livekit.io/agents/models/stt/openai/) streaming mode is not used, because its Realtime WebSocket API is not implemented by all compatible backends); supports the official OpenAI API and any OpenAI-compatible endpoint ## Getting Started @@ -14,7 +14,7 @@ ecosystem. - Python 3.10+ - A LiveKit instance -- A Gladia API key +- A Gladia API key **or** an OpenAI API key (depending on your chosen STT provider) - uv: - See installation instructions: https://docs.astral.sh/uv/getting-started/installation/ @@ -48,13 +48,17 @@ ecosystem. LIVEKIT_API_KEY=... LIVEKIT_API_SECRET=... - # Gladia API Key + # For Gladia (default provider): GLADIA_API_KEY=... + + # For OpenAI (set STT_PROVIDER=openai): + # STT_PROVIDER=openai + # OPENAI_API_KEY=... ``` Feel free to check `.env.example` for any other configurations of interest. - **All options ingested by the Gladia STT plugin are exposed via env vars**. + **All options ingested by the Gladia STT plugin are exposed via env vars**; the OpenAI-specific settings are `OPENAI_API_KEY`, `OPENAI_STT_MODEL` and `OPENAI_BASE_URL`. 
### Running @@ -98,6 +102,30 @@ docker run --network host --rm -it --env-file .env bbb-livekit-stt Pre-built images are available via GitHub Container Registry as well. +### OpenAI STT provider + +Set `STT_PROVIDER=openai` to use OpenAI STT instead of Gladia. + +**Official OpenAI API:** + +```bash +STT_PROVIDER=openai +OPENAI_API_KEY=your-key +# OPENAI_STT_MODEL=gpt-4o-transcribe # default; use "whisper-1" for classic Whisper +``` + +**OpenAI-compatible endpoint** (e.g. a self-hosted Whisper server): + +```bash +STT_PROVIDER=openai +OPENAI_API_KEY=any-value +OPENAI_BASE_URL=http://your-server:8000 +OPENAI_STT_MODEL=your-model-name +``` + +> **Note**: OpenAI STT does not support real-time translation. Only the original +> transcript language is returned, matching the user's BBB speech locale. `OPENAI_BASE_URL` must be the server root without a `/v1` suffix, because the agent appends `/v1/audio/transcriptions` itself. + ### Development #### Testing @@ -114,12 +142,20 @@ Run with coverage: uv run pytest tests/ --ignore=tests/integration --cov --cov-report=term-missing ``` -Integration tests require a real Gladia API key and make live requests to the Gladia service. Set `GLADIA_API_KEY` and run: +Integration tests require a real API key and make live requests to the STT service. + +For Gladia, set `GLADIA_API_KEY` and run: ```bash GLADIA_API_KEY=your-key uv run pytest tests/integration -m integration ``` +For OpenAI, set `OPENAI_API_KEY` (and optionally `OPENAI_STT_MODEL` / `OPENAI_BASE_URL`, which the tests also read) and run: + +```bash +OPENAI_API_KEY=your-key uv run pytest tests/integration -m integration +``` + #### Linting This project uses [ruff](https://docs.astral.sh/ruff/) for linting and formatting. 
To check for issues: diff --git a/main.py b/main.py index 58621ba..414e285 100644 --- a/main.py +++ b/main.py @@ -49,7 +49,7 @@ async def on_redis_message(message_data: str): meeting_id = routing.get("meetingId") user_id = routing.get("userId") - if meeting_id != agent.room.name: + if agent.room is None or meeting_id != agent.room.name: return if event_name == RedisManager.USER_SPEECH_LOCALE_CHANGED_EVT_MSG: @@ -102,7 +102,8 @@ async def on_final_transcript( original_lang = original_locale.split("-")[0] for alternative in event.alternatives: - transcript_lang = alternative.language + # Some providers (e.g. OpenAI) may not report a language; fall back to original. + transcript_lang = alternative.language or original_lang text = alternative.text bbb_locale = None start_time_adjusted = math.floor(open_time + alternative.start_time) @@ -171,7 +172,8 @@ async def on_interim_transcript( min_utterance_length = p_settings.get("min_utterance_length", 0) for alternative in event.alternatives: - transcript_lang = alternative.language + # Some providers (e.g. OpenAI) may not report a language; fall back to original. 
+ transcript_lang = alternative.language or original_lang text = alternative.text start_time_adjusted = math.floor(open_time + alternative.start_time) end_time_adjusted = math.floor(open_time + alternative.end_time) diff --git a/providers/__init__.py b/providers/__init__.py index e85919a..03d0b9f 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -6,4 +6,8 @@ def create_agent(provider: str) -> BaseSttAgent: from providers.gladia import GladiaSttAgent, gladia_config return GladiaSttAgent(gladia_config) + if provider == "openai": + from providers.openai import OpenAiSttAgent, openai_config + + return OpenAiSttAgent(openai_config) raise ValueError(f"Unknown STT provider: {provider}") diff --git a/providers/openai.py b/providers/openai.py new file mode 100644 index 0000000..754731e --- /dev/null +++ b/providers/openai.py @@ -0,0 +1,244 @@ +import asyncio +import logging +import os +import time +from dataclasses import dataclass, field + +import aiohttp +import numpy as np +from livekit import rtc +from livekit.agents import stt + +from providers.base import BaseSttAgent, BaseSttConfig + +# Energy-based voice activity detection parameters. +# RMS threshold (int16 scale 0–32768): frames below this are considered silence. +_SILENCE_THRESHOLD_RMS = 500 +# Seconds of continuous silence after speech before the segment is flushed. +_SILENCE_DURATION_S = 0.8 +# Maximum segment duration before a forced flush (prevents unbounded buffering). 
+_MAX_BUFFER_DURATION_S = 12.0 + + +@dataclass +class OpenAiConfig(BaseSttConfig): + api_key: str | None = field(default_factory=lambda: os.getenv("OPENAI_API_KEY")) + model: str = field( + default_factory=lambda: os.getenv("OPENAI_STT_MODEL", "gpt-4o-transcribe") + ) + base_url: str | None = field( + default_factory=lambda: os.getenv("OPENAI_BASE_URL", None) + ) + + +openai_config = OpenAiConfig() + + +class OpenAiSttAgent(BaseSttAgent): + def __init__(self, config: OpenAiConfig): + super().__init__(config) + self._http_session: aiohttp.ClientSession | None = None + + # --- HTTP session management --- + + def _get_http_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + self._http_session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=30) + ) + return self._http_session + + async def _transcribe_wav(self, wav_bytes: bytes, language: str) -> str: + """Call the OpenAI-compatible REST endpoint directly. + + Constructs the URL as ``{base_url}/v1/audio/transcriptions`` so that + custom backends (e.g. ``http://my-server/api/``) work correctly + regardless of how the OpenAI SDK would handle the ``/v1`` path segment. 
+ """ + base_url = (self.config.base_url or "https://api.openai.com").rstrip("/") + url = f"{base_url}/v1/audio/transcriptions" + + form = aiohttp.FormData() + form.add_field( + "file", wav_bytes, filename="audio.wav", content_type="audio/wav" + ) + form.add_field("model", self.config.model) + form.add_field("response_format", "json") + if language: + form.add_field("language", language) + + headers = {"Authorization": f"Bearer {self.config.api_key}"} + session = self._get_http_session() + async with session.post(url, data=form, headers=headers) as resp: + resp.raise_for_status() + result = await resp.json() + return result.get("text", "").strip() + + # --- BaseSttAgent abstract method implementations --- + + def _create_stt_stream(self, locale: str) -> stt.SpeechStream: # type: ignore[override] + """Not used: REST mode overrides start_transcription_for_user directly.""" + raise NotImplementedError("OpenAI REST mode does not use STT streams") + + def _update_stream_locale(self, user_id: str, locale: str): + """Restart the pipeline with the new locale.""" + provider = self.participant_settings.get(user_id, {}).get("provider", "openai") + self.stop_transcription_for_user(user_id) + self.start_transcription_for_user(user_id, locale, provider) + + # --- Override start_transcription_for_user to pass language, not a stream --- + + def start_transcription_for_user(self, user_id: str, locale: str, provider: str): + settings = self.participant_settings.setdefault(user_id, {}) + settings["locale"] = locale + settings["provider"] = provider + + participant = self._find_participant(user_id) + if not participant: + logging.error( + f"Cannot start transcription, participant {user_id} not found." + ) + return + + track = self._find_audio_track(participant) + if not track: + logging.warning( + f"Won't start transcription yet, no audio track found for {user_id}." 
+ ) + return + + if participant.identity in self.processing_info: + logging.debug( + f"Transcription task already running for {participant.identity}, ignoring start command." + ) + return + + language = self._sanitize_locale(locale) + task = asyncio.create_task( + self._run_transcription_pipeline(participant, track, language) + ) + self.processing_info[participant.identity] = {"task": task} + logging.info( + f"Started transcription for {participant.identity} with locale {locale}." + ) + + # --- Override _cleanup to close HTTP session --- + + async def _cleanup(self): + await super()._cleanup() + if self._http_session: + await self._http_session.close() + self._http_session = None + + # --- REST-based transcription pipeline --- + + async def _run_transcription_pipeline( # type: ignore[override] + self, + participant: rtc.RemoteParticipant, + track: rtc.Track, + language: str, + ): + """Collect audio, segment by silence, and transcribe via REST API. + + The OpenAI plugin's stream() uses the Realtime WebSocket API which is + not implemented by all OpenAI-compatible backends. This implementation + uses energy-based silence detection (RMS threshold) to segment audio + into utterances, then calls the standard REST /audio/transcriptions + endpoint for each segment. + + TODO: Support the Realtime WebSocket endpoint as an opt-in mode (e.g. + via an ``OpenAiConfig`` flag). When enabled, delegate to the livekit + openai plugin's stream() directly. This would unlock lower latency for + backends that implement the OpenAI Realtime API. 
+ """ + audio_stream = rtc.AudioStream(track) + open_time = time.time() + self.open_time = open_time + + speech_buffer: list[rtc.AudioFrame] = [] + buffer_duration = 0.0 + silence_duration = 0.0 + was_speaking = False + speech_start_time = 0.0 + + async def flush_segment(frames: list[rtc.AudioFrame], seg_start: float) -> None: + if not frames: + return + try: + wav_bytes = rtc.combine_audio_frames(frames).to_wav_bytes() + text = await self._transcribe_wav(wav_bytes, language) + if text: + seg_end = time.time() - open_time + event = stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[ + stt.SpeechData( + text=text, + language=language, + start_time=seg_start, + end_time=seg_end, + ) + ], + ) + self.emit( + "final_transcript", + participant=participant, + event=event, + open_time=open_time, + ) + except asyncio.CancelledError: + raise + except Exception as e: + logging.error( + f"Error transcribing segment for {participant.identity}: {e}" + ) + + try: + async for audio_event in audio_stream: + frame = audio_event.frame + samples = np.frombuffer(frame.data, dtype=np.int16) + rms = float(np.sqrt(np.mean(samples.astype(np.float32) ** 2))) + is_speaking = rms > _SILENCE_THRESHOLD_RMS + frame_duration = frame.samples_per_channel / frame.sample_rate + + if is_speaking: + if not was_speaking: + speech_start_time = time.time() - open_time + speech_buffer.append(frame) + buffer_duration += frame_duration + silence_duration = 0.0 + was_speaking = True + + if buffer_duration >= _MAX_BUFFER_DURATION_S: + # Safety flush: prevent unbounded buffer growth during + # continuous speech or sustained noise above the RMS threshold. + await flush_segment(speech_buffer[:], speech_start_time) + speech_buffer.clear() + buffer_duration = 0.0 + speech_start_time = time.time() - open_time + elif was_speaking: + # Carry silence frames so the segment has natural trailing audio. 
+ speech_buffer.append(frame) + buffer_duration += frame_duration + silence_duration += frame_duration + + if ( + silence_duration >= _SILENCE_DURATION_S + or buffer_duration >= _MAX_BUFFER_DURATION_S + ): + await flush_segment(speech_buffer[:], speech_start_time) + speech_buffer.clear() + buffer_duration = 0.0 + silence_duration = 0.0 + was_speaking = False + + # Flush any remaining buffered speech at end of stream. + await flush_segment(speech_buffer[:], speech_start_time) + + except asyncio.CancelledError: + logging.info(f"Transcription for {participant.identity} was cancelled.") + except Exception as e: + logging.error(f"Error during transcription for track {track.sid}: {e}") + finally: + self.processing_info.pop(participant.identity, None) + await audio_stream.aclose() diff --git a/tests/integration/test_openai_stt.py b/tests/integration/test_openai_stt.py new file mode 100644 index 0000000..8b08a10 --- /dev/null +++ b/tests/integration/test_openai_stt.py @@ -0,0 +1,91 @@ +"""Integration tests for the OpenAI STT pipeline. + +These tests require a valid OPENAI_API_KEY environment variable and make real +requests to the OpenAI transcription service. They are skipped automatically +when the key is absent. 
+""" + +import os + +import numpy as np +import pytest +from livekit import rtc + +from providers.openai import OpenAiConfig, OpenAiSttAgent + +pytestmark = pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY environment variable is not set", +) + + +def _make_agent() -> OpenAiSttAgent: + config = OpenAiConfig( + api_key=os.environ["OPENAI_API_KEY"], + model=os.environ.get("OPENAI_STT_MODEL", "gpt-4o-transcribe"), + base_url=os.environ.get("OPENAI_BASE_URL"), + ) + return OpenAiSttAgent(config) + + +def _silent_wav_bytes(duration_s: float = 0.5, sample_rate: int = 16000) -> bytes: + """Return WAV bytes containing silent PCM audio.""" + num_samples = int(duration_s * sample_rate) + frame = rtc.AudioFrame( + data=bytes(num_samples * 2), + sample_rate=sample_rate, + num_channels=1, + samples_per_channel=num_samples, + ) + return frame.to_wav_bytes() + + +def _tone_wav_bytes( + duration_s: float = 0.5, sample_rate: int = 16000, freq_hz: float = 440.0 +) -> bytes: + """Return WAV bytes containing a sine-wave tone (not speech).""" + num_samples = int(duration_s * sample_rate) + t = np.linspace(0, duration_s, num_samples, endpoint=False) + samples = (np.sin(2 * np.pi * freq_hz * t) * 16000).astype(np.int16) + frame = rtc.AudioFrame( + data=samples.tobytes(), + sample_rate=sample_rate, + num_channels=1, + samples_per_channel=num_samples, + ) + return frame.to_wav_bytes() + + +@pytest.mark.integration +async def test_transcribe_wav_silent_audio_returns_empty(): + """Silent audio should return an empty or near-empty transcript.""" + agent = _make_agent() + try: + result = await agent._transcribe_wav(_silent_wav_bytes(), language="en") + assert isinstance(result, str) + # Silent audio may return empty or a minimal artefact — just no long text. 
+ assert len(result) < 20, f"Unexpected transcript for silence: {result!r}" + finally: + await agent._cleanup() + + +@pytest.mark.integration +async def test_transcribe_wav_returns_string(): + """A tone clip should return a string (possibly empty) without raising.""" + agent = _make_agent() + try: + result = await agent._transcribe_wav(_tone_wav_bytes(), language="en") + assert isinstance(result, str) + finally: + await agent._cleanup() + + +@pytest.mark.integration +async def test_cleanup_closes_http_session(): + """_cleanup() should close the aiohttp session opened during a request.""" + agent = _make_agent() + # Trigger session creation via a real request + await agent._transcribe_wav(_silent_wav_bytes(), language="en") + assert agent._http_session is not None + await agent._cleanup() + assert agent._http_session is None diff --git a/tests/test_config.py b/tests/test_config.py index e40169a..2459461 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,7 @@ redact_config_values, ) from providers.gladia import GladiaConfig +from providers.openai import OpenAiConfig class TestGetBoolEnv: @@ -223,3 +224,33 @@ def test_min_confidence_interim_overrides_base(self, monkeypatch): config = GladiaConfig() assert config.min_confidence_interim == pytest.approx(0.2) assert config.min_confidence_final == pytest.approx(0.5) + + +class TestOpenAiConfigDefaults: + @pytest.fixture(autouse=True) + def _clean_openai_env(self, monkeypatch): + """Remove all OPENAI_* env vars so dataclass defaults are exercised.""" + for key in list(os.environ): + if key.startswith("OPENAI_"): + monkeypatch.delenv(key, raising=False) + + def test_model_defaults_to_gpt4o_transcribe(self): + assert OpenAiConfig().model == "gpt-4o-transcribe" + + def test_api_key_defaults_to_none(self): + assert OpenAiConfig().api_key is None + + def test_base_url_defaults_to_none(self): + assert OpenAiConfig().base_url is None + + def test_model_overridden_by_env_var(self, monkeypatch): + 
monkeypatch.setenv("OPENAI_STT_MODEL", "whisper-1") + assert OpenAiConfig().model == "whisper-1" + + def test_base_url_overridden_by_env_var(self, monkeypatch): + monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:8000") + assert OpenAiConfig().base_url == "http://localhost:8000" + + def test_api_key_overridden_by_env_var(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + assert OpenAiConfig().api_key == "sk-test" diff --git a/tests/test_openai_agent.py b/tests/test_openai_agent.py new file mode 100644 index 0000000..2e35a29 --- /dev/null +++ b/tests/test_openai_agent.py @@ -0,0 +1,291 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from livekit import rtc +from livekit.agents import stt + +from providers.openai import ( + OpenAiConfig, + OpenAiSttAgent, + _SILENCE_THRESHOLD_RMS, +) + + +def _make_agent(**kwargs): + config = OpenAiConfig(api_key="fake-key", **kwargs) + return OpenAiSttAgent(config) + + +def _make_agent_with_room(participants=None, **kwargs): + agent = _make_agent(**kwargs) + mock_room = MagicMock() + mock_room.remote_participants = participants or {} + agent.room = mock_room + return agent + + +def _make_participant(identity, audio_track=None): + participant = MagicMock(spec=rtc.RemoteParticipant) + participant.identity = identity + pubs = {} + if audio_track: + pub = MagicMock() + pub.track = audio_track + pub.track.kind = rtc.TrackKind.KIND_AUDIO + pubs["audio"] = pub + participant.track_publications = pubs + return participant + + +def _make_audio_event(amplitude: int = 0) -> MagicMock: + """Create a mock audio event with PCM bytes at the given amplitude.""" + samples = np.full(160, amplitude, dtype=np.int16) + event = MagicMock() + event.frame.data = samples.tobytes() + event.frame.sample_rate = 16000 + event.frame.samples_per_channel = 160 + return event + + +def _make_loud_event() -> MagicMock: + """Audio event with RMS energy above the speech 
threshold.""" + return _make_audio_event(amplitude=int(_SILENCE_THRESHOLD_RMS * 2)) + + +class TestOpenAiConfigDefaults: + @pytest.fixture(autouse=True) + def _clean_openai_env(self, monkeypatch): + for key in list(__import__("os").environ): + if key.startswith("OPENAI_"): + monkeypatch.delenv(key, raising=False) + + def test_model_defaults_to_gpt4o_transcribe(self): + assert OpenAiConfig().model == "gpt-4o-transcribe" + + def test_api_key_defaults_to_none(self): + assert OpenAiConfig().api_key is None + + def test_base_url_defaults_to_none(self): + assert OpenAiConfig().base_url is None + + +class TestUpdateLocaleForUser: + def test_updates_locale_in_participant_settings(self): + agent = _make_agent_with_room() + agent.participant_settings["user_1"] = {"locale": "en", "provider": "openai"} + + agent.update_locale_for_user("user_1", "fr") + + assert agent.participant_settings["user_1"]["locale"] == "fr" + + def test_restarts_transcription_when_active(self): + """OpenAI REST requires stop+restart to change locale.""" + mock_track = MagicMock() + mock_track.kind = rtc.TrackKind.KIND_AUDIO + participant = _make_participant("user_1", audio_track=mock_track) + agent = _make_agent_with_room(participants={"pid": participant}) + agent.participant_settings["user_1"] = {"locale": "en", "provider": "openai"} + agent.processing_info["user_1"] = {"task": MagicMock()} + + with ( + patch.object(agent, "stop_transcription_for_user") as mock_stop, + patch.object(agent, "start_transcription_for_user") as mock_start, + ): + agent.update_locale_for_user("user_1", "de") + + mock_stop.assert_called_once_with("user_1") + mock_start.assert_called_once_with("user_1", "de", "openai") + + def test_does_not_restart_when_no_active_transcription(self): + agent = _make_agent_with_room() + agent.participant_settings["user_1"] = {"locale": "en", "provider": "openai"} + + with ( + patch.object(agent, "stop_transcription_for_user") as mock_stop, + patch.object(agent, "start_transcription_for_user") 
as mock_start, + ): + agent.update_locale_for_user("user_1", "fr") + + mock_stop.assert_not_called() + mock_start.assert_not_called() + assert agent.participant_settings["user_1"]["locale"] == "fr" + + +class TestStartTranscriptionForUser: + async def test_passes_sanitized_locale_to_pipeline(self): + """Locale 'pt-BR' should be sanitized to 'pt' when starting the pipeline.""" + mock_track = MagicMock() + mock_track.kind = rtc.TrackKind.KIND_AUDIO + participant = _make_participant("user_1", audio_track=mock_track) + agent = _make_agent_with_room(participants={"pid": participant}) + + with patch.object( + agent, "_run_transcription_pipeline", new_callable=AsyncMock + ) as mock_pipeline: + agent.start_transcription_for_user("user_1", "pt-BR", "openai") + await asyncio.sleep(0) + + mock_pipeline.assert_called_once_with(participant, mock_track, "pt") + agent.processing_info.pop("user_1", None) + + async def test_processing_info_has_no_stream_key(self): + """REST mode stores only 'task' in processing_info.""" + mock_track = MagicMock() + mock_track.kind = rtc.TrackKind.KIND_AUDIO + participant = _make_participant("user_1", audio_track=mock_track) + agent = _make_agent_with_room(participants={"pid": participant}) + + with patch.object(agent, "_run_transcription_pipeline", new_callable=AsyncMock): + agent.start_transcription_for_user("user_1", "en", "openai") + + assert "task" in agent.processing_info["user_1"] + assert "stream" not in agent.processing_info["user_1"] + agent.processing_info.pop("user_1", None) + + +class TestRunTranscriptionPipeline: + async def test_cancellation_cleans_up_processing_info(self): + """CancelledError should be caught and processing_info entry removed.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.side_effect = asyncio.CancelledError + + agent.processing_info["user_1"] = 
{"task": MagicMock()} + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + await agent._run_transcription_pipeline(mock_participant, mock_track, "en") + + assert "user_1" not in agent.processing_info + + async def test_emits_final_transcript_for_speech_frames(self): + """Speech frames trigger a final_transcript event via REST API.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + # One loud audio frame followed by end-of-stream triggers end-of-stream flush + loud_event = _make_loud_event() + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.return_value = iter([loud_event]) + + emitted = [] + agent.on("final_transcript", lambda **kw: emitted.append(kw)) + agent._transcribe_wav = AsyncMock(return_value="hello world") + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + with patch("providers.openai.rtc.combine_audio_frames"): + await agent._run_transcription_pipeline( + mock_participant, mock_track, "en" + ) + await asyncio.sleep(0) + + assert len(emitted) == 1 + assert emitted[0]["participant"] is mock_participant + event = emitted[0]["event"] + assert event.type == stt.SpeechEventType.FINAL_TRANSCRIPT + assert event.alternatives[0].text == "hello world" + + async def test_does_not_emit_for_empty_transcript(self): + """No event emitted when REST returns empty text.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + loud_event = _make_loud_event() + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.return_value = iter([loud_event]) + + emitted = [] + agent.on("final_transcript", lambda **kw: emitted.append(kw)) + agent._transcribe_wav = AsyncMock(return_value="") + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + with 
patch("providers.openai.rtc.combine_audio_frames"): + await agent._run_transcription_pipeline( + mock_participant, mock_track, "en" + ) + + assert len(emitted) == 0 + + async def test_does_not_call_transcribe_for_silent_audio(self): + """Silent frames (below energy threshold) should not trigger REST calls.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + silent_event = _make_audio_event(amplitude=0) + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.return_value = iter([silent_event]) + + agent._transcribe_wav = AsyncMock() + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + await agent._run_transcription_pipeline(mock_participant, mock_track, "en") + + agent._transcribe_wav.assert_not_called() + + async def test_generic_exception_cleans_up_processing_info(self): + """Unexpected exceptions should be caught and processing_info cleaned up.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.side_effect = RuntimeError("boom") + + agent.processing_info["user_1"] = {"task": MagicMock()} + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + await agent._run_transcription_pipeline(mock_participant, mock_track, "en") + + assert "user_1" not in agent.processing_info + + async def test_segment_has_start_and_end_times(self): + """Emitted events must have non-zero start_time and end_time on SpeechData.""" + agent = _make_agent() + mock_participant = MagicMock(spec=rtc.RemoteParticipant) + mock_participant.identity = "user_1" + mock_track = MagicMock() + + loud_event = _make_loud_event() + mock_audio_stream = AsyncMock() + mock_audio_stream.__aiter__.return_value = iter([loud_event]) + + emitted = [] + agent.on("final_transcript", 
lambda **kw: emitted.append(kw)) + agent._transcribe_wav = AsyncMock(return_value="hello") + + with patch("providers.openai.rtc.AudioStream", return_value=mock_audio_stream): + with patch("providers.openai.rtc.combine_audio_frames"): + await agent._run_transcription_pipeline( + mock_participant, mock_track, "en" + ) + await asyncio.sleep(0) + + assert len(emitted) == 1 + alt = emitted[0]["event"].alternatives[0] + assert alt.start_time >= 0.0 + assert alt.end_time >= alt.start_time + + +class TestCleanup: + async def test_closes_http_session_on_cleanup(self): + """_cleanup() should close the aiohttp session.""" + agent = _make_agent() + mock_session = AsyncMock() + agent._http_session = mock_session + + await agent._cleanup() + + mock_session.close.assert_called_once() + assert agent._http_session is None diff --git a/tests/test_providers_init.py b/tests/test_providers_init.py index bd0c901..34220a6 100644 --- a/tests/test_providers_init.py +++ b/tests/test_providers_init.py @@ -5,6 +5,7 @@ from providers import create_agent from providers.gladia import GladiaSttAgent +from providers.openai import OpenAiSttAgent class TestCreateAgent: @@ -13,6 +14,10 @@ def test_returns_gladia_agent_for_gladia_provider(self): agent = create_agent("gladia") assert isinstance(agent, GladiaSttAgent) + def test_returns_openai_agent_for_openai_provider(self): + agent = create_agent("openai") + assert isinstance(agent, OpenAiSttAgent) + def test_raises_for_unknown_provider(self): with pytest.raises(ValueError, match="Unknown STT provider"): create_agent("nonexistent")