diff --git a/README.md b/README.md index b3286d58..a70daf62 100644 --- a/README.md +++ b/README.md @@ -102,13 +102,33 @@ openplanter-agent --provider ollama --list-models The base URL defaults to `http://localhost:11434/v1` and can be overridden with `OPENPLANTER_OLLAMA_BASE_URL` or `--base-url`. The first request may be slow while Ollama loads the model into memory; a 120-second first-byte timeout is used automatically. -Additional service keys: `EXA_API_KEY` (web search), `VOYAGE_API_KEY` (embeddings). +Additional service keys: `EXA_API_KEY` (web search), `VOYAGE_API_KEY` (embeddings), `MISTRAL_TRANSCRIPTION_API_KEY` or `MISTRAL_API_KEY` (audio transcription). + +### Audio Transcription + +OpenPlanter includes an `audio_transcribe` tool backed by Mistral's offline transcription API. It accepts local workspace audio or video files, returns transcript text plus any timestamps or diarization metadata Mistral provides, and automatically falls back to overlapping chunked transcription for long recordings when `chunking` is left at `auto`. + +Useful overrides: + +```bash +export MISTRAL_API_KEY=... +export OPENPLANTER_MISTRAL_TRANSCRIPTION_MODEL=voxtral-mini-latest +export OPENPLANTER_MISTRAL_TRANSCRIPTION_MAX_BYTES=104857600 +export OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS=900 +export OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS=2.0 +``` + +Notes: +- The tool only accepts local workspace files. +- Long-form chunking requires `ffmpeg` and `ffprobe` to be available at runtime. +- `chunking: "force"` always chunks, and `chunking: "off"` keeps the single-upload path. +- Video inputs are audio-extracted with `ffmpeg` before transcription. All keys can also be set with an `OPENPLANTER_` prefix (e.g. `OPENPLANTER_OPENAI_API_KEY`), via `.env` files in the workspace, or via CLI flags. ## Agent Tools -The agent has access to 19 tools, organized around its investigation workflow: +The agent has access to 20 tools, organized around its investigation workflow: **Dataset ingestion & workspace** — `list_files`, `search_files`, `repo_map`, `read_file`, `write_file`, `edit_file`, `hashline_edit`, `apply_patch` — load, inspect, and transform source datasets; write structured findings. @@ -116,6 +136,8 @@ The agent has access to 19 tools, organized around its investigation workflow: **Web** — `web_search` (Exa), `fetch_url` — pull public records, verify entities, and retrieve supplementary data. +**Audio** — `audio_transcribe` — transcribe local audio or video with Mistral, including optional timestamps, diarization, and automatic chunking for long recordings. + **Planning & delegation** — `think`, `subtask`, `execute`, `list_artifacts`, `read_artifact` — decompose investigations into focused sub-tasks, each with acceptance criteria and independent verification. In **recursive mode** (the default), the agent spawns sub-agents via `subtask` and `execute` to parallelize entity resolution, cross-dataset linking, and evidence-chain construction across large investigations. diff --git a/agent/__main__.py b/agent/__main__.py index 55436c00..d42ca40e 100644 --- a/agent/__main__.py +++ b/agent/__main__.py @@ -213,6 +213,7 @@ def _load_credentials( cerebras_api_key=user_creds.cerebras_api_key, exa_api_key=user_creds.exa_api_key, voyage_api_key=user_creds.voyage_api_key, + mistral_transcription_api_key=user_creds.mistral_transcription_api_key, ) store = CredentialStore(workspace=cfg.workspace, session_root_dir=cfg.session_root_dir) @@ -229,6 +230,8 @@ def _load_credentials( creds.exa_api_key = stored.exa_api_key if stored.voyage_api_key: creds.voyage_api_key = stored.voyage_api_key + if stored.mistral_transcription_api_key: + creds.mistral_transcription_api_key = stored.mistral_transcription_api_key env_creds = credentials_from_env() if env_creds.openai_api_key: @@ -243,6 +246,8 @@ def _load_credentials( creds.exa_api_key = env_creds.exa_api_key if env_creds.voyage_api_key: creds.voyage_api_key = env_creds.voyage_api_key + if env_creds.mistral_transcription_api_key: + creds.mistral_transcription_api_key = env_creds.mistral_transcription_api_key for env_path in discover_env_candidates(cfg.workspace): file_creds = parse_env_file(env_path) @@ -304,6 +309,7 @@ def _apply_runtime_overrides(cfg: AgentConfig, args: argparse.Namespace, creds: cfg.cerebras_api_key = creds.cerebras_api_key cfg.exa_api_key = creds.exa_api_key cfg.voyage_api_key = creds.voyage_api_key + cfg.mistral_transcription_api_key = creds.mistral_transcription_api_key cfg.api_key = cfg.openai_api_key if args.base_url: diff --git a/agent/builder.py b/agent/builder.py index a47d3e31..6abb6887 100644 --- a/agent/builder.py +++ b/agent/builder.py @@ -165,6 +165,15 @@ def build_engine(cfg: AgentConfig) -> RLMEngine: max_search_hits=cfg.max_search_hits, exa_api_key=cfg.exa_api_key, exa_base_url=cfg.exa_base_url, + mistral_transcription_api_key=cfg.mistral_transcription_api_key, + mistral_transcription_base_url=cfg.mistral_transcription_base_url, + mistral_transcription_model=cfg.mistral_transcription_model, + mistral_transcription_max_bytes=cfg.mistral_transcription_max_bytes, + mistral_transcription_chunk_max_seconds=cfg.mistral_transcription_chunk_max_seconds, + mistral_transcription_chunk_overlap_seconds=cfg.mistral_transcription_chunk_overlap_seconds, + mistral_transcription_max_chunks=cfg.mistral_transcription_max_chunks, + mistral_transcription_request_timeout_sec=cfg.mistral_transcription_request_timeout_sec, + max_observation_chars=cfg.max_observation_chars, ) try: diff --git a/agent/config.py b/agent/config.py index 83239de2..87e85264 100644 --- a/agent/config.py +++ b/agent/config.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from pathlib import Path +MISTRAL_TRANSCRIPTION_BASE_URL = "https://api.mistral.ai" +MISTRAL_TRANSCRIPTION_DEFAULT_MODEL = "voxtral-mini-latest" +MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS = 900 +MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS = 2.0 +MISTRAL_TRANSCRIPTION_MAX_CHUNKS = 48 +MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC = 180 + PROVIDER_DEFAULT_MODELS: dict[str, str] = { "openai": "gpt-5.2", "anthropic": "claude-opus-4-6", @@ -27,12 +34,24 @@ class AgentConfig: cerebras_base_url: str = "https://api.cerebras.ai/v1" ollama_base_url: str = "http://localhost:11434/v1" exa_base_url: str = "https://api.exa.ai" + mistral_transcription_base_url: str = MISTRAL_TRANSCRIPTION_BASE_URL openai_api_key: str | None = None anthropic_api_key: str | None = None openrouter_api_key: str | None = None cerebras_api_key: str | None = None exa_api_key: str | None = None voyage_api_key: str | None = None + mistral_transcription_api_key: str | None = None + mistral_transcription_model: str = MISTRAL_TRANSCRIPTION_DEFAULT_MODEL + mistral_transcription_max_bytes: int = 100 * 1024 * 1024 + mistral_transcription_chunk_max_seconds: int = MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS + mistral_transcription_chunk_overlap_seconds: float = ( + MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS + ) + mistral_transcription_max_chunks: int = MISTRAL_TRANSCRIPTION_MAX_CHUNKS + mistral_transcription_request_timeout_sec: int = ( + MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC + ) max_depth: int = 4 max_steps_per_call: int = 100 budget_extension_enabled: bool = True @@ -71,6 +90,11 @@ def from_env(cls, workspace: str | Path) -> "AgentConfig": cerebras_api_key = os.getenv("OPENPLANTER_CEREBRAS_API_KEY") or os.getenv("CEREBRAS_API_KEY") exa_api_key = os.getenv("OPENPLANTER_EXA_API_KEY") or os.getenv("EXA_API_KEY") voyage_api_key = os.getenv("OPENPLANTER_VOYAGE_API_KEY") or os.getenv("VOYAGE_API_KEY") + mistral_transcription_api_key = ( + os.getenv("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + or os.getenv("MISTRAL_TRANSCRIPTION_API_KEY") + or os.getenv("MISTRAL_API_KEY") + ) openai_base_url = os.getenv("OPENPLANTER_OPENAI_BASE_URL") or os.getenv( "OPENPLANTER_BASE_URL", "https://api.openai.com/v1", @@ -100,12 +124,51 @@ def from_env(cls, workspace: str | Path) -> "AgentConfig": cerebras_base_url=os.getenv("OPENPLANTER_CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1"), ollama_base_url=os.getenv("OPENPLANTER_OLLAMA_BASE_URL", "http://localhost:11434/v1"), exa_base_url=os.getenv("OPENPLANTER_EXA_BASE_URL", "https://api.exa.ai"), + mistral_transcription_base_url=os.getenv( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_BASE_URL", + os.getenv("MISTRAL_TRANSCRIPTION_BASE_URL") + or os.getenv("MISTRAL_BASE_URL") + or MISTRAL_TRANSCRIPTION_BASE_URL, + ), openai_api_key=openai_api_key, anthropic_api_key=anthropic_api_key, openrouter_api_key=openrouter_api_key, cerebras_api_key=cerebras_api_key, exa_api_key=exa_api_key, voyage_api_key=voyage_api_key, + mistral_transcription_api_key=(mistral_transcription_api_key or "").strip() or None, + mistral_transcription_model=( + os.getenv("OPENPLANTER_MISTRAL_TRANSCRIPTION_MODEL") + or os.getenv("MISTRAL_TRANSCRIPTION_MODEL") + or MISTRAL_TRANSCRIPTION_DEFAULT_MODEL + ), + mistral_transcription_max_bytes=int( + os.getenv("OPENPLANTER_MISTRAL_TRANSCRIPTION_MAX_BYTES", "104857600") + ), + mistral_transcription_chunk_max_seconds=int( + os.getenv( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS", + str(MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS), + ) + ), + mistral_transcription_chunk_overlap_seconds=float( + os.getenv( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS", + str(MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS), + ) + ), + mistral_transcription_max_chunks=int( + os.getenv( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_MAX_CHUNKS", + str(MISTRAL_TRANSCRIPTION_MAX_CHUNKS), + ) + ), + mistral_transcription_request_timeout_sec=int( + os.getenv( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC", + str(MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC), + ) + ), max_depth=int(os.getenv("OPENPLANTER_MAX_DEPTH", "4")), max_steps_per_call=int(os.getenv("OPENPLANTER_MAX_STEPS", "100")), budget_extension_enabled=budget_extension_enabled, diff --git a/agent/credentials.py b/agent/credentials.py index fb3fb052..a6a90866 100644 --- a/agent/credentials.py +++ b/agent/credentials.py @@ -17,6 +17,7 @@ class CredentialBundle: cerebras_api_key: str | None = None exa_api_key: str | None = None voyage_api_key: str | None = None + mistral_transcription_api_key: str | None = None def has_any(self) -> bool: return bool( @@ -26,6 +27,10 @@ def has_any(self) -> bool: or (self.cerebras_api_key and self.cerebras_api_key.strip()) or (self.exa_api_key and self.exa_api_key.strip()) or (self.voyage_api_key and self.voyage_api_key.strip()) + or ( + self.mistral_transcription_api_key + and self.mistral_transcription_api_key.strip() + ) ) def merge_missing(self, other: "CredentialBundle") -> None: @@ -41,6 +46,11 @@ def merge_missing(self, other: "CredentialBundle") -> None: self.exa_api_key = other.exa_api_key if not self.voyage_api_key and other.voyage_api_key: self.voyage_api_key = other.voyage_api_key + if ( + not self.mistral_transcription_api_key + and other.mistral_transcription_api_key + ): + self.mistral_transcription_api_key = other.mistral_transcription_api_key def to_json(self) -> dict[str, str]: out: dict[str, str] = {} @@ -56,6 +66,8 @@ def to_json(self) -> dict[str, str]: out["exa_api_key"] = self.exa_api_key if self.voyage_api_key: out["voyage_api_key"] = self.voyage_api_key + if self.mistral_transcription_api_key: + out["mistral_transcription_api_key"] = self.mistral_transcription_api_key return out @classmethod @@ -69,6 +81,10 @@ def from_json(cls, payload: dict[str, str] | None) -> "CredentialBundle": cerebras_api_key=(payload.get("cerebras_api_key") or "").strip() or None, exa_api_key=(payload.get("exa_api_key") or "").strip() or None, voyage_api_key=(payload.get("voyage_api_key") or "").strip() or None, + mistral_transcription_api_key=( + payload.get("mistral_transcription_api_key") or "" + ).strip() + or None, ) @@ -115,6 +131,13 @@ def parse_env_file(path: Path) -> CredentialBundle: or None, exa_api_key=(env.get("EXA_API_KEY") or env.get("OPENPLANTER_EXA_API_KEY") or "").strip() or None, voyage_api_key=(env.get("VOYAGE_API_KEY") or env.get("OPENPLANTER_VOYAGE_API_KEY") or "").strip() or None, + mistral_transcription_api_key=( + env.get("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + or env.get("MISTRAL_TRANSCRIPTION_API_KEY") + or env.get("MISTRAL_API_KEY") + or "" + ).strip() + or None, ) @@ -140,6 +163,13 @@ def credentials_from_env() -> CredentialBundle: or None, exa_api_key=(os.getenv("OPENPLANTER_EXA_API_KEY") or os.getenv("EXA_API_KEY") or "").strip() or None, voyage_api_key=(os.getenv("OPENPLANTER_VOYAGE_API_KEY") or os.getenv("VOYAGE_API_KEY") or "").strip() or None, + mistral_transcription_api_key=( + os.getenv("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + or os.getenv("MISTRAL_TRANSCRIPTION_API_KEY") + or os.getenv("MISTRAL_API_KEY") + or "" + ).strip() + or None, ) @@ -230,6 +260,7 @@ def prompt_for_credentials( cerebras_api_key=existing.cerebras_api_key, exa_api_key=existing.exa_api_key, voyage_api_key=existing.voyage_api_key, + mistral_transcription_api_key=existing.mistral_transcription_api_key, ) should_prompt = force or not current.has_any() @@ -263,6 +294,9 @@ def _ask(label: str, existing_value: str | None) -> str | None: current.cerebras_api_key = _ask("Cerebras", current.cerebras_api_key) current.exa_api_key = _ask("Exa", current.exa_api_key) current.voyage_api_key = _ask("Voyage", current.voyage_api_key) + current.mistral_transcription_api_key = _ask( + "Mistral Transcription", current.mistral_transcription_api_key + ) if not force and current.has_any() and not existing.has_any(): changed = True return current, changed diff --git a/agent/engine.py b/agent/engine.py index 33ff033c..1e4f4de1 100644 --- a/agent/engine.py +++ b/agent/engine.py @@ -32,6 +32,7 @@ "fetch_url", "read_file", "read_image", + "audio_transcribe", "list_artifacts", "read_artifact", } @@ -1261,6 +1262,80 @@ def _apply_tool_call( self._pending_image.data = (b64, media_type) return False, text + if name == "audio_transcribe": + path = str(args.get("path", "")).strip() + if not path: + return False, "audio_transcribe requires path" + diarize = args.get("diarize") + diarize = diarize if isinstance(diarize, bool) else None + raw_timestamps = args.get("timestamp_granularities") + if isinstance(raw_timestamps, list): + timestamp_granularities = [ + str(v).strip() for v in raw_timestamps if str(v).strip() + ] + elif isinstance(raw_timestamps, str) and raw_timestamps.strip(): + timestamp_granularities = [raw_timestamps.strip()] + else: + timestamp_granularities = None + raw_context_bias = args.get("context_bias") + if isinstance(raw_context_bias, list): + context_bias = [ + str(v).strip() for v in raw_context_bias if str(v).strip() + ] + elif isinstance(raw_context_bias, str) and raw_context_bias.strip(): + context_bias = [ + part.strip() + for part in raw_context_bias.split(",") + if part.strip() + ] + else: + context_bias = None + language = str(args.get("language", "")).strip() or None + model = str(args.get("model", "")).strip() or None + raw_temperature = args.get("temperature") + temperature = None + if isinstance(raw_temperature, (int, float)) and not isinstance( + raw_temperature, bool + ): + temperature = float(raw_temperature) + chunking = str(args.get("chunking", "")).strip().lower() or None + raw_chunk_max_seconds = args.get("chunk_max_seconds") + chunk_max_seconds = None + if isinstance(raw_chunk_max_seconds, int) and not isinstance( + raw_chunk_max_seconds, bool + ): + chunk_max_seconds = raw_chunk_max_seconds + raw_chunk_overlap_seconds = args.get("chunk_overlap_seconds") + chunk_overlap_seconds = None + if isinstance(raw_chunk_overlap_seconds, (int, float)) and not isinstance( + raw_chunk_overlap_seconds, bool + ): + chunk_overlap_seconds = float(raw_chunk_overlap_seconds) + raw_max_chunks = args.get("max_chunks") + max_chunks = None + if isinstance(raw_max_chunks, int) and not isinstance(raw_max_chunks, bool): + max_chunks = raw_max_chunks + raw_continue_on_chunk_error = args.get("continue_on_chunk_error") + continue_on_chunk_error = ( + raw_continue_on_chunk_error + if isinstance(raw_continue_on_chunk_error, bool) + else None + ) + return False, self.tools.audio_transcribe( + path=path, + diarize=diarize, + timestamp_granularities=timestamp_granularities, + context_bias=context_bias, + language=language, + model=model, + temperature=temperature, + chunking=chunking, + chunk_max_seconds=chunk_max_seconds, + chunk_overlap_seconds=chunk_overlap_seconds, + max_chunks=max_chunks, + continue_on_chunk_error=continue_on_chunk_error, + ) + if name == "write_file": path = str(args.get("path", "")).strip() if not path: diff --git a/agent/model.py b/agent/model.py index 30bc3ff7..a029dae1 100644 --- a/agent/model.py +++ b/agent/model.py @@ -15,6 +15,23 @@ class ModelError(RuntimeError): pass +class RateLimitError(ModelError): + def __init__( + self, + message: str, + *, + status_code: int | None = None, + provider_code: str | int | None = None, + body: str = "", + retry_after_sec: float | None = None, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.provider_code = provider_code + self.body = body + self.retry_after_sec = retry_after_sec + + # --------------------------------------------------------------------------- # Core data types # --------------------------------------------------------------------------- diff --git a/agent/tool_defs.py b/agent/tool_defs.py index 323edbde..90fb3ba5 100644 --- a/agent/tool_defs.py +++ b/agent/tool_defs.py @@ -134,6 +134,72 @@ "additionalProperties": False, }, }, + { + "name": "audio_transcribe", + "description": ( + "Transcribe a local audio file with Mistral's offline transcription API. " + "Supports diarization, timestamp granularity, context bias, language, " + "model override, temperature, and optional chunking for long-form audio/video." + ), + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative or absolute path to the audio file within the workspace.", + }, + "diarize": { + "type": "boolean", + "description": "Whether to request speaker diarization.", + }, + "timestamp_granularities": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional timestamp granularity values such as 'segment' or 'word'.", + }, + "context_bias": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional bias phrases to steer transcription toward expected terms.", + }, + "language": { + "type": "string", + "description": "Optional ISO language hint. Cannot be combined with timestamp_granularities.", + }, + "model": { + "type": "string", + "description": "Optional transcription model override.", + }, + "temperature": { + "type": "number", + "description": "Optional decoding temperature.", + }, + "chunking": { + "type": "string", + "description": "Long-form handling mode: 'auto', 'off', or 'force'.", + "enum": ["auto", "off", "force"], + }, + "chunk_max_seconds": { + "type": "integer", + "description": "Maximum chunk duration in seconds for chunked transcription.", + }, + "chunk_overlap_seconds": { + "type": "number", + "description": "Overlap between adjacent chunks in seconds.", + }, + "max_chunks": { + "type": "integer", + "description": "Maximum number of chunks allowed for a transcription run.", + }, + "continue_on_chunk_error": { + "type": "boolean", + "description": "Continue remaining chunks and return partial output if a chunk fails.", + }, + }, + "required": ["path"], + "additionalProperties": False, + }, + }, { "name": "write_file", "description": "Create or overwrite a file in the workspace with the given content.", diff --git a/agent/tools.py b/agent/tools.py index 86a9e5ce..8310c5be 100644 --- a/agent/tools.py +++ b/agent/tools.py @@ -2,8 +2,10 @@ import ast import base64 +import copy import fnmatch import json +import mimetypes import os import signal import shutil @@ -12,6 +14,7 @@ import threading import urllib.error import urllib.request +import uuid import re as _re import zlib from contextlib import contextmanager @@ -34,6 +37,7 @@ _HASHLINE_PREFIX_RE = _re.compile(r"^\d+:[0-9a-f]{2}\|") _HEREDOC_RE = _re.compile(r"<<-?\s*['\"]?\w+['\"]?") _INTERACTIVE_RE = _re.compile(r"(^|[;&|]\s*)(vim|nano|less|more|top|htop|man)\b") +_TOKEN_NORMALIZE_RE = _re.compile(r"[^a-z0-9]+") def _line_hash(line: str) -> str: @@ -52,10 +56,19 @@ class WorkspaceTools: command_timeout_sec: int = 45 max_shell_output_chars: int = 16000 max_file_chars: int = 20000 + max_observation_chars: int = 6000 max_files_listed: int = 400 max_search_hits: int = 200 exa_api_key: str | None = None exa_base_url: str = "https://api.exa.ai" + mistral_transcription_api_key: str | None = None + mistral_transcription_base_url: str = "https://api.mistral.ai" + mistral_transcription_model: str = "voxtral-mini-latest" + mistral_transcription_max_bytes: int = 100 * 1024 * 1024 + mistral_transcription_chunk_max_seconds: int = 900 + mistral_transcription_chunk_overlap_seconds: float = 2.0 + mistral_transcription_max_chunks: int = 48 + mistral_transcription_request_timeout_sec: int = 180 def __post_init__(self) -> None: self.root = self.root.expanduser().resolve() @@ -548,6 +561,875 @@ def read_image(self, path: str) -> tuple[str, str | None, str | None]: text = f"Image {rel} ({len(raw):,} bytes, {media_type})" return text, b64, media_type + _AUDIO_EXTENSIONS = { + ".aac", + ".flac", + ".m4a", + ".mp3", + ".mpeg", + ".mpga", + ".oga", + ".ogg", + ".opus", + ".wav", + } + _VIDEO_EXTENSIONS = { + ".avi", + ".m4v", + ".mkv", + ".mov", + ".mp4", + ".webm", + } + _TIMESTAMP_GRANULARITIES = {"segment", "word"} + _AUDIO_CHUNKING_MODES = {"auto", "force", "off"} + _AUDIO_CHUNK_TARGET_FILL_RATIO = 0.85 + _AUDIO_CHUNK_BYTES_PER_SECOND = 32000 + _AUDIO_MIN_CHUNK_SECONDS = 30.0 + _AUDIO_MAX_CHUNK_SECONDS = 1800.0 + _AUDIO_MAX_CHUNK_OVERLAP_SECONDS = 15.0 + _AUDIO_MAX_CHUNKS = 200 + _AUDIO_SPEAKER_FIELDS = {"speaker", "speaker_id", "speaker_label"} + + def _mistral_transcription_url(self) -> str: + base = self.mistral_transcription_base_url.rstrip("/") + if base.endswith("/v1"): + return f"{base}/audio/transcriptions" + return f"{base}/v1/audio/transcriptions" + + def _encode_multipart_form_data( + self, + *, + fields: list[tuple[str, str]], + file_field_name: str, + file_name: str, + file_bytes: bytes, + media_type: str, + ) -> tuple[bytes, str]: + boundary = f"----OpenPlanter{uuid.uuid4().hex}" + chunks: list[bytes] = [] + for key, value in fields: + chunks.append(f"--{boundary}\r\n".encode("utf-8")) + chunks.append( + f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode( + "utf-8" + ) + ) + chunks.append(value.encode("utf-8")) + chunks.append(b"\r\n") + safe_name = Path(file_name).name.replace('"', "") + chunks.append(f"--{boundary}\r\n".encode("utf-8")) + chunks.append( + ( + f'Content-Disposition: form-data; name="{file_field_name}"; ' + f'filename="{safe_name}"\r\n' + ).encode("utf-8") + ) + chunks.append(f"Content-Type: {media_type}\r\n\r\n".encode("utf-8")) + chunks.append(file_bytes) + chunks.append(b"\r\n") + chunks.append(f"--{boundary}--\r\n".encode("utf-8")) + return b"".join(chunks), boundary + + def _mistral_transcription_request( + self, + *, + resolved: Path, + model: str, + diarize: bool | None, + timestamp_granularities: list[str] | None, + context_bias: list[str] | None, + language: str | None, + temperature: float | None, + ) -> dict[str, Any]: + if not ( + self.mistral_transcription_api_key + and self.mistral_transcription_api_key.strip() + ): + raise ToolError("Mistral transcription API key not configured") + try: + size = resolved.stat().st_size + except OSError as exc: + raise ToolError(f"Failed to inspect audio file {resolved.name}: {exc}") from exc + if size > self.mistral_transcription_max_bytes: + raise ToolError( + f"Audio file too large: {size:,} bytes " + f"(max {self.mistral_transcription_max_bytes:,} bytes)" + ) + try: + file_bytes = resolved.read_bytes() + except OSError as exc: + raise ToolError(f"Failed to read audio file {resolved.name}: {exc}") from exc + + media_type = mimetypes.guess_type(resolved.name)[0] or "application/octet-stream" + fields: list[tuple[str, str]] = [ + ("model", model), + ("stream", "false"), + ] + if diarize is not None: + fields.append(("diarize", "true" if diarize else "false")) + if language: + fields.append(("language", language)) + if temperature is not None: + fields.append(("temperature", str(temperature))) + for granularity in timestamp_granularities or []: + fields.append(("timestamp_granularities", granularity)) + for phrase in context_bias or []: + fields.append(("context_bias", phrase)) + + body, boundary = self._encode_multipart_form_data( + fields=fields, + file_field_name="file", + file_name=resolved.name, + file_bytes=file_bytes, + media_type=media_type, + ) + req = urllib.request.Request( + url=self._mistral_transcription_url(), + data=body, + headers={ + "Authorization": f"Bearer {self.mistral_transcription_api_key}", + "Content-Type": f"multipart/form-data; boundary={boundary}", + }, + method="POST", + ) + try: + with urllib.request.urlopen( + req, timeout=self.mistral_transcription_request_timeout_sec + ) as resp: + raw = resp.read().decode("utf-8", errors="replace") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise ToolError(f"Mistral transcription HTTP {exc.code}: {body}") from exc + except urllib.error.URLError as exc: + raise ToolError(f"Mistral transcription connection error: {exc}") from exc + except OSError as exc: + raise ToolError(f"Mistral transcription network error: {exc}") from exc + + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ToolError( + f"Mistral transcription returned non-JSON payload: {raw[:500]}" + ) from exc + if not isinstance(parsed, dict): + raise ToolError( + f"Mistral transcription returned non-object response: {type(parsed)!r}" + ) + return parsed + + def _audio_transcribe_max_chars(self) -> int: + return min(self.max_file_chars, self.max_observation_chars) + + def _audio_transcribe_options( + self, + *, + diarize: bool | None, + timestamp_granularities: list[str] | None, + context_bias: list[str] | None, + language: str | None, + temperature: float | None, + chunking: str, + chunk_max_seconds: int | None, + chunk_overlap_seconds: float | None, + max_chunks: int | None, + continue_on_chunk_error: bool | None, + ) -> dict[str, Any]: + options: dict[str, Any] = {"chunking": chunking} + if diarize is not None: + options["diarize"] = diarize + if timestamp_granularities: + options["timestamp_granularities"] = timestamp_granularities + if context_bias: + options["context_bias"] = context_bias + if language: + options["language"] = language + if temperature is not None: + options["temperature"] = temperature + if chunk_max_seconds is not None: + options["chunk_max_seconds"] = chunk_max_seconds + if chunk_overlap_seconds is not None: + options["chunk_overlap_seconds"] = chunk_overlap_seconds + if max_chunks is not None: + options["max_chunks"] = max_chunks + if continue_on_chunk_error is not None: + options["continue_on_chunk_error"] = continue_on_chunk_error + return options + + def _ensure_media_tools(self) -> None: + missing = [ + name for name in ("ffmpeg", "ffprobe") if shutil.which(name) is None + ] + if missing: + joined = ", ".join(missing) + raise ToolError( + f"Long-form transcription requires {joined}. Install ffmpeg/ffprobe and retry." + ) + + def _run_media_command(self, argv: list[str]) -> str: + try: + completed = subprocess.run( + argv, + capture_output=True, + text=True, + timeout=self.command_timeout_sec, + check=False, + ) + except FileNotFoundError as exc: + raise ToolError(f"Media tooling not available: {argv[0]}") from exc + except subprocess.TimeoutExpired as exc: + raise ToolError(f"{argv[0]} timed out after {self.command_timeout_sec}s") from exc + if completed.returncode != 0: + stderr = completed.stderr.strip() or completed.stdout.strip() + raise ToolError(f"{argv[0]} failed: {stderr or 'unknown error'}") + return completed.stdout + + def _probe_media_duration(self, source: Path) -> float: + raw = self._run_media_command( + [ + "ffprobe", + "-v", + "error", + "-print_format", + "json", + "-show_format", + str(source), + ] + ) + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ToolError(f"ffprobe returned invalid JSON for {source.name}") from exc + duration_value = ( + parsed.get("format", {}).get("duration") + if isinstance(parsed, dict) + else None + ) + try: + duration = float(duration_value) + except (TypeError, ValueError) as exc: + raise ToolError(f"ffprobe did not return a valid duration for {source.name}") from exc + if duration <= 0: + raise ToolError(f"ffprobe reported non-positive duration for {source.name}") + return duration + + def _extract_audio_source(self, source: Path, output: Path) -> None: + self._run_media_command( + [ + "ffmpeg", + "-nostdin", + "-y", + "-i", + str(source), + "-vn", + "-ac", + "1", + "-ar", + "16000", + "-c:a", + "pcm_s16le", + str(output), + ] + ) + + def _extract_audio_chunk( + self, + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + self._run_media_command( + [ + "ffmpeg", + "-nostdin", + "-y", + "-ss", + f"{start_sec:.3f}", + "-i", + str(source), + "-t", + f"{duration_sec:.3f}", + "-vn", + "-ac", + "1", + "-ar", + "16000", + "-c:a", + "pcm_s16le", + str(output), + ] + ) + + def _audio_chunk_seconds_budget(self, requested_seconds: float) -> float: + safe_seconds = ( + self.mistral_transcription_max_bytes + * self._AUDIO_CHUNK_TARGET_FILL_RATIO + / self._AUDIO_CHUNK_BYTES_PER_SECOND + ) + if safe_seconds <= 0: + raise ToolError("Mistral transcription max-bytes budget is too small to chunk audio") + return min(requested_seconds, safe_seconds) + + def _plan_audio_chunks( + self, + *, + duration_sec: float, + chunk_seconds: float, + overlap_seconds: float, + max_chunks: int, + ) -> list[dict[str, float]]: + if duration_sec <= 0: + raise ToolError("Cannot chunk media with non-positive duration") + chunk_seconds = max(1.0, chunk_seconds) + overlap_seconds = min(max(0.0, overlap_seconds), max(0.0, chunk_seconds - 0.001)) + chunks: list[dict[str, float]] = [] + start = 0.0 + while start < duration_sec - 1e-6: + end = min(duration_sec, start + chunk_seconds) + index = len(chunks) + chunks.append( + { + "index": float(index), + "start_sec": round(start, 3), + "end_sec": round(end, 3), + "duration_sec": round(end - start, 3), + "leading_overlap_sec": 0.0 if index == 0 else round(overlap_seconds, 3), + } + ) + if len(chunks) > max_chunks: + raise ToolError( + f"Chunk plan would create {len(chunks)} chunks (max {max_chunks})" + ) + if end >= duration_sec - 1e-6: + break + next_start = end - overlap_seconds + if next_start <= start + 1e-6: + next_start = end + start = next_start + return chunks + + def _is_video_extension(self, ext: str) -> bool: + return ext in self._VIDEO_EXTENSIONS + + def _normalized_audio_token(self, token: str) -> str: + return _TOKEN_NORMALIZE_RE.sub("", token.lower()) + + def _dedupe_audio_overlap_text(self, existing_text: str, incoming_text: str) -> str: + if not existing_text.strip(): + return incoming_text.strip() + current_tokens = incoming_text.split() + if not current_tokens: + return "" + previous_tokens = existing_text.split() + max_window = min(len(previous_tokens), len(current_tokens), 80) + if max_window < 5: + return incoming_text.strip() + previous_norm = [ + self._normalized_audio_token(token) + for token in previous_tokens[-max_window:] + ] + current_norm = [ + self._normalized_audio_token(token) + for token in current_tokens[:max_window] + ] + for match_len in range(max_window, 4, -1): + if previous_norm[-match_len:] == current_norm[:match_len]: + return " ".join(current_tokens[match_len:]).strip() + return incoming_text.strip() + + def _entry_time_bounds(self, entry: dict[str, Any]) -> tuple[float, float] | None: + start = entry.get("start") + end = entry.get("end") + if isinstance(start, (int, float)) and isinstance(end, (int, float)): + return float(start), float(end) + timestamps = entry.get("timestamps") + if ( + isinstance(timestamps, list) + and len(timestamps) >= 2 + and isinstance(timestamps[0], (int, float)) + and isinstance(timestamps[1], (int, float)) + ): + return float(timestamps[0]), float(timestamps[1]) + return None + + def _set_entry_time_bounds( + self, + entry: dict[str, Any], + *, + start: float, + end: float, + ) -> None: + if "start" in entry or "end" in entry: + entry["start"] = round(start, 3) + entry["end"] = round(end, 3) + elif isinstance(entry.get("timestamps"), list): + timestamps = list(entry.get("timestamps", [])) + while len(timestamps) < 2: + timestamps.append(0.0) + timestamps[0] = round(start, 3) + timestamps[1] = round(end, 3) + entry["timestamps"] = timestamps + + def _prefix_audio_speakers(self, value: Any, prefix: str) -> Any: + if isinstance(value, list): + return [self._prefix_audio_speakers(item, prefix) for item in value] + if isinstance(value, dict): + copied: dict[str, Any] = {} + for key, item in value.items(): + if ( + key in self._AUDIO_SPEAKER_FIELDS + and isinstance(item, str) + and item.strip() + ): + copied[key] = f"{prefix}{item.strip()}" + else: + copied[key] = self._prefix_audio_speakers(item, prefix) + return copied + return value + + def _shift_audio_items( + self, + items: list[Any], + *, + chunk_start_sec: float, + leading_overlap_sec: float, + speaker_prefix: str, + ) -> list[Any]: + shifted: list[Any] = [] + for item in items: + copied = self._prefix_audio_speakers(copy.deepcopy(item), speaker_prefix) + if isinstance(copied, dict): + bounds = self._entry_time_bounds(copied) + if bounds is not None: + start, end = bounds + if end <= leading_overlap_sec + 1e-6: + continue + if start < leading_overlap_sec: + start = leading_overlap_sec + self._set_entry_time_bounds( + copied, + start=start + chunk_start_sec, + end=end + chunk_start_sec, + ) + shifted.append(copied) + return shifted + + def _collect_chunk_metadata( + self, + parsed: dict[str, Any], + *, + chunk_start_sec: float, + leading_overlap_sec: float, + speaker_prefix: str, + ) -> dict[str, list[Any]]: + aggregated: dict[str, list[Any]] = {} + if isinstance(parsed.get("segments"), list): + aggregated["segments"] = self._shift_audio_items( + parsed["segments"], + chunk_start_sec=chunk_start_sec, + leading_overlap_sec=leading_overlap_sec, + speaker_prefix=speaker_prefix, + ) + elif isinstance(parsed.get("chunks"), list): + aggregated["segments"] = self._shift_audio_items( + parsed["chunks"], + chunk_start_sec=chunk_start_sec, + leading_overlap_sec=leading_overlap_sec, + speaker_prefix=speaker_prefix, + ) + if isinstance(parsed.get("words"), list): + aggregated["words"] = self._shift_audio_items( + parsed["words"], + chunk_start_sec=chunk_start_sec, + leading_overlap_sec=leading_overlap_sec, + speaker_prefix=speaker_prefix, + ) + if isinstance(parsed.get("diarization"), list): + aggregated["diarization"] = self._shift_audio_items( + parsed["diarization"], + chunk_start_sec=chunk_start_sec, + leading_overlap_sec=leading_overlap_sec, + speaker_prefix=speaker_prefix, + ) + return aggregated + + def _audio_json_length(self, payload: dict[str, Any]) -> int: + return len(json.dumps(payload, indent=2, ensure_ascii=True)) + + def _truncate_audio_text( + self, + payload: dict[str, Any], + *, + max_chars: int, + ) -> None: + text = str(payload.get("text", "")) + if not text: + return + base = copy.deepcopy(payload) + base["text"] = "" + if self._audio_json_length(base) > max_chars: + payload["text"] = "" + payload.setdefault("truncation", {})["text_truncated_chars"] = len(text) + return + low = 0 + high = len(text) + while low < high: + mid = (low + high + 1) // 2 + base["text"] = text[:mid] + if self._audio_json_length(base) <= max_chars: + low = mid + else: + high = mid - 1 + payload["text"] = text[:low] + omitted = len(text) - low + if omitted > 0: + payload.setdefault("truncation", {})["text_truncated_chars"] = omitted + + def _serialize_audio_envelope( + self, + envelope: dict[str, Any], + *, + max_chars: int, + ) -> str: + payload = copy.deepcopy(envelope) + payload.setdefault("truncation", {"applied": False}) + if self._audio_json_length(payload) <= max_chars: + return json.dumps(payload, indent=2, ensure_ascii=True) + + truncation = payload.setdefault("truncation", {}) + truncation["applied"] = True + response = payload.get("response") + omitted_response_fields: dict[str, int] = {} + + if isinstance(response, dict): + removal_order = ["words", "diarization", "segments"] + if payload.get("mode") != "chunked": + removal_order.append("chunks") + for key in removal_order: + value = response.get(key) + if isinstance(value, list) and value: + omitted_response_fields[key] = len(value) + response.pop(key, None) + if self._audio_json_length(payload) <= max_chars: + break + if omitted_response_fields: + truncation["omitted_response_fields"] = omitted_response_fields + if ( + payload.get("mode") == "chunked" + and isinstance(response.get("chunks"), list) + and self._audio_json_length(payload) > max_chars + ): + chunk_summaries = response["chunks"] + keep = min(len(chunk_summaries), 12) + omitted = len(chunk_summaries) - keep + if omitted > 0: + response["chunks"] = chunk_summaries[:keep] + truncation["omitted_chunk_statuses"] = omitted + + if self._audio_json_length(payload) > max_chars: + self._truncate_audio_text(payload, max_chars=max_chars) + + if ( + isinstance(payload.get("response"), dict) + and isinstance(payload["response"].get("chunks"), list) + and self._audio_json_length(payload) > max_chars + ): + while ( + len(payload["response"]["chunks"]) > 3 + and self._audio_json_length(payload) > max_chars + ): + payload["response"]["chunks"].pop() + truncation["omitted_chunk_statuses"] = truncation.get( + "omitted_chunk_statuses", 0 + ) + 1 + + if self._audio_json_length(payload) > max_chars and isinstance( + payload.get("options"), dict + ): + if isinstance(payload["options"].get("context_bias"), list): + truncation["omitted_context_bias_phrases"] = len( + payload["options"]["context_bias"] + ) + payload["options"].pop("context_bias", None) + + return json.dumps(payload, indent=2, ensure_ascii=True) + + def audio_transcribe( + self, + path: str, + diarize: bool | None = None, + timestamp_granularities: list[str] | None = None, + context_bias: list[str] | None = None, + language: str | None = None, + model: str | None = None, + temperature: float | None = None, + chunking: str | None = None, + chunk_max_seconds: int | None = None, + chunk_overlap_seconds: float | None = None, + max_chunks: int | None = None, + continue_on_chunk_error: bool | None = None, + ) -> str: + resolved = self._resolve_path(path) + if not resolved.exists(): + return f"File not found: {path}" + if resolved.is_dir(): + return f"Path is a directory, not a file: {path}" + ext = resolved.suffix.lower() + if ext not in self._AUDIO_EXTENSIONS and ext not in self._VIDEO_EXTENSIONS: + return ( + f"Unsupported audio format: {ext or '(none)'}. " + f"Supported: {', '.join(sorted(self._AUDIO_EXTENSIONS | self._VIDEO_EXTENSIONS))}" + ) + if language and timestamp_granularities: + return ( + "language cannot be combined with timestamp_granularities for " + "Mistral offline transcription" + ) + chunk_mode = (chunking or "auto").strip().lower() + if chunk_mode not in self._AUDIO_CHUNKING_MODES: + return "chunking must be one of auto, off, or force" + if chunk_max_seconds is not None and not ( + self._AUDIO_MIN_CHUNK_SECONDS + <= float(chunk_max_seconds) + <= self._AUDIO_MAX_CHUNK_SECONDS + ): + return ( + "chunk_max_seconds must be between " + f"{int(self._AUDIO_MIN_CHUNK_SECONDS)} and {int(self._AUDIO_MAX_CHUNK_SECONDS)}" + ) + if chunk_overlap_seconds is not None and not ( + 0.0 <= float(chunk_overlap_seconds) <= self._AUDIO_MAX_CHUNK_OVERLAP_SECONDS + ): + return ( + "chunk_overlap_seconds must be between 0 and " + f"{int(self._AUDIO_MAX_CHUNK_OVERLAP_SECONDS)}" + ) + if max_chunks is not None and not (1 <= max_chunks <= self._AUDIO_MAX_CHUNKS): + return f"max_chunks must be between 1 and {self._AUDIO_MAX_CHUNKS}" + normalized_timestamps: list[str] | None = None + if timestamp_granularities: + seen: set[str] = set() + normalized_timestamps = [] + for item in timestamp_granularities: + value = item.strip().lower() + if not value: + continue + if value not in self._TIMESTAMP_GRANULARITIES: + return ( + "timestamp_granularities must be drawn from " + f"{', '.join(sorted(self._TIMESTAMP_GRANULARITIES))}" + ) + if value not in seen: + normalized_timestamps.append(value) + seen.add(value) + normalized_bias = [item.strip() for item in (context_bias or []) if item.strip()] + if len(normalized_bias) > 100: + return "context_bias supports at most 100 phrases" + chosen_model = (model or self.mistral_transcription_model or "").strip() + if not chosen_model: + return "No Mistral transcription model configured" + self._files_read.add(resolved) + rel = resolved.relative_to(self.root).as_posix() + options = self._audio_transcribe_options( + diarize=diarize, + timestamp_granularities=normalized_timestamps, + context_bias=normalized_bias, + language=language, + temperature=temperature, + chunking=chunk_mode, + chunk_max_seconds=chunk_max_seconds, + chunk_overlap_seconds=chunk_overlap_seconds, + max_chunks=max_chunks, + continue_on_chunk_error=continue_on_chunk_error, + ) + + try: + with tempfile.TemporaryDirectory(prefix="openplanter-audio-") as temp_root: + temp_dir = Path(temp_root) + upload_source = resolved + if self._is_video_extension(ext): + self._ensure_media_tools() + upload_source = temp_dir / "video-source.wav" + self._extract_audio_source(resolved, upload_source) + + try: + upload_size = upload_source.stat().st_size + except OSError as exc: + raise ToolError( + f"Failed to inspect audio file {upload_source.name}: {exc}" + ) from exc + + chunk_requested = chunk_mode == "force" or ( + chunk_mode == "auto" + and upload_size > self.mistral_transcription_max_bytes + ) + + if not chunk_requested: + parsed = self._mistral_transcription_request( + resolved=upload_source, + model=chosen_model, + diarize=diarize, + timestamp_granularities=normalized_timestamps, + context_bias=normalized_bias, + language=language, + temperature=temperature, + ) + envelope = { + "provider": "mistral", + "service": "transcription", + "path": rel, + "model": chosen_model, + "options": options, + "text": str(parsed.get("text", "")), + "response": parsed, + } + return self._serialize_audio_envelope( + envelope, max_chars=self._audio_transcribe_max_chars() + ) + + self._ensure_media_tools() + duration_sec = self._probe_media_duration(upload_source) + requested_chunk_seconds = float( + chunk_max_seconds or self.mistral_transcription_chunk_max_seconds + ) + requested_chunk_seconds = min( + requested_chunk_seconds, self._AUDIO_MAX_CHUNK_SECONDS + ) + effective_chunk_seconds = self._audio_chunk_seconds_budget( + requested_chunk_seconds + ) + effective_overlap_seconds = min( + float( + chunk_overlap_seconds + if chunk_overlap_seconds is not None + else self.mistral_transcription_chunk_overlap_seconds + ), + max(0.0, effective_chunk_seconds - 0.001), + ) + effective_max_chunks = max_chunks or self.mistral_transcription_max_chunks + chunk_plan = self._plan_audio_chunks( + duration_sec=duration_sec, + chunk_seconds=effective_chunk_seconds, + overlap_seconds=effective_overlap_seconds, + max_chunks=effective_max_chunks, + ) + warnings: list[str] = [] + chunk_statuses: list[dict[str, Any]] = [] + stitched_text = "" + partial = False + aggregated_response: dict[str, Any] = { + "speaker_scope": ( + "chunk_local_prefixed" if diarize else "not_requested" + ), + "chunks": chunk_statuses, + } + + for plan_entry in chunk_plan: + index = int(plan_entry["index"]) + start_sec = float(plan_entry["start_sec"]) + end_sec = float(plan_entry["end_sec"]) + duration_value = float(plan_entry["duration_sec"]) + leading_overlap_sec = float(plan_entry["leading_overlap_sec"]) + chunk_path = temp_dir / f"chunk-{index:03d}.wav" + try: + self._extract_audio_chunk( + upload_source, + chunk_path, + start_sec=start_sec, + duration_sec=duration_value, + ) + parsed = self._mistral_transcription_request( + resolved=chunk_path, + model=chosen_model, + diarize=diarize, + timestamp_granularities=normalized_timestamps, + context_bias=normalized_bias, + language=language, + temperature=temperature, + ) + except ToolError as exc: + partial = True + message = f"chunk {index} failed: {exc}" + chunk_statuses.append( + { + "index": index, + "start_sec": start_sec, + "end_sec": end_sec, + "status": "error", + "error": str(exc), + } + ) + if continue_on_chunk_error: + warnings.append(message) + continue + return f"audio_transcribe failed in chunk {index}: {exc}" + + chunk_text = str(parsed.get("text", "")).strip() + deduped_text = self._dedupe_audio_overlap_text( + stitched_text, chunk_text + ) + if deduped_text: + stitched_text = ( + f"{stitched_text} {deduped_text}".strip() + if stitched_text + else deduped_text + ) + + metadata = self._collect_chunk_metadata( + parsed, + chunk_start_sec=start_sec, + leading_overlap_sec=leading_overlap_sec, + speaker_prefix=f"c{index}_", + ) + for key, values in metadata.items(): + if values: + aggregated_response.setdefault(key, []).extend(values) + + chunk_statuses.append( + { + "index": index, + "start_sec": start_sec, + "end_sec": end_sec, + "status": "ok", + "text_chars": len(chunk_text), + } + ) + + if not any( + chunk.get("status") == "ok" for chunk in chunk_statuses + ): + return "audio_transcribe failed: no chunk completed successfully" + + envelope = { + "provider": "mistral", + "service": "transcription", + "mode": "chunked", + "path": rel, + "model": chosen_model, + "options": options, + "chunking": { + "strategy": "overlap_window", + "chunk_seconds": round(effective_chunk_seconds, 3), + "overlap_seconds": round(effective_overlap_seconds, 3), + "total_chunks": len(chunk_plan), + "failed_chunks": sum( + 1 for chunk in chunk_statuses if chunk["status"] != "ok" + ), + "partial": partial, + }, + "text": stitched_text, + "response": aggregated_response, + } + if warnings: + envelope["warnings"] = warnings + return self._serialize_audio_envelope( + envelope, max_chars=self._audio_transcribe_max_chars() + ) + except ToolError as exc: + return str(exc) + def write_file(self, path: str, content: str) -> str: resolved = self._resolve_path(path) if resolved.exists() and resolved.is_file() and resolved not in self._files_read: diff --git a/agent/tui.py b/agent/tui.py index 0d7184ec..3217f3dc 100644 --- a/agent/tui.py +++ b/agent/tui.py @@ -424,6 +424,7 @@ def _clip_event(text: str) -> str: _KEY_ARGS: dict[str, str] = { "read_file": "path", "read_image": "path", + "audio_transcribe": "path", "write_file": "path", "edit_file": "path", "hashline_edit": "path", diff --git a/openplanter-desktop/Cargo.lock b/openplanter-desktop/Cargo.lock index 39951ed9..f67ed7f2 100644 --- a/openplanter-desktop/Cargo.lock +++ b/openplanter-desktop/Cargo.lock @@ -2026,6 +2026,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3113,6 +3123,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -4580,6 +4591,12 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.24" diff --git a/openplanter-desktop/crates/op-core/Cargo.toml b/openplanter-desktop/crates/op-core/Cargo.toml index eaf10099..08671359 100644 --- a/openplanter-desktop/crates/op-core/Cargo.toml +++ b/openplanter-desktop/crates/op-core/Cargo.toml @@ -15,7 +15,7 @@ uuid = { workspace = true } regex = { workspace = true } async-trait = "0.1" tokio-util = { workspace = true } -reqwest = { version = "0.12", features = ["json", "stream"] } +reqwest = { version = "0.12", features = ["json", "multipart", "stream"] } reqwest-eventsource = "0.6" futures = "0.3" petgraph = { version = "0.7", features = ["serde-1"] } diff --git a/openplanter-desktop/crates/op-core/src/builder.rs b/openplanter-desktop/crates/op-core/src/builder.rs index a0c4e319..8be61129 100644 --- a/openplanter-desktop/crates/op-core/src/builder.rs +++ b/openplanter-desktop/crates/op-core/src/builder.rs @@ -386,6 +386,17 @@ mod tests { assert_eq!(resolve_provider(&cfg).unwrap(), "ollama"); } + #[test] + fn test_resolve_provider_auto_ignores_mistral_transcription_key() { + let cfg = AgentConfig { + provider: "auto".into(), + model: "some-unknown-model".into(), + mistral_transcription_api_key: Some("mistral-test".into()), + ..Default::default() + }; + assert_eq!(resolve_provider(&cfg).unwrap(), "anthropic"); + } + #[test] fn test_resolve_provider_anthropic_key_preferred_first() { let cfg = AgentConfig { diff --git a/openplanter-desktop/crates/op-core/src/config.rs b/openplanter-desktop/crates/op-core/src/config.rs index 194e9a16..2a92486c 100644 --- a/openplanter-desktop/crates/op-core/src/config.rs +++ b/openplanter-desktop/crates/op-core/src/config.rs @@ -5,6 +5,13 @@ use std::sync::LazyLock; use serde::{Deserialize, Serialize}; +pub const MISTRAL_TRANSCRIPTION_BASE_URL: &str = "https://api.mistral.ai"; +pub const MISTRAL_TRANSCRIPTION_DEFAULT_MODEL: &str = "voxtral-mini-latest"; +pub const MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS: i64 = 900; +pub const MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS: f64 = 2.0; +pub const MISTRAL_TRANSCRIPTION_MAX_CHUNKS: i64 = 48; +pub const MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC: i64 = 180; + /// Default model for each supported provider. pub static PROVIDER_DEFAULT_MODELS: LazyLock> = LazyLock::new(|| { @@ -64,6 +71,7 @@ pub struct AgentConfig { pub cerebras_base_url: String, pub ollama_base_url: String, pub exa_base_url: String, + pub mistral_transcription_base_url: String, // API keys pub api_key: Option, @@ -73,6 +81,13 @@ pub struct AgentConfig { pub cerebras_api_key: Option, pub exa_api_key: Option, pub voyage_api_key: Option, + pub mistral_transcription_api_key: Option, + pub mistral_transcription_model: String, + pub mistral_transcription_max_bytes: i64, + pub mistral_transcription_chunk_max_seconds: i64, + pub mistral_transcription_chunk_overlap_seconds: f64, + pub mistral_transcription_max_chunks: i64, + pub mistral_transcription_request_timeout_sec: i64, // Limits pub max_depth: i64, @@ -116,6 +131,7 @@ impl Default for AgentConfig { cerebras_base_url: "https://api.cerebras.ai/v1".into(), ollama_base_url: "http://localhost:11434/v1".into(), exa_base_url: "https://api.exa.ai".into(), + mistral_transcription_base_url: MISTRAL_TRANSCRIPTION_BASE_URL.into(), api_key: None, openai_api_key: None, anthropic_api_key: None, @@ -123,6 +139,14 @@ impl Default for AgentConfig { cerebras_api_key: None, exa_api_key: None, voyage_api_key: None, + mistral_transcription_api_key: None, + mistral_transcription_model: MISTRAL_TRANSCRIPTION_DEFAULT_MODEL.into(), + mistral_transcription_max_bytes: 100 * 1024 * 1024, + mistral_transcription_chunk_max_seconds: MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS, + mistral_transcription_chunk_overlap_seconds: + MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS, + mistral_transcription_max_chunks: MISTRAL_TRANSCRIPTION_MAX_CHUNKS, + mistral_transcription_request_timeout_sec: MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC, max_depth: 4, max_steps_per_call: 100, budget_extension_enabled: true, @@ -175,6 +199,10 @@ impl AgentConfig { let voyage_api_key = env_opt("OPENPLANTER_VOYAGE_API_KEY") .or_else(|| env_opt("VOYAGE_API_KEY")); + let mistral_transcription_api_key = env_opt("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + .or_else(|| env_opt("MISTRAL_TRANSCRIPTION_API_KEY")) + .or_else(|| env_opt("MISTRAL_API_KEY")); + let openai_base_url = env_opt("OPENPLANTER_OPENAI_BASE_URL") .or_else(|| env_opt("OPENPLANTER_BASE_URL")) .unwrap_or_else(|| "https://api.openai.com/v1".into()); @@ -222,12 +250,40 @@ impl AgentConfig { "http://localhost:11434/v1", ), exa_base_url: env_or("OPENPLANTER_EXA_BASE_URL", "https://api.exa.ai"), + mistral_transcription_base_url: env_opt("OPENPLANTER_MISTRAL_TRANSCRIPTION_BASE_URL") + .or_else(|| env_opt("MISTRAL_TRANSCRIPTION_BASE_URL")) + .or_else(|| env_opt("MISTRAL_BASE_URL")) + .unwrap_or_else(|| MISTRAL_TRANSCRIPTION_BASE_URL.into()), openai_api_key, anthropic_api_key, openrouter_api_key, cerebras_api_key, exa_api_key, voyage_api_key, + mistral_transcription_api_key, + mistral_transcription_model: env_opt("OPENPLANTER_MISTRAL_TRANSCRIPTION_MODEL") + .or_else(|| env_opt("MISTRAL_TRANSCRIPTION_MODEL")) + .unwrap_or_else(|| MISTRAL_TRANSCRIPTION_DEFAULT_MODEL.into()), + mistral_transcription_max_bytes: env_int( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_MAX_BYTES", + 100 * 1024 * 1024, + ), + mistral_transcription_chunk_max_seconds: env_int( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS", + MISTRAL_TRANSCRIPTION_CHUNK_MAX_SECONDS, + ), + mistral_transcription_chunk_overlap_seconds: env_float( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS", + MISTRAL_TRANSCRIPTION_CHUNK_OVERLAP_SECONDS, + ), + mistral_transcription_max_chunks: env_int( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_MAX_CHUNKS", + MISTRAL_TRANSCRIPTION_MAX_CHUNKS, + ), + mistral_transcription_request_timeout_sec: env_int( + "OPENPLANTER_MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC", + MISTRAL_TRANSCRIPTION_REQUEST_TIMEOUT_SEC, + ), max_depth: env_int("OPENPLANTER_MAX_DEPTH", 4), max_steps_per_call: env_int("OPENPLANTER_MAX_STEPS", 100), budget_extension_enabled: env_bool("OPENPLANTER_BUDGET_EXTENSION_ENABLED", true), diff --git a/openplanter-desktop/crates/op-core/src/config_hydration.rs b/openplanter-desktop/crates/op-core/src/config_hydration.rs index 840e31a4..7bf5de7b 100644 --- a/openplanter-desktop/crates/op-core/src/config_hydration.rs +++ b/openplanter-desktop/crates/op-core/src/config_hydration.rs @@ -36,6 +36,7 @@ pub fn merge_credentials_into_config( merge!(cerebras_api_key); merge!(exa_api_key); merge!(voyage_api_key); + merge!(mistral_transcription_api_key); } pub fn apply_settings_to_config(cfg: &mut AgentConfig, settings: &PersistentSettings) { diff --git a/openplanter-desktop/crates/op-core/src/credentials.rs b/openplanter-desktop/crates/op-core/src/credentials.rs index cafed675..22615d73 100644 --- a/openplanter-desktop/crates/op-core/src/credentials.rs +++ b/openplanter-desktop/crates/op-core/src/credentials.rs @@ -18,24 +18,23 @@ pub struct CredentialBundle { pub cerebras_api_key: Option, pub exa_api_key: Option, pub voyage_api_key: Option, + pub mistral_transcription_api_key: Option, } impl CredentialBundle { /// Returns `true` if any key has a non-empty value. pub fn has_any(&self) -> bool { - let keys: [&Option; 6] = [ + let keys = [ &self.openai_api_key, &self.anthropic_api_key, &self.openrouter_api_key, &self.cerebras_api_key, &self.exa_api_key, &self.voyage_api_key, + &self.mistral_transcription_api_key, ]; - keys.iter().any(|k| { - k.as_ref() - .map(|v| !v.trim().is_empty()) - .unwrap_or(false) - }) + keys.iter() + .any(|k| k.as_ref().map(|v| !v.trim().is_empty()).unwrap_or(false)) } /// Fill in missing keys from `other`. @@ -53,6 +52,7 @@ impl CredentialBundle { fill!(cerebras_api_key); fill!(exa_api_key); fill!(voyage_api_key); + fill!(mistral_transcription_api_key); } /// Serialize to JSON map, omitting `None` values. @@ -71,6 +71,10 @@ impl CredentialBundle { add!(cerebras_api_key, "cerebras_api_key"); add!(exa_api_key, "exa_api_key"); add!(voyage_api_key, "voyage_api_key"); + add!( + mistral_transcription_api_key, + "mistral_transcription_api_key" + ); out } @@ -89,6 +93,7 @@ impl CredentialBundle { cerebras_api_key: get_str(payload, "cerebras_api_key"), exa_api_key: get_str(payload, "exa_api_key"), voyage_api_key: get_str(payload, "voyage_api_key"), + mistral_transcription_api_key: get_str(payload, "mistral_transcription_api_key"), } } } @@ -152,13 +157,15 @@ pub fn parse_env_file(path: &Path) -> CredentialBundle { "OPENROUTER_API_KEY", "OPENPLANTER_OPENROUTER_API_KEY", ), - cerebras_api_key: get_key( - &env_map, - "CEREBRAS_API_KEY", - "OPENPLANTER_CEREBRAS_API_KEY", - ), + cerebras_api_key: get_key(&env_map, "CEREBRAS_API_KEY", "OPENPLANTER_CEREBRAS_API_KEY"), exa_api_key: get_key(&env_map, "EXA_API_KEY", "OPENPLANTER_EXA_API_KEY"), voyage_api_key: get_key(&env_map, "VOYAGE_API_KEY", "OPENPLANTER_VOYAGE_API_KEY"), + mistral_transcription_api_key: env_map + .get("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + .or_else(|| env_map.get("MISTRAL_TRANSCRIPTION_API_KEY")) + .or_else(|| env_map.get("MISTRAL_API_KEY")) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), } } @@ -179,6 +186,12 @@ pub fn credentials_from_env() -> CredentialBundle { cerebras_api_key: env_key("OPENPLANTER_CEREBRAS_API_KEY", "CEREBRAS_API_KEY"), exa_api_key: env_key("OPENPLANTER_EXA_API_KEY", "EXA_API_KEY"), voyage_api_key: env_key("OPENPLANTER_VOYAGE_API_KEY", "VOYAGE_API_KEY"), + mistral_transcription_api_key: env::var("OPENPLANTER_MISTRAL_TRANSCRIPTION_API_KEY") + .ok() + .or_else(|| env::var("MISTRAL_TRANSCRIPTION_API_KEY").ok()) + .or_else(|| env::var("MISTRAL_API_KEY").ok()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()), } } @@ -316,6 +329,33 @@ mod tests { assert!(bundle.has_any()); } + #[test] + fn test_credential_bundle_has_any_with_voyage_key() { + let bundle = CredentialBundle { + voyage_api_key: Some("voyage-test".into()), + ..Default::default() + }; + assert!(bundle.has_any()); + } + + #[test] + fn test_credential_bundle_whitespace_only_values_do_not_count() { + let bundle = CredentialBundle { + voyage_api_key: Some(" ".into()), + ..Default::default() + }; + assert!(!bundle.has_any()); + } + + #[test] + fn test_credential_bundle_has_any_with_mistral_transcription_key() { + let bundle = CredentialBundle { + mistral_transcription_api_key: Some("mistral-test".into()), + ..Default::default() + }; + assert!(bundle.has_any()); + } + #[test] fn test_credential_bundle_merge_missing() { let mut a = CredentialBundle { @@ -325,11 +365,16 @@ mod tests { let b = CredentialBundle { openai_api_key: Some("should-not-overwrite".into()), anthropic_api_key: Some("new-key".into()), + mistral_transcription_api_key: Some("mistral-key".into()), ..Default::default() }; a.merge_missing(&b); assert_eq!(a.openai_api_key, Some("existing".into())); assert_eq!(a.anthropic_api_key, Some("new-key".into())); + assert_eq!( + a.mistral_transcription_api_key, + Some("mistral-key".into()) + ); } #[test] @@ -338,12 +383,17 @@ mod tests { openai_api_key: Some("sk-123".into()), anthropic_api_key: None, openrouter_api_key: Some("or-456".into()), + mistral_transcription_api_key: Some("mistral-789".into()), ..Default::default() }; let json = bundle.to_json(); assert_eq!(json.get("openai_api_key").unwrap(), "sk-123"); assert!(!json.contains_key("anthropic_api_key")); assert_eq!(json.get("openrouter_api_key").unwrap(), "or-456"); + assert_eq!( + json.get("mistral_transcription_api_key").unwrap(), + "mistral-789" + ); } #[test] @@ -357,6 +407,7 @@ mod tests { OPENAI_API_KEY=sk-from-env export ANTHROPIC_API_KEY='ant-key' EXA_API_KEY="exa-quoted" +MISTRAL_API_KEY=mistral-from-env UNRELATED_VAR=foo "#, ) @@ -366,6 +417,10 @@ UNRELATED_VAR=foo assert_eq!(bundle.openai_api_key, Some("sk-from-env".into())); assert_eq!(bundle.anthropic_api_key, Some("ant-key".into())); assert_eq!(bundle.exa_api_key, Some("exa-quoted".into())); + assert_eq!( + bundle.mistral_transcription_api_key, + Some("mistral-from-env".into()) + ); assert!(bundle.cerebras_api_key.is_none()); } @@ -384,7 +439,10 @@ UNRELATED_VAR=foo env_map.get("OPENPLANTER_WORKSPACE"), Some(&"workspace".to_string()) ); - assert_eq!(env_map.get("OPENAI_API_KEY"), Some(&"sk-from-env".to_string())); + assert_eq!( + env_map.get("OPENAI_API_KEY"), + Some(&"sk-from-env".to_string()) + ); } #[test] @@ -406,12 +464,17 @@ UNRELATED_VAR=foo let bundle = CredentialBundle { openai_api_key: Some("sk-test".into()), anthropic_api_key: Some("ant-test".into()), + mistral_transcription_api_key: Some("mistral-test".into()), ..Default::default() }; store.save(&bundle).unwrap(); let loaded = store.load(); assert_eq!(loaded.openai_api_key, Some("sk-test".into())); assert_eq!(loaded.anthropic_api_key, Some("ant-test".into())); + assert_eq!( + loaded.mistral_transcription_api_key, + Some("mistral-test".into()) + ); } #[test] diff --git a/openplanter-desktop/crates/op-core/src/engine/mod.rs b/openplanter-desktop/crates/op-core/src/engine/mod.rs index ea3b1517..4b6faa18 100644 --- a/openplanter-desktop/crates/op-core/src/engine/mod.rs +++ b/openplanter-desktop/crates/op-core/src/engine/mod.rs @@ -481,6 +481,7 @@ fn is_recon_tool(name: &str) -> bool { | "fetch_url" | "read_file" | "read_image" + | "audio_transcribe" | "list_artifacts" | "read_artifact" ) diff --git a/openplanter-desktop/crates/op-core/src/model/mod.rs b/openplanter-desktop/crates/op-core/src/model/mod.rs index 4f2781ec..81b04ca3 100644 --- a/openplanter-desktop/crates/op-core/src/model/mod.rs +++ b/openplanter-desktop/crates/op-core/src/model/mod.rs @@ -8,6 +8,24 @@ use serde::{Deserialize, Serialize}; use crate::events::DeltaEvent; use tokio_util::sync::CancellationToken; +/// Structured model error for provider rate limiting. +#[derive(Debug, Clone)] +pub struct RateLimitError { + pub message: String, + pub status_code: Option, + pub provider_code: Option, + pub body: String, + pub retry_after_sec: Option, +} + +impl std::fmt::Display for RateLimitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for RateLimitError {} + /// A single tool call returned by the model. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { diff --git a/openplanter-desktop/crates/op-core/src/tools/audio.rs b/openplanter-desktop/crates/op-core/src/tools/audio.rs new file mode 100644 index 00000000..cfee6185 --- /dev/null +++ b/openplanter-desktop/crates/op-core/src/tools/audio.rs @@ -0,0 +1,1514 @@ +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::time::Duration; + +use reqwest::multipart::{Form, Part}; +use serde_json::{Map, Value, json}; +use tokio::process::Command; +use tokio::time::timeout; +use uuid::Uuid; + +use super::{ToolResult, filesystem}; + +const AUDIO_EXTENSIONS: &[&str] = &[ + ".aac", ".flac", ".m4a", ".mp3", ".mpeg", ".mpga", ".oga", ".ogg", ".opus", ".wav", +]; +const VIDEO_EXTENSIONS: &[&str] = &[".avi", ".m4v", ".mkv", ".mov", ".mp4", ".webm"]; +const TIMESTAMP_GRANULARITIES: &[&str] = &["segment", "word"]; +const CHUNKING_MODES: &[&str] = &["auto", "force", "off"]; +const AUDIO_CHUNK_TARGET_FILL_RATIO: f64 = 0.85; +const AUDIO_CHUNK_BYTES_PER_SECOND: f64 = 32_000.0; +const AUDIO_MIN_CHUNK_SECONDS: f64 = 30.0; +const AUDIO_MAX_CHUNK_SECONDS: f64 = 1800.0; +const AUDIO_MAX_CHUNK_OVERLAP_SECONDS: f64 = 15.0; +const AUDIO_MAX_CHUNKS: i64 = 200; +const SPEAKER_FIELDS: &[&str] = &["speaker", "speaker_id", "speaker_label"]; + +#[derive(Debug, Clone)] +struct ChunkPlan { + index: usize, + start_sec: f64, + end_sec: f64, + duration_sec: f64, + leading_overlap_sec: f64, +} + +struct TempAudioDir { + path: PathBuf, +} + +impl TempAudioDir { + fn new() -> Result { + let path = std::env::temp_dir().join(format!("openplanter-audio-{}", Uuid::new_v4())); + std::fs::create_dir_all(&path) + .map_err(|error| format!("Failed to create temp audio directory: {error}"))?; + Ok(Self { path }) + } +} + +impl Drop for TempAudioDir { + fn drop(&mut self) { + let _ = std::fs::remove_dir_all(&self.path); + } +} + +fn transcription_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/'); + if trimmed.ends_with("/v1") { + format!("{trimmed}/audio/transcriptions") + } else { + format!("{trimmed}/v1/audio/transcriptions") + } +} + +fn audio_media_type(path: &Path) -> &'static str { + match path + .extension() + .and_then(|value| value.to_str()) + .map(|value| value.to_ascii_lowercase()) + .as_deref() + { + Some("aac") => "audio/aac", + Some("flac") => "audio/flac", + Some("m4a") => "audio/mp4", + Some("mp3") | Some("mpga") => "audio/mpeg", + Some("mpeg") => "audio/mpeg", + Some("oga") | Some("ogg") | Some("opus") => "audio/ogg", + Some("wav") => "audio/wav", + _ => "application/octet-stream", + } +} + +fn rel_path(root: &Path, path: &Path) -> String { + let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.to_path_buf()); + path.strip_prefix(&canon_root) + .unwrap_or(path) + .to_string_lossy() + .replace('\\', "/") +} + +fn is_video_extension(ext: &str) -> bool { + VIDEO_EXTENSIONS.iter().any(|value| *value == ext) +} + +fn is_supported_extension(ext: &str) -> bool { + AUDIO_EXTENSIONS.iter().any(|value| *value == ext) || is_video_extension(ext) +} + +fn json_length(payload: &Value) -> usize { + serde_json::to_string_pretty(payload) + .unwrap_or_else(|_| payload.to_string()) + .len() +} + +fn build_options( + diarize: Option, + timestamp_granularities: Option<&[String]>, + context_bias: Option<&[String]>, + language: Option<&str>, + temperature: Option, + chunking: &str, + chunk_max_seconds: Option, + chunk_overlap_seconds: Option, + max_chunks: Option, + continue_on_chunk_error: Option, +) -> Value { + let mut options = Map::new(); + options.insert("chunking".into(), Value::String(chunking.to_string())); + if let Some(value) = diarize { + options.insert("diarize".into(), Value::Bool(value)); + } + if let Some(values) = timestamp_granularities.filter(|values| !values.is_empty()) { + options.insert( + "timestamp_granularities".into(), + Value::Array(values.iter().cloned().map(Value::String).collect()), + ); + } + if let Some(values) = context_bias.filter(|values| !values.is_empty()) { + options.insert( + "context_bias".into(), + Value::Array(values.iter().cloned().map(Value::String).collect()), + ); + } + if let Some(value) = language.filter(|value| !value.trim().is_empty()) { + options.insert("language".into(), Value::String(value.to_string())); + } + if let Some(value) = temperature { + if let Some(number) = serde_json::Number::from_f64(value) { + options.insert("temperature".into(), Value::Number(number)); + } + } + if let Some(value) = chunk_max_seconds { + options.insert("chunk_max_seconds".into(), Value::Number(value.into())); + } + if let Some(value) = chunk_overlap_seconds { + if let Some(number) = serde_json::Number::from_f64(value) { + options.insert("chunk_overlap_seconds".into(), Value::Number(number)); + } + } + if let Some(value) = max_chunks { + options.insert("max_chunks".into(), Value::Number(value.into())); + } + if let Some(value) = continue_on_chunk_error { + options.insert("continue_on_chunk_error".into(), Value::Bool(value)); + } + Value::Object(options) +} + +fn normalize_audio_token(token: &str) -> String { + token + .chars() + .filter(|ch| ch.is_ascii_alphanumeric()) + .flat_map(char::to_lowercase) + .collect() +} + +fn dedupe_audio_overlap_text(existing: &str, incoming: &str) -> String { + if existing.trim().is_empty() { + return incoming.trim().to_string(); + } + let current_tokens: Vec<&str> = incoming.split_whitespace().collect(); + if current_tokens.is_empty() { + return String::new(); + } + let previous_tokens: Vec<&str> = existing.split_whitespace().collect(); + let max_window = previous_tokens.len().min(current_tokens.len()).min(80); + if max_window < 5 { + return incoming.trim().to_string(); + } + let previous_norm: Vec = previous_tokens[previous_tokens.len() - max_window..] + .iter() + .map(|token| normalize_audio_token(token)) + .collect(); + let current_norm: Vec = current_tokens[..max_window] + .iter() + .map(|token| normalize_audio_token(token)) + .collect(); + for match_len in (5..=max_window).rev() { + if previous_norm[max_window - match_len..] == current_norm[..match_len] { + return current_tokens[match_len..].join(" ").trim().to_string(); + } + } + incoming.trim().to_string() +} + +fn which_binary(name: &str) -> bool { + std::env::var_os("PATH") + .map(|paths| { + std::env::split_paths(&paths).any(|path| { + let candidate = path.join(name); + let executable = candidate.is_file(); + if executable { + return true; + } + #[cfg(windows)] + { + return path.join(format!("{name}.exe")).is_file(); + } + #[cfg(not(windows))] + { + false + } + }) + }) + .unwrap_or(false) +} + +fn ensure_media_tools() -> Result<(), String> { + let missing: Vec<&str> = ["ffmpeg", "ffprobe"] + .into_iter() + .filter(|name| !which_binary(name)) + .collect(); + if missing.is_empty() { + Ok(()) + } else { + Err(format!( + "Long-form transcription requires {}. Install ffmpeg/ffprobe and retry.", + missing.join(", ") + )) + } +} + +async fn run_media_command( + program: &str, + args: &[String], + timeout_sec: u64, +) -> Result { + let mut command = Command::new(program); + command + .args(args) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .stdin(Stdio::null()); + let output = timeout(Duration::from_secs(timeout_sec), command.output()) + .await + .map_err(|_| format!("{program} timed out after {timeout_sec}s"))? + .map_err(|error| format!("Media tooling not available: {program}: {error}"))?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + return Err(format!( + "{program} failed: {}", + if !stderr.is_empty() { + stderr + } else if !stdout.is_empty() { + stdout + } else { + "unknown error".to_string() + } + )); + } + Ok(String::from_utf8_lossy(&output.stdout).to_string()) +} + +async fn probe_media_duration(path: &Path, timeout_sec: u64) -> Result { + let stdout = run_media_command( + "ffprobe", + &[ + "-v".to_string(), + "error".to_string(), + "-print_format".to_string(), + "json".to_string(), + "-show_format".to_string(), + path.display().to_string(), + ], + timeout_sec, + ) + .await?; + let parsed: Value = serde_json::from_str(&stdout) + .map_err(|error| format!("ffprobe returned invalid JSON: {error}"))?; + let duration_value = parsed + .get("format") + .and_then(Value::as_object) + .and_then(|format| format.get("duration")) + .cloned() + .ok_or_else(|| { + format!( + "ffprobe did not return a valid duration for {}", + path.display() + ) + })?; + let parsed_duration = match duration_value { + Value::String(value) => value + .parse::() + .map_err(|error| format!("ffprobe did not return a valid duration: {error}"))?, + Value::Number(value) => value + .as_f64() + .ok_or_else(|| "ffprobe did not return a valid numeric duration".to_string())?, + _ => { + return Err(format!( + "ffprobe did not return a valid duration for {}", + path.display() + )); + } + }; + if parsed_duration <= 0.0 { + return Err(format!( + "ffprobe reported non-positive duration for {}", + path.display() + )); + } + Ok(parsed_duration) +} + +async fn extract_audio_source( + source: &Path, + output: &Path, + timeout_sec: u64, +) -> Result<(), String> { + run_media_command( + "ffmpeg", + &[ + "-nostdin".to_string(), + "-y".to_string(), + "-i".to_string(), + source.display().to_string(), + "-vn".to_string(), + "-ac".to_string(), + "1".to_string(), + "-ar".to_string(), + "16000".to_string(), + "-c:a".to_string(), + "pcm_s16le".to_string(), + output.display().to_string(), + ], + timeout_sec, + ) + .await + .map(|_| ()) +} + +async fn extract_audio_chunk( + source: &Path, + output: &Path, + start_sec: f64, + duration_sec: f64, + timeout_sec: u64, +) -> Result<(), String> { + run_media_command( + "ffmpeg", + &[ + "-nostdin".to_string(), + "-y".to_string(), + "-ss".to_string(), + format!("{start_sec:.3}"), + "-i".to_string(), + source.display().to_string(), + "-t".to_string(), + format!("{duration_sec:.3}"), + "-vn".to_string(), + "-ac".to_string(), + "1".to_string(), + "-ar".to_string(), + "16000".to_string(), + "-c:a".to_string(), + "pcm_s16le".to_string(), + output.display().to_string(), + ], + timeout_sec, + ) + .await + .map(|_| ()) +} + +fn audio_chunk_seconds_budget(max_bytes: usize, requested_seconds: f64) -> Result { + let safe_seconds = + (max_bytes as f64 * AUDIO_CHUNK_TARGET_FILL_RATIO) / AUDIO_CHUNK_BYTES_PER_SECOND; + if safe_seconds <= 0.0 { + return Err("Mistral transcription max-bytes budget is too small to chunk audio".into()); + } + Ok(requested_seconds.min(safe_seconds)) +} + +fn plan_audio_chunks( + duration_sec: f64, + chunk_seconds: f64, + overlap_seconds: f64, + max_chunks: i64, +) -> Result, String> { + if duration_sec <= 0.0 { + return Err("Cannot chunk media with non-positive duration".into()); + } + let chunk_seconds = chunk_seconds.max(1.0); + let overlap_seconds = overlap_seconds + .max(0.0) + .min((chunk_seconds - 0.001).max(0.0)); + let mut chunks = Vec::new(); + let mut start = 0.0; + while start < duration_sec - 1e-6 { + let end = (start + chunk_seconds).min(duration_sec); + let index = chunks.len(); + chunks.push(ChunkPlan { + index, + start_sec: (start * 1000.0).round() / 1000.0, + end_sec: (end * 1000.0).round() / 1000.0, + duration_sec: ((end - start) * 1000.0).round() / 1000.0, + leading_overlap_sec: if index == 0 { + 0.0 + } else { + (overlap_seconds * 1000.0).round() / 1000.0 + }, + }); + if chunks.len() as i64 > max_chunks { + return Err(format!( + "Chunk plan would create {} chunks (max {max_chunks})", + chunks.len() + )); + } + if end >= duration_sec - 1e-6 { + break; + } + let mut next_start = end - overlap_seconds; + if next_start <= start + 1e-6 { + next_start = end; + } + start = next_start; + } + Ok(chunks) +} + +fn entry_time_bounds(entry: &Map) -> Option<(f64, f64)> { + if let (Some(start), Some(end)) = ( + entry.get("start").and_then(Value::as_f64), + entry.get("end").and_then(Value::as_f64), + ) { + return Some((start, end)); + } + let timestamps = entry.get("timestamps")?.as_array()?; + if timestamps.len() < 2 { + return None; + } + Some((timestamps[0].as_f64()?, timestamps[1].as_f64()?)) +} + +fn set_entry_time_bounds(entry: &mut Map, start: f64, end: f64) { + if entry.contains_key("start") || entry.contains_key("end") { + entry.insert("start".into(), json!(((start * 1000.0).round() / 1000.0))); + entry.insert("end".into(), json!(((end * 1000.0).round() / 1000.0))); + } else if let Some(timestamps) = entry.get_mut("timestamps").and_then(Value::as_array_mut) { + while timestamps.len() < 2 { + timestamps.push(json!(0.0)); + } + timestamps[0] = json!(((start * 1000.0).round() / 1000.0)); + timestamps[1] = json!(((end * 1000.0).round() / 1000.0)); + } +} + +fn prefix_audio_speakers(value: &Value, prefix: &str) -> Value { + match value { + Value::Array(items) => Value::Array( + items + .iter() + .map(|item| prefix_audio_speakers(item, prefix)) + .collect(), + ), + Value::Object(object) => Value::Object( + object + .iter() + .map(|(key, item)| { + let value = if SPEAKER_FIELDS.contains(&key.as_str()) { + item.as_str() + .map(|speaker| Value::String(format!("{prefix}{speaker}"))) + .unwrap_or_else(|| prefix_audio_speakers(item, prefix)) + } else { + prefix_audio_speakers(item, prefix) + }; + (key.clone(), value) + }) + .collect(), + ), + _ => value.clone(), + } +} + +fn shift_audio_items( + items: &[Value], + chunk_start_sec: f64, + leading_overlap_sec: f64, + speaker_prefix: &str, +) -> Vec { + let mut shifted = Vec::new(); + for item in items { + let mut copied = prefix_audio_speakers(item, speaker_prefix); + if let Some(object) = copied.as_object_mut() { + if let Some((mut start, end)) = entry_time_bounds(object) { + if end <= leading_overlap_sec + 1e-6 { + continue; + } + if start < leading_overlap_sec { + start = leading_overlap_sec; + } + set_entry_time_bounds(object, start + chunk_start_sec, end + chunk_start_sec); + } + } + shifted.push(copied); + } + shifted +} + +fn collect_chunk_metadata( + parsed: &Value, + chunk_start_sec: f64, + leading_overlap_sec: f64, + speaker_prefix: &str, +) -> Map { + let mut aggregated = Map::new(); + if let Some(items) = parsed.get("segments").and_then(Value::as_array) { + aggregated.insert( + "segments".into(), + Value::Array(shift_audio_items( + items, + chunk_start_sec, + leading_overlap_sec, + speaker_prefix, + )), + ); + } else if let Some(items) = parsed.get("chunks").and_then(Value::as_array) { + aggregated.insert( + "segments".into(), + Value::Array(shift_audio_items( + items, + chunk_start_sec, + leading_overlap_sec, + speaker_prefix, + )), + ); + } + if let Some(items) = parsed.get("words").and_then(Value::as_array) { + aggregated.insert( + "words".into(), + Value::Array(shift_audio_items( + items, + chunk_start_sec, + leading_overlap_sec, + speaker_prefix, + )), + ); + } + if let Some(items) = parsed.get("diarization").and_then(Value::as_array) { + aggregated.insert( + "diarization".into(), + Value::Array(shift_audio_items( + items, + chunk_start_sec, + leading_overlap_sec, + speaker_prefix, + )), + ); + } + aggregated +} + +fn truncate_audio_text(payload: &mut Value, max_chars: usize) { + let original = payload + .get("text") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + if original.is_empty() { + return; + } + let mut base = payload.clone(); + base["text"] = Value::String(String::new()); + if json_length(&base) > max_chars { + payload["text"] = Value::String(String::new()); + payload["truncation"]["text_truncated_chars"] = json!(original.len()); + return; + } + + let mut low = 0usize; + let mut high = original.len(); + while low < high { + let mid = (low + high + 1) / 2; + let idx = original.floor_char_boundary(mid); + base["text"] = Value::String(original[..idx].to_string()); + if json_length(&base) <= max_chars { + low = idx; + } else if idx == 0 { + high = 0; + } else { + high = idx - 1; + } + } + let final_idx = original.floor_char_boundary(low); + payload["text"] = Value::String(original[..final_idx].to_string()); + let omitted = original.len().saturating_sub(final_idx); + if omitted > 0 { + payload["truncation"]["text_truncated_chars"] = json!(omitted); + } +} + +fn serialize_audio_envelope(mut payload: Value, max_chars: usize) -> String { + if payload.get("truncation").is_none() { + payload["truncation"] = json!({"applied": false}); + } + if json_length(&payload) <= max_chars { + return serde_json::to_string_pretty(&payload).unwrap_or_else(|_| payload.to_string()); + } + + payload["truncation"]["applied"] = Value::Bool(true); + let mode = payload + .get("mode") + .and_then(Value::as_str) + .unwrap_or("") + .to_string(); + + let mut omitted_response_fields = Map::new(); + let mut removal_order = vec!["words", "diarization", "segments"]; + if mode != "chunked" { + removal_order.push("chunks"); + } + for key in removal_order { + let removed = payload + .get_mut("response") + .and_then(Value::as_object_mut) + .and_then(|response| response.remove(key)); + if let Some(value) = removed { + if let Some(items) = value.as_array() { + if !items.is_empty() { + omitted_response_fields.insert(key.into(), json!(items.len())); + } + } + if json_length(&payload) <= max_chars { + break; + } + } + } + if !omitted_response_fields.is_empty() { + payload["truncation"]["omitted_response_fields"] = Value::Object(omitted_response_fields); + } + + if mode == "chunked" && json_length(&payload) > max_chars { + let omitted = payload + .get_mut("response") + .and_then(Value::as_object_mut) + .and_then(|response| response.get_mut("chunks")) + .and_then(Value::as_array_mut) + .map(|chunks| { + let keep = chunks.len().min(12); + let omitted = chunks.len().saturating_sub(keep); + if omitted > 0 { + chunks.truncate(keep); + } + omitted + }) + .unwrap_or(0); + if omitted > 0 { + payload["truncation"]["omitted_chunk_statuses"] = json!(omitted); + } + } + + if json_length(&payload) > max_chars { + truncate_audio_text(&mut payload, max_chars); + } + + if json_length(&payload) > max_chars { + while json_length(&payload) > max_chars { + let popped = payload + .get_mut("response") + .and_then(Value::as_object_mut) + .and_then(|response| response.get_mut("chunks")) + .and_then(Value::as_array_mut) + .map(|chunks| { + if chunks.len() > 3 { + chunks.pop(); + true + } else { + false + } + }) + .unwrap_or(false); + if !popped { + break; + } + let current = payload["truncation"] + .get("omitted_chunk_statuses") + .and_then(Value::as_u64) + .unwrap_or(0); + payload["truncation"]["omitted_chunk_statuses"] = json!(current + 1); + } + } + + if json_length(&payload) > max_chars { + if let Some(options) = payload.get_mut("options").and_then(Value::as_object_mut) { + if let Some(context_bias) = options.remove("context_bias") { + if let Some(items) = context_bias.as_array() { + payload["truncation"]["omitted_context_bias_phrases"] = json!(items.len()); + } + } + } + } + + serde_json::to_string_pretty(&payload).unwrap_or_else(|_| payload.to_string()) +} + +async fn mistral_transcription_request( + api_key: &str, + base_url: &str, + resolved: &Path, + model: &str, + diarize: Option, + timestamp_granularities: Option<&[String]>, + context_bias: Option<&[String]>, + language: Option<&str>, + temperature: Option, + max_bytes: usize, + request_timeout_sec: u64, +) -> Result { + let metadata = std::fs::metadata(resolved).map_err(|error| { + format!( + "Failed to inspect audio file {}: {error}", + resolved.display() + ) + })?; + if metadata.len() as usize > max_bytes { + return Err(format!( + "Audio file too large: {} bytes (max {} bytes)", + metadata.len(), + max_bytes + )); + } + let bytes = std::fs::read(resolved) + .map_err(|error| format!("Failed to read audio file {}: {error}", resolved.display()))?; + let filename = resolved + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or("audio"); + let mut form = Form::new() + .text("model", model.to_string()) + .text("stream", "false") + .part( + "file", + Part::bytes(bytes) + .file_name(filename.to_string()) + .mime_str(audio_media_type(resolved)) + .expect("audio_media_type always returns a valid MIME type"), + ); + if let Some(value) = diarize { + form = form.text("diarize", if value { "true" } else { "false" }); + } + if let Some(value) = language.filter(|value| !value.trim().is_empty()) { + form = form.text("language", value.to_string()); + } + if let Some(value) = temperature { + form = form.text("temperature", value.to_string()); + } + if let Some(values) = timestamp_granularities { + for value in values { + form = form.text("timestamp_granularities", value.clone()); + } + } + if let Some(values) = context_bias { + for value in values { + form = form.text("context_bias", value.clone()); + } + } + + let client = reqwest::Client::new(); + let response = client + .post(transcription_endpoint(base_url)) + .bearer_auth(api_key) + .timeout(Duration::from_secs(request_timeout_sec)) + .multipart(form) + .send() + .await + .map_err(|error| format!("Mistral transcription request failed: {error}"))?; + let status = response.status(); + let raw = response + .text() + .await + .map_err(|error| format!("Mistral transcription returned unreadable body: {error}"))?; + if !status.is_success() { + return Err(format!( + "Mistral transcription HTTP {}: {}", + status.as_u16(), + raw + )); + } + serde_json::from_str(&raw).map_err(|error| { + format!( + "Mistral transcription returned non-JSON payload: {error}: {}", + filesystem::clip(&raw, 500) + ) + }) +} + +#[allow(clippy::too_many_arguments)] +pub async fn audio_transcribe( + root: &Path, + api_key: Option<&str>, + base_url: &str, + default_model: &str, + max_bytes: usize, + default_chunk_max_seconds: i64, + default_chunk_overlap_seconds: f64, + default_max_chunks: i64, + path: &str, + diarize: Option, + timestamp_granularities: Option<&[String]>, + context_bias: Option<&[String]>, + language: Option<&str>, + model: Option<&str>, + temperature: Option, + chunking: Option<&str>, + chunk_max_seconds: Option, + chunk_overlap_seconds: Option, + max_chunks: Option, + continue_on_chunk_error: Option, + max_chars: usize, + command_timeout_sec: u64, + request_timeout_sec: u64, + files_read: &mut HashSet, +) -> ToolResult { + let resolved = match filesystem::resolve_path(root, path) { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + if !resolved.exists() { + return ToolResult::error(format!("File not found: {path}")); + } + if resolved.is_dir() { + return ToolResult::error(format!("Path is a directory, not a file: {path}")); + } + let ext = resolved + .extension() + .and_then(|value| value.to_str()) + .map(|value| format!(".{}", value.to_ascii_lowercase())) + .unwrap_or_default(); + if !is_supported_extension(&ext) { + let mut supported: Vec<&str> = AUDIO_EXTENSIONS.iter().copied().collect(); + supported.extend(VIDEO_EXTENSIONS.iter().copied()); + supported.sort_unstable(); + return ToolResult::error(format!( + "Unsupported audio format: {}. Supported: {}", + if ext.is_empty() { "(none)" } else { &ext }, + supported.join(", ") + )); + } + if language.is_some() && timestamp_granularities.is_some() { + return ToolResult::error( + "language cannot be combined with timestamp_granularities for Mistral offline transcription" + .into(), + ); + } + let chunk_mode = chunking.unwrap_or("auto").trim().to_ascii_lowercase(); + if !CHUNKING_MODES.iter().any(|value| *value == chunk_mode) { + return ToolResult::error("chunking must be one of auto, off, or force".into()); + } + if chunk_max_seconds + .map(|value| { + !(AUDIO_MIN_CHUNK_SECONDS as i64..=AUDIO_MAX_CHUNK_SECONDS as i64).contains(&value) + }) + .unwrap_or(false) + { + return ToolResult::error(format!( + "chunk_max_seconds must be between {} and {}", + AUDIO_MIN_CHUNK_SECONDS as i64, AUDIO_MAX_CHUNK_SECONDS as i64 + )); + } + if chunk_overlap_seconds + .map(|value| !(0.0..=AUDIO_MAX_CHUNK_OVERLAP_SECONDS).contains(&value)) + .unwrap_or(false) + { + return ToolResult::error(format!( + "chunk_overlap_seconds must be between 0 and {}", + AUDIO_MAX_CHUNK_OVERLAP_SECONDS as i64 + )); + } + if max_chunks + .map(|value| !(1..=AUDIO_MAX_CHUNKS).contains(&value)) + .unwrap_or(false) + { + return ToolResult::error(format!( + "max_chunks must be between 1 and {AUDIO_MAX_CHUNKS}" + )); + } + + let api_key = match api_key { + Some(value) if !value.trim().is_empty() => value, + _ => return ToolResult::error("Mistral transcription API key not configured".into()), + }; + let chosen_model = model.unwrap_or(default_model).trim(); + if chosen_model.is_empty() { + return ToolResult::error("No Mistral transcription model configured".into()); + } + let normalized_timestamps = timestamp_granularities.map(|values| { + values + .iter() + .map(|value| value.trim().to_ascii_lowercase()) + .filter(|value| !value.is_empty()) + .collect::>() + }); + if normalized_timestamps.as_ref().is_some_and(|values| { + values + .iter() + .any(|value| !TIMESTAMP_GRANULARITIES.contains(&value.as_str())) + }) { + return ToolResult::error(format!( + "timestamp_granularities must be drawn from {}", + TIMESTAMP_GRANULARITIES.join(", ") + )); + } + let normalized_bias = context_bias.map(|values| { + values + .iter() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + .collect::>() + }); + if normalized_bias + .as_ref() + .is_some_and(|values| values.len() > 100) + { + return ToolResult::error("context_bias supports at most 100 phrases".into()); + } + + let options = build_options( + diarize, + normalized_timestamps.as_deref(), + normalized_bias.as_deref(), + language, + temperature, + &chunk_mode, + chunk_max_seconds, + chunk_overlap_seconds, + max_chunks, + continue_on_chunk_error, + ); + + let temp_dir = match TempAudioDir::new() { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + let mut upload_source = resolved.clone(); + if is_video_extension(&ext) { + if let Err(error) = ensure_media_tools() { + return ToolResult::error(error); + } + let extracted = temp_dir.path.join("video-source.wav"); + if let Err(error) = extract_audio_source(&resolved, &extracted, command_timeout_sec).await { + return ToolResult::error(error); + } + upload_source = extracted; + } + + let upload_size = match std::fs::metadata(&upload_source) { + Ok(value) => value.len() as usize, + Err(error) => { + return ToolResult::error(format!( + "Failed to inspect audio file {}: {error}", + upload_source.display() + )); + } + }; + files_read.insert(resolved.clone()); + + let chunk_requested = + chunk_mode == "force" || (chunk_mode == "auto" && upload_size > max_bytes); + + if !chunk_requested { + let parsed = match mistral_transcription_request( + api_key, + base_url, + &upload_source, + chosen_model, + diarize, + normalized_timestamps.as_deref(), + normalized_bias.as_deref(), + language, + temperature, + max_bytes, + request_timeout_sec, + ) + .await + { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + let envelope = json!({ + "provider": "mistral", + "service": "transcription", + "path": rel_path(root, &resolved), + "model": chosen_model, + "options": options, + "text": parsed.get("text").and_then(Value::as_str).unwrap_or_default(), + "response": parsed, + }); + return ToolResult::ok(serialize_audio_envelope(envelope, max_chars)); + } + + if let Err(error) = ensure_media_tools() { + return ToolResult::error(error); + } + + let duration_sec = match probe_media_duration(&upload_source, command_timeout_sec).await { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + let requested_chunk_seconds = (chunk_max_seconds.unwrap_or(default_chunk_max_seconds) as f64) + .min(AUDIO_MAX_CHUNK_SECONDS); + let effective_chunk_seconds = + match audio_chunk_seconds_budget(max_bytes, requested_chunk_seconds) { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + let effective_overlap_seconds = chunk_overlap_seconds + .unwrap_or(default_chunk_overlap_seconds) + .min((effective_chunk_seconds - 0.001).max(0.0)); + let effective_max_chunks = max_chunks.unwrap_or(default_max_chunks); + let chunk_plan = match plan_audio_chunks( + duration_sec, + effective_chunk_seconds, + effective_overlap_seconds, + effective_max_chunks, + ) { + Ok(value) => value, + Err(error) => return ToolResult::error(error), + }; + + let mut chunk_statuses: Vec = Vec::new(); + let mut warnings: Vec = Vec::new(); + let mut stitched_text = String::new(); + let mut aggregated_response = Map::new(); + aggregated_response.insert( + "speaker_scope".into(), + Value::String(if diarize.unwrap_or(false) { + "chunk_local_prefixed".into() + } else { + "not_requested".into() + }), + ); + aggregated_response.insert("chunks".into(), Value::Array(Vec::new())); + let mut partial = false; + let continue_on_chunk_error = continue_on_chunk_error.unwrap_or(false); + + for chunk in &chunk_plan { + let chunk_path = temp_dir.path.join(format!("chunk-{:03}.wav", chunk.index)); + if let Err(error) = extract_audio_chunk( + &upload_source, + &chunk_path, + chunk.start_sec, + chunk.duration_sec, + command_timeout_sec, + ) + .await + { + partial = true; + chunk_statuses.push(json!({ + "index": chunk.index, + "start_sec": chunk.start_sec, + "end_sec": chunk.end_sec, + "status": "error", + "error": error, + })); + if continue_on_chunk_error { + warnings.push(format!("chunk {} failed: {error}", chunk.index)); + continue; + } + return ToolResult::error(format!( + "audio_transcribe failed in chunk {}: {error}", + chunk.index + )); + } + + let parsed = match mistral_transcription_request( + api_key, + base_url, + &chunk_path, + chosen_model, + diarize, + normalized_timestamps.as_deref(), + normalized_bias.as_deref(), + language, + temperature, + max_bytes, + request_timeout_sec, + ) + .await + { + Ok(value) => value, + Err(error) => { + partial = true; + chunk_statuses.push(json!({ + "index": chunk.index, + "start_sec": chunk.start_sec, + "end_sec": chunk.end_sec, + "status": "error", + "error": error, + })); + if continue_on_chunk_error { + warnings.push(format!("chunk {} failed: {error}", chunk.index)); + continue; + } + return ToolResult::error(format!( + "audio_transcribe failed in chunk {}: {error}", + chunk.index + )); + } + }; + + let chunk_text = parsed + .get("text") + .and_then(Value::as_str) + .unwrap_or_default(); + let deduped_text = dedupe_audio_overlap_text(&stitched_text, chunk_text); + if !deduped_text.is_empty() { + if stitched_text.is_empty() { + stitched_text = deduped_text; + } else { + stitched_text = format!("{stitched_text} {deduped_text}"); + } + } + + let metadata = collect_chunk_metadata( + &parsed, + chunk.start_sec, + chunk.leading_overlap_sec, + &format!("c{}_", chunk.index), + ); + for (key, value) in metadata { + if let Some(existing) = aggregated_response + .get_mut(&key) + .and_then(Value::as_array_mut) + { + if let Some(items) = value.as_array() { + existing.extend(items.iter().cloned()); + } + } else { + aggregated_response.insert(key, value); + } + } + + chunk_statuses.push(json!({ + "index": chunk.index, + "start_sec": chunk.start_sec, + "end_sec": chunk.end_sec, + "status": "ok", + "text_chars": chunk_text.len(), + })); + } + + if !chunk_statuses + .iter() + .any(|chunk| chunk.get("status").and_then(Value::as_str) == Some("ok")) + { + return ToolResult::error( + "audio_transcribe failed: no chunk completed successfully".into(), + ); + } + + aggregated_response.insert("chunks".into(), Value::Array(chunk_statuses.clone())); + let mut envelope = json!({ + "provider": "mistral", + "service": "transcription", + "mode": "chunked", + "path": rel_path(root, &resolved), + "model": chosen_model, + "options": options, + "chunking": { + "strategy": "overlap_window", + "chunk_seconds": ((effective_chunk_seconds * 1000.0).round() / 1000.0), + "overlap_seconds": ((effective_overlap_seconds * 1000.0).round() / 1000.0), + "total_chunks": chunk_plan.len(), + "failed_chunks": chunk_statuses.iter().filter(|chunk| { + chunk.get("status").and_then(Value::as_str) != Some("ok") + }).count(), + "partial": partial, + }, + "text": stitched_text.trim(), + "response": Value::Object(aggregated_response), + }); + if !warnings.is_empty() { + envelope["warnings"] = Value::Array(warnings.into_iter().map(Value::String).collect()); + } + ToolResult::ok(serialize_audio_envelope(envelope, max_chars)) +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::{Json, Router, body::Bytes, routing::post}; + use std::sync::{Arc, Mutex}; + use tempfile::tempdir; + use tokio::net::TcpListener; + + async fn capture_transcription(body: Bytes) -> Json { + Json(json!({ + "text": "hello world", + "chunks": [{"text": "hello world", "timestamps": [0.0, 1.0]}], + "raw_body": String::from_utf8_lossy(&body).to_string(), + })) + } + + async fn spawn_server() -> String { + let app = Router::new().route("/v1/audio/transcriptions", post(capture_transcription)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + format!("http://{}", addr) + } + + fn install_fake_media_tools(root: &Path) { + let ffprobe = root.join("ffprobe"); + let ffmpeg = root.join("ffmpeg"); + std::fs::write( + &ffprobe, + "#!/bin/sh\nprintf '{\"format\":{\"duration\":\"50.0\"}}'\n", + ) + .unwrap(); + std::fs::write( + &ffmpeg, + "#!/bin/sh\nout=\"\"\nfor arg in \"$@\"; do out=\"$arg\"; done\nprintf 'chunk' > \"$out\"\n", + ) + .unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&ffprobe, std::fs::Permissions::from_mode(0o755)).unwrap(); + std::fs::set_permissions(&ffmpeg, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + } + + fn install_budget_sensitive_media_tools(root: &Path, duration_seconds: f64) { + let ffprobe = root.join("ffprobe"); + let ffmpeg = root.join("ffmpeg"); + std::fs::write( + &ffprobe, + format!("#!/bin/sh\nprintf '{{\"format\":{{\"duration\":\"{duration_seconds}\"}}}}'\n"), + ) + .unwrap(); + std::fs::write( + &ffmpeg, + "#!/bin/sh\nout=\"\"\nduration=\"\"\nprev=\"\"\nfor arg in \"$@\"; do\n if [ \"$prev\" = \"-t\" ]; then duration=\"$arg\"; fi\n prev=\"$arg\"\n out=\"$arg\"\ndone\nif [ -n \"$duration\" ]; then\n bytes=$(awk \"BEGIN { printf \\\"%d\\\", $duration * 32000 }\")\n dd if=/dev/zero of=\"$out\" bs=1 count=\"$bytes\" status=none\nelse\n printf 'chunk' > \"$out\"\nfi\n", + ) + .unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&ffprobe, std::fs::Permissions::from_mode(0o755)).unwrap(); + std::fs::set_permissions(&ffmpeg, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + } + + #[tokio::test] + async fn test_audio_transcribe_success() { + let dir = tempdir().unwrap(); + let audio = dir.path().join("clip.wav"); + std::fs::write(&audio, b"RIFF\x00\x00\x00\x00WAVEfmt ").unwrap(); + let root = dir.path().to_path_buf(); + let base_url = spawn_server().await; + let mut files_read = HashSet::new(); + + let result = audio_transcribe( + &root, + Some("mistral-key"), + &base_url, + "voxtral-mini-latest", + 1024 * 1024, + 900, + 2.0, + 48, + "clip.wav", + Some(true), + Some(&["segment".to_string()]), + Some(&["OpenPlanter".to_string()]), + None, + None, + Some(0.2), + None, + None, + None, + None, + None, + 20_000, + 5, + 5, + &mut files_read, + ) + .await; + + assert!(!result.is_error, "unexpected error: {}", result.content); + let parsed: Value = serde_json::from_str(&result.content).unwrap(); + assert_eq!(parsed["provider"], "mistral"); + assert_eq!(parsed["path"], "clip.wav"); + assert_eq!(parsed["text"], "hello world"); + assert_eq!(parsed["options"]["diarize"], true); + let raw_body = parsed["response"]["raw_body"].as_str().unwrap(); + assert!(raw_body.contains("name=\"model\"")); + assert!(raw_body.contains("name=\"timestamp_granularities\"")); + assert!(raw_body.contains("name=\"context_bias\"")); + } + + #[tokio::test] + async fn test_audio_transcribe_rejects_language_and_timestamps() { + let dir = tempdir().unwrap(); + let audio = dir.path().join("clip.wav"); + std::fs::write(&audio, b"RIFF\x00\x00\x00\x00WAVEfmt ").unwrap(); + let root = dir.path().to_path_buf(); + let mut files_read = HashSet::new(); + + let result = audio_transcribe( + &root, + Some("mistral-key"), + "https://api.mistral.ai", + "voxtral-mini-latest", + 1024 * 1024, + 900, + 2.0, + 48, + "clip.wav", + None, + Some(&["word".to_string()]), + None, + Some("en"), + None, + None, + None, + None, + None, + None, + None, + 20_000, + 5, + 5, + &mut files_read, + ) + .await; + + assert!(result.is_error); + assert!(result.content.contains("cannot be combined")); + } + + #[tokio::test] + async fn test_audio_transcribe_chunks_oversize_audio() { + let dir = tempdir().unwrap(); + install_fake_media_tools(dir.path()); + let original_path = std::env::var_os("PATH"); + unsafe { + let mut parts = vec![dir.path().to_path_buf()]; + if let Some(existing) = &original_path { + parts.extend(std::env::split_paths(existing)); + } + std::env::set_var("PATH", std::env::join_paths(parts).unwrap()); + } + + let counter = Arc::new(Mutex::new(0usize)); + let counter_clone = counter.clone(); + let app = Router::new().route( + "/v1/audio/transcriptions", + post(move |_body: Bytes| { + let counter = counter_clone.clone(); + async move { + let mut state = counter.lock().unwrap(); + let response = if *state == 0 { + json!({ + "text": "hello there general kenobi from tatooine", + "segments": [{"text":"hello there general kenobi from tatooine","start":0.0,"end":4.0,"speaker":"speaker_a"}] + }) + } else { + json!({ + "text": "there general kenobi from tatooine today", + "segments": [{"text":"there general kenobi from tatooine today","start":0.0,"end":4.0,"speaker":"speaker_a"}] + }) + }; + *state += 1; + Json(response) + } + }), + ); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let audio = dir.path().join("clip.wav"); + std::fs::write(&audio, vec![b'x'; 1_200_000]).unwrap(); + let root = dir.path().to_path_buf(); + let mut files_read = HashSet::new(); + + let result = audio_transcribe( + &root, + Some("mistral-key"), + &format!("http://{}", addr), + "voxtral-mini-latest", + 1_100_000, + 900, + 2.0, + 48, + "clip.wav", + Some(true), + None, + None, + None, + None, + None, + Some("auto"), + Some(30), + Some(2.0), + None, + None, + 20_000, + 5, + 5, + &mut files_read, + ) + .await; + + if let Some(value) = original_path { + unsafe { std::env::set_var("PATH", value) }; + } + + assert!(!result.is_error, "unexpected error: {}", result.content); + let parsed: Value = serde_json::from_str(&result.content).unwrap(); + assert_eq!(parsed["mode"], "chunked"); + assert_eq!( + parsed["text"], + "hello there general kenobi from tatooine today" + ); + assert_eq!(parsed["chunking"]["total_chunks"], 2); + assert_eq!(parsed["response"]["segments"][0]["speaker"], "c0_speaker_a"); + assert_eq!(parsed["response"]["segments"][1]["speaker"], "c1_speaker_a"); + } + + #[tokio::test] + async fn test_audio_transcribe_preserves_byte_budgeted_chunk_size() { + let dir = tempdir().unwrap(); + install_budget_sensitive_media_tools(dir.path(), 35.0); + let original_path = std::env::var_os("PATH"); + unsafe { + let mut parts = vec![dir.path().to_path_buf()]; + if let Some(existing) = &original_path { + parts.extend(std::env::split_paths(existing)); + } + std::env::set_var("PATH", std::env::join_paths(parts).unwrap()); + } + + let counter = Arc::new(Mutex::new(0usize)); + let counter_clone = counter.clone(); + let app = Router::new().route( + "/v1/audio/transcriptions", + post(move |_body: Bytes| { + let counter = counter_clone.clone(); + async move { + let mut state = counter.lock().unwrap(); + *state += 1; + Json(json!({ + "text": format!("chunk {}", *state), + })) + } + }), + ); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let audio = dir.path().join("clip.wav"); + std::fs::write(&audio, vec![b'x'; 512]).unwrap(); + let root = dir.path().to_path_buf(); + let mut files_read = HashSet::new(); + + let result = audio_transcribe( + &root, + Some("mistral-key"), + &format!("http://{}", addr), + "voxtral-mini-latest", + 300_000, + 900, + 0.0, + 48, + "clip.wav", + None, + None, + None, + None, + None, + None, + Some("force"), + Some(30), + Some(0.0), + None, + None, + 20_000, + 5, + 5, + &mut files_read, + ) + .await; + + if let Some(value) = original_path { + unsafe { std::env::set_var("PATH", value) }; + } + + assert!(!result.is_error, "unexpected error: {}", result.content); + let parsed: Value = serde_json::from_str(&result.content).unwrap(); + assert_eq!(parsed["mode"], "chunked"); + assert!(parsed["chunking"]["chunk_seconds"].as_f64().unwrap() < 30.0); + assert!(parsed["chunking"]["total_chunks"].as_u64().unwrap() >= 5); + } +} diff --git a/openplanter-desktop/crates/op-core/src/tools/defs.rs b/openplanter-desktop/crates/op-core/src/tools/defs.rs index 9f630fcb..0c422bc3 100644 --- a/openplanter-desktop/crates/op-core/src/tools/defs.rs +++ b/openplanter-desktop/crates/op-core/src/tools/defs.rs @@ -67,6 +67,68 @@ fn mvp_tool_defs() -> Vec { "additionalProperties": false }), }, + ToolDef { + name: "audio_transcribe", + description: "Transcribe a local audio file with Mistral's offline transcription API. Supports diarization, timestamp granularity, context bias, language, model override, temperature, and optional chunking for long-form audio/video.", + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Relative or absolute path to the audio file within the workspace." + }, + "diarize": { + "type": "boolean", + "description": "Whether to request speaker diarization." + }, + "timestamp_granularities": { + "type": "array", + "items": { "type": "string" }, + "description": "Optional timestamp granularity values such as 'segment' or 'word'." + }, + "context_bias": { + "type": "array", + "items": { "type": "string" }, + "description": "Optional bias phrases to steer transcription toward expected terms." + }, + "language": { + "type": "string", + "description": "Optional ISO language hint. Cannot be combined with timestamp_granularities." + }, + "model": { + "type": "string", + "description": "Optional transcription model override." + }, + "temperature": { + "type": "number", + "description": "Optional decoding temperature." + }, + "chunking": { + "type": "string", + "description": "Long-form handling mode: 'auto', 'off', or 'force'.", + "enum": ["auto", "off", "force"] + }, + "chunk_max_seconds": { + "type": "integer", + "description": "Maximum chunk duration in seconds for chunked transcription." + }, + "chunk_overlap_seconds": { + "type": "number", + "description": "Overlap between adjacent chunks in seconds." + }, + "max_chunks": { + "type": "integer", + "description": "Maximum number of chunks allowed for a transcription run." + }, + "continue_on_chunk_error": { + "type": "boolean", + "description": "Continue remaining chunks and return partial output if a chunk fails." + } + }, + "required": ["path"], + "additionalProperties": false + }), + }, ToolDef { name: "write_file", description: "Create or overwrite a file in the workspace with the given content.", @@ -501,6 +563,7 @@ mod tests { fn test_tool_names() { let names = tool_names(); assert!(names.contains(&"read_file")); + assert!(names.contains(&"audio_transcribe")); assert!(names.contains(&"run_shell")); assert!(names.contains(&"web_search")); assert!(names.contains(&"think")); diff --git a/openplanter-desktop/crates/op-core/src/tools/mod.rs b/openplanter-desktop/crates/op-core/src/tools/mod.rs index 6ae2065d..ce85b3b8 100644 --- a/openplanter-desktop/crates/op-core/src/tools/mod.rs +++ b/openplanter-desktop/crates/op-core/src/tools/mod.rs @@ -2,7 +2,7 @@ /// /// The `WorkspaceTools` struct is the central dispatcher that owns tool state /// (files-read set, background jobs) and routes tool calls to the appropriate module. - +pub mod audio; pub mod defs; pub mod filesystem; pub mod shell; @@ -56,6 +56,14 @@ pub struct WorkspaceTools { max_observation_chars: usize, exa_api_key: Option, exa_base_url: String, + mistral_transcription_api_key: Option, + mistral_transcription_base_url: String, + mistral_transcription_model: String, + mistral_transcription_max_bytes: usize, + mistral_transcription_chunk_max_seconds: i64, + mistral_transcription_chunk_overlap_seconds: f64, + mistral_transcription_max_chunks: i64, + mistral_transcription_request_timeout_sec: u64, files_read: HashSet, bg_jobs: shell::BgJobs, } @@ -74,6 +82,17 @@ impl WorkspaceTools { max_observation_chars: config.max_observation_chars as usize, exa_api_key: config.exa_api_key.clone(), exa_base_url: config.exa_base_url.clone(), + mistral_transcription_api_key: config.mistral_transcription_api_key.clone(), + mistral_transcription_base_url: config.mistral_transcription_base_url.clone(), + mistral_transcription_model: config.mistral_transcription_model.clone(), + mistral_transcription_max_bytes: config.mistral_transcription_max_bytes as usize, + mistral_transcription_chunk_max_seconds: config.mistral_transcription_chunk_max_seconds, + mistral_transcription_chunk_overlap_seconds: config + .mistral_transcription_chunk_overlap_seconds, + mistral_transcription_max_chunks: config.mistral_transcription_max_chunks, + mistral_transcription_request_timeout_sec: config + .mistral_transcription_request_timeout_sec + as u64, files_read: HashSet::new(), bg_jobs: shell::BgJobs::new(), } @@ -97,6 +116,17 @@ impl WorkspaceTools { max_observation_chars: config.max_observation_chars as usize, exa_api_key: config.exa_api_key.clone(), exa_base_url: config.exa_base_url.clone(), + mistral_transcription_api_key: config.mistral_transcription_api_key.clone(), + mistral_transcription_base_url: config.mistral_transcription_base_url.clone(), + mistral_transcription_model: config.mistral_transcription_model.clone(), + mistral_transcription_max_bytes: config.mistral_transcription_max_bytes as usize, + mistral_transcription_chunk_max_seconds: config.mistral_transcription_chunk_max_seconds, + mistral_transcription_chunk_overlap_seconds: config + .mistral_transcription_chunk_overlap_seconds, + mistral_transcription_max_chunks: config.mistral_transcription_max_chunks, + mistral_transcription_request_timeout_sec: config + .mistral_transcription_request_timeout_sec + as u64, files_read: HashSet::new(), bg_jobs: shell::BgJobs::new(), } @@ -181,6 +211,101 @@ impl WorkspaceTools { self.command_timeout_sec, ) } + "audio_transcribe" => { + let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(""); + let diarize = args.get("diarize").and_then(|v| v.as_bool()); + let timestamp_granularities: Option> = args + .get("timestamp_granularities") + .and_then(|v| { + if let Some(values) = v.as_array() { + Some( + values + .iter() + .filter_map(|value| { + value.as_str().map(|s| s.trim().to_string()) + }) + .filter(|value| !value.is_empty()) + .collect::>(), + ) + } else { + v.as_str().map(|value| vec![value.trim().to_string()]) + } + }) + .filter(|values| !values.is_empty()); + let context_bias: Option> = args + .get("context_bias") + .and_then(|v| { + if let Some(values) = v.as_array() { + Some( + values + .iter() + .filter_map(|value| { + value.as_str().map(|s| s.trim().to_string()) + }) + .filter(|value| !value.is_empty()) + .collect::>(), + ) + } else { + v.as_str().map(|value| { + value + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + .map(ToString::to_string) + .collect::>() + }) + } + }) + .filter(|values| !values.is_empty()); + let language = args + .get("language") + .and_then(|v| v.as_str()) + .filter(|value| !value.trim().is_empty()); + let model = args + .get("model") + .and_then(|v| v.as_str()) + .filter(|value| !value.trim().is_empty()); + let temperature = args.get("temperature").and_then(|v| v.as_f64()); + let chunking = args + .get("chunking") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|value| !value.is_empty()); + let chunk_max_seconds = args.get("chunk_max_seconds").and_then(|v| v.as_i64()); + let chunk_overlap_seconds = + args.get("chunk_overlap_seconds").and_then(|v| v.as_f64()); + let max_chunks = args.get("max_chunks").and_then(|v| v.as_i64()); + let continue_on_chunk_error = args + .get("continue_on_chunk_error") + .and_then(|v| v.as_bool()); + audio::audio_transcribe( + &self.root, + self.mistral_transcription_api_key.as_deref(), + &self.mistral_transcription_base_url, + &self.mistral_transcription_model, + self.mistral_transcription_max_bytes, + self.mistral_transcription_chunk_max_seconds, + self.mistral_transcription_chunk_overlap_seconds, + self.mistral_transcription_max_chunks, + path, + diarize, + timestamp_granularities.as_deref(), + context_bias.as_deref(), + language, + model, + temperature, + chunking, + chunk_max_seconds, + chunk_overlap_seconds, + max_chunks, + continue_on_chunk_error, + self.max_file_chars.min(self.max_observation_chars), + self.command_timeout_sec, + self.mistral_transcription_request_timeout_sec, + &mut self.files_read, + ) + .await + } // Shell "run_shell" => { diff --git a/openplanter-desktop/crates/op-core/tests/test_model_streaming.rs b/openplanter-desktop/crates/op-core/tests/test_model_streaming.rs index 5e792de0..2b8eab1a 100644 --- a/openplanter-desktop/crates/op-core/tests/test_model_streaming.rs +++ b/openplanter-desktop/crates/op-core/tests/test_model_streaming.rs @@ -13,7 +13,7 @@ use axum::routing::post; use axum::Router; use tokio_util::sync::CancellationToken; -use op_core::events::{DeltaEvent, DeltaKind}; +use op_core::events::{CompletionMeta, DeltaEvent, DeltaKind, LoopMetrics}; use op_core::model::openai::OpenAIModel; use op_core::model::anthropic::AnthropicModel; use op_core::model::{BaseModel, Message}; @@ -448,7 +448,12 @@ async fn test_solve_with_mock_anthropic() { fn emit_step(&self, event: StepEvent) { self.events.lock().unwrap().push(Ev::Step(event)); } - fn emit_complete(&self, result: &str) { + fn emit_complete( + &self, + result: &str, + _: Option, + _: Option, + ) { self.events.lock().unwrap().push(Ev::Complete(result.to_string())); } fn emit_error(&self, message: &str) { @@ -539,7 +544,12 @@ async fn test_solve_with_mock_openai() { fn emit_step(&self, event: StepEvent) { self.events.lock().unwrap().push(Ev2::Step(event)); } - fn emit_complete(&self, result: &str) { + fn emit_complete( + &self, + result: &str, + _: Option, + _: Option, + ) { self.events.lock().unwrap().push(Ev2::Complete(result.to_string())); } fn emit_error(&self, message: &str) { @@ -619,7 +629,7 @@ async fn test_solve_http_error_emits_error() { fn emit_trace(&self, _: &str) {} fn emit_delta(&self, _: DeltaEvent) {} fn emit_step(&self, _: StepEvent) {} - fn emit_complete(&self, _: &str) {} + fn emit_complete(&self, _: &str, _: Option, _: Option) {} fn emit_error(&self, msg: &str) { self.errors.lock().unwrap().push(msg.to_string()); } @@ -664,7 +674,7 @@ async fn test_solve_cancel_emits_cancelled() { fn emit_trace(&self, _: &str) {} fn emit_delta(&self, _: DeltaEvent) {} fn emit_step(&self, _: StepEvent) {} - fn emit_complete(&self, _: &str) {} + fn emit_complete(&self, _: &str, _: Option, _: Option) {} fn emit_error(&self, msg: &str) { self.events.lock().unwrap().push(msg.to_string()); } @@ -707,7 +717,12 @@ async fn test_solve_demo_mode_bypasses_llm() { fn emit_trace(&self, _: &str) {} fn emit_delta(&self, _: DeltaEvent) {} fn emit_step(&self, _: StepEvent) {} - fn emit_complete(&self, result: &str) { + fn emit_complete( + &self, + result: &str, + _: Option, + _: Option, + ) { self.events.lock().unwrap().push(result.to_string()); } fn emit_error(&self, msg: &str) { @@ -746,7 +761,7 @@ async fn test_solve_missing_key_emits_error() { fn emit_trace(&self, _: &str) {} fn emit_delta(&self, _: DeltaEvent) {} fn emit_step(&self, _: StepEvent) {} - fn emit_complete(&self, _: &str) {} + fn emit_complete(&self, _: &str, _: Option, _: Option) {} fn emit_error(&self, msg: &str) { self.errors.lock().unwrap().push(msg.to_string()); } @@ -872,7 +887,12 @@ async fn test_solve_multi_step_agentic_loop() { fn emit_step(&self, event: StepEvent) { self.events.lock().unwrap().push(Ev3::Step(event)); } - fn emit_complete(&self, result: &str) { + fn emit_complete( + &self, + result: &str, + _: Option, + _: Option, + ) { self.events.lock().unwrap().push(Ev3::Complete(result.to_string())); } fn emit_error(&self, message: &str) { diff --git a/openplanter-desktop/crates/op-tauri/src/commands/agent.rs b/openplanter-desktop/crates/op-tauri/src/commands/agent.rs index 0bf58ff4..d3eeb81f 100644 --- a/openplanter-desktop/crates/op-tauri/src/commands/agent.rs +++ b/openplanter-desktop/crates/op-tauri/src/commands/agent.rs @@ -203,14 +203,14 @@ mod tests { } #[tokio::test] - async fn test_build_solve_initial_context_degrades_to_no_packet_on_load_failure() { + async fn test_build_solve_initial_context_ignores_invalid_typed_state_without_warning() { let tmp = tempdir().unwrap(); fs::write(tmp.path().join("investigation_state.json"), "{not-json") .await .unwrap(); let (context, warning) = build_solve_initial_context(tmp.path(), "sid").await; - assert!(warning.is_some()); + assert!(warning.is_none()); assert!(context.question_reasoning_packet.is_none()); assert_eq!(context.session_id, Some("sid".to_string())); assert_eq!(context.session_dir, Some(tmp.path().display().to_string())); diff --git a/openplanter-desktop/crates/op-tauri/src/commands/config.rs b/openplanter-desktop/crates/op-tauri/src/commands/config.rs index 2015140c..e88fb985 100644 --- a/openplanter-desktop/crates/op-tauri/src/commands/config.rs +++ b/openplanter-desktop/crates/op-tauri/src/commands/config.rs @@ -142,6 +142,10 @@ pub fn build_credential_status(cfg: &op_core::config::AgentConfig) -> HashMap { cerebras: false, ollama: true, exa: false, + mistral_transcription: true, })); const status = await getCredentialsStatus(); expect(status.openai).toBe(true); expect(status.openrouter).toBe(false); + expect(status.mistral_transcription).toBe(true); }); it("listSessions sends limit", async () => { diff --git a/openplanter-desktop/frontend/src/commands/model.test.ts b/openplanter-desktop/frontend/src/commands/model.test.ts index 1e8bc2bd..5b191178 100644 --- a/openplanter-desktop/frontend/src/commands/model.test.ts +++ b/openplanter-desktop/frontend/src/commands/model.test.ts @@ -30,6 +30,11 @@ describe("inferProvider", () => { expect(inferProvider("llama3.2")).toBe("ollama"); }); + it("mistral chat models stay ollama while voxtral stays tool-only", () => { + expect(inferProvider("mistral")).toBe("ollama"); + expect(inferProvider("voxtral-mini-latest")).toBeNull(); + }); + it("qwen-3 returns cerebras", () => { expect(inferProvider("qwen-3-235b-a22b-instruct-2507")).toBe("cerebras"); }); diff --git a/openplanter-desktop/frontend/src/components/App.test.ts b/openplanter-desktop/frontend/src/components/App.test.ts index 1aaf912c..e2947a86 100644 --- a/openplanter-desktop/frontend/src/components/App.test.ts +++ b/openplanter-desktop/frontend/src/components/App.test.ts @@ -58,7 +58,7 @@ describe("createApp", () => { __setHandler("list_sessions", () => [SESSION_B, SESSION_A]); __setHandler("get_credentials_status", () => ({ openai: true, anthropic: true, openrouter: false, - cerebras: false, ollama: true, exa: false, + cerebras: false, ollama: true, exa: false, mistral_transcription: true, })); __setHandler("open_session", () => ({ id: "20260227-120000-cccc3333", @@ -105,7 +105,7 @@ describe("createApp", () => { await vi.waitFor(() => { const creds = root.querySelector(".cred-status"); - expect(creds!.children.length).toBe(6); + expect(creds!.children.length).toBe(7); expect(creds!.querySelector(".cred-ok")!.textContent).toContain("openai"); expect(creds!.querySelector(".cred-missing")!.textContent).toContain("openrouter"); }); diff --git a/openplanter-desktop/frontend/src/components/App.ts b/openplanter-desktop/frontend/src/components/App.ts index cdb5f98e..7655209f 100644 --- a/openplanter-desktop/frontend/src/components/App.ts +++ b/openplanter-desktop/frontend/src/components/App.ts @@ -311,7 +311,15 @@ async function loadCredentials(container: HTMLElement): Promise { try { const status = await getCredentialsStatus(); container.innerHTML = ""; - const providers = ["openai", "anthropic", "openrouter", "cerebras", "ollama", "exa"]; + const providers = [ + "openai", + "anthropic", + "openrouter", + "cerebras", + "ollama", + "exa", + "mistral_transcription", + ]; for (const p of providers) { const row = document.createElement("div"); const hasKey = status[p] ?? false; diff --git a/tests/test_audio_transcribe.py b/tests/test_audio_transcribe.py new file mode 100644 index 00000000..215a38f3 --- /dev/null +++ b/tests/test_audio_transcribe.py @@ -0,0 +1,445 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from agent.tool_defs import TOOL_DEFINITIONS +from agent.tools import ToolError, WorkspaceTools + + +def _write_audio(path: Path, payload: bytes = b"RIFF\x00\x00\x00\x00WAVEfmt ") -> None: + path.write_bytes(payload) + + +def _make_tools(tmp_path: Path, **overrides: object) -> WorkspaceTools: + defaults: dict[str, object] = { + "root": tmp_path, + "mistral_transcription_api_key": "mistral-key", + "max_file_chars": 20_000, + "max_observation_chars": 20_000, + } + defaults.update(overrides) + return WorkspaceTools(**defaults) + + +class TestAudioTranscribeTool: + def test_audio_transcribe_success_returns_wrapped_response(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio) + tools = _make_tools(tmp_path) + mocked = { + "text": "hello world", + "chunks": [{"text": "hello world", "timestamps": [0.0, 1.0]}], + } + + with pytest.MonkeyPatch.context() as mp: + mp.setattr( + tools, + "_mistral_transcription_request", + lambda **_: mocked, + ) + raw = tools.audio_transcribe( + "clip.wav", + diarize=True, + timestamp_granularities=["segment"], + context_bias=["OpenPlanter", "Mistral"], + model="voxtral-mini-latest", + temperature=0.2, + ) + + parsed = json.loads(raw) + assert parsed["provider"] == "mistral" + assert parsed["path"] == "clip.wav" + assert parsed["text"] == "hello world" + assert parsed["options"]["diarize"] is True + assert parsed["options"]["timestamp_granularities"] == ["segment"] + assert parsed["options"]["context_bias"] == ["OpenPlanter", "Mistral"] + assert parsed["response"]["chunks"][0]["text"] == "hello world" + + def test_audio_transcribe_requires_key(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio) + tools = WorkspaceTools(root=tmp_path) + out = tools.audio_transcribe("clip.wav") + assert "Mistral transcription API key not configured" in out + + def test_audio_transcribe_rejects_language_with_timestamps(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio) + tools = _make_tools(tmp_path) + out = tools.audio_transcribe( + "clip.wav", + language="en", + timestamp_granularities=["word"], + ) + assert "cannot be combined" in out + + def test_audio_transcribe_rejects_non_audio_extension(self, tmp_path: Path) -> None: + note = tmp_path / "notes.txt" + note.write_text("hello", encoding="utf-8") + tools = _make_tools(tmp_path) + out = tools.audio_transcribe("notes.txt") + assert "Unsupported audio format" in out + + def test_audio_transcribe_path_escape_blocked(self, tmp_path: Path) -> None: + tools = _make_tools(tmp_path) + with pytest.raises(ToolError, match="escapes workspace"): + tools.audio_transcribe("../../etc/passwd.wav") + + def test_audio_transcribe_auto_chunks_oversize_files(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 1_200_000) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=1_100_000, + ) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + mp.setattr(tools, "_probe_media_duration", lambda _: 50.0) + + def fake_extract( + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + output.write_bytes(b"chunk") + + responses = iter( + [ + { + "text": "hello there general kenobi from tatooine", + "segments": [ + { + "text": "hello there general kenobi from tatooine", + "start": 0.0, + "end": 4.0, + "speaker": "speaker_a", + } + ], + }, + { + "text": "there general kenobi from tatooine today", + "segments": [ + { + "text": "there general kenobi from tatooine today", + "start": 0.0, + "end": 4.0, + "speaker": "speaker_a", + } + ], + }, + ] + ) + mp.setattr(tools, "_extract_audio_chunk", fake_extract) + mp.setattr( + tools, + "_mistral_transcription_request", + lambda **_: next(responses), + ) + + raw = tools.audio_transcribe( + "clip.wav", + diarize=True, + chunk_max_seconds=30, + chunk_overlap_seconds=2, + ) + + parsed = json.loads(raw) + assert parsed["mode"] == "chunked" + assert parsed["text"] == "hello there general kenobi from tatooine today" + assert parsed["chunking"]["total_chunks"] == 2 + assert parsed["response"]["segments"][0]["speaker"] == "c0_speaker_a" + assert parsed["response"]["segments"][1]["speaker"] == "c1_speaker_a" + assert parsed["response"]["segments"][1]["start"] == pytest.approx( + parsed["chunking"]["chunk_seconds"], abs=0.01 + ) + assert parsed["response"]["segments"][1]["end"] == pytest.approx( + parsed["chunking"]["chunk_seconds"] + 2.0, abs=0.01 + ) + + def test_audio_transcribe_off_keeps_oversize_rejection(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 512) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=64, + ) + out = tools.audio_transcribe("clip.wav", chunking="off") + assert "Audio file too large" in out + + def test_audio_transcribe_preserves_byte_budgeted_chunk_size( + self, tmp_path: Path + ) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 512) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=300_000, + ) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + mp.setattr(tools, "_probe_media_duration", lambda _: 35.0) + + def fake_extract( + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + output.write_bytes(b"x" * int(duration_sec * 32_000)) + + observed_sizes: list[int] = [] + + def fake_request(*, resolved: Path, **_: object) -> dict[str, object]: + observed_sizes.append(resolved.stat().st_size) + if observed_sizes[-1] > tools.mistral_transcription_max_bytes: + raise ToolError( + f"Audio file too large: {observed_sizes[-1]:,} bytes " + f"(max {tools.mistral_transcription_max_bytes:,} bytes)" + ) + return {"text": f"chunk {len(observed_sizes)}"} + + mp.setattr(tools, "_extract_audio_chunk", fake_extract) + mp.setattr(tools, "_mistral_transcription_request", fake_request) + + raw = tools.audio_transcribe( + "clip.wav", + chunking="force", + chunk_max_seconds=30, + chunk_overlap_seconds=0, + ) + + parsed = json.loads(raw) + assert parsed["mode"] == "chunked" + assert parsed["chunking"]["chunk_seconds"] < 30 + assert observed_sizes + assert max(observed_sizes) <= tools.mistral_transcription_max_bytes + + def test_audio_transcribe_force_chunks_even_when_under_limit(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 32) + tools = _make_tools(tmp_path) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + mp.setattr(tools, "_probe_media_duration", lambda _: 58.0) + mp.setattr( + tools, + "_extract_audio_chunk", + lambda *args, **kwargs: kwargs["output"].write_bytes(b"chunk"), + raising=False, + ) + responses = iter( + [ + {"text": "one two three four five"}, + {"text": "three four five six"}, + ] + ) + + def fake_chunk( + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + output.write_bytes(b"chunk") + + mp.setattr(tools, "_extract_audio_chunk", fake_chunk) + mp.setattr( + tools, + "_mistral_transcription_request", + lambda **_: next(responses), + ) + raw = tools.audio_transcribe( + "clip.wav", + chunking="force", + chunk_max_seconds=30, + chunk_overlap_seconds=2, + ) + + parsed = json.loads(raw) + assert parsed["mode"] == "chunked" + assert parsed["options"]["chunking"] == "force" + + def test_audio_transcribe_reports_missing_media_tools(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 512) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=64, + ) + with pytest.MonkeyPatch.context() as mp: + mp.setattr( + tools, + "_ensure_media_tools", + lambda: (_ for _ in ()).throw( + ToolError( + "Long-form transcription requires ffmpeg, ffprobe. Install ffmpeg/ffprobe and retry." + ) + ), + ) + out = tools.audio_transcribe("clip.wav") + assert "ffmpeg" in out and "ffprobe" in out + + def test_audio_transcribe_extracts_video_before_upload(self, tmp_path: Path) -> None: + video = tmp_path / "clip.mp4" + video.write_bytes(b"video") + tools = _make_tools(tmp_path) + extracted: dict[str, str] = {} + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + + def fake_extract(source: Path, output: Path) -> None: + extracted["source"] = source.name + output.write_bytes(b"wav") + + def fake_request(*, resolved: Path, **_: object) -> dict[str, object]: + extracted["uploaded_suffix"] = resolved.suffix + return {"text": "video transcript"} + + mp.setattr(tools, "_extract_audio_source", fake_extract) + mp.setattr(tools, "_mistral_transcription_request", fake_request) + raw = tools.audio_transcribe("clip.mp4", chunking="off") + + parsed = json.loads(raw) + assert extracted["source"] == "clip.mp4" + assert extracted["uploaded_suffix"] == ".wav" + assert parsed["text"] == "video transcript" + + def test_audio_transcribe_fail_fast_on_chunk_error(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 1_200_000) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=1_100_000, + ) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + mp.setattr(tools, "_probe_media_duration", lambda _: 50.0) + + def fake_extract( + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + output.write_bytes(b"chunk") + + calls = {"count": 0} + + def fake_request(**_: object) -> dict[str, object]: + calls["count"] += 1 + if calls["count"] == 2: + raise ToolError("boom") + return {"text": "alpha beta gamma delta epsilon"} + + mp.setattr(tools, "_extract_audio_chunk", fake_extract) + mp.setattr(tools, "_mistral_transcription_request", fake_request) + out = tools.audio_transcribe( + "clip.wav", + chunk_max_seconds=30, + chunk_overlap_seconds=2, + ) + + assert "audio_transcribe failed in chunk 1" in out + + def test_audio_transcribe_can_return_partial_chunked_output(self, tmp_path: Path) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio, payload=b"x" * 1_200_000) + tools = _make_tools( + tmp_path, + mistral_transcription_max_bytes=1_100_000, + ) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_ensure_media_tools", lambda: None) + mp.setattr(tools, "_probe_media_duration", lambda _: 60.0) + + def fake_extract( + source: Path, + output: Path, + *, + start_sec: float, + duration_sec: float, + ) -> None: + output.write_bytes(b"chunk") + + calls = {"count": 0} + + def fake_request(**_: object) -> dict[str, object]: + calls["count"] += 1 + if calls["count"] == 2: + raise ToolError("boom") + return {"text": f"chunk {calls['count']} transcript words words words"} + + mp.setattr(tools, "_extract_audio_chunk", fake_extract) + mp.setattr(tools, "_mistral_transcription_request", fake_request) + raw = tools.audio_transcribe( + "clip.wav", + chunk_max_seconds=30, + chunk_overlap_seconds=1, + continue_on_chunk_error=True, + ) + + parsed = json.loads(raw) + assert parsed["chunking"]["partial"] is True + assert parsed["chunking"]["failed_chunks"] == 1 + assert parsed["warnings"][0].startswith("chunk 1 failed") + + def test_audio_transcribe_structured_truncation_keeps_valid_json( + self, + tmp_path: Path, + ) -> None: + audio = tmp_path / "clip.wav" + _write_audio(audio) + tools = _make_tools( + tmp_path, + max_file_chars=400, + max_observation_chars=400, + ) + mocked = { + "text": "word " * 200, + "segments": [ + {"text": "segment", "start": 0.0, "end": 1.0, "speaker": "speaker_a"} + for _ in range(30) + ], + "words": [ + {"text": "word", "start": 0.0, "end": 0.1, "speaker": "speaker_a"} + for _ in range(60) + ], + } + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(tools, "_mistral_transcription_request", lambda **_: mocked) + raw = tools.audio_transcribe("clip.wav") + + parsed = json.loads(raw) + assert parsed["truncation"]["applied"] is True + assert "text_truncated_chars" in parsed["truncation"] + + +class TestAudioTranscribeToolDef: + def test_audio_transcribe_in_tool_definitions(self) -> None: + names = [d["name"] for d in TOOL_DEFINITIONS] + assert "audio_transcribe" in names + + def test_audio_transcribe_definition_schema(self) -> None: + defn = next(d for d in TOOL_DEFINITIONS if d["name"] == "audio_transcribe") + assert defn["parameters"]["required"] == ["path"] + props = defn["parameters"]["properties"] + assert "context_bias" in props + assert props["context_bias"]["type"] == "array" + assert props["chunking"]["enum"] == ["auto", "off", "force"] + assert props["chunk_max_seconds"]["type"] == "integer" + assert props["continue_on_chunk_error"]["type"] == "boolean" diff --git a/tests/test_credentials.py b/tests/test_credentials.py index dd4a5ac9..78729f4d 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -24,6 +24,7 @@ def test_parse_env_file_extracts_supported_keys(self) -> None: "ANTHROPIC_API_KEY=an-key", "OPENROUTER_API_KEY=or-key", "EXA_API_KEY=exa-key", + "MISTRAL_API_KEY=mistral-key", ] ), encoding="utf-8", @@ -33,6 +34,7 @@ def test_parse_env_file_extracts_supported_keys(self) -> None: self.assertEqual(creds.anthropic_api_key, "an-key") self.assertEqual(creds.openrouter_api_key, "or-key") self.assertEqual(creds.exa_api_key, "exa-key") + self.assertEqual(creds.mistral_transcription_api_key, "mistral-key") def test_parse_env_assignments_preserves_generic_workspace_keys(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: @@ -59,6 +61,7 @@ def test_store_roundtrip(self) -> None: anthropic_api_key="an", openrouter_api_key="or", exa_api_key="exa", + mistral_transcription_api_key="mistral", ) store.save(creds) loaded = store.load() diff --git a/tests/test_settings.py b/tests/test_settings.py index 2f85fa12..1f1af13a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -4,7 +4,9 @@ import unittest from pathlib import Path +from agent.__main__ import _resolve_provider from agent.builder import _validate_model_provider, infer_provider_for_model +from agent.credentials import CredentialBundle from agent.model import ModelError from agent.settings import PersistentSettings, SettingsStore, normalize_reasoning_effort from agent.tui import SLASH_COMMANDS, _compute_suggestions @@ -216,5 +218,11 @@ def test_unknown_model_passes(self) -> None: _validate_model_provider("some-random-model", "anthropic") +class ResolveProviderTests(unittest.TestCase): + def test_mistral_transcription_key_does_not_change_chat_provider(self) -> None: + creds = CredentialBundle(mistral_transcription_api_key="mistral-test") + self.assertEqual(_resolve_provider("auto", creds), "anthropic") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tool_defs.py b/tests/test_tool_defs.py index 5efccd53..a985725b 100644 --- a/tests/test_tool_defs.py +++ b/tests/test_tool_defs.py @@ -30,7 +30,7 @@ def test_tool_count(self) -> None: self.assertEqual(len(names), len(TOOL_DEFINITIONS)) expected = { "list_files", "search_files", "repo_map", "web_search", "fetch_url", - "read_file", "read_image", "write_file", "apply_patch", "edit_file", + "read_file", "read_image", "audio_transcribe", "write_file", "apply_patch", "edit_file", "hashline_edit", "run_shell", "run_shell_bg", "check_shell_bg", "kill_shell_bg", "think", "subtask", "execute",