From 54068f09a7f163e51140e7d58910ddd6171e735d Mon Sep 17 00:00:00 2001 From: Will Griffin Date: Mon, 26 Jan 2026 14:42:51 -0700 Subject: [PATCH 1/2] feat: expand to multi-capability studio server Transforms tts-server into studio-server with modular backends: ## New Capabilities - **TTS**: Text-to-speech with voice cloning (Qwen3-TTS) - **Face**: Face embedding extraction for IP-Adapter FaceID (InsightFace) - **Transcription**: Audio transcription with word timings (Whisper) ## Architecture - Modular backend system with abstract base classes - Each capability can be enabled/disabled via environment variables - Mock TTS backend for development without GPU ## API Changes - New endpoints: /v1/tts/*, /v1/face/*, /v1/transcribe - Removed legacy backwards-compatibility endpoints - Clean REST API structure ## Environment Variables - TTS_BACKEND: qwen3-tts (default) or mock - FACE_ENABLED: true/false - TRANSCRIPTION_ENABLED: true/false --- CLAUDE.md | 90 +++++-- Dockerfile | 14 +- README.md | 240 +++++++++++------ backends/__init__.py | 27 ++ backends/base.py | 25 ++ backends/face.py | 254 ++++++++++++++++++ backends/transcription.py | 221 ++++++++++++++++ backends/tts.py | 340 ++++++++++++++++++++++++ requirements.txt | 12 +- server.py | 536 +++++++++++++++----------------------- tests/test_server.py | 89 ++++--- 11 files changed, 1392 insertions(+), 456 deletions(-) create mode 100644 backends/__init__.py create mode 100644 backends/base.py create mode 100644 backends/face.py create mode 100644 backends/transcription.py create mode 100644 backends/tts.py diff --git a/CLAUDE.md b/CLAUDE.md index 9d5a780..c518197 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,43 +4,95 @@ This file provides guidance to Claude Code when working with this repository. ## Project Overview -TTS Server is a multi-model text-to-speech API with voice cloning support. It's designed to be backend-agnostic, allowing different TTS models to be plugged in. 
+Studio Server is an AI-powered studio utilities API for video production. It provides modular backends for: + +- **TTS**: Text-to-speech with voice cloning (Qwen3-TTS) +- **Face**: Face embedding extraction for IP-Adapter FaceID (InsightFace) +- **Transcription**: Audio transcription with word-level timestamps (Whisper) ## Architecture ``` -server.py -├── TTSBackend (abstract base class) -│ └── Qwen3TTSBackend (implementation) -├── BACKENDS registry -└── FastAPI application +studio-server/ +├── server.py # FastAPI application with all endpoints +├── backends/ +│ ├── __init__.py # Backend exports +│ ├── base.py # Abstract Backend base class +│ ├── tts.py # TTSBackend + Qwen3TTSBackend +│ ├── face.py # FaceBackend + InsightFaceBackend +│ └── transcription.py # TranscriptionBackend + WhisperBackend +└── tests/ ``` ## Key Design Decisions -1. **Backend Abstraction**: All TTS models implement `TTSBackend` interface -2. **ref_text Support**: Voice cloning accepts both audio and transcript for quality -3. **Stateless**: No voice profile storage - consuming app manages assets -4. **GPU First**: Designed for CUDA, falls back to CPU +1. **Modular Backends**: Each capability (TTS, Face, Transcription) has its own backend abstraction +2. **Optional Loading**: Face and Transcription backends can be disabled via environment variables +3. **Legacy Compatibility**: Old TTS endpoints remain for backwards compatibility +4. **Stateless**: No asset storage - consuming app manages files +5. 
**GPU First**: Designed for CUDA, falls back to CPU + +## API Structure + +**TTS Endpoints:** +- `GET /v1/tts/speakers` - List available speakers +- `POST /v1/tts/extract` - Extract voice prompt from audio +- `POST /v1/tts/synthesize` - Synthesize speech + +**Face Endpoints:** +- `POST /v1/face/embed` - Extract face embedding from image +- `POST /v1/face/embed-all` - Extract all faces from image +- `POST /v1/face/compare` - Compare two embeddings + +**Transcription Endpoints:** +- `POST /v1/transcribe` - Transcribe audio with word timings + +**Health/Info:** +- `GET /health` - Health check with backend status +- `GET /v1/models` - List available backends + +## Environment Variables -## API +| Variable | Default | Description | +|----------|---------|-------------| +| `TTS_BACKEND` | `qwen3-tts` | TTS backend (`qwen3-tts`, `mock`) | +| `FACE_ENABLED` | `true` | Load face backend | +| `FACE_BACKEND` | `insightface` | Face backend | +| `TRANSCRIPTION_ENABLED` | `true` | Load transcription backend | +| `TRANSCRIPTION_BACKEND` | `whisper` | Transcription backend | -- `POST /v1/audio/speech` - Main synthesis endpoint -- `GET /health` - Health check -- `GET /v1/models` - List backends +## Development Mode + +For local development without GPU, use the mock TTS backend: + +```bash +TTS_BACKEND=mock FACE_ENABLED=false TRANSCRIPTION_ENABLED=false python server.py +``` ## Adding New Backends -1. Extend `TTSBackend` class -2. Implement `load()`, `synthesize()`, `get_info()` -3. Add to `BACKENDS` dict +1. Create or extend the appropriate backend class in `backends/` +2. Implement `load()`, `get_info()`, and capability-specific methods +3. Add to the `*_BACKENDS` registry dict +4. Update environment variable handling in `server.py` if needed ## Deployment Deployed to happyvertical k8s cluster via Flux GitOps. 
-Manifests in: `happyvertical/iac/manifests/applications/tts-server/` +Manifests in: `happyvertical/iac/manifests/applications/studio-server/` ## Related Packages -- `@happyvertical/ai` - SDK package that may consume this service +- `@happyvertical/histrio` - Video production agent that consumes this service +- `@happyvertical/ai` - SDK package with TTS client - SMRT voice packages - May integrate via TypeScript client + +## Testing + +```bash +# Run tests +pytest + +# Run with coverage +pytest --cov=backends --cov=server +``` diff --git a/Dockerfile b/Dockerfile index 4085ed4..b5b2db7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ -# TTS Server - Multi-model text-to-speech API -# Supports: Qwen3-TTS with voice cloning +# Studio Server - AI-powered studio utilities for video production +# Supports: TTS (Qwen3-TTS), Face Embedding (InsightFace), Transcription (Whisper) FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime @@ -11,6 +11,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ffmpeg \ sox \ libsox-dev \ + libgl1-mesa-glx \ + libglib2.0-0 \ && rm -rf /var/lib/apt/lists/* # Install Python dependencies @@ -18,11 +20,15 @@ COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Copy application +COPY backends/ ./backends/ COPY server.py . -# Environment variables +# Environment variables - defaults ENV TTS_BACKEND=qwen3-tts -ENV TTS_MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base +ENV FACE_ENABLED=true +ENV FACE_BACKEND=insightface +ENV TRANSCRIPTION_ENABLED=true +ENV TRANSCRIPTION_BACKEND=whisper # Expose port EXPOSE 8000 diff --git a/README.md b/README.md index 231730b..b2ef4eb 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,43 @@ -# TTS Server +# Studio Server -Multi-model text-to-speech API with voice cloning support. +AI-powered studio utilities for video production. 
## Features -- **Multi-model architecture** - Pluggable backend system for different TTS models -- **Voice cloning** - Clone voices with reference audio + transcript -- **Voice prompt caching** - Extract and reuse voice embeddings for faster synthesis -- **ref_text support** - Provide transcript for better voice cloning quality +- **TTS (Text-to-Speech)** - Voice synthesis with cloning support (Qwen3-TTS) +- **Face Embedding** - Extract face embeddings for IP-Adapter FaceID (InsightFace) +- **Transcription** - Audio transcription with word-level timestamps (Whisper) +- **Modular backends** - Pluggable architecture for each capability - **GPU accelerated** - CUDA support for fast inference -## Supported Backends +## API Overview -| Backend | Voice Cloning | ref_text | Voice Prompt | -|---------|--------------|----------|--------------| -| `qwen3-tts` | ✅ | ✅ | ✅ | +| Capability | Endpoint | Description | +|------------|----------|-------------| +| **TTS** | `POST /v1/tts/synthesize` | Synthesize speech from text | +| **TTS** | `POST /v1/tts/extract` | Extract reusable voice prompt | +| **TTS** | `GET /v1/tts/speakers` | List available speakers | +| **Face** | `POST /v1/face/embed` | Extract face embedding from image | +| **Face** | `POST /v1/face/embed-all` | Extract all faces from image | +| **Face** | `POST /v1/face/compare` | Compare two face embeddings | +| **Transcription** | `POST /v1/transcribe` | Transcribe audio with word timings | -## API Endpoints +## Quick Start -### `POST /v1/audio/speech` +```bash +# Install dependencies +pip install -r requirements.txt + +# Run server (loads all backends) +python server.py + +# Run with specific backends disabled +FACE_ENABLED=false TRANSCRIPTION_ENABLED=false python server.py +``` + +## TTS Endpoints + +### `POST /v1/tts/synthesize` Synthesize speech from text. @@ -28,133 +47,200 @@ Synthesize speech from text. 
- `speaker`: Preset speaker for basic TTS (e.g., "Vivian", "Ryan") - `ref_audio`: Reference audio file for voice cloning (on-the-fly) - `ref_text`: Transcript of reference audio (improves cloning quality) -- `voice_prompt`: Pre-extracted voice prompt from `/v1/voice/extract` (cached) +- `voice_prompt`: Pre-extracted voice prompt from `/v1/tts/extract` (cached) - `speed`: Speech speed multiplier (default: 1.0) -**Example:** +**Examples:** ```bash # Basic TTS with default speaker -curl -X POST http://localhost:8000/v1/audio/speech \ - -F "text=Hello, world!" \ - -o output.wav - -# Basic TTS with specific speaker -curl -X POST http://localhost:8000/v1/audio/speech \ +curl -X POST http://localhost:8000/v1/tts/synthesize \ -F "text=Hello, world!" \ - -F "speaker=Ryan" \ -o output.wav -# Voice cloning (on-the-fly) - processes ref_audio each time -curl -X POST http://localhost:8000/v1/audio/speech \ +# Voice cloning with reference audio +curl -X POST http://localhost:8000/v1/tts/synthesize \ -F "text=Hello, this is my cloned voice." \ -F "ref_audio=@reference.wav" \ -F "ref_text=This is the transcript of my reference audio." \ -o cloned.wav - -# Voice cloning (cached) - faster, uses pre-extracted prompt -curl -X POST http://localhost:8000/v1/audio/speech \ - -F "text=Hello, this is my cloned voice." \ - -F "voice_prompt=$VOICE_PROMPT" \ - -o cloned.wav ``` -### `POST /v1/voice/extract` - -Extract a reusable voice prompt from reference audio. The returned prompt can be cached and reused with `/v1/audio/speech` to avoid re-processing the reference audio on every request. - -**Parameters:** -- `ref_audio` (required): Reference audio file -- `ref_text`: Transcript of reference audio (improves quality) -- `language`: Language of the reference audio (default: "English") +### `POST /v1/tts/extract` -**Returns:** -- `voice_prompt`: Base64-encoded voice embedding (store this) -- `format`: Encoding format (e.g., "base64-numpy") +Extract a reusable voice prompt from reference audio. 
-**Example:** ```bash -# Extract voice prompt -VOICE_PROMPT=$(curl -X POST http://localhost:8000/v1/voice/extract \ +VOICE_PROMPT=$(curl -X POST http://localhost:8000/v1/tts/extract \ -F "ref_audio=@reference.wav" \ - -F "ref_text=This is the transcript of my reference audio." \ + -F "ref_text=This is the transcript." \ | jq -r '.voice_prompt') -# Use the cached prompt for multiple synthesis requests -curl -X POST http://localhost:8000/v1/audio/speech \ - -F "text=First sentence with cloned voice." \ +# Reuse for multiple synthesis requests +curl -X POST http://localhost:8000/v1/tts/synthesize \ + -F "text=First sentence." \ -F "voice_prompt=$VOICE_PROMPT" \ - -o output1.wav + -o output.wav +``` -curl -X POST http://localhost:8000/v1/audio/speech \ - -F "text=Second sentence with same voice." \ - -F "voice_prompt=$VOICE_PROMPT" \ - -o output2.wav +## Face Embedding Endpoints + +### `POST /v1/face/embed` + +Extract face embedding from an image for IP-Adapter FaceID (PerformerDNA). + +```bash +curl -X POST http://localhost:8000/v1/face/embed \ + -F "image=@portrait.jpg" \ + -F "return_bbox=true" ``` -### `GET /health` +**Response:** +```json +{ + "embedding": "base64-encoded-512-dim-vector", + "embedding_dim": 512, + "confidence": 0.98, + "bbox": [100, 50, 300, 350] +} +``` -Health check endpoint. +### `POST /v1/face/embed-all` -### `GET /v1/models` +Extract embeddings for all faces in an image. -List available TTS backends. +```bash +curl -X POST http://localhost:8000/v1/face/embed-all \ + -F "image=@group-photo.jpg" \ + -F "max_faces=5" +``` -### `GET /v1/speakers` +### `POST /v1/face/compare` -List available preset speakers for basic TTS. +Compare two face embeddings for similarity. 
-## Docker +```bash +curl -X POST http://localhost:8000/v1/face/compare \ + -H "Content-Type: application/json" \ + -d '{"embedding1": "...", "embedding2": "..."}' +``` + +**Response:** +```json +{ + "similarity": 0.85, + "same_person": true +} +``` + +## Transcription Endpoints + +### `POST /v1/transcribe` + +Transcribe audio with word-level timestamps for lip-sync alignment. ```bash -# Build -docker build -t tts-server . +curl -X POST http://localhost:8000/v1/transcribe \ + -F "audio=@speech.wav" \ + -F "word_timestamps=true" +``` -# Run with GPU -docker run --gpus all -p 8000:8000 tts-server +**Response:** +```json +{ + "text": "Hello, this is a test.", + "language": "en", + "duration": 2.5, + "word_timings": [ + {"word": "Hello", "start": 0.0, "end": 0.4, "confidence": 0.98}, + {"word": "this", "start": 0.5, "end": 0.7, "confidence": 0.95}, + ... + ] +} ``` ## Environment Variables | Variable | Default | Description | |----------|---------|-------------| -| `TTS_BACKEND` | `qwen3-tts` | TTS backend to use | +| `TTS_BACKEND` | `qwen3-tts` | TTS backend (`qwen3-tts`, `mock`) | +| `FACE_ENABLED` | `true` | Enable face embedding backend | +| `FACE_BACKEND` | `insightface` | Face backend to use | +| `TRANSCRIPTION_ENABLED` | `true` | Enable transcription backend | +| `TRANSCRIPTION_BACKEND` | `whisper` | Transcription backend to use | -## Development +### Development Mode (No GPU) + +For local development without GPU, use the mock TTS backend: ```bash -# Install dependencies -pip install -r requirements.txt -pip install -r requirements-dev.txt +TTS_BACKEND=mock FACE_ENABLED=false TRANSCRIPTION_ENABLED=false python server.py +``` -# Run tests -pytest +## Docker -# Run server locally -python server.py +```bash +# Build +docker build -t studio-server . 
+ +# Run with GPU (all backends) +docker run --gpus all -p 8000:8000 studio-server + +# Run TTS only +docker run --gpus all -p 8000:8000 \ + -e FACE_ENABLED=false \ + -e TRANSCRIPTION_ENABLED=false \ + studio-server +``` + +## Project Structure + +``` +studio-server/ +├── server.py # FastAPI application +├── backends/ +│ ├── __init__.py +│ ├── base.py # Base Backend class +│ ├── tts.py # TTS backends (Qwen3-TTS) +│ ├── face.py # Face backends (InsightFace) +│ └── transcription.py # Transcription backends (Whisper) +├── tests/ +├── requirements.txt +├── Dockerfile +└── README.md ``` ## Adding New Backends -1. Create a class extending `TTSBackend` -2. Implement `load()`, `synthesize()`, and `get_info()` -3. Register in the `BACKENDS` dict +Each backend type has its own module. To add a new backend: ```python +# backends/tts.py class MyTTSBackend(TTSBackend): def load(self) -> None: # Load your model pass - def synthesize(self, text, language, ref_audio, ref_text, speed): + def synthesize(self, text, language, speaker, ref_audio, ref_text, speed): # Generate audio return wav_bytes, sample_rate def get_info(self) -> dict: return {"backend": "my-tts", ...} -BACKENDS["my-tts"] = MyTTSBackend + def get_speakers(self) -> List[str]: + return ["speaker1", "speaker2"] + +# Register in TTS_BACKENDS dict +TTS_BACKENDS["my-tts"] = MyTTSBackend ``` +## Legacy Endpoints + +For backwards compatibility, the old TTS endpoints are still available: +- `GET /v1/speakers` → `/v1/tts/speakers` +- `POST /v1/voice/extract` → `/v1/tts/extract` +- `POST /v1/audio/speech` → `/v1/tts/synthesize` + ## License MIT diff --git a/backends/__init__.py b/backends/__init__.py new file mode 100644 index 0000000..1357013 --- /dev/null +++ b/backends/__init__.py @@ -0,0 +1,27 @@ +""" +Studio Server Backends + +Modular backends for various AI-powered studio utilities: +- TTS: Text-to-speech with voice cloning +- Face: Face embedding extraction for IP-Adapter +- Transcription: Audio transcription with word 
timings +- Lighting: Scene lighting analysis +""" + +from .base import Backend +from .tts import TTSBackend, Qwen3TTSBackend, TTS_BACKENDS +from .face import FaceBackend, InsightFaceBackend, FACE_BACKENDS +from .transcription import TranscriptionBackend, WhisperBackend, TRANSCRIPTION_BACKENDS + +__all__ = [ + "Backend", + "TTSBackend", + "Qwen3TTSBackend", + "TTS_BACKENDS", + "FaceBackend", + "InsightFaceBackend", + "FACE_BACKENDS", + "TranscriptionBackend", + "WhisperBackend", + "TRANSCRIPTION_BACKENDS", +] diff --git a/backends/base.py b/backends/base.py new file mode 100644 index 0000000..0bdffab --- /dev/null +++ b/backends/base.py @@ -0,0 +1,25 @@ +""" +Base Backend Abstraction + +All studio backends implement this interface for consistent lifecycle management. +""" + +from abc import ABC, abstractmethod + + +class Backend(ABC): + """Abstract base class for all studio backends.""" + + @abstractmethod + def load(self) -> None: + """Load the model/resources into memory.""" + pass + + @abstractmethod + def get_info(self) -> dict: + """Return backend information for health checks.""" + pass + + def unload(self) -> None: + """Unload the model/resources from memory. Optional cleanup.""" + pass diff --git a/backends/face.py b/backends/face.py new file mode 100644 index 0000000..1fd4fbf --- /dev/null +++ b/backends/face.py @@ -0,0 +1,254 @@ +""" +Face Backend - Face embedding extraction for IP-Adapter consistency. + +Extracts face embeddings from images for use with IP-Adapter FaceID +to maintain consistent face generation across videos (PerformerDNA). 
+ +Supported models: +- insightface: InsightFace with buffalo_l model (default) +""" + +import io +import base64 +from abc import abstractmethod +from typing import List, Optional, Dict, Any + +import numpy as np + +from .base import Backend + + +class FaceBackend(Backend): + """Abstract base class for face embedding backends.""" + + @abstractmethod + def extract_embedding( + self, + image: bytes, + return_bbox: bool = False, + ) -> Dict[str, Any]: + """ + Extract face embedding from an image. + + Args: + image: Image bytes (JPEG, PNG, etc.) + return_bbox: Whether to return bounding box + + Returns: + Dict with: + - embedding: Base64-encoded face embedding (512-dim) + - bbox: Optional bounding box [x1, y1, x2, y2] + - confidence: Detection confidence + """ + pass + + @abstractmethod + def extract_embeddings( + self, + image: bytes, + max_faces: int = 5, + ) -> List[Dict[str, Any]]: + """ + Extract embeddings for all faces in an image. + + Args: + image: Image bytes + max_faces: Maximum number of faces to extract + + Returns: + List of embedding dicts (same format as extract_embedding) + """ + pass + + @abstractmethod + def compare_embeddings( + self, + embedding1: str, + embedding2: str, + ) -> float: + """ + Compare two face embeddings for similarity. + + Args: + embedding1: Base64-encoded embedding + embedding2: Base64-encoded embedding + + Returns: + Cosine similarity score (0-1) + """ + pass + + +class InsightFaceBackend(FaceBackend): + """ + InsightFace backend for face embedding extraction. + + Uses the buffalo_l model for high-quality embeddings compatible + with IP-Adapter FaceID. 
+ """ + + def __init__(self, model_name: str = "buffalo_l"): + self.model_name = model_name + self.app = None + self.device = "cuda" if self._has_cuda() else "cpu" + + def _has_cuda(self) -> bool: + """Check if CUDA is available.""" + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False + + def load(self) -> None: + """Load InsightFace model.""" + from insightface.app import FaceAnalysis + + print(f"Loading InsightFace model: {self.model_name}") + + # Initialize FaceAnalysis + self.app = FaceAnalysis( + name=self.model_name, + providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] + if self.device == "cuda" else ['CPUExecutionProvider'] + ) + + # Prepare with detection size + self.app.prepare(ctx_id=0 if self.device == "cuda" else -1, det_size=(640, 640)) + + print(f"InsightFace model loaded on {self.device}") + + def extract_embedding( + self, + image: bytes, + return_bbox: bool = False, + ) -> Dict[str, Any]: + """Extract face embedding from the primary face in an image.""" + if self.app is None: + raise RuntimeError("Model not loaded") + + import cv2 + import numpy as np + + # Decode image + nparr = np.frombuffer(image, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if img is None: + raise ValueError("Could not decode image") + + # Detect faces + faces = self.app.get(img) + + if not faces: + raise ValueError("No face detected in image") + + # Get the largest face (most likely the main subject) + face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1])) + + # Encode embedding to base64 + embedding_bytes = face.embedding.astype(np.float32).tobytes() + embedding_b64 = base64.b64encode(embedding_bytes).decode('utf-8') + + result = { + "embedding": embedding_b64, + "embedding_dim": len(face.embedding), + "confidence": float(face.det_score), + } + + if return_bbox: + result["bbox"] = face.bbox.tolist() + + return result + + def extract_embeddings( + self, + image: bytes, + max_faces: int 
= 5, + ) -> List[Dict[str, Any]]: + """Extract embeddings for all faces in an image.""" + if self.app is None: + raise RuntimeError("Model not loaded") + + import cv2 + import numpy as np + + # Decode image + nparr = np.frombuffer(image, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if img is None: + raise ValueError("Could not decode image") + + # Detect faces + faces = self.app.get(img) + + if not faces: + return [] + + # Sort by face size (largest first) + faces = sorted( + faces, + key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]), + reverse=True + )[:max_faces] + + results = [] + for face in faces: + embedding_bytes = face.embedding.astype(np.float32).tobytes() + embedding_b64 = base64.b64encode(embedding_bytes).decode('utf-8') + + results.append({ + "embedding": embedding_b64, + "embedding_dim": len(face.embedding), + "confidence": float(face.det_score), + "bbox": face.bbox.tolist(), + }) + + return results + + def compare_embeddings( + self, + embedding1: str, + embedding2: str, + ) -> float: + """Compare two face embeddings using cosine similarity.""" + import numpy as np + + # Decode embeddings + emb1 = np.frombuffer(base64.b64decode(embedding1), dtype=np.float32) + emb2 = np.frombuffer(base64.b64decode(embedding2), dtype=np.float32) + + # Normalize + emb1 = emb1 / np.linalg.norm(emb1) + emb2 = emb2 / np.linalg.norm(emb2) + + # Cosine similarity + similarity = float(np.dot(emb1, emb2)) + + return similarity + + def get_info(self) -> dict: + return { + "backend": "insightface", + "model": self.model_name, + "embedding_dim": 512, + "device": self.device, + } + + def unload(self) -> None: + """Unload model from memory.""" + self.app = None + + +# Backend registry +FACE_BACKENDS = { + "insightface": InsightFaceBackend, +} + + +def get_face_backend(name: str) -> FaceBackend: + """Get a face backend instance by name.""" + if name not in FACE_BACKENDS: + raise ValueError(f"Unknown face backend: {name}. 
Available: {list(FACE_BACKENDS.keys())}") + return FACE_BACKENDS[name]() diff --git a/backends/transcription.py b/backends/transcription.py new file mode 100644 index 0000000..0f260cd --- /dev/null +++ b/backends/transcription.py @@ -0,0 +1,221 @@ +""" +Transcription Backend - Audio transcription with word-level timing. + +Extracts transcripts and word timings from audio for lip-sync alignment. + +Supported models: +- whisper: OpenAI Whisper with word-level timestamps +""" + +import io +from abc import abstractmethod +from typing import List, Optional, Dict, Any + +from .base import Backend + + +class WordTiming: + """Word timing information for lip-sync.""" + + def __init__(self, word: str, start: float, end: float, confidence: float = 1.0): + self.word = word + self.start = start + self.end = end + self.confidence = confidence + + def to_dict(self) -> Dict[str, Any]: + return { + "word": self.word, + "start": self.start, + "end": self.end, + "confidence": self.confidence, + } + + +class TranscriptionResult: + """Transcription result with text and word timings.""" + + def __init__( + self, + text: str, + language: str, + duration: float, + word_timings: List[WordTiming], + ): + self.text = text + self.language = language + self.duration = duration + self.word_timings = word_timings + + def to_dict(self) -> Dict[str, Any]: + return { + "text": self.text, + "language": self.language, + "duration": self.duration, + "word_timings": [w.to_dict() for w in self.word_timings], + } + + +class TranscriptionBackend(Backend): + """Abstract base class for transcription backends.""" + + @abstractmethod + def transcribe( + self, + audio: bytes, + language: Optional[str] = None, + word_timestamps: bool = True, + ) -> TranscriptionResult: + """ + Transcribe audio to text with optional word-level timestamps. + + Args: + audio: Audio bytes (WAV, MP3, etc.) 
+ language: Target language code (auto-detect if None) + word_timestamps: Whether to include word-level timestamps + + Returns: + TranscriptionResult with text and word timings + """ + pass + + +class WhisperBackend(TranscriptionBackend): + """ + OpenAI Whisper backend for transcription. + + Uses faster-whisper for efficient inference with word-level timestamps. + """ + + def __init__( + self, + model_size: str = "large-v3", + compute_type: str = "float16", + ): + self.model_size = model_size + self.compute_type = compute_type + self.model = None + self.device = "cuda" if self._has_cuda() else "cpu" + + def _has_cuda(self) -> bool: + """Check if CUDA is available.""" + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False + + def load(self) -> None: + """Load Whisper model.""" + from faster_whisper import WhisperModel + + print(f"Loading Whisper model: {self.model_size}") + + # Adjust compute type for CPU + compute_type = self.compute_type + if self.device == "cpu" and compute_type == "float16": + compute_type = "int8" + + self.model = WhisperModel( + self.model_size, + device=self.device, + compute_type=compute_type, + ) + + print(f"Whisper model loaded on {self.device}") + + def transcribe( + self, + audio: bytes, + language: Optional[str] = None, + word_timestamps: bool = True, + ) -> TranscriptionResult: + """Transcribe audio with word-level timestamps.""" + if self.model is None: + raise RuntimeError("Model not loaded") + + import tempfile + import os + import soundfile as sf + import numpy as np + + # Write audio to temp file (faster-whisper requires file path) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + temp_path = f.name + + # Try to decode audio + try: + audio_buffer = io.BytesIO(audio) + audio_data, sample_rate = sf.read(audio_buffer) + + # Ensure mono + if len(audio_data.shape) > 1: + audio_data = audio_data.mean(axis=1) + + # Write as WAV + sf.write(temp_path, audio_data, sample_rate) + 
duration = len(audio_data) / sample_rate + except Exception as e: + os.unlink(temp_path) + raise ValueError(f"Could not decode audio: {e}") + + try: + # Transcribe + segments, info = self.model.transcribe( + temp_path, + language=language, + word_timestamps=word_timestamps, + vad_filter=True, # Filter out silence + ) + + # Collect results + full_text = [] + word_timings = [] + + for segment in segments: + full_text.append(segment.text.strip()) + + if word_timestamps and segment.words: + for word in segment.words: + word_timings.append(WordTiming( + word=word.word.strip(), + start=word.start, + end=word.end, + confidence=word.probability if hasattr(word, 'probability') else 1.0, + )) + + return TranscriptionResult( + text=" ".join(full_text), + language=info.language, + duration=duration, + word_timings=word_timings, + ) + + finally: + os.unlink(temp_path) + + def get_info(self) -> dict: + return { + "backend": "whisper", + "model_size": self.model_size, + "compute_type": self.compute_type, + "device": self.device, + "supports_word_timestamps": True, + } + + def unload(self) -> None: + """Unload model from memory.""" + self.model = None + + +# Backend registry +TRANSCRIPTION_BACKENDS = { + "whisper": WhisperBackend, +} + + +def get_transcription_backend(name: str) -> TranscriptionBackend: + """Get a transcription backend instance by name.""" + if name not in TRANSCRIPTION_BACKENDS: + raise ValueError(f"Unknown transcription backend: {name}. Available: {list(TRANSCRIPTION_BACKENDS.keys())}") + return TRANSCRIPTION_BACKENDS[name]() diff --git a/backends/tts.py b/backends/tts.py new file mode 100644 index 0000000..a9b768d --- /dev/null +++ b/backends/tts.py @@ -0,0 +1,340 @@ +""" +TTS Backend - Text-to-speech with voice cloning support. 
+ +Supported models: +- qwen3-tts: Qwen3-TTS with voice cloning via ref_audio + ref_text +""" + +import io +import base64 +from abc import abstractmethod +from typing import Optional, Tuple, List, Any + +import torch +import numpy as np +import soundfile as sf + +from .base import Backend + + +class TTSBackend(Backend): + """Abstract base class for TTS model backends.""" + + @abstractmethod + def synthesize( + self, + text: str, + language: str = "English", + speaker: Optional[str] = None, + ref_audio: Optional[Tuple[bytes, int]] = None, + ref_text: Optional[str] = None, + speed: float = 1.0, + ) -> Tuple[bytes, int]: + """ + Synthesize speech from text. + + Args: + text: Text to synthesize + language: Target language + speaker: Preset speaker name (for basic TTS) + ref_audio: Optional (audio_bytes, sample_rate) for voice cloning + ref_text: Optional transcript of reference audio for voice cloning + speed: Speech speed multiplier + + Returns: + Tuple of (wav_bytes, sample_rate) + """ + pass + + @abstractmethod + def get_speakers(self) -> List[str]: + """Return available preset speakers.""" + pass + + def extract_voice_prompt( + self, + ref_audio: Tuple[Any, int], + ref_text: Optional[str] = None, + ) -> str: + """ + Extract a reusable voice prompt from reference audio. + + Args: + ref_audio: Tuple of (audio_data, sample_rate) + ref_text: Optional transcript of reference audio + + Returns: + Base64-encoded voice prompt that can be reused + """ + raise NotImplementedError("This backend does not support voice prompt extraction") + + def synthesize_with_prompt( + self, + text: str, + voice_prompt: str, + language: str = "English", + speed: float = 1.0, + ) -> Tuple[bytes, int]: + """ + Synthesize speech using a pre-extracted voice prompt. 
+ + Args: + text: Text to synthesize + voice_prompt: Base64-encoded voice prompt from extract_voice_prompt + language: Target language + speed: Speech speed multiplier + + Returns: + Tuple of (wav_bytes, sample_rate) + """ + raise NotImplementedError("This backend does not support voice prompt synthesis") + + +class Qwen3TTSBackend(TTSBackend): + """ + Qwen3-TTS backend with voice cloning support. + + Uses CustomVoice model for basic TTS and Base model for voice cloning. + """ + + def __init__( + self, + custom_voice_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + base_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base", + ): + self.custom_voice_model_name = custom_voice_model + self.base_model_name = base_model + self.custom_voice_model = None + self.base_model = None + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + def load(self) -> None: + from qwen_tts import Qwen3TTSModel + from qwen_tts.inference.qwen3_tts_model import VoiceClonePromptItem + + # Allow VoiceClonePromptItem for safe torch.load deserialization + torch.serialization.add_safe_globals([VoiceClonePromptItem]) + + # Load CustomVoice model for basic TTS with preset speakers + print(f"Loading CustomVoice model: {self.custom_voice_model_name}") + self.custom_voice_model = Qwen3TTSModel.from_pretrained( + self.custom_voice_model_name, + device_map=self.device, + dtype=self.dtype, + ) + print("CustomVoice model loaded") + + # Load Base model for voice cloning + print(f"Loading Base model: {self.base_model_name}") + self.base_model = Qwen3TTSModel.from_pretrained( + self.base_model_name, + device_map=self.device, + dtype=self.dtype, + ) + print("Base model loaded") + + def synthesize( + self, + text: str, + language: str = "English", + speaker: Optional[str] = None, + ref_audio: Optional[Tuple[bytes, int]] = None, + ref_text: Optional[str] = None, + speed: float = 1.0, + ) -> Tuple[bytes, int]: + if 
self.custom_voice_model is None or self.base_model is None: + raise RuntimeError("Models not loaded") + + if ref_audio: + # Voice cloning path - use Base model + # Normalize ref_text - strip whitespace, convert empty to None + normalized_ref_text = ref_text.strip() if ref_text else None + + wavs, sr = self.base_model.generate_voice_clone( + text=text, + language=language, + ref_audio=ref_audio, + ref_text=normalized_ref_text, + max_new_tokens=2048, + ) + else: + # Basic TTS path - use CustomVoice model with preset speaker + speaker = speaker or "Vivian" + wavs, sr = self.custom_voice_model.generate_custom_voice( + text=text, + language=language, + speaker=speaker, + ) + + # Convert to WAV bytes + wav_buffer = io.BytesIO() + sf.write(wav_buffer, wavs[0], sr, format='WAV') + wav_buffer.seek(0) + + return wav_buffer.read(), sr + + def get_info(self) -> dict: + return { + "backend": "qwen3-tts", + "custom_voice_model": self.custom_voice_model_name, + "base_model": self.base_model_name, + "supports_voice_cloning": True, + "supports_ref_text": True, + "supports_voice_prompt": True, + "device": "cuda" if torch.cuda.is_available() else "cpu", + } + + def get_speakers(self) -> List[str]: + """Return available preset speakers from CustomVoice model.""" + if self.custom_voice_model is None: + return [] + try: + return self.custom_voice_model.get_supported_speakers() + except Exception: + # Fallback to known speakers + return ["Vivian", "Ryan", "Sophia", "Isabella", "Evan", "Lily"] + + def extract_voice_prompt( + self, + ref_audio: Tuple[Any, int], + ref_text: Optional[str] = None, + ) -> str: + """Extract a reusable voice prompt from reference audio.""" + if self.base_model is None: + raise RuntimeError("Base model not loaded") + + # Normalize ref_text - strip whitespace, convert empty to None + normalized_ref_text = ref_text.strip() if ref_text else None + + # Use the Base model's create_voice_clone_prompt method + voice_prompt = self.base_model.create_voice_clone_prompt( + 
ref_audio=ref_audio, + ref_text=normalized_ref_text, + ) + + # Serialize using torch.save which handles tensors natively + buffer = io.BytesIO() + torch.save(voice_prompt, buffer) + buffer.seek(0) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + def synthesize_with_prompt( + self, + text: str, + voice_prompt: str, + language: str = "English", + speed: float = 1.0, + ) -> Tuple[bytes, int]: + """Synthesize speech using a pre-extracted voice prompt.""" + if self.base_model is None: + raise RuntimeError("Base model not loaded") + + # Decode the voice prompt using torch.load + buffer = io.BytesIO(base64.b64decode(voice_prompt)) + prompt_tensor = torch.load(buffer, map_location=self.device, weights_only=True) + + # Generate using the cached voice prompt + wavs, sr = self.base_model.generate_voice_clone( + text=text, + language=language, + voice_clone_prompt=prompt_tensor, + max_new_tokens=2048, + ) + + # Convert to WAV bytes + wav_buffer = io.BytesIO() + sf.write(wav_buffer, wavs[0], sr, format='WAV') + wav_buffer.seek(0) + + return wav_buffer.read(), sr + + def unload(self) -> None: + """Unload models from memory.""" + self.custom_voice_model = None + self.base_model = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +class MockTTSBackend(TTSBackend): + """ + Mock TTS backend for development/testing without GPU. + + Generates silence audio for testing API routes. 
+
+    """
+
+    def __init__(self):
+        self.sample_rate = 24000
+
+    def load(self) -> None:
+        print("MockTTS backend loaded (no GPU required)")
+
+    def synthesize(
+        self,
+        text: str,
+        language: str = "English",
+        speaker: Optional[str] = None,
+        ref_audio: Optional[Tuple[bytes, int]] = None,
+        ref_text: Optional[str] = None,
+        speed: float = 1.0,
+    ) -> Tuple[bytes, int]:
+        # Generate ~1 second of silence per 20 characters (rough estimate)
+        duration = max(1.0, len(text) / 20)
+        num_samples = int(duration * self.sample_rate)
+        audio = np.zeros(num_samples, dtype=np.float32)
+
+        # Add a small click at the beginning to indicate something was generated
+        if num_samples > 100:
+            audio[50:60] = 0.1
+
+        # Convert to WAV bytes
+        wav_buffer = io.BytesIO()
+        sf.write(wav_buffer, audio, self.sample_rate, format='WAV')
+        wav_buffer.seek(0)
+
+        return wav_buffer.read(), self.sample_rate
+
+    def get_info(self) -> dict:
+        return {
+            "backend": "mock",
+            "supports_voice_cloning": True,
+            "supports_ref_text": True,
+            "supports_voice_prompt": True,
+            "note": "Mock backend for development - generates silence",
+        }
+
+    def get_speakers(self) -> List[str]:
+        return ["MockSpeaker1", "MockSpeaker2"]
+
+    def extract_voice_prompt(
+        self,
+        ref_audio: Tuple[Any, int],
+        ref_text: Optional[str] = None,
+    ) -> str:
+        # Return a mock voice prompt (base64 of "mock_voice_prompt_data")
+        return base64.b64encode(b"mock_voice_prompt_data").decode('utf-8')
+
+    def synthesize_with_prompt(
+        self,
+        text: str,
+        voice_prompt: str,
+        language: str = "English",
+        speed: float = 1.0,
+    ) -> Tuple[bytes, int]:
+        # Same as synthesize for mock
+        return self.synthesize(text, language, speed=speed)
+
+
+# Backend registry
+TTS_BACKENDS = {
+    "qwen3-tts": Qwen3TTSBackend,
+    "mock": MockTTSBackend,
+}
+
+
+def get_tts_backend(name: str) -> TTSBackend:
+    """Get a TTS backend instance by name."""
+    if name not in TTS_BACKENDS:
+        raise ValueError(f"Unknown TTS backend: {name}. 
Available: {list(TTS_BACKENDS.keys())}") + return TTS_BACKENDS[name]() diff --git a/requirements.txt b/requirements.txt index 82fdbac..47a7f46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,20 @@ fastapi>=0.115.0 uvicorn[standard]>=0.32.0 python-multipart>=0.0.12 +pydantic>=2.0.0 # Audio processing soundfile>=0.12.1 numpy>=1.24.0 -# Qwen3-TTS backend +# TTS backend (Qwen3-TTS) qwen-tts>=0.0.5 +torch>=2.0.0 + +# Face backend (InsightFace) +insightface>=0.7.0 +opencv-python>=4.8.0 +onnxruntime-gpu>=1.16.0 # or onnxruntime for CPU + +# Transcription backend (faster-whisper) +faster-whisper>=1.0.0 diff --git a/server.py b/server.py index 25c1789..14ec3ac 100644 --- a/server.py +++ b/server.py @@ -1,372 +1,147 @@ """ -TTS Server - Multi-model text-to-speech API with voice cloning support. +Studio Server - AI-powered studio utilities for video production. -Supported models: -- qwen3-tts: Qwen3-TTS with voice cloning via ref_audio + ref_text +Provides REST API endpoints for: +- TTS: Text-to-speech with voice cloning (Qwen3-TTS) +- Face: Face embedding extraction for IP-Adapter (InsightFace) +- Transcription: Audio transcription with word timings (Whisper) -Voice cloning can be done two ways: -1. On-the-fly: Pass ref_audio + ref_text with each synthesis request -2. Cached: Extract a voice_prompt once, then reuse it for multiple requests +All backends are modular and can be swapped via environment variables. 
""" import io import os -import base64 -from abc import ABC, abstractmethod -from typing import Optional, Tuple, List, Any +from typing import Optional from contextlib import asynccontextmanager -import torch -import numpy as np import soundfile as sf from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import StreamingResponse, JSONResponse +from pydantic import BaseModel - -# ============================================================================= -# Model Backend Abstraction -# ============================================================================= - -class TTSBackend(ABC): - """Abstract base class for TTS model backends.""" - - @abstractmethod - def load(self) -> None: - """Load the model into memory.""" - pass - - @abstractmethod - def synthesize( - self, - text: str, - language: str = "English", - speaker: Optional[str] = None, - ref_audio: Optional[Tuple[bytes, int]] = None, - ref_text: Optional[str] = None, - speed: float = 1.0, - ) -> Tuple[bytes, int]: - """ - Synthesize speech from text. - - Args: - text: Text to synthesize - language: Target language - speaker: Preset speaker name (for basic TTS) - ref_audio: Optional (audio_bytes, sample_rate) for voice cloning - ref_text: Optional transcript of reference audio for voice cloning - speed: Speech speed multiplier - - Returns: - Tuple of (wav_bytes, sample_rate) - """ - pass - - @abstractmethod - def get_info(self) -> dict: - """Return model information.""" - pass - - @abstractmethod - def get_speakers(self) -> List[str]: - """Return available preset speakers.""" - pass - - def extract_voice_prompt( - self, - ref_audio: Tuple[Any, int], - ref_text: Optional[str] = None, - ) -> str: - """ - Extract a reusable voice prompt from reference audio. 
- - Args: - ref_audio: Tuple of (audio_data, sample_rate) - ref_text: Optional transcript of reference audio - - Returns: - Base64-encoded voice prompt that can be reused - """ - raise NotImplementedError("This backend does not support voice prompt extraction") - - def synthesize_with_prompt( - self, - text: str, - voice_prompt: str, - language: str = "English", - speed: float = 1.0, - ) -> Tuple[bytes, int]: - """ - Synthesize speech using a pre-extracted voice prompt. - - Args: - text: Text to synthesize - voice_prompt: Base64-encoded voice prompt from extract_voice_prompt - language: Target language - speed: Speech speed multiplier - - Returns: - Tuple of (wav_bytes, sample_rate) - """ - raise NotImplementedError("This backend does not support voice prompt synthesis") - - -class Qwen3TTSBackend(TTSBackend): - """ - Qwen3-TTS backend with voice cloning support. - - Uses CustomVoice model for basic TTS and Base model for voice cloning. - """ - - def __init__( - self, - custom_voice_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", - base_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base", - ): - self.custom_voice_model_name = custom_voice_model - self.base_model_name = base_model - self.custom_voice_model = None - self.base_model = None - self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 - - def load(self) -> None: - from qwen_tts import Qwen3TTSModel - from qwen_tts.inference.qwen3_tts_model import VoiceClonePromptItem - - # Allow VoiceClonePromptItem for safe torch.load deserialization - torch.serialization.add_safe_globals([VoiceClonePromptItem]) - - # Load CustomVoice model for basic TTS with preset speakers - print(f"Loading CustomVoice model: {self.custom_voice_model_name}") - self.custom_voice_model = Qwen3TTSModel.from_pretrained( - self.custom_voice_model_name, - device_map=self.device, - dtype=self.dtype, - ) - print("CustomVoice model loaded") - - # Load Base model 
for voice cloning - print(f"Loading Base model: {self.base_model_name}") - self.base_model = Qwen3TTSModel.from_pretrained( - self.base_model_name, - device_map=self.device, - dtype=self.dtype, - ) - print("Base model loaded") - - def synthesize( - self, - text: str, - language: str = "English", - speaker: Optional[str] = None, - ref_audio: Optional[Tuple[bytes, int]] = None, - ref_text: Optional[str] = None, - speed: float = 1.0, - ) -> Tuple[bytes, int]: - if self.custom_voice_model is None or self.base_model is None: - raise RuntimeError("Models not loaded") - - if ref_audio: - # Voice cloning path - use Base model - # Normalize ref_text - strip whitespace, convert empty to None - normalized_ref_text = ref_text.strip() if ref_text else None - - wavs, sr = self.base_model.generate_voice_clone( - text=text, - language=language, - ref_audio=ref_audio, - ref_text=normalized_ref_text, - max_new_tokens=2048, - ) - else: - # Basic TTS path - use CustomVoice model with preset speaker - speaker = speaker or "Vivian" - wavs, sr = self.custom_voice_model.generate_custom_voice( - text=text, - language=language, - speaker=speaker, - ) - - # Convert to WAV bytes - wav_buffer = io.BytesIO() - sf.write(wav_buffer, wavs[0], sr, format='WAV') - wav_buffer.seek(0) - - return wav_buffer.read(), sr - - def get_info(self) -> dict: - return { - "backend": "qwen3-tts", - "custom_voice_model": self.custom_voice_model_name, - "base_model": self.base_model_name, - "supports_voice_cloning": True, - "supports_ref_text": True, - "supports_voice_prompt": True, - "device": "cuda" if torch.cuda.is_available() else "cpu", - } - - def get_speakers(self) -> List[str]: - """Return available preset speakers from CustomVoice model.""" - if self.custom_voice_model is None: - return [] - try: - return self.custom_voice_model.get_supported_speakers() - except Exception: - # Fallback to known speakers - return ["Vivian", "Ryan", "Sophia", "Isabella", "Evan", "Lily"] - - def extract_voice_prompt( - self, 
- ref_audio: Tuple[Any, int], - ref_text: Optional[str] = None, - ) -> str: - """Extract a reusable voice prompt from reference audio.""" - if self.base_model is None: - raise RuntimeError("Base model not loaded") - - # Normalize ref_text - strip whitespace, convert empty to None - normalized_ref_text = ref_text.strip() if ref_text else None - - # Use the Base model's create_voice_clone_prompt method - voice_prompt = self.base_model.create_voice_clone_prompt( - ref_audio=ref_audio, - ref_text=normalized_ref_text, - ) - - # Serialize using torch.save which handles tensors natively - buffer = io.BytesIO() - torch.save(voice_prompt, buffer) - buffer.seek(0) - return base64.b64encode(buffer.getvalue()).decode('utf-8') - - def synthesize_with_prompt( - self, - text: str, - voice_prompt: str, - language: str = "English", - speed: float = 1.0, - ) -> Tuple[bytes, int]: - """Synthesize speech using a pre-extracted voice prompt.""" - if self.base_model is None: - raise RuntimeError("Base model not loaded") - - # Decode the voice prompt using torch.load - buffer = io.BytesIO(base64.b64decode(voice_prompt)) - prompt_tensor = torch.load(buffer, map_location=self.device, weights_only=True) - - # Generate using the cached voice prompt - wavs, sr = self.base_model.generate_voice_clone( - text=text, - language=language, - voice_clone_prompt=prompt_tensor, - max_new_tokens=2048, - ) - - # Convert to WAV bytes - wav_buffer = io.BytesIO() - sf.write(wav_buffer, wavs[0], sr, format='WAV') - wav_buffer.seek(0) - - return wav_buffer.read(), sr +from backends.tts import TTSBackend, get_tts_backend, TTS_BACKENDS +from backends.face import FaceBackend, get_face_backend, FACE_BACKENDS +from backends.transcription import TranscriptionBackend, get_transcription_backend, TRANSCRIPTION_BACKENDS # ============================================================================= -# Backend Registry +# Backend Instances # ============================================================================= 
-BACKENDS = { - "qwen3-tts": Qwen3TTSBackend, -} - -def get_backend(name: str) -> TTSBackend: - """Get a backend instance by name.""" - if name not in BACKENDS: - raise ValueError(f"Unknown backend: {name}. Available: {list(BACKENDS.keys())}") - return BACKENDS[name]() - - -# ============================================================================= -# FastAPI Application -# ============================================================================= - -backend: Optional[TTSBackend] = None +tts_backend: Optional[TTSBackend] = None +face_backend: Optional[FaceBackend] = None +transcription_backend: Optional[TranscriptionBackend] = None @asynccontextmanager async def lifespan(app: FastAPI): - """Load model on startup.""" - global backend - - backend_name = os.environ.get("TTS_BACKEND", "qwen3-tts") - backend = get_backend(backend_name) - backend.load() - + """Load backends on startup based on environment config.""" + global tts_backend, face_backend, transcription_backend + + # Load TTS backend (always enabled) + tts_name = os.environ.get("TTS_BACKEND", "qwen3-tts") + print(f"\n=== Loading TTS backend: {tts_name} ===") + tts_backend = get_tts_backend(tts_name) + tts_backend.load() + + # Load Face backend (optional, enabled by default) + if os.environ.get("FACE_ENABLED", "true").lower() == "true": + face_name = os.environ.get("FACE_BACKEND", "insightface") + print(f"\n=== Loading Face backend: {face_name} ===") + face_backend = get_face_backend(face_name) + face_backend.load() + + # Load Transcription backend (optional, enabled by default) + if os.environ.get("TRANSCRIPTION_ENABLED", "true").lower() == "true": + transcription_name = os.environ.get("TRANSCRIPTION_BACKEND", "whisper") + print(f"\n=== Loading Transcription backend: {transcription_name} ===") + transcription_backend = get_transcription_backend(transcription_name) + transcription_backend.load() + + print("\n=== Studio Server Ready ===\n") yield # Cleanup - backend = None + if tts_backend: + 
tts_backend.unload() + if face_backend: + face_backend.unload() + if transcription_backend: + transcription_backend.unload() app = FastAPI( - title="TTS Server", - description="Multi-model text-to-speech API with voice cloning support", - version="0.3.3", + title="Studio Server", + description="AI-powered studio utilities for video production: TTS, face embedding, transcription", + version="0.4.0", lifespan=lifespan, ) +# ============================================================================= +# Health & Info Endpoints +# ============================================================================= + @app.get("/health") async def health(): """Health check endpoint.""" - return {"status": "ok", "backend": backend.get_info() if backend else None} + backends = {} + if tts_backend: + backends["tts"] = tts_backend.get_info() + if face_backend: + backends["face"] = face_backend.get_info() + if transcription_backend: + backends["transcription"] = transcription_backend.get_info() + + return { + "status": "ok", + "backends": backends, + } @app.get("/v1/models") async def list_models(): - """List available TTS backends.""" + """List available backends by category.""" return { - "models": [ - {"id": name, "object": "model"} - for name in BACKENDS.keys() - ] + "tts": list(TTS_BACKENDS.keys()), + "face": list(FACE_BACKENDS.keys()), + "transcription": list(TRANSCRIPTION_BACKENDS.keys()), } -@app.get("/v1/speakers") +# ============================================================================= +# TTS Endpoints +# ============================================================================= + +@app.get("/v1/tts/speakers") async def list_speakers(): """List available preset speakers for basic TTS.""" - if backend is None: - raise HTTPException(status_code=503, detail="Model not loaded") - return {"speakers": backend.get_speakers()} + if tts_backend is None: + raise HTTPException(status_code=503, detail="TTS backend not loaded") + return {"speakers": 
tts_backend.get_speakers()} -@app.post("/v1/voice/extract") -async def extract_voice_prompt_endpoint( +@app.post("/v1/tts/extract") +async def extract_voice_prompt( ref_audio: UploadFile = File(..., description="Reference audio for voice extraction"), ref_text: Optional[str] = Form(None, description="Transcript of reference audio"), ): """ Extract a reusable voice prompt from reference audio. - The returned `voice_prompt` can be cached and reused with `/v1/audio/speech` + The returned `voice_prompt` can be cached and reused with `/v1/tts/synthesize` to avoid re-processing the reference audio on every request. - - **Returns:** - - `voice_prompt`: Base64-encoded voice embedding (store this) - - `format`: Always "base64-numpy" for this backend """ - if backend is None: - raise HTTPException(status_code=503, detail="Model not loaded") + if tts_backend is None: + raise HTTPException(status_code=503, detail="TTS backend not loaded") try: - # Read reference audio audio_bytes = await ref_audio.read() audio_buffer = io.BytesIO(audio_bytes) audio_data, sample_rate = sf.read(audio_buffer) ref_audio_data = (audio_data, sample_rate) - # Extract voice prompt - voice_prompt = backend.extract_voice_prompt( + voice_prompt = tts_backend.extract_voice_prompt( ref_audio=ref_audio_data, ref_text=ref_text, ) @@ -383,14 +158,14 @@ async def extract_voice_prompt_endpoint( raise HTTPException(status_code=500, detail=str(e)) -@app.post("/v1/audio/speech") +@app.post("/v1/tts/synthesize") async def synthesize_speech( text: str = Form(..., description="Text to synthesize"), language: str = Form("English", description="Target language"), - speaker: Optional[str] = Form(None, description="Preset speaker for basic TTS (e.g., Vivian, Ryan)"), + speaker: Optional[str] = Form(None, description="Preset speaker for basic TTS"), ref_audio: Optional[UploadFile] = File(None, description="Reference audio for voice cloning"), ref_text: Optional[str] = Form(None, description="Transcript of reference 
audio"), - voice_prompt: Optional[str] = Form(None, description="Pre-extracted voice prompt from /v1/voice/extract"), + voice_prompt: Optional[str] = Form(None, description="Pre-extracted voice prompt"), speed: float = Form(1.0, description="Speech speed multiplier"), ): """ @@ -399,32 +174,27 @@ async def synthesize_speech( **Basic TTS:** Just provide `text` and optionally `speaker`. **Voice Cloning (on-the-fly):** Provide `ref_audio` and `ref_text`. - The `ref_text` should be the exact transcript of the reference audio. - **Voice Cloning (cached):** Provide `voice_prompt` from `/v1/voice/extract`. - This is faster as it skips re-processing the reference audio. + **Voice Cloning (cached):** Provide `voice_prompt` from `/v1/tts/extract`. """ - if backend is None: - raise HTTPException(status_code=503, detail="Model not loaded") + if tts_backend is None: + raise HTTPException(status_code=503, detail="TTS backend not loaded") try: - # Priority: voice_prompt > ref_audio > speaker if voice_prompt: - # Use pre-extracted voice prompt - wav_bytes, sample_rate = backend.synthesize_with_prompt( + wav_bytes, sample_rate = tts_backend.synthesize_with_prompt( text=text, voice_prompt=voice_prompt, language=language, speed=speed, ) elif ref_audio: - # On-the-fly voice cloning audio_bytes = await ref_audio.read() audio_buffer = io.BytesIO(audio_bytes) audio_data, sample_rate = sf.read(audio_buffer) ref_audio_data = (audio_data, sample_rate) - wav_bytes, sample_rate = backend.synthesize( + wav_bytes, sample_rate = tts_backend.synthesize( text=text, language=language, speaker=speaker, @@ -433,8 +203,7 @@ async def synthesize_speech( speed=speed, ) else: - # Basic TTS with preset speaker - wav_bytes, sample_rate = backend.synthesize( + wav_bytes, sample_rate = tts_backend.synthesize( text=text, language=language, speaker=speaker, @@ -458,6 +227,129 @@ async def synthesize_speech( raise HTTPException(status_code=500, detail=str(e)) +# 
============================================================================= +# Face Embedding Endpoints +# ============================================================================= + +@app.post("/v1/face/embed") +async def extract_face_embedding( + image: UploadFile = File(..., description="Image containing a face"), + return_bbox: bool = Form(False, description="Include bounding box in response"), +): + """ + Extract face embedding from an image. + + Returns a 512-dimensional embedding compatible with IP-Adapter FaceID + for maintaining face consistency across video generation (PerformerDNA). + """ + if face_backend is None: + raise HTTPException(status_code=503, detail="Face backend not loaded") + + try: + image_bytes = await image.read() + result = face_backend.extract_embedding(image_bytes, return_bbox=return_bbox) + return JSONResponse(result) + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/v1/face/embed-all") +async def extract_all_face_embeddings( + image: UploadFile = File(..., description="Image containing faces"), + max_faces: int = Form(5, description="Maximum number of faces to extract"), +): + """ + Extract face embeddings for all faces in an image. + + Useful when processing reference images that may contain multiple people. + Returns faces sorted by size (largest first). 
+ """ + if face_backend is None: + raise HTTPException(status_code=503, detail="Face backend not loaded") + + try: + image_bytes = await image.read() + results = face_backend.extract_embeddings(image_bytes, max_faces=max_faces) + return JSONResponse({"faces": results, "count": len(results)}) + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +class CompareEmbeddingsRequest(BaseModel): + """Request body for comparing face embeddings.""" + embedding1: str + embedding2: str + + +@app.post("/v1/face/compare") +async def compare_face_embeddings(request: CompareEmbeddingsRequest): + """ + Compare two face embeddings for similarity. + + Returns a cosine similarity score between 0 and 1. + Scores above 0.6 typically indicate the same person. + """ + if face_backend is None: + raise HTTPException(status_code=503, detail="Face backend not loaded") + + try: + similarity = face_backend.compare_embeddings( + request.embedding1, + request.embedding2, + ) + return JSONResponse({ + "similarity": similarity, + "same_person": similarity > 0.6, + }) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Transcription Endpoints +# ============================================================================= + +@app.post("/v1/transcribe") +async def transcribe_audio( + audio: UploadFile = File(..., description="Audio file to transcribe"), + language: Optional[str] = Form(None, description="Language code (auto-detect if empty)"), + word_timestamps: bool = Form(True, description="Include word-level timestamps"), +): + """ + Transcribe audio to text with word-level timestamps. + + Returns the full transcript and per-word timing information + for lip-sync alignment. 
+ """ + if transcription_backend is None: + raise HTTPException(status_code=503, detail="Transcription backend not loaded") + + try: + audio_bytes = await audio.read() + result = transcription_backend.transcribe( + audio_bytes, + language=language, + word_timestamps=word_timestamps, + ) + return JSONResponse(result.to_dict()) + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Main +# ============================================================================= + if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/tests/test_server.py b/tests/test_server.py index ef2857f..f5e9e00 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,7 @@ """ -Tests for TTS Server. +Tests for Studio Server. -These tests mock the Qwen3-TTS model to test server logic without GPU. +These tests mock the backends to test server logic without GPU. 
""" import io @@ -56,25 +56,28 @@ def mock_torch(): @pytest.fixture def client(mock_qwen3_model, mock_torch): - """Create test client with mocked model.""" - with patch('server.Qwen3TTSBackend.load') as mock_load: - # Import server after patching + """Create test client with mocked TTS backend only.""" + with patch('backends.tts.Qwen3TTSBackend.load') as mock_load: + # Import after patching import server + from backends.tts import Qwen3TTSBackend # Create backend and inject mocked models - backend = server.Qwen3TTSBackend() + backend = Qwen3TTSBackend() backend.custom_voice_model = mock_qwen3_model backend.base_model = mock_qwen3_model backend.device = "cpu" backend.dtype = None - # Set the global backend - server.backend = backend + # Set the global backend (disable face and transcription) + server.tts_backend = backend + server.face_backend = None + server.transcription_backend = None yield TestClient(server.app) # Cleanup - server.backend = None + server.tts_backend = None class TestHealthEndpoint: @@ -85,8 +88,8 @@ def test_health_returns_ok(self, client): assert response.status_code == 200 data = response.json() assert data["status"] == "ok" - assert data["backend"] is not None - assert data["backend"]["backend"] == "qwen3-tts" + assert "backends" in data + assert data["backends"]["tts"]["backend"] == "qwen3-tts" class TestModelsEndpoint: @@ -96,28 +99,30 @@ def test_list_models(self, client): response = client.get("/v1/models") assert response.status_code == 200 data = response.json() - assert "models" in data - assert len(data["models"]) >= 1 - assert any(m["id"] == "qwen3-tts" for m in data["models"]) + assert "tts" in data + assert "face" in data + assert "transcription" in data + assert "qwen3-tts" in data["tts"] class TestSpeakersEndpoint: - """Tests for /v1/speakers endpoint.""" + """Tests for /v1/tts/speakers endpoint.""" def test_list_speakers(self, client): - response = client.get("/v1/speakers") + response = client.get("/v1/tts/speakers") assert 
response.status_code == 200 data = response.json() assert "speakers" in data assert "Vivian" in data["speakers"] + class TestSynthesizeEndpoint: - """Tests for /v1/audio/speech endpoint.""" + """Tests for /v1/tts/synthesize endpoint.""" def test_basic_tts(self, client, mock_qwen3_model): response = client.post( - "/v1/audio/speech", + "/v1/tts/synthesize", data={"text": "Hello world", "language": "English"} ) assert response.status_code == 200 @@ -131,7 +136,7 @@ def test_basic_tts(self, client, mock_qwen3_model): def test_basic_tts_with_speaker(self, client, mock_qwen3_model): response = client.post( - "/v1/audio/speech", + "/v1/tts/synthesize", data={ "text": "Hello world", "language": "English", @@ -153,7 +158,7 @@ def test_voice_cloning(self, client, mock_qwen3_model): wav_buffer.seek(0) response = client.post( - "/v1/audio/speech", + "/v1/tts/synthesize", data={ "text": "Clone my voice", "language": "English", @@ -170,17 +175,10 @@ def test_voice_cloning(self, client, mock_qwen3_model): assert call_args.kwargs["text"] == "Clone my voice" assert call_args.kwargs["ref_text"] == "This is a reference transcript" - def test_empty_text_fails(self, client): - response = client.post( - "/v1/audio/speech", - data={"text": ""} - ) - # FastAPI validates Form(...) 
as required - assert response.status_code == 422 class TestVoiceExtractEndpoint: - """Tests for /v1/voice/extract endpoint.""" + """Tests for /v1/tts/extract endpoint.""" def test_extract_voice_prompt(self, client, mock_qwen3_model): # Create a minimal WAV file for testing @@ -192,7 +190,7 @@ def test_extract_voice_prompt(self, client, mock_qwen3_model): wav_buffer.seek(0) response = client.post( - "/v1/voice/extract", + "/v1/tts/extract", data={ "ref_text": "This is a reference transcript", }, @@ -219,12 +217,13 @@ def test_extract_without_ref_text(self, client, mock_qwen3_model): wav_buffer.seek(0) response = client.post( - "/v1/voice/extract", + "/v1/tts/extract", files={"ref_audio": ("ref.wav", wav_buffer, "audio/wav")} ) assert response.status_code == 200 + class TestSynthesizeWithPrompt: """Tests for synthesizing with pre-extracted voice prompt.""" @@ -240,7 +239,7 @@ def test_synthesize_with_voice_prompt(self, client, mock_qwen3_model): wav_buffer.seek(0) extract_response = client.post( - "/v1/voice/extract", + "/v1/tts/extract", data={"ref_text": "Reference transcript"}, files={"ref_audio": ("ref.wav", wav_buffer, "audio/wav")} ) @@ -249,7 +248,7 @@ def test_synthesize_with_voice_prompt(self, client, mock_qwen3_model): # Now use the voice prompt for synthesis response = client.post( - "/v1/audio/speech", + "/v1/tts/synthesize", data={ "text": "Hello with cached voice", "voice_prompt": voice_prompt, @@ -271,10 +270,34 @@ class TestBackendInfo: def test_get_info(self, client): response = client.get("/health") data = response.json() - backend_info = data["backend"] + backend_info = data["backends"]["tts"] assert backend_info["supports_voice_cloning"] is True assert backend_info["supports_ref_text"] is True assert backend_info["supports_voice_prompt"] is True assert "custom_voice_model" in backend_info assert "base_model" in backend_info + + +class TestFaceEndpoints: + """Tests for /v1/face/* endpoints when backend is disabled.""" + + def 
test_face_embed_returns_503_when_disabled(self, client): + """Face endpoints should return 503 when backend is not loaded.""" + response = client.post( + "/v1/face/embed", + files={"image": ("test.jpg", b"fake image data", "image/jpeg")} + ) + assert response.status_code == 503 + + +class TestTranscriptionEndpoints: + """Tests for /v1/transcribe endpoint when backend is disabled.""" + + def test_transcribe_returns_503_when_disabled(self, client): + """Transcription endpoint should return 503 when backend is not loaded.""" + response = client.post( + "/v1/transcribe", + files={"audio": ("test.wav", b"fake audio data", "audio/wav")} + ) + assert response.status_code == 503 From 08ad253be06aa83f0f289b02ac4e94c356d2a55b Mon Sep 17 00:00:00 2001 From: Will Griffin Date: Mon, 26 Jan 2026 14:48:13 -0700 Subject: [PATCH 2/2] fix: add build-essential for insightface compilation --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index b5b2db7..1e29f7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime WORKDIR /app -# Install system dependencies +# Install system dependencies (including build tools for insightface) RUN apt-get update && apt-get install -y --no-install-recommends \ libsndfile1 \ ffmpeg \ @@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libsox-dev \ libgl1-mesa-glx \ libglib2.0-0 \ + build-essential \ && rm -rf /var/lib/apt/lists/* # Install Python dependencies