Skip to content

Commit 572f214

Browse files
[VA-136] Support Vad params (#42)
* Add VadConfig dataclass and integrate VAD configuration into STT client parameters * Update VAD configuration handling in STT client
1 parent 003b65c commit 572f214

4 files changed

Lines changed: 337 additions & 10 deletions

File tree

aiola/clients/stt/client.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AiolaValidationError,
1818
)
1919
from ...http_client import create_async_authenticated_client, create_authenticated_client
20-
from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse
20+
from ...types import AiolaClientOptions, File, TasksConfig, TranscriptionResponse, VadConfig
2121
from .stream_client import AsyncStreamConnection, StreamConnection
2222

2323
if TYPE_CHECKING:
@@ -54,6 +54,7 @@ def _build_query_and_headers(
5454
time_zone: str | None,
5555
keywords: dict[str, str] | None,
5656
tasks_config: TasksConfig | None,
57+
vad_config: VadConfig | None,
5758
access_token: str,
5859
) -> tuple[dict[str, str], dict[str, str]]:
5960
"""Build query parameters and headers for streaming requests."""
@@ -73,6 +74,8 @@ def _build_query_and_headers(
7374
query["keywords"] = json.dumps(keywords)
7475
if tasks_config is not None:
7576
query["tasks_config"] = json.dumps(tasks_config)
77+
if vad_config is not None:
78+
query["vad_config"] = json.dumps(vad_config)
7679

7780
headers = {
7881
"Authorization": f"Bearer {access_token}",
@@ -88,6 +91,7 @@ def _validate_stream_params(
8891
time_zone: str | None,
8992
keywords: dict[str, str] | None,
9093
tasks_config: TasksConfig | None,
94+
vad_config: VadConfig | None,
9195
) -> None:
9296
"""Validate streaming parameters."""
9397
if flow_id is not None and not isinstance(flow_id, str):
@@ -100,8 +104,10 @@ def _validate_stream_params(
100104
raise AiolaValidationError("time_zone must be a string")
101105
if keywords is not None and not isinstance(keywords, dict):
102106
raise AiolaValidationError("keywords must be a dictionary")
103-
if tasks_config is not None and not isinstance(tasks_config, dict):
104-
raise AiolaValidationError("tasks_config must be a dictionary")
107+
if tasks_config is not None and not isinstance(tasks_config, dict | TasksConfig):
108+
raise AiolaValidationError("tasks_config must be a dictionary or a TasksConfig object")
109+
if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
110+
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")
105111

106112

107113
class SttClient(_BaseStt):
@@ -119,6 +125,7 @@ def stream(
119125
time_zone: str | None = None,
120126
keywords: dict[str, str] | None = None,
121127
tasks_config: TasksConfig | None = None,
128+
vad_config: VadConfig | None = None,
122129
) -> StreamConnection:
123130
"""Create a streaming connection for real-time transcription.
124131
@@ -135,7 +142,9 @@ def stream(
135142
StreamConnection: A connection object for real-time streaming.
136143
"""
137144
try:
138-
self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config)
145+
self._validate_stream_params(
146+
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config
147+
)
139148

140149
# Resolve workflow_id with proper precedence
141150
resolved_workflow_id = self._resolve_workflow_id(workflow_id)
@@ -149,7 +158,7 @@ def stream(
149158

150159
# Build query parameters and headers
151160
query, headers = self._build_query_and_headers(
152-
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token
161+
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token
153162
)
154163

155164
url = self._build_url(query)
@@ -168,6 +177,7 @@ def transcribe_file(
168177
*,
169178
language: str | None = None,
170179
keywords: dict[str, str] | None = None,
180+
vad_config: VadConfig | None = None,
171181
) -> TranscriptionResponse:
172182
"""Transcribe an audio file and return the transcription result."""
173183

@@ -180,12 +190,16 @@ def transcribe_file(
180190
if keywords is not None and not isinstance(keywords, dict):
181191
raise AiolaValidationError("keywords must be a dictionary")
182192

193+
if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
194+
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")
195+
183196
try:
184197
# Prepare the form data
185198
files = {"file": file}
186199
data = {
187200
"language": language or "en",
188201
"keywords": json.dumps(keywords or {}),
202+
"vad_config": json.dumps(vad_config or {}),
189203
}
190204

191205
# Create authenticated HTTP client and make request
@@ -229,6 +243,7 @@ async def stream(
229243
time_zone: str | None = None,
230244
keywords: dict[str, str] | None = None,
231245
tasks_config: TasksConfig | None = None,
246+
vad_config: VadConfig | None = None,
232247
) -> AsyncStreamConnection:
233248
"""Create an async streaming connection for real-time transcription.
234249
@@ -245,7 +260,9 @@ async def stream(
245260
AsyncStreamConnection: A connection object for real-time async streaming.
246261
"""
247262
try:
248-
self._validate_stream_params(workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config)
263+
self._validate_stream_params(
264+
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config
265+
)
249266

250267
# Resolve workflow_id with proper precedence
251268
resolved_workflow_id = self._resolve_workflow_id(workflow_id)
@@ -259,7 +276,7 @@ async def stream(
259276

260277
# Build query parameters and headers
261278
query, headers = self._build_query_and_headers(
262-
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, access_token
279+
workflow_id, execution_id, lang_code, time_zone, keywords, tasks_config, vad_config, access_token
263280
)
264281

265282
url = self._build_url(query)
@@ -278,6 +295,7 @@ async def transcribe_file(
278295
*,
279296
language: str | None = None,
280297
keywords: dict[str, str] | None = None,
298+
vad_config: VadConfig | None = None,
281299
) -> TranscriptionResponse:
282300
"""Transcribe an audio file and return the transcription result."""
283301

@@ -290,12 +308,16 @@ async def transcribe_file(
290308
if keywords is not None and not isinstance(keywords, dict):
291309
raise AiolaValidationError("keywords must be a dictionary")
292310

311+
if vad_config is not None and not isinstance(vad_config, dict | VadConfig):
312+
raise AiolaValidationError("vad_config must be a dictionary or a VadConfig object")
313+
293314
try:
294315
# Prepare the form data
295316
files = {"file": file}
296317
data = {
297318
"language": language or "en",
298319
"keywords": json.dumps(keywords or {}),
320+
"vad_config": json.dumps(vad_config or {}),
299321
}
300322

301323
# Create authenticated HTTP client and make request

aiola/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,14 @@ class TasksConfig:
140140
TRANSLATION: TranslationPayload | None = None
141141

142142

143+
@dataclass
144+
class VadConfig:
145+
threshold: float | None = None
146+
min_speech_ms: float | None = None
147+
min_silence_ms: float | None = None
148+
max_segment_ms: float | None = None
149+
150+
143151
FileContent = Union[IO[bytes], bytes, str]
144152
File = Union[
145153
# file (or bytes)

0 commit comments

Comments
 (0)