From 3b5ddd822544df9baeda74e3f7841653db551825 Mon Sep 17 00:00:00 2001
From: Pascal Berrang
Date: Sun, 22 Feb 2026 18:12:05 +0100
Subject: [PATCH 1/3] Add model warmup command and download-aware TUI progress
 for first-run setup

---
 README.md                                     |  11 ++
 src/ownscribe/cli.py                          |  29 +++
 src/ownscribe/pipeline.py                     |  36 ++++
 src/ownscribe/progress.py                     | 186 +++++++++++++++++-
 src/ownscribe/transcription/base.py           |   4 +
 .../transcription/whisperx_transcriber.py     | 144 +++++++++++---
 tests/test_cli.py                             |  19 ++
 tests/test_pipeline.py                        |  15 ++
 tests/test_progress.py                        | 157 +++++++++++++++
 tests/test_transcription.py                   | 136 +++++++++++++
 10 files changed, 705 insertions(+), 32 deletions(-)
 create mode 100644 tests/test_progress.py

diff --git a/README.md b/README.md
index 0808866..162dcaf 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,8 @@ This will:
 3. Summarize with your local LLM
 4. Save everything to `~/ownscribe/YYYY-MM-DD_HHMMSS/`
 
+On first run, WhisperX / pyannote may download model files. ownscribe shows best-effort download progress in the TUI while this happens; run `ownscribe warmup` beforehand to do these downloads up front under a dedicated `Preparing models` step.
+
 ### Options
 
 ```bash
@@ -117,6 +119,7 @@ ownscribe --template lecture         # use the lecture summarization template
 ```bash
 ownscribe devices                    # list audio devices (uses native CoreAudio when available)
 ownscribe apps                       # list running apps with PIDs for use with --pid
+ownscribe warmup                     # prefetch WhisperX/pyannote models before a meeting
 ownscribe transcribe recording.wav   # transcribe an audio file (saves alongside the input)
 ownscribe summarize transcript.md    # summarize a transcript (saves alongside the input)
 ownscribe resume ./2026-02-20_1736   # resume a failed/partial pipeline in a directory
@@ -125,6 +128,14 @@
 ownscribe config                     # open config file in $EDITOR
 ownscribe cleanup                    # remove ownscribe data from disk
 ```
+Use `warmup` ahead of time to avoid first-run model download delays while recording:
+
+```bash
+ownscribe warmup                     # prefetch Whisper model (+ diarization if enabled in config)
+ownscribe warmup --language en       # also prefetch alignment model for English
+ownscribe warmup --with-diarization  # force diarization warmup for this run
+```
+
 ### Searching Meeting Notes
 
 Use `ask` to search across all your meeting notes with natural-language questions:
diff --git a/src/ownscribe/cli.py b/src/ownscribe/cli.py
index c0d6790..3bcb35f 100644
--- a/src/ownscribe/cli.py
+++ b/src/ownscribe/cli.py
@@ -149,6 +149,35 @@ def transcribe(
     run_transcribe(config, file)
 
 
+@cli.command()
+@click.option("--model", default=None, help="Whisper model size.")
+@click.option("--language", default=None, help="Language code to prefetch alignment model for (e.g. en, de, fr).")
+@click.option(
+    "--with-diarization/--no-diarization",
+    "with_diarization",
+    default=None,
+    help="Override diarization warmup (defaults to config setting).",
+)
+@click.pass_context
+def warmup(
+    ctx: click.Context,
+    model: str | None,
+    language: str | None,
+    with_diarization: bool | None,
+) -> None:
+    """Prefetch WhisperX/pyannote models to avoid first-run stalls."""
+    config = ctx.obj["config"]
+    if model:
+        config.transcription.model = model
+    if language:
+        config.transcription.language = language
+    if with_diarization is not None:
+        config.diarization.enabled = with_diarization
+
+    from ownscribe.pipeline import run_warmup
+    run_warmup(config)
+
+
 @cli.command()
 @click.argument("file", type=click.Path(exists=True))
 @click.option("--template", default=None, help="Summarization template (meeting, lecture, brief, or custom).")
diff --git a/src/ownscribe/pipeline.py b/src/ownscribe/pipeline.py
index aecc0db..c81de3a 100644
--- a/src/ownscribe/pipeline.py
+++ b/src/ownscribe/pipeline.py
@@ -222,6 +222,42 @@ def run_transcribe(config: Config, audio_file: str) -> None:
     _do_transcribe_and_summarize(config, audio_path, out_dir, summarize=False)
 
 
+def run_warmup(config: Config) -> None:
+    """Prefetch transcription/diarization models without processing audio."""
+    diar_enabled = config.diarization.enabled and bool(config.diarization.hf_token)
+    hf_token_warning = (
+        config.diarization.enabled and not config.diarization.hf_token
+    )
+
+    with PipelineProgress(diarize=False, summarize=False, transcribe=False) as progress:
+        try:
+            transcriber = _create_transcriber(config, progress=progress)
+        except ImportError:
+            click.echo(
+                "Error: WhisperX is not installed. Install with:\n"
+                "  uv pip install 'ownscribe[transcription]'",
+                err=True,
+            )
+            raise SystemExit(1) from None
+
+        transcriber.prepare_models(language=config.transcription.language or None)
+
+    click.echo(f"Whisper model ready: {config.transcription.model}")
+    if config.transcription.language:
+        click.echo(f"Alignment model ready: {config.transcription.language}")
+    else:
+        click.echo("Alignment model not preloaded (language auto-detect).")
+
+    if diar_enabled:
+        click.echo("Diarization pipeline ready.")
+    elif hf_token_warning:
+        click.echo(
+            "Warning: Diarization enabled but no HF token configured. "
+            "Skipping diarization warmup.",
+            err=True,
+        )
+
+
 def run_summarize(config: Config, transcript_file: str) -> None:
     """Summarize a transcript file and save the summary alongside the input."""
     transcript_path = Path(transcript_file).resolve()
diff --git a/src/ownscribe/progress.py b/src/ownscribe/progress.py
index 6fca686..0ce49b6 100644
--- a/src/ownscribe/progress.py
+++ b/src/ownscribe/progress.py
@@ -17,6 +17,18 @@
 _INTERVAL = 0.1
 
 _PROGRESS_RE = re.compile(r"Progress:\s*([\d.]+)%")
+_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")
+_TQDM_RE = re.compile(
+    r"(?:(?P<filename>[^:\r\n]+):\s*)?"
+    r"(?P<percent>[\d.]+)%\|.*?\|\s*"
+    r"(?P<done>[\d.]+)\s*(?P<done_unit>[kKMGTPE]?i?B)\s*/\s*"
+    r"(?P<total>[\d.]+)\s*(?P<total_unit>[kKMGTPE]?i?B)"
+)
+_BYTES_RE = re.compile(
+    r"(?P<done>[\d.]+)\s*(?P<done_unit>[kKMGTPE]?i?B)\s*/\s*"
+    r"(?P<total>[\d.]+)\s*(?P<total_unit>[kKMGTPE]?i?B)"
+)
+_PERCENT_RE = re.compile(r"(?P<percent>[\d.]+)%")
 
 
 class Spinner:
@@ -115,6 +127,130 @@ def flush(self) -> None:
         pass
 
 
+@dataclass
+class DownloadProgressEvent:
+    """Best-effort parsed progress for model downloads/preparation."""
+
+    filename: str | None = None
+    percent: float | None = None
+    bytes_done: int | None = None
+    bytes_total: int | None = None
+
+
+def _parse_size_to_bytes(value: str, unit: str) -> int:
+    multipliers = {
+        "B": 1,
+        "KB": 1024,
+        "MB": 1024**2,
+        "GB": 1024**3,
+        "TB": 1024**4,
+        "PB": 1024**5,
+        "KIB": 1024,
+        "MIB": 1024**2,
+        "GIB": 1024**3,
+        "TIB": 1024**4,
+        "PIB": 1024**5,
+    }
+    factor = multipliers.get(unit.strip().upper())
+    if factor is None:
+        raise ValueError(f"Unknown size unit: {unit}")
+    return int(float(value) * factor)
+
+
+def parse_download_progress(text: str) -> DownloadProgressEvent | None:
+    """Parse a tqdm/HF-style progress line into a structured event."""
+    clean = _ANSI_RE.sub("", text).strip()
+    if not clean:
+        return None
+
+    if m := _TQDM_RE.search(clean):
+        filename = (m.group("filename") or "").strip() or None
+        return DownloadProgressEvent(
+            filename=filename,
+            percent=float(m.group("percent")),
+            bytes_done=_parse_size_to_bytes(m.group("done"), m.group("done_unit")),
+            bytes_total=_parse_size_to_bytes(m.group("total"), m.group("total_unit")),
+        )
+
+    if m := _BYTES_RE.search(clean):
+        percent = None
+        if m2 := _PERCENT_RE.search(clean):
+            percent = float(m2.group("percent"))
+        return DownloadProgressEvent(
+            percent=percent,
+            bytes_done=_parse_size_to_bytes(m.group("done"), m.group("done_unit")),
+            bytes_total=_parse_size_to_bytes(m.group("total"), m.group("total_unit")),
+        )
+
+    if m := _PERCENT_RE.search(clean):
+        return DownloadProgressEvent(percent=float(m.group("percent")))
+
+    return None
+
+
+def _human_bytes(value: int) -> str:
+    size = float(value)
+    for unit in ("B", "KB", "MB", "GB", "TB"):
+        if size < 1024 or unit == "TB":
+            if unit == "B":
+                return f"{int(size)} {unit}"
+            return f"{size:.1f} {unit}"
+        size /= 1024
+    return f"{size:.1f} TB"
+
+
+def format_download_progress(event: DownloadProgressEvent, *, include_percent: bool = True) -> str:
+    """Format a parsed download progress event for display in the TUI."""
+    parts: list[str] = []
+    if event.filename:
+        parts.append(event.filename)
+    if event.bytes_done is not None and event.bytes_total is not None:
+        parts.append(f"{_human_bytes(event.bytes_done)} / {_human_bytes(event.bytes_total)}")
+    if include_percent and event.percent is not None:
+        parts.append(f"{int(event.percent)}%")
+    return " ".join(parts).strip()
+
+
+def download_event_fraction(event: DownloadProgressEvent) -> float | None:
+    """Convert a parsed download event to a progress-bar fraction."""
+    if event.bytes_done is not None and event.bytes_total and event.bytes_total > 0:
+        return max(0.0, min(1.0, event.bytes_done / event.bytes_total))
+    if event.percent is not None:
+        return max(0.0, min(1.0, event.percent / 100.0))
+    return None
+
+
+class DownloadProgressWriter:
+    """File-like object that parses download progress from captured output."""
+
+    def __init__(self, update_fn: Callable[[DownloadProgressEvent], None]) -> None:
+        self._update_fn = update_fn
+        self._buffer = ""
+
+    def write(self, text: str) -> int:
+        self._buffer += text
+        while True:
+            idx_r = self._buffer.find("\r")
+            idx_n = 
self._buffer.find("\n") + idxs = [idx for idx in (idx_r, idx_n) if idx != -1] + if not idxs: + break + idx = min(idxs) + chunk = self._buffer[:idx] + self._buffer = self._buffer[idx + 1:] + self._consume(chunk) + return len(text) + + def flush(self) -> None: + if self._buffer: + self._consume(self._buffer) + self._buffer = "" + + def _consume(self, chunk: str) -> None: + if event := parse_download_progress(chunk): + self._update_fn(event) + + # --------------------------------------------------------------------------- # Pipeline-level checklist progress # --------------------------------------------------------------------------- @@ -136,8 +272,10 @@ class _Step: class PipelineProgress: """Full-pipeline checklist display.""" - def __init__(self, *, diarize: bool = False, summarize: bool = False) -> None: - steps: list[_Step] = [_Step("transcribing", "Transcribing", indent=0)] + def __init__(self, *, diarize: bool = False, summarize: bool = False, transcribe: bool = True) -> None: + steps: list[_Step] = [_Step("preparing_models", "Preparing models", indent=0)] + if transcribe: + steps.append(_Step("transcribing", "Transcribing", indent=0)) if diarize: steps.append(_Step("diarizing", "Diarizing", indent=0)) steps.extend([ @@ -153,6 +291,7 @@ def __init__(self, *, diarize: bool = False, summarize: bool = False) -> None: self._active: set[str] = set() self._completed: set[str] = set() self._progress: dict[str, float] = {} + self._details: dict[str, str] = {} self._lock = threading.Lock() self._lines_rendered = 0 self._stderr = sys.stderr @@ -174,6 +313,7 @@ def __exit__(self, *_exc) -> None: for key in list(self._active): self._completed.add(key) self._progress.pop(key, None) + self._details.pop(key, None) self._active.clear() self._render_all(final=True) @@ -190,8 +330,11 @@ def begin(self, key: str) -> None: if other.indent == step.indent: self._active.discard(other_key) self._completed.add(other_key) + self._progress.pop(other_key, None) + self._details.pop(other_key, None) self._active.add(key) self._progress.pop(key, None) + self._details.pop(key, None) # Lazy-start animation thread on first begin() if self._thread is None: self._stop.clear() @@ -206,6 +349,7 @@ def complete(self, key: str) -> None: self._active.discard(key) self._completed.add(key) self._progress.pop(key, None) + self._details.pop(key, None) # If top-level step, also complete any active sub-steps if step.indent == 0: for s in self._steps: @@ -213,18 +357,29 @@ def complete(self, key: str) -> None: self._active.discard(s.key) self._completed.add(s.key) self._progress.pop(s.key, None) + self._details.pop(s.key, None) def fail(self, key: str) -> None: """Mark a step as failed — removes from active without completing.""" with self._lock: self._active.discard(key) self._progress.pop(key, None) + self._details.pop(key, None) def update(self, key: str, fraction: float) -> None: with self._lock: if key in self._step_map: self._progress[key] = max(0.0, min(1.0, fraction)) + def set_detail(self, key: str, text: str | None) -> None: + with self._lock: + if key not in self._step_map: + return + if text: + self._details[key] = text + else: + self._details.pop(key, None) + def diarization_hook(self, step_name: str, _artifact, **kwargs) -> None: """Pyannote-compatible hook callback for diarization progress.""" # Map pyannote step names to our keys @@ -251,6 +406,7 @@ def _render_all(self, *, final: bool = False) -> None: active = set(self._active) completed = set(self._completed) progress = dict(self._progress) + details = 
dict(self._details) # Pick a spinner frame (not needed for final) frame = "" @@ -269,20 +425,29 @@ def _render_all(self, *, final: bool = False) -> None: filled = int(frac * _BAR_WIDTH) bar = _FILLED * filled + _EMPTY * (_BAR_WIDTH - filled) pct = int(frac * 100) - lines.append(f"{indent}{step.label:<20s} [{bar}] {pct:3d}%") + lines.append(f"{indent}{frame} {step.label:<20s} [{bar}] {pct:3d}%") else: lines.append(f"{indent}{frame} {step.label}") + if detail := details.get(step.key): + lines.append(f"{indent} {detail}") else: lines.append(f"{indent}\u25cb {step.label}") # Move cursor up to overwrite previous render - if self._lines_rendered > 0: - self._stderr.write(f"\033[{self._lines_rendered}A") - - output = "\n".join(lines) - self._stderr.write(f"{output}\033[K\n") + prev_lines = self._lines_rendered + if prev_lines > 0: + self._stderr.write(f"\033[{prev_lines}A") + + # If the render shrinks (e.g. detail line disappears), explicitly clear the + # now-stale trailing rows by writing blank cleared lines. + render_lines = list(lines) + if prev_lines > len(render_lines): + render_lines.extend([""] * (prev_lines - len(render_lines))) + + for line in render_lines: + self._stderr.write(f"{line}\033[K\n") self._stderr.flush() - self._lines_rendered = len(lines) + self._lines_rendered = len(render_lines) def _animate(self) -> None: while not self._stop.is_set(): @@ -305,5 +470,8 @@ def fail(self, key: str) -> None: def update(self, key: str, fraction: float) -> None: pass + def set_detail(self, key: str, text: str | None) -> None: + pass + def diarization_hook(self, step_name: str, _artifact, **kwargs) -> None: pass diff --git a/src/ownscribe/transcription/base.py b/src/ownscribe/transcription/base.py index c259556..dbbec43 100644 --- a/src/ownscribe/transcription/base.py +++ b/src/ownscribe/transcription/base.py @@ -11,6 +11,10 @@ class Transcriber(abc.ABC): """Base class for transcription backends.""" + def prepare_models(self, language: str | None = None) -> None: + """Optional hook to prefetch/load models before transcription.""" + _ = language + @abc.abstractmethod def transcribe(self, audio_path: Path) -> TranscriptResult: """Transcribe an audio file and return structured results.""" diff --git a/src/ownscribe/transcription/whisperx_transcriber.py b/src/ownscribe/transcription/whisperx_transcriber.py index c8ed403..dc8807f 100644 --- a/src/ownscribe/transcription/whisperx_transcriber.py +++ b/src/ownscribe/transcription/whisperx_transcriber.py @@ -11,7 +11,14 @@ import click from ownscribe.config import DiarizationConfig, TranscriptionConfig -from ownscribe.progress import NullProgress, ProgressWriter +from ownscribe.progress import ( + DownloadProgressEvent, + DownloadProgressWriter, + NullProgress, + ProgressWriter, + download_event_fraction, + format_download_progress, +) from ownscribe.transcription.base import Transcriber from ownscribe.transcription.models import Segment, TranscriptResult, Word @@ -31,6 +38,8 @@ def __init__( self._diar_config = diarization_config self._progress = progress or NullProgress() self._model = None + self._align_models: dict[str, tuple[object, object]] = {} + self._diarize_model = None def _load_model(self): import whisperx @@ -44,6 +53,111 @@ def _load_model(self): language=self._tx_config.language or None, ) + def _configure_runtime_env(self) -> None: + os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + if self._diar_config is None or not self._diar_config.telemetry: + os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") + 
os.environ.setdefault("PYANNOTE_METRICS_ENABLED", "0") + + def _set_detail(self, key: str, text: str | None) -> None: + set_detail = getattr(self._progress, "set_detail", None) + if callable(set_detail): + set_detail(key, text) + + def _set_prepare_detail(self, text: str | None) -> None: + self._set_detail("preparing_models", text) + + def _on_download_progress(self, step_key: str, stage_label: str, event: DownloadProgressEvent) -> None: + fraction = download_event_fraction(event) + if fraction is not None: + self._progress.update(step_key, fraction) + formatted = format_download_progress(event, include_percent=fraction is None) + if formatted: + self._set_detail(step_key, f"{stage_label}: {formatted}") + elif fraction is None and event.percent is not None: + self._set_detail(step_key, f"{stage_label}: {int(event.percent)}%") + + def _capture_download_output(self, step_key: str, stage_label: str, fn, *args, **kwargs): + writer = DownloadProgressWriter( + lambda event: self._on_download_progress(step_key, stage_label, event) + ) + self._progress.update(step_key, 0.0) + self._set_detail(step_key, stage_label) + with contextlib.ExitStack() as stack: + stack.enter_context(contextlib.redirect_stdout(writer)) + stack.enter_context(contextlib.redirect_stderr(writer)) + result = fn(*args, **kwargs) + writer.flush() + return result + + def _capture_prep_output(self, stage_label: str, fn, *args, **kwargs): + return self._capture_download_output("preparing_models", stage_label, fn, *args, **kwargs) + + def _load_align_model(self, language: str, *, step_key: str = "preparing_models") -> tuple[object, object]: + import whisperx + + if language in self._align_models: + return self._align_models[language] + + align_model, align_metadata = self._capture_download_output( + step_key, + f"Loading alignment model ({language})", + whisperx.load_align_model, + language_code=language, + device="cpu", + ) + self._align_models[language] = (align_model, align_metadata) + return align_model, align_metadata + + def _load_diarization_pipeline(self, *, step_key: str = "preparing_models"): + from whisperx.diarize import DiarizationPipeline + + if self._diarize_model is not None: + return self._diarize_model + + device = self._resolve_diarization_device(self._diar_config.device) + self._diarize_model = self._capture_download_output( + step_key, + "Loading diarization pipeline", + DiarizationPipeline, + use_auth_token=self._diar_config.hf_token, + device=device, + ) + return self._diarize_model + + def prepare_models(self, language: str | None = None) -> None: + self._configure_runtime_env() + progress = self._progress + progress.begin("preparing_models") + try: + if self._model is None: + self._capture_prep_output( + f"Loading Whisper model ({self._tx_config.model})", + self._load_model, + ) + else: + self._set_prepare_detail(f"Whisper model ready ({self._tx_config.model})") + + align_language = language or self._tx_config.language or None + if align_language: + self._load_align_model(align_language) + else: + self._set_prepare_detail( + "Whisper model ready. Alignment model will load after language detection." 
+ ) + + if ( + self._diar_config + and self._diar_config.enabled + and self._diar_config.hf_token + ): + self._load_diarization_pipeline() + + progress.complete("preparing_models") + except Exception: + progress.fail("preparing_models") + raise + def transcribe(self, audio_path: Path) -> TranscriptResult: import shutil @@ -56,10 +170,7 @@ def transcribe(self, audio_path: Path) -> TranscriptResult: raise SystemExit(1) # --- Telemetry toggle (must happen before importing whisperx) --- - os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") - if self._diar_config is None or not self._diar_config.telemetry: - os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") - os.environ.setdefault("PYANNOTE_METRICS_ENABLED", "0") + self._configure_runtime_env() hf_token_warning: str | None = None if ( @@ -89,6 +200,7 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: import whisperx progress = self._progress + self.prepare_models(language=self._tx_config.language or None) devnull = open(os.devnull, "w") # noqa: SIM115 try: @@ -96,10 +208,6 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: with contextlib.redirect_stdout(devnull): progress.begin("transcribing") - if self._model is None: - with contextlib.redirect_stderr(devnull): - self._load_model() - audio = whisperx.load_audio(str(audio_path)) tx_writer = ProgressWriter( @@ -119,9 +227,7 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: language = result.get("language", "") - align_model, align_metadata = whisperx.load_align_model( - language_code=language, device="cpu" - ) + align_model, align_metadata = self._load_align_model(language, step_key="transcribing") with contextlib.redirect_stdout(align_writer): result = whisperx.align( result["segments"], @@ -142,7 +248,7 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: and self._diar_config.enabled and self._diar_config.hf_token ): - result = self._diarize(audio, result, devnull) + result = self._diarize(audio, result) finally: devnull.close() @@ -181,22 +287,14 @@ def _resolve_diarization_device(device_cfg: str) -> str: return "mps" if torch.backends.mps.is_available() else "cpu" return device_cfg - def _diarize(self, audio, result, devnull): + def _diarize(self, audio, result): import pandas as pd import torch import whisperx - from whisperx.diarize import DiarizationPipeline progress = self._progress progress.begin("diarizing") - - device = self._resolve_diarization_device(self._diar_config.device) - - # Load the diarization pipeline (model loading happens inside) - with contextlib.redirect_stderr(devnull): - diarize_model = DiarizationPipeline( - use_auth_token=self._diar_config.hf_token, device=device - ) + diarize_model = self._load_diarization_pipeline(step_key="diarizing") # Build audio_data dict the same way whisperx does internally audio_data = { diff --git a/tests/test_cli.py b/tests/test_cli.py index 952b4e9..35bc4be 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -95,6 +95,12 @@ def test_resume_help(self): assert result.exit_code == 0 assert "Resume a partially-completed pipeline" in result.output + def test_warmup_help(self): + runner = CliRunner() + result = runner.invoke(cli, ["warmup", "--help"]) + assert result.exit_code == 0 + assert "Prefetch WhisperX/pyannote models" in result.output + def test_cleanup_help(self): runner = CliRunner() result = runner.invoke(cli, ["cleanup", "--help"]) @@ -120,6 +126,19 @@ def test_keep_recording_default_is_true(self): assert config.output.keep_recording 
is True +class TestWarmupCommand: + def test_warmup_invokes_pipeline_with_overrides(self): + runner = CliRunner() + with _mock_config(), mock.patch("ownscribe.pipeline.run_warmup") as mock_warmup: + result = runner.invoke(cli, ["warmup", "--model", "large-v3", "--language", "de", "--with-diarization"]) + + assert result.exit_code == 0 + config = mock_warmup.call_args[0][0] + assert config.transcription.model == "large-v3" + assert config.transcription.language == "de" + assert config.diarization.enabled is True + + class TestCleanup: def test_all_yes_removes_dirs(self, tmp_path): config_dir = tmp_path / "config" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e34f1d8..94ab98a 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -294,6 +294,21 @@ def test_summarization_failure_preserves_transcript(self, tmp_path): assert not (tmp_path / "summary.md").exists() +class TestRunWarmup: + def test_run_warmup_calls_prepare_models(self): + from ownscribe.pipeline import run_warmup + + config = Config() + config.transcription.language = "en" + + mock_transcriber = mock.MagicMock() + + with mock.patch("ownscribe.pipeline._create_transcriber", return_value=mock_transcriber): + run_warmup(config) + + mock_transcriber.prepare_models.assert_called_once_with(language="en") + + class TestRunTranscribeColocation: """Test that run_transcribe saves output alongside the input file.""" diff --git a/tests/test_progress.py b/tests/test_progress.py new file mode 100644 index 0000000..a8cdce7 --- /dev/null +++ b/tests/test_progress.py @@ -0,0 +1,157 @@ +"""Tests for progress parsing and TUI rendering helpers.""" + +from __future__ import annotations + +import io +from unittest import mock + +from ownscribe.progress import ( + _BRAILLE, + DownloadProgressEvent, + DownloadProgressWriter, + PipelineProgress, + download_event_fraction, + format_download_progress, + parse_download_progress, +) + + +class TestDownloadProgressParsing: + def test_parses_tqdm_style_line(self): + event = parse_download_progress( + "model.bin: 26%|##5 | 123MB/466MB [00:10<00:20, 12.3MB/s]" + ) + + assert event is not None + assert event.filename == "model.bin" + assert event.percent == 26.0 + assert event.bytes_done is not None + assert event.bytes_total is not None + assert format_download_progress(event).startswith("model.bin") + + def test_ignores_non_progress_noise(self): + assert parse_download_progress("Some unrelated log line") is None + + def test_writer_handles_partial_carriage_return_updates(self): + events = [] + writer = DownloadProgressWriter(events.append) + + writer.write("model.bin: 12%|##") + writer.write(" | 12MB/100MB [00:01<00:08, 10MB/s]\r") + writer.flush() + + assert events + assert events[-1].percent == 12.0 + assert events[-1].filename == "model.bin" + + +class TestDownloadProgressFraction: + def test_prefers_bytes_ratio(self): + event = parse_download_progress("model.bin: 26%|##5| 123MB/466MB [00:10<00:20]") + assert event is not None + fraction = download_event_fraction(event) + assert fraction is not None + assert 0.26 < fraction < 0.27 + + def test_uses_percent_when_bytes_missing(self): + event = parse_download_progress("Progress 75% complete") + assert event is not None + assert download_event_fraction(event) == 0.75 + + def test_clamps_fraction(self): + from ownscribe.progress import DownloadProgressEvent + + assert download_event_fraction(DownloadProgressEvent(percent=150)) == 1.0 + assert download_event_fraction(DownloadProgressEvent(percent=-5)) == 0.0 + + def 
test_returns_none_without_progress_numbers(self): + assert download_event_fraction(DownloadProgressEvent(filename="model.bin")) is None + + +class TestDownloadProgressFormatting: + def test_can_omit_percent_when_bar_already_shows_it(self): + event = DownloadProgressEvent( + filename="model.bin", + percent=10.0, + bytes_done=10 * 1024**2, + bytes_total=100 * 1024**2, + ) + text = format_download_progress(event, include_percent=False) + + assert "model.bin" in text + assert "%" not in text + + +class TestPipelineProgressDetails: + def test_renders_detail_line_for_active_step(self): + progress = PipelineProgress(transcribe=False) + progress._stderr = io.StringIO() + progress.begin("preparing_models") + progress.set_detail("preparing_models", "Downloading model.bin 12 MB / 100 MB (12%)") + + progress._render_all(final=True) + output = progress._stderr.getvalue() + + assert "Preparing models" in output + assert "Downloading model.bin" in output + + progress._stop.set() + if progress._thread is not None: + progress._thread.join() + + def test_detail_clears_on_complete(self): + progress = PipelineProgress(transcribe=False) + progress._stderr = io.StringIO() + progress.begin("preparing_models") + progress.set_detail("preparing_models", "Downloading...") + progress.complete("preparing_models") + + progress._render_all(final=True) + output = progress._stderr.getvalue() + + assert "Downloading..." not in output + + progress._stop.set() + if progress._thread is not None: + progress._thread.join() + + def test_determinate_active_row_includes_spinner_glyph(self): + progress = PipelineProgress(transcribe=False) + progress._stderr = io.StringIO() + progress.begin("preparing_models") + progress.update("preparing_models", 0.1) + + with mock.patch("ownscribe.progress.time.time", return_value=0.0): + progress._render_all(final=False) + + output = progress._stderr.getvalue() + assert f" {_BRAILLE[0]} Preparing models" in output + assert "[██" in output or "[█" in output + + progress._stop.set() + if progress._thread is not None: + progress._thread.join() + + def test_renderer_clears_each_line_and_stale_rows_when_detail_disappears(self): + progress = PipelineProgress(transcribe=False) + progress._stderr = io.StringIO() + progress.begin("preparing_models") + progress.set_detail("preparing_models", "Downloading...") + + progress._render_all(final=True) + first = progress._stderr.getvalue() + first_len = len(first) + assert first.count("\033[K\n") >= 2 + + progress.set_detail("preparing_models", None) + progress._render_all(final=True) + second_delta = progress._stderr.getvalue()[first_len:] + + assert "\033[2A" in second_delta + # One real line + one blank clearing line for the removed detail row. + assert second_delta.count("\033[K\n") == 2 + assert "Downloading..." 
not in second_delta + + progress._stop.set() + if progress._thread is not None: + progress._thread.join() diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 187cfd7..967b29f 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -16,3 +16,139 @@ def test_missing_ffmpeg_exits(self): with mock.patch("shutil.which", return_value=None), pytest.raises(SystemExit): transcriber.transcribe(mock.MagicMock()) + + +class _FakeProgress: + def __init__(self) -> None: + self.calls: list[tuple[str, str]] = [] + self.details: dict[str, str] = {} + self.updates: list[tuple[str, float]] = [] + + def begin(self, key: str) -> None: + self.calls.append(("begin", key)) + + def complete(self, key: str) -> None: + self.calls.append(("complete", key)) + + def fail(self, key: str) -> None: + self.calls.append(("fail", key)) + + def update(self, key: str, fraction: float) -> None: + self.updates.append((key, fraction)) + + def set_detail(self, key: str, text: str | None) -> None: + if text is None: + self.details.pop(key, None) + else: + self.details[key] = text + + def diarization_hook(self, step_name: str, _artifact, **kwargs) -> None: + _ = (step_name, _artifact, kwargs) + + +class TestPrepareModels: + def test_prepare_models_emits_preparing_models_lifecycle(self): + from ownscribe.config import TranscriptionConfig + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + progress = _FakeProgress() + transcriber = WhisperXTranscriber(TranscriptionConfig(language="en"), None, progress=progress) + + def passthrough(stage_label, fn, *args, **kwargs): + _ = stage_label + return fn(*args, **kwargs) + + with ( + mock.patch.object(transcriber, "_capture_prep_output", side_effect=passthrough), + mock.patch.object(transcriber, "_load_model", side_effect=lambda: setattr(transcriber, "_model", object())), + mock.patch.object(transcriber, "_load_align_model", return_value=(object(), object())), + ): + transcriber.prepare_models(language="en") + + assert ("begin", "preparing_models") in progress.calls + assert ("complete", "preparing_models") in progress.calls + assert ("fail", "preparing_models") not in progress.calls + + def test_prepare_models_skips_diarization_without_token(self): + from ownscribe.config import DiarizationConfig, TranscriptionConfig + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + progress = _FakeProgress() + diar = DiarizationConfig(enabled=True, hf_token="") + transcriber = WhisperXTranscriber(TranscriptionConfig(language="en"), diar, progress=progress) + + def passthrough(stage_label, fn, *args, **kwargs): + _ = stage_label + return fn(*args, **kwargs) + + with ( + mock.patch.object(transcriber, "_capture_prep_output", side_effect=passthrough), + mock.patch.object(transcriber, "_load_model", side_effect=lambda: setattr(transcriber, "_model", object())), + mock.patch.object(transcriber, "_load_align_model", return_value=(object(), object())), + mock.patch.object(transcriber, "_load_diarization_pipeline") as mock_diar_load, + ): + transcriber.prepare_models(language="en") + + mock_diar_load.assert_not_called() + + def test_prepare_models_reuses_loaded_whisper_model(self): + from ownscribe.config import TranscriptionConfig + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + progress = _FakeProgress() + transcriber = WhisperXTranscriber(TranscriptionConfig(language="en"), None, progress=progress) + + def passthrough(stage_label, fn, *args, **kwargs): + _ = stage_label + 
return fn(*args, **kwargs) + + with ( + mock.patch.object(transcriber, "_capture_prep_output", side_effect=passthrough), + mock.patch.object( + transcriber, + "_load_model", + side_effect=lambda: setattr(transcriber, "_model", object()), + ) as mock_load_model, + mock.patch.object(transcriber, "_load_align_model", return_value=(object(), object())), + ): + transcriber.prepare_models(language="en") + transcriber.prepare_models(language="en") + + assert mock_load_model.call_count == 1 + + +class TestDownloadProgressHooks: + def test_on_download_progress_updates_detail_and_bar(self): + from ownscribe.config import TranscriptionConfig + from ownscribe.progress import DownloadProgressEvent + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + progress = _FakeProgress() + transcriber = WhisperXTranscriber(TranscriptionConfig(), None, progress=progress) + + transcriber._on_download_progress( + "preparing_models", + "Loading Whisper model (base)", + DownloadProgressEvent(filename="model.bin", percent=25.0), + ) + + assert ("preparing_models", 0.25) in progress.updates + assert "Loading Whisper model (base)" in progress.details["preparing_models"] + assert "model.bin" in progress.details["preparing_models"] + assert "25%" not in progress.details["preparing_models"] + + def test_capture_download_output_resets_bar_to_zero(self): + from ownscribe.config import TranscriptionConfig + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + progress = _FakeProgress() + transcriber = WhisperXTranscriber(TranscriptionConfig(), None, progress=progress) + + def fake_loader(): + print("model.bin: 12%|##| 12MB/100MB [00:01<00:08]") + + transcriber._capture_download_output("preparing_models", "Loading Whisper model (base)", fake_loader) + + assert progress.updates + assert progress.updates[0] == ("preparing_models", 0.0) + assert any(key == "preparing_models" and frac > 0 for key, frac in progress.updates[1:]) From 4a781fc624c62e17628a03221874d86929890247 Mon Sep 17 00:00:00 2001 From: Pascal Berrang Date: Sun, 22 Feb 2026 20:27:34 +0100 Subject: [PATCH 2/3] Improve pipeline steps (remove prepare and move into substeps) --- src/ownscribe/pipeline.py | 2 +- src/ownscribe/progress.py | 13 ++++- .../transcription/whisperx_transcriber.py | 48 ++++++++++++++----- tests/test_pipeline.py | 18 +++++++ tests/test_progress.py | 16 +++++-- tests/test_transcription.py | 36 ++++++++++++++ 6 files changed, 113 insertions(+), 20 deletions(-) diff --git a/src/ownscribe/pipeline.py b/src/ownscribe/pipeline.py index c81de3a..05d1ae4 100644 --- a/src/ownscribe/pipeline.py +++ b/src/ownscribe/pipeline.py @@ -229,7 +229,7 @@ def run_warmup(config: Config) -> None: config.diarization.enabled and not config.diarization.hf_token ) - with PipelineProgress(diarize=False, summarize=False, transcribe=False) as progress: + with PipelineProgress(diarize=False, summarize=False, transcribe=False, include_prepare=True) as progress: try: transcriber = _create_transcriber(config, progress=progress) except ImportError: diff --git a/src/ownscribe/progress.py b/src/ownscribe/progress.py index 0ce49b6..7853af2 100644 --- a/src/ownscribe/progress.py +++ b/src/ownscribe/progress.py @@ -272,8 +272,17 @@ class _Step: class PipelineProgress: """Full-pipeline checklist display.""" - def __init__(self, *, diarize: bool = False, summarize: bool = False, transcribe: bool = True) -> None: - steps: list[_Step] = [_Step("preparing_models", "Preparing models", indent=0)] + def __init__( + self, + *, + diarize: 
bool = False, + summarize: bool = False, + transcribe: bool = True, + include_prepare: bool = False, + ) -> None: + steps: list[_Step] = [] + if include_prepare: + steps.append(_Step("preparing_models", "Preparing models", indent=0)) if transcribe: steps.append(_Step("transcribing", "Transcribing", indent=0)) if diarize: diff --git a/src/ownscribe/transcription/whisperx_transcriber.py b/src/ownscribe/transcription/whisperx_transcriber.py index dc8807f..d8ee3d3 100644 --- a/src/ownscribe/transcription/whisperx_transcriber.py +++ b/src/ownscribe/transcription/whisperx_transcriber.py @@ -93,6 +93,28 @@ def _capture_download_output(self, step_key: str, stage_label: str, fn, *args, * def _capture_prep_output(self, stage_label: str, fn, *args, **kwargs): return self._capture_download_output("preparing_models", stage_label, fn, *args, **kwargs) + def _prepare_transcription_models( + self, + *, + language: str | None, + step_key: str, + show_deferred_align_note: bool = False, + ) -> None: + if self._model is None: + self._capture_download_output( + step_key, + f"Loading Whisper model ({self._tx_config.model})", + self._load_model, + ) + + if language: + self._load_align_model(language, step_key=step_key) + elif show_deferred_align_note: + self._set_detail( + step_key, + "Whisper model ready. Alignment model will load after language detection.", + ) + def _load_align_model(self, language: str, *, step_key: str = "preparing_models") -> tuple[object, object]: import whisperx @@ -130,21 +152,15 @@ def prepare_models(self, language: str | None = None) -> None: progress = self._progress progress.begin("preparing_models") try: - if self._model is None: - self._capture_prep_output( - f"Loading Whisper model ({self._tx_config.model})", - self._load_model, - ) - else: + if self._model is not None: self._set_prepare_detail(f"Whisper model ready ({self._tx_config.model})") align_language = language or self._tx_config.language or None - if align_language: - self._load_align_model(align_language) - else: - self._set_prepare_detail( - "Whisper model ready. Alignment model will load after language detection." 
- ) + self._prepare_transcription_models( + language=align_language, + step_key="preparing_models", + show_deferred_align_note=True, + ) if ( self._diar_config @@ -200,7 +216,6 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: import whisperx progress = self._progress - self.prepare_models(language=self._tx_config.language or None) devnull = open(os.devnull, "w") # noqa: SIM115 try: @@ -208,6 +223,13 @@ def _transcribe_inner(self, audio_path: Path) -> TranscriptResult: with contextlib.redirect_stdout(devnull): progress.begin("transcribing") + self._prepare_transcription_models( + language=self._tx_config.language or None, + step_key="transcribing", + show_deferred_align_note=False, + ) + self._set_detail("transcribing", None) + audio = whisperx.load_audio(str(audio_path)) tx_writer = ProgressWriter( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 94ab98a..c9abbe3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -308,6 +308,24 @@ def test_run_warmup_calls_prepare_models(self): mock_transcriber.prepare_models.assert_called_once_with(language="en") + def test_run_warmup_enables_prepare_step_in_progress(self): + from ownscribe.pipeline import run_warmup + + config = Config() + mock_transcriber = mock.MagicMock() + fake_progress = mock.MagicMock() + + with ( + mock.patch("ownscribe.pipeline._create_transcriber", return_value=mock_transcriber), + mock.patch("ownscribe.pipeline.PipelineProgress") as mock_progress_cls, + ): + mock_progress_cls.return_value.__enter__.return_value = fake_progress + run_warmup(config) + + _, kwargs = mock_progress_cls.call_args + assert kwargs["include_prepare"] is True + assert kwargs["transcribe"] is False + class TestRunTranscribeColocation: """Test that run_transcribe saves output alongside the input file.""" diff --git a/tests/test_progress.py b/tests/test_progress.py index a8cdce7..02ca19f 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -84,7 +84,7 @@ def test_can_omit_percent_when_bar_already_shows_it(self): class TestPipelineProgressDetails: def test_renders_detail_line_for_active_step(self): - progress = PipelineProgress(transcribe=False) + progress = PipelineProgress(transcribe=False, include_prepare=True) progress._stderr = io.StringIO() progress.begin("preparing_models") progress.set_detail("preparing_models", "Downloading model.bin 12 MB / 100 MB (12%)") @@ -100,7 +100,7 @@ def test_renders_detail_line_for_active_step(self): progress._thread.join() def test_detail_clears_on_complete(self): - progress = PipelineProgress(transcribe=False) + progress = PipelineProgress(transcribe=False, include_prepare=True) progress._stderr = io.StringIO() progress.begin("preparing_models") progress.set_detail("preparing_models", "Downloading...") @@ -116,7 +116,7 @@ def test_detail_clears_on_complete(self): progress._thread.join() def test_determinate_active_row_includes_spinner_glyph(self): - progress = PipelineProgress(transcribe=False) + progress = PipelineProgress(transcribe=False, include_prepare=True) progress._stderr = io.StringIO() progress.begin("preparing_models") progress.update("preparing_models", 0.1) @@ -133,7 +133,7 @@ def test_determinate_active_row_includes_spinner_glyph(self): progress._thread.join() def test_renderer_clears_each_line_and_stale_rows_when_detail_disappears(self): - progress = PipelineProgress(transcribe=False) + progress = PipelineProgress(transcribe=False, include_prepare=True) progress._stderr = io.StringIO() progress.begin("preparing_models") 
progress.set_detail("preparing_models", "Downloading...") @@ -155,3 +155,11 @@ def test_renderer_clears_each_line_and_stale_rows_when_detail_disappears(self): progress._stop.set() if progress._thread is not None: progress._thread.join() + + def test_preparing_models_is_not_included_by_default(self): + progress = PipelineProgress(transcribe=True) + assert "preparing_models" not in progress._step_map + + def test_preparing_models_can_be_enabled_explicitly(self): + progress = PipelineProgress(transcribe=False, include_prepare=True) + assert "preparing_models" in progress._step_map diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 967b29f..e616667 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -2,6 +2,7 @@ from __future__ import annotations +import types from unittest import mock import pytest @@ -152,3 +153,38 @@ def fake_loader(): assert progress.updates assert progress.updates[0] == ("preparing_models", 0.0) assert any(key == "preparing_models" and frac > 0 for key, frac in progress.updates[1:]) + + def test_transcribe_inner_does_not_use_preparing_models_step(self): + from ownscribe.config import TranscriptionConfig + from ownscribe.transcription.whisperx_transcriber import WhisperXTranscriber + + class _Audio: + shape = (16000,) + + fake_whisperx = types.SimpleNamespace( + load_audio=lambda _path: _Audio(), + align=lambda *args, **kwargs: {"segments": []}, + ) + + progress = _FakeProgress() + transcriber = WhisperXTranscriber(TranscriptionConfig(language="en"), None, progress=progress) + transcriber._model = mock.MagicMock() + transcriber._model.transcribe.return_value = {"segments": [], "language": "en"} + + with ( + mock.patch.dict("sys.modules", {"whisperx": fake_whisperx}), + mock.patch.object(transcriber, "_prepare_transcription_models") as mock_prepare_runtime, + mock.patch.object(transcriber, "_load_align_model", return_value=(object(), object())), + mock.patch.object(transcriber, "prepare_models") as mock_prepare_models, + ): + result = transcriber._transcribe_inner(mock.MagicMock()) + + mock_prepare_models.assert_not_called() + mock_prepare_runtime.assert_called_once_with( + language="en", + step_key="transcribing", + show_deferred_align_note=False, + ) + assert ("begin", "transcribing") in progress.calls + assert ("begin", "preparing_models") not in progress.calls + assert result.language == "en" From 5537c69b76f2e28879eb29a650e8032689d25fd9 Mon Sep 17 00:00:00 2001 From: Pascal Berrang Date: Sun, 22 Feb 2026 20:40:37 +0100 Subject: [PATCH 3/3] Harden download progress parsing against malformed output --- src/ownscribe/progress.py | 8 +++++++- tests/test_progress.py | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/ownscribe/progress.py b/src/ownscribe/progress.py index 7853af2..83ee9d5 100644 --- a/src/ownscribe/progress.py +++ b/src/ownscribe/progress.py @@ -3,6 +3,7 @@ from __future__ import annotations import itertools +import logging import re import sys import threading @@ -247,7 +248,12 @@ def flush(self) -> None: self._buffer = "" def _consume(self, chunk: str) -> None: - if event := parse_download_progress(chunk): + try: + event = parse_download_progress(chunk) + except (ValueError, OverflowError): + logging.getLogger(__name__).debug("Ignoring malformed download progress output: %r", chunk, exc_info=True) + return + if event: self._update_fn(event) diff --git a/tests/test_progress.py b/tests/test_progress.py index 02ca19f..b40c24e 100644 --- a/tests/test_progress.py +++ 
b/tests/test_progress.py @@ -44,6 +44,15 @@ def test_writer_handles_partial_carriage_return_updates(self): assert events[-1].percent == 12.0 assert events[-1].filename == "model.bin" + def test_writer_ignores_unknown_size_units_without_crashing(self): + events = [] + writer = DownloadProgressWriter(events.append) + + writer.write("model.bin: 10%|# | 1EiB/2EiB [00:01<00:09]\r") + writer.flush() + + assert events == [] + class TestDownloadProgressFraction: def test_prefers_bytes_ratio(self):
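
For reference, a minimal sketch of how the helpers introduced in this series compose, assuming the `ownscribe.progress` API exactly as added in patch 1. `fake_download` is a hypothetical stand-in for any loader that prints tqdm-style progress to stdout/stderr (e.g. huggingface_hub during a model fetch):

```python
import contextlib

from ownscribe.progress import (
    DownloadProgressWriter,
    download_event_fraction,
    format_download_progress,
    parse_download_progress,
)

# 1. A single tqdm/HF-style line becomes a structured event.
line = "model.bin:  26%|##5       | 123MB/466MB [00:10<00:20, 12.3MB/s]"
event = parse_download_progress(line)
assert event is not None and event.filename == "model.bin"
print(format_download_progress(event))  # model.bin 123.0 MB / 466.0 MB 26%
print(download_event_fraction(event))   # 0.2639... (bytes ratio, not the printed percent)


# 2. Capturing a noisy downloader, the way _capture_download_output does internally.
def fake_download() -> None:
    # Hypothetical stand-in: the transcriber wraps whisperx loaders this way.
    print("model.bin:  12%|##        | 12MB/100MB [00:01<00:08, 10.0MB/s]")


events = []
writer = DownloadProgressWriter(events.append)
with contextlib.redirect_stdout(writer):
    fake_download()
writer.flush()  # parse any trailing partial line
print(download_event_fraction(events[-1]))  # 0.12
```

`download_event_fraction` prefers the bytes ratio over the printed percentage because tqdm rounds the percent, so the TUI bar gets smoother updates from the raw byte counts.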