diff --git a/.gitignore b/.gitignore
index 33b7875..344d83d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
 __pycache__
 tmp
-cache
\ No newline at end of file
+cache
+mlx_models/
+asset/
+config/
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..185a1e0
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.5 # Specify the latest stable version
+    hooks:
+      - id: codespell
+        args: ["-w"] # The -w flag tells codespell to automatically apply fixes
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.1.1 # Replace with the latest stable version of ruff-pre-commit
+    hooks:
+      - id: ruff
+        args: ["--fix"] # This will automatically fix linting issues
diff --git a/LLM/chat.py b/LLM/chat.py
index bc8ac4f..6f5569d 100644
--- a/LLM/chat.py
+++ b/LLM/chat.py
@@ -6,7 +6,7 @@ class Chat:
     def __init__(self, size):
         self.size = size
         self.init_chat_message = None
-        # maxlen is necessary pair, since a each new step we add an prompt and assitant answer
+        # maxlen is necessarily a pair, since at each new step we add a prompt and an assistant answer
         self.buffer = []
 
     def append(self, item):
diff --git a/LLM/language_model.py b/LLM/language_model.py
index ddeb34b..202e007 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -68,7 +68,7 @@ def setup(
         if init_chat_role:
             if not init_chat_prompt:
                 raise ValueError(
-                    "An initial promt needs to be specified when setting init_chat_role."
+                    "An initial prompt needs to be specified when setting init_chat_role."
                 )
             self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
         self.user_role = user_role
@@ -111,7 +111,7 @@ def warmup(self):
         )
 
     def process(self, prompt):
-        logger.debug("infering language model...")
+        logger.debug("inferring language model...")
         language_code = None
         if isinstance(prompt, tuple):
             prompt, language_code = prompt
diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py
index 87812c5..8269b3b 100644
--- a/LLM/mlx_language_model.py
+++ b/LLM/mlx_language_model.py
@@ -42,7 +42,7 @@ def setup(
         if init_chat_role:
             if not init_chat_prompt:
                 raise ValueError(
-                    "An initial promt needs to be specified when setting init_chat_role."
+                    "An initial prompt needs to be specified when setting init_chat_role."
                 )
             self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
         self.user_role = user_role
@@ -68,7 +68,7 @@ def warmup(self):
         )
 
     def process(self, prompt):
-        logger.debug("infering language model...")
+        logger.debug("inferring language model...")
 
         language_code = None
         if isinstance(prompt, tuple):
diff --git a/LLM/openai_api_language_model.py b/LLM/openai_api_language_model.py
index dcbabe0..2866867 100644
--- a/LLM/openai_api_language_model.py
+++ b/LLM/openai_api_language_model.py
@@ -44,7 +44,7 @@ def setup(
         if init_chat_role:
             if not init_chat_prompt:
                 raise ValueError(
-                    "An initial promt needs to be specified when setting init_chat_role."
+                    "An initial prompt needs to be specified when setting init_chat_role."
                 )
             self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
         self.user_role = user_role
@@ -54,7 +54,7 @@ def setup(
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
         start = time.time()
-        response = self.client.chat.completions.create(
+        _ = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "system", "content": "You are a helpful assistant"},
diff --git a/README.md b/README.md
index 02c1676..1d517ba 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ This repository implements a speech-to-speech cascaded pipeline consisting of th
 2. **Speech to Text (STT)**
 3. **Language Model (LM)**
 4. **Text to Speech (TTS)**
+5. **Speech to Visemes (STV)**
 
 ### Modularity
 The pipeline provides a fully open and modular approach, with a focus on leveraging models available through the Transformers library on the Hugging Face hub. The code is designed for easy modification, and we already support device-specific and external library implementations:
@@ -50,6 +51,9 @@ The pipeline provides a fully open and modular approach, with a focus on leverag
 - [MeloTTS](https://github.com/myshell-ai/MeloTTS)
 - [ChatTTS](https://github.com/2noise/ChatTTS?tab=readme-ov-file)
 
+**STV**
+- [Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/en/model_doc/wav2vec2_phoneme) + [Phoneme to viseme mapping](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis-viseme?tabs=visemeid&pivots=programming-language-python#map-phonemes-to-visemes)
+
 ## Setup
 
 Clone the repository:
@@ -80,7 +84,7 @@ The pipeline can be run in two ways:
 - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
 - **Local approach**: Runs locally.
 
-### Recommanded setup
+### Recommended setup
 
 ### Server/Client Approach
 
@@ -120,7 +124,7 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
 
 ### Recommended usage with Cuda
 
-Leverage Torch Compile for Whisper and Parler-TTS. **The usage of Parler-TTS allows for audio output streaming, futher reducing the overeall latency** 🚀:
+Leverage Torch Compile for Whisper and Parler-TTS. **The usage of Parler-TTS allows for audio output streaming, further reducing the overall latency** 🚀:
 
 ```bash
 python s2s_pipeline.py \
@@ -216,6 +220,13 @@ For example:
 --lm_model_name google/gemma-2b-it
 ```
 
+
+### STV parameters
+See the [Wav2Vec2STVHandlerArguments](arguments_classes/w2v_stv_arguments.py) class. Notably:
+- `stv_model_name` defaults to `bookbot/wav2vec2-ljspeech-gruut`, chosen because it is accurate and fast enough
+- `stv_skip`: set it to `True` if you don't need visemes
+
+
 ### Generation parameters
 Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed.
 
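The new STV stage documented above boils down to two steps: phoneme-level ASR with character timestamps, then a lookup in `STV/phoneme_viseme_map.json` (Azure viseme IDs). Below is a minimal sketch of that flow outside the handler plumbing, mirroring what `Wav2Vec2STVHandler.speech_to_visemes` in this patch does; it assumes a mono waveform at the model's 16 kHz sampling rate and the mapping file shipped here.

```python
import json

import numpy as np
from transformers import pipeline

# Phoneme-level CTC model used by the new Wav2Vec2STVHandler
asr = pipeline(
    "automatic-speech-recognition",
    model="bookbot/wav2vec2-ljspeech-gruut",
)

# Azure-style phoneme -> viseme ID table shipped in this patch
with open("STV/phoneme_viseme_map.json") as f:
    phoneme_viseme_map = json.load(f)


def speech_to_visemes(audio: np.ndarray) -> list:
    """Return [{"viseme": id, "timestamp": (start, end)}, ...] for a waveform."""
    # Character-level timestamps yield one chunk per recognized phoneme
    result = asr(audio, return_timestamps="char")
    visemes = []
    for chunk in result.get("chunks", []):
        for viseme_id in phoneme_viseme_map.get(chunk["text"], []):
            visemes.append({"viseme": viseme_id, "timestamp": chunk["timestamp"]})
    return visemes
```

Phonemes that are missing from the map simply yield no viseme, which matches the handler's behavior of defaulting to an empty list.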
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 53b6b5a..2f2d657 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -4,7 +4,6 @@ from lightning_whisper_mlx import LightningWhisperMLX import numpy as np from rich.console import Console -from copy import copy import torch logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ def warmup(self): _ = self.model.transcribe(dummy_input)["text"].strip() def process(self, spoken_prompt): - logger.debug("infering whisper...") + logger.debug("inferring whisper...") global pipeline_start pipeline_start = perf_counter() diff --git a/STT/paraformer_handler.py b/STT/paraformer_handler.py index 99fd6ac..09d481b 100644 --- a/STT/paraformer_handler.py +++ b/STT/paraformer_handler.py @@ -28,7 +28,6 @@ def setup( device="cuda", gen_kwargs={}, ): - print(model_name) if len(model_name.split("/")) > 1: model_name = model_name.split("/")[-1] self.device = device @@ -45,7 +44,7 @@ def warmup(self): _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "") def process(self, spoken_prompt): - logger.debug("infering paraformer...") + logger.debug("inferring paraformer...") global pipeline_start pipeline_start = perf_counter() diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index 0930087..88c578f 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -109,7 +109,7 @@ def warmup(self): ) def process(self, spoken_prompt): - logger.debug("infering whisper...") + logger.debug("inferring whisper...") global pipeline_start pipeline_start = perf_counter() diff --git a/STV/phoneme_viseme_map.json b/STV/phoneme_viseme_map.json new file mode 100644 index 0000000..8c91531 --- /dev/null +++ b/STV/phoneme_viseme_map.json @@ -0,0 +1 @@ +{"æ":[1],"ə":[1],"ʌ":[1],"ɑ":[2],"ɔ":[3],"ɛ":[4],"ʊ":[4],"ɝ":[5],"j":[6],"i":[6],"ɪ":[6],"w":[7],"u":[7],"o":[8],"aʊ":[9],"ɔɪ":[10],"aɪ":[11],"h":[12],"ɹ":[13],"l":[14],"s":[15],"z":[15],"ʃ":[16],"tʃ":[19,16],"dʒ":[19,16],"ʒ":[16],"ð":[17],"f":[18],"v":[18],"d":[19],"t":[19],"n":[19],"θ":[19],"k":[20],"g":[20],"ŋ":[20],"p":[21],"b":[21],"m":[21]," 
":[0],"a":[2],"aː":[2],"iː":[6],"uː":[7],"dˤ":[19],"q":[20],"tˤ":[19],"ʔ":[19],"ħ":[12],"ðˤ":[17],"ɣ":[20],"x":[12],"sˤ":[15],"r":[13],"ʕ":[12],"j͡a":[6,2],"ɤ":[1],"j͡u":[6,7],"t͡s":[19,15],"zʲ":[15],"lʲ":[14],"nʲ":[19],"d͡ʒ":[19,16],"mʲ":[21],"tʲ":[19],"rʲ":[13],"pʲ":[21],"dʲ":[19],"vʲ":[18],"sʲ":[15],"bʲ":[21],"kʲ":[20],"gʲ":[20],"fʲ":[18],"t͡ʃ":[19,16],"d͡z":[19,15],"e":[4],"β":[21],"ʎ":[14],"ɲ":[19],"ɾ":[19],"ɛː":[4],"oː":[8],"o͡ʊ̯":[8,4],"a͡ʊ":[2,4],"ɛ͡ʊ̯":[4,4],"c":[16],"ɟ":[16],"r̝":[13],"ɦ":[12],"ɱ":[21],"r̝̊":[13],"ɑː":[2],"ɒ":[2],"ɒː":[2],"ɔː":[3],"ɐ":[4],"æː":[1],"ø":[1],"øː":[1],"eː":[4],"œ":[4],"œː":[4],"y":[4],"yː":[4],"kʰ":[20],"pʰ":[21],"ʁ":[13],"ɐ̯":[4],"ɕ":[16],"ʏ":[7],"ai":[2,6],"au":[2,7],"ɔy":[3,4],"ɔʏ̯":[3,4],"ʤ":[16],"pf":[21,18],"ʀ":[13],"ts":[19,15],"ç":[12],"ʝ":[12],"ɛə":[4,1],"ɜː":[5],"eɪ":[4,6],"ɪə":[6,1],"əʊ":[1,4],"ʊə":[4,1],"iy":[6],"oʊ":[8,4],"ju":[6,7],"ɪɹ":[6,13],"ɛɹ":[4,13],"ʊɹ":[4,13],"aɪɹ":[11,13],"aʊɹ":[9,13],"ɔɹ":[3,13],"ɑɹ":[2,13],"ɚ":[1],"j͡j":[6,6],"ɑ͡i":[2,6],"ɑ͡u":[2,7],"æ͡i":[1,6],"æ͡y":[1,4],"e͡i":[4,6],"ø͡i":[1,6],"ø͡y":[1,4],"e͡u":[4,7],"e͡y":[4,4],"i͡e":[6,4],"i͡u":[6,7],"i͡y":[6,4],"o͡i":[8,6],"o͡u":[8,7],"u͡i":[7,6],"u͡o":[7,8],"y͡ø":[4,1],"y͡i":[4,6],"ʋ":[18],"ɑ̃":[2],"ɛ̃":[4],"ɔ̃":[3],"œ̃":[4],"ɥ":[7],"n‿":[19],"t‿":[19],"z‿":[15],"ʨ":[16],"ʥ":[16],"bː":[21],"dː":[19],"ɟː":[16],"d͡ʒː":[19,16],"dz":[19,15],"dzː":[19,15],"fː":[18],"gː":[20],"hː":[12],"jː":[6],"ɲː":[19],"kː":[20],"lː":[14],"mː":[21],"nː":[19],"pː":[21],"rː":[13],"sː":[15],"ʃː":[16],"tː":[19],"cː":[16],"t͡sː":[19,15],"t͡ʃː":[19,16],"vː":[18],"ɰ":[20],"zː":[15],"ʒː":[16],"a͡i":[2,6],"ɔ͡i":[3,6],"ɛj":[4,6],"ɛu":[4,7],"ei":[4,6],"eu":[4,7],"ɔj":[3,6],"oi":[8,6],"ou":[8,7],"ʧ":[16],"tʃː":[19,16],"ʣ":[15],"ʣː":[15],"ʤː":[16],"ʎː":[14],"ʦ":[15],"ʦː":[15],"ɯ":[6],"ɰ͡i":[20,6],"w͡a":[7,2],"w͡ɛ":[7,4],"w͡e":[7,4],"w͡i":[7,6],"w͡ʌ":[7,1],"j͡ɛ":[6,4],"j͡e":[6,4],"j͡ʌ":[6,1],"j͡o":[6,8],"b̥":[21],"t͡ɕʰ":[19,16],"d̥":[19],"g̥":[20],"d͡ʑ":[19,16],"d͡ʑ̥":[19,16],"t͡ɕ":[19,16],"sʰ":[15],"tʰ":[19],"ʉ":[6],"ʉː":[6],"æɪ":[1,6],"æʉ":[1,6],"ɑɪ":[2,6],"œʏ":[4,7],"ɔʏ":[3,7],"ʉɪ":[6,6],"ʂ":[15],"ɖ":[19],"ɭ":[14],"ɳ":[19],"ʈ":[19],"ɛ͡i":[4,6],"œ͡y":[4,4],"χ":[12],"ɨ":[6],"t͡ʂ":[19,15],"d̪ʲ":[19],"ɡ":[20],"d͡ʐ":[19,15],"l̪ʲ":[14],"t̪ʲ":[19],"xʲ":[12],"ʑ":[16],"ĩ":[6],"ũ":[7],"ɐ̃":[4],"ẽ":[4],"õ":[8],"w̃":[7],"j̃":[6],"ɐj":[4,6],"ɐ̃j̃":[4,6],"ɐ̃w̃":[4,7],"ɐ͡w":[4,7],"a͡j":[2,6],"ɔ͡j":[3,6],"a͡w":[2,7],"ɛ͡w":[4,7],"e͡w":[4,7],"i͡w":[6,7],"o͡j":[8,6],"õj̃":[8,6],"u͡j":[7,6],"ũj̃":[7,6],"ɫ":[14],"e̯a":[4,2],"e̯o":[4,8],"o̯a":[8,2],"d͡ʒʲ":[19,16],"ʃʲ":[16],"t͡sʲ":[19,15],"t͡ʃʲ":[19,16],"ʒʲ":[16],"ʐ":[15],"ɕː":[16],"i͡a":[6,2],"r̩":[13],"r̩ː":[13],"l̩":[14],"l̩ː":[14],"ɴ":[19],"u̯":[7],"i̯":[6],"dˡ":[19],"dn":[19,19],"tˡ":[19],"tn":[19,19],"ʍ":[7],"a‿u":[2,7],"ɶ":[8],"ɵ":[1],"ɧ":[16],"ia":[6,2],"əː":[1],"ua":[7,2],"ɯː":[6],"ɯa":[6,2],"tɕʰ":[19,16],"œ͡ɟ":[4,16],"i͡ɟ":[6,16],"o͡ɟ":[8,16],"u͡ɟ":[7,16],"ɯ͡ɟ":[6,16],"y͡ɟ":[4,16],"ɮ":[6],"u͡a":[7,2],"ɛ̆j":[4,6],"ə͡j":[1,6],"i͡e͡w":[6,4,7],"ɨ͡ə":[6,1],"ie":[6,4],"ăw":[2,7],"ăj":[2,6],"ɨ͡ə͡j":[6,1,6],"ɔ̆w":[3,7],"ɨ͡w":[6,7],"e͡j":[4,6],"ɨ͡ʌ͡w":[6,1,7],"ɨ͡j":[6,6],"iə":[6,1],"a͡ʲ":[2],"ɓ":[21],"ɗ":[19]} \ No newline at end of file diff --git a/STV/w2v_stv_handler.py b/STV/w2v_stv_handler.py new file mode 100644 index 0000000..20c2cef --- /dev/null +++ b/STV/w2v_stv_handler.py @@ -0,0 +1,253 @@ +import json +import logging +import time +from typing import Any, Dict, Generator, List + +import numpy as np +from rich.console import Console +from transformers import 
pipeline + +from baseHandler import BaseHandler + +logger = logging.getLogger(__name__) +console = Console() + + +class Wav2Vec2STVHandler(BaseHandler): + """ + Handles the Speech-To-Viseme generation using a Wav2Vec2 model for automatic + speech recognition (ASR) and phoneme mapping to visemes. + + Attributes: + MIN_AUDIO_LENGTH (float): Minimum length of audio (in seconds) required + for phoneme extraction. + """ + + MIN_AUDIO_LENGTH = 0.5 # Minimum audio length in seconds for phoneme extraction + + def setup( + self, + should_listen: bool, + model_name: str = "bookbot/wav2vec2-ljspeech-gruut", + blocksize: int = 512, + device: str = "cuda", + skip: bool = False, + gen_kwargs: Dict[str, Any] = {}, # Not used + ) -> None: + """ + Initializes the handler by loading the ASR model and phoneme-to-viseme map. + + Args: + should_listen (bool): Flag indicating whether the speech-to-speech pipeline should start + listening to the user or not. + model_name (str): Name of the ASR model to use. + Defaults to "bookbot/wav2vec2-ljspeech-gruut". + blocksize (int): Size of each audio block when processing audio. + Defaults to 512. + device (str): Device to run the model on ("cuda", "mps", or "cpu"). + Defaults to "cuda". + skip (bool): If True, the speech-to-viseme process is skipped. + Defaults to False. + gen_kwargs (dict): Additional parameters for speech generation. + + Returns: + None + """ + self.device = device + self.gen_kwargs = gen_kwargs + self.blocksize = blocksize + self.should_listen = should_listen + self.skip = skip + + # Load phoneme-to-viseme map from the JSON file + # inspired by https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-ssml-phonetic-sets + phoneme_viseme_map_file = "STV/phoneme_viseme_map.json" + with open(phoneme_viseme_map_file, "r") as f: + self.phoneme_viseme_map = json.load(f) + + # Initialize the ASR pipeline using the specified model and device + self.asr_pipeline = pipeline( + "automatic-speech-recognition", + model=model_name, + device=device, + torch_dtype="auto", + ) + self.expected_sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate + + # Initialize an empty dictionary to store audio batch data + self.audio_batch = { + "waveform": np.array([]), + "sampling_rate": self.expected_sampling_rate, + } + self.text_batch = None + self.should_listen_flag = False + + self.warmup() # Perform model warmup + + def warmup(self) -> None: + """Warms up the model with dummy input to prepare it for inference. + + Returns: + None + """ + logger.info(f"Warming up {self.__class__.__name__}") + start_time = time.time() + + # Create dummy input for warmup inference + dummy_input = np.random.randn(self.blocksize).astype(np.int16) + _ = self.speech_to_visemes(dummy_input) + + warmup_time = time.time() - start_time + logger.info( + f"{self.__class__.__name__}: warmed up in {warmup_time:.4f} seconds!" + ) + + def speech_to_visemes(self, audio: Any) -> List[Dict[str, Any]]: + """ + Converts speech audio to visemes by performing Automatic Speech Recognition (ASR) + and mapping phonemes to visemes. + + Args: + audio (Any): The input audio data. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing mapped visemes + and their corresponding timestamps. + + Note: + Heuristically, the input audio should be at least 0.5 seconds long for proper phoneme extraction. + """ + + def _map_phonemes_to_visemes( + data: Dict[str, Any], + ) -> List[Dict[str, Any]]: + """ + Maps extracted phonemes to their corresponding visemes based on a predefined map. 
+ + Args: + data (Dict[str, Any]): Dictionary containing phoneme data where data['chunks'] + holds a list of phonemes and their timestamps. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with viseme IDs and their corresponding timestamps. + """ + viseme_list = [] + chunks = data.get("chunks", []) + + # Map each phoneme to corresponding visemes + for chunk in chunks: + phoneme = chunk.get("text", None) + timestamp = chunk.get("timestamp", None) + visemes = self.phoneme_viseme_map.get(phoneme, []) + + for viseme in visemes: + viseme_list.append({"viseme": viseme, "timestamp": timestamp}) + + return viseme_list + + # Perform ASR to extract phoneme data, including timestamps + try: + asr_result = self.asr_pipeline(audio, return_timestamps="char") + except Exception as e: + logger.error(f"ASR error: {e}") + return [] + # Map the phonemes obtained from ASR to visemes + return _map_phonemes_to_visemes(asr_result) + + def process(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]: + """ + Processes an audio file to generate visemes and output blocks of audio data + along with corresponding viseme data. + + Args: + data (Dict[str, Any]): Dictionary containing audio, text, and potentially additional information. + + Yields: + Dict: A dictionary containing audio waveform, and optionally viseme data, text, and potentially additional information. + """ + + if "sentence_end" in data and data["sentence_end"]: + self.should_listen_flag = True + if self.skip: # Skip viseme extraction if the flag is set + yield { + "audio": { + "waveform": data["audio"]["waveform"], + "sampling_rate": data["audio"]["sampling_rate"], + }, + "text": data["text"] if "text" in data else None, + } + else: + # Check if text data is present and save it for later + if "text" in data and data["text"] is not None: + self.text_batch = data["text"] + # Concatenate new audio data into the buffer if available and valid + if "audio" in data and data["audio"] is not None: + audio_data = data["audio"] + # Check if the sampling rate is valid and matches the expected one + if audio_data.get("sampling_rate", None) != self.expected_sampling_rate: + logger.error( + f"Expected sampling rate {self.expected_sampling_rate}, " + f"but got {audio_data['sampling_rate']}." 
+ ) + return + # Append the waveform to the audio buffer + self.audio_batch["waveform"] = np.concatenate( + (self.audio_batch["waveform"], audio_data["waveform"]), axis=0 + ) + + # Ensure the total audio length is sufficient for phoneme extraction + if ( + len(self.audio_batch["waveform"]) / self.audio_batch["sampling_rate"] + < self.MIN_AUDIO_LENGTH + ): + return + else: + logger.debug("Starting viseme inference...") + + # Perform viseme inference using the accumulated audio batch + viseme_data = self.speech_to_visemes(self.audio_batch["waveform"]) + logger.debug("Viseme inference completed.") + + # Print the visemes and timestamps to the console + for viseme in viseme_data: + console.print( + f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}" + ) + + # Process the audio in chunks of the defined blocksize + self.audio_batch["waveform"] = self.audio_batch["waveform"].astype( + np.int16 + ) + for i in range(0, len(self.audio_batch["waveform"]), self.blocksize): + chunk_waveform = self.audio_batch["waveform"][ + i : i + self.blocksize + ] + padded_waveform = np.pad( + chunk_waveform, (0, self.blocksize - len(chunk_waveform)) + ) + + chunk_data = { + "audio": { + "waveform": padded_waveform, + "sample_rate": self.audio_batch["sampling_rate"], + } + } + + # Add text and viseme data only in the first chunk + if i == 0: + if self.text_batch: + chunk_data["text"] = self.text_batch + if viseme_data and len(viseme_data) > 0: + chunk_data["visemes"] = viseme_data + yield chunk_data + + # Reset the audio and text buffer after processing + self.audio_batch = { + "waveform": np.array([]), + "sampling_rate": self.expected_sampling_rate, + } + self.text_batch = "" + + if self.should_listen_flag: + self.should_listen.set() + self.should_listen_flag = False diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py index 6bdc6bf..6c177c4 100644 --- a/TTS/chatTTS_handler.py +++ b/TTS/chatTTS_handler.py @@ -17,13 +17,11 @@ class ChatTTSHandler(BaseHandler): def setup( self, - should_listen, device="cuda", gen_kwargs={}, # Unused stream=True, chunk_size=512, ): - self.should_listen = should_listen self.device = device self.model = ChatTTS.Chat() self.model.load(compile=False) # Doesn't work for me with True @@ -33,6 +31,7 @@ def setup( self.params_infer_code = ChatTTS.Chat.InferCodeParams( spk_emb=rnd_spk_emb, ) + self.output_sampling_rate = 16000 self.warmup() def warmup(self): @@ -40,6 +39,8 @@ def warmup(self): _ = self.model.infer("text") def process(self, llm_sentence): + if isinstance(llm_sentence, tuple): + llm_sentence, _ = llm_sentence # Ignore language console.print(f"[green]ASSISTANT: {llm_sentence}") if self.device == "mps": import time @@ -59,24 +60,62 @@ def process(self, llm_sentence): wavs = [np.array([])] for gen in wavs_gen: if gen[0] is None or len(gen[0]) == 0: - self.should_listen.set() - return - audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) - audio_chunk = (audio_chunk * 32768).astype(np.int16)[0] - while len(audio_chunk) > self.chunk_size: - yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据 - audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据 - yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk))) + return { + "text": llm_sentence, + "sentence_end": True + } + + # Resample the audio to 16000 Hz + audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=self.output_sampling_rate) + # Ensure the audio is converted to mono (single channel) + if len(audio_chunk.shape) > 1: + audio_chunk = 
librosa.to_mono(audio_chunk) + audio_chunk = (audio_chunk * 32768).astype(np.int16) + + # Loop through audio chunks, yielding dict for each chunk + for i in range(0, len(audio_chunk), self.chunk_size): + chunk_data = { + "audio": { + "waveform": np.pad( + audio_chunk[i : i + self.chunk_size], + (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), + ), + "sampling_rate": self.output_sampling_rate, + } + } + # Include text for the first chunk + if i == 0: + chunk_data["text"] = llm_sentence # Assuming llm_sentence is defined elsewhere + if i >= len(audio_chunk) - self.chunk_size: + # This is the last round + chunk_data["sentence_end"] = True + yield chunk_data else: wavs = wavs_gen if len(wavs[0]) == 0: - self.should_listen.set() - return - audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) + return { + "sentence_end": True + } + audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=self.output_sampling_rate) + # Ensure the audio is converted to mono (single channel) + if len(audio_chunk.shape) > 1: + audio_chunk = librosa.to_mono(audio_chunk) audio_chunk = (audio_chunk * 32768).astype(np.int16) + for i in range(0, len(audio_chunk), self.chunk_size): - yield np.pad( - audio_chunk[i : i + self.chunk_size], - (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), - ) - self.should_listen.set() + chunk_data = { + "audio": { + "waveform": np.pad( + audio_chunk[i : i + self.chunk_size], + (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), + ), + "sampling_rate": self.output_sampling_rate, + } + } + # For the first chunk, include text + if i == 0: + chunk_data["text"] = llm_sentence + if i >= len(audio_chunk) - self.chunk_size: + # This is the last round + chunk_data["sentence_end"] = True + yield chunk_data diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index 6dd50f1..be25007 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -28,18 +28,15 @@ "ko": "KR", } - class MeloTTSHandler(BaseHandler): def setup( self, - should_listen, - device="mps", + device="auto", language="en", speaker_to_id="en", gen_kwargs={}, # Unused blocksize=512, ): - self.should_listen = should_listen self.device = device self.language = language self.model = TTS( @@ -49,6 +46,8 @@ def setup( WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id] ] self.blocksize = blocksize + self.output_sampling_rate = 16000 + self.warmup() def warmup(self): @@ -96,14 +95,27 @@ def process(self, llm_sentence): logger.error(f"Error in MeloTTSHandler: {e}") audio_chunk = np.array([]) if len(audio_chunk) == 0: - self.should_listen.set() - return - audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + return { + "text": llm_sentence, + "sentence_end": True + } + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=self.output_sampling_rate) audio_chunk = (audio_chunk * 32768).astype(np.int16) - for i in range(0, len(audio_chunk), self.blocksize): - yield np.pad( - audio_chunk[i : i + self.blocksize], - (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), - ) - self.should_listen.set() + for i in range(0, len(audio_chunk), self.blocksize): + chunk_data = { + "audio": { + "waveform": np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])) + ), + "sampling_rate": self.output_sampling_rate + } + } + # For the first chunk, include text + if i == 0: + chunk_data["text"] = llm_sentence + if i >= len(audio_chunk) - self.blocksize: + # This is the last round + 
chunk_data["sentence_end"] = True + yield chunk_data diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py index ac539c7..2b84e8d 100644 --- a/TTS/parler_handler.py +++ b/TTS/parler_handler.py @@ -14,7 +14,6 @@ from transformers.utils.import_utils import ( is_flash_attn_2_available, ) - torch._inductor.config.fx_graph_cache = True # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS torch._dynamo.config.cache_size_limit = 15 @@ -34,7 +33,6 @@ class ParlerTTSHandler(BaseHandler): def setup( self, - should_listen, model_name="ylacombe/parler-tts-mini-jenny-30H", device="cuda", torch_dtype="float16", @@ -48,7 +46,6 @@ def setup( play_steps_s=1, blocksize=512, ): - self.should_listen = should_listen self.device = device self.torch_dtype = getattr(torch, torch_dtype) self.gen_kwargs = gen_kwargs @@ -77,6 +74,7 @@ def setup( self.model.forward = torch.compile( self.model.forward, mode=self.compile_mode, fullgraph=True ) + self.output_sampling_rate = 16000 self.warmup() @@ -180,12 +178,24 @@ def process(self, llm_sentence): logger.info( f"Time to first audio: {perf_counter() - pipeline_start:.3f}" ) - audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) + audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=self.output_sampling_rate) audio_chunk = (audio_chunk * 32768).astype(np.int16) - for i in range(0, len(audio_chunk), self.blocksize): - yield np.pad( - audio_chunk[i : i + self.blocksize], - (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), - ) - self.should_listen.set() + for i in range(0, len(audio_chunk), self.blocksize): + chunk_data = { + "audio": { + "waveform": np.pad( + audio_chunk[i : i + self.blocksize], + (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])) + ), + "sampling_rate": self.output_sampling_rate + } + } + # For the first chunk, include text + if i == 0: + chunk_data["text"] = llm_sentence + if i >= len(audio_chunk) - self.blocksize: + # This is the last round + chunk_data["sentence_end"] = True + + yield chunk_data diff --git a/arguments_classes/parler_tts_arguments.py b/arguments_classes/parler_tts_arguments.py index 5159432..b519751 100644 --- a/arguments_classes/parler_tts_arguments.py +++ b/arguments_classes/parler_tts_arguments.py @@ -36,7 +36,7 @@ class ParlerTTSHandlerArguments: tts_gen_max_new_tokens: int = field( default=512, metadata={ - "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs" + "help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs" }, ) description: str = field( @@ -57,6 +57,6 @@ class ParlerTTSHandlerArguments: max_prompt_pad_length: int = field( default=8, metadata={ - "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible." + "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximum power of 2 possible." 
}, ) diff --git a/arguments_classes/w2v_stv_arguments.py b/arguments_classes/w2v_stv_arguments.py new file mode 100644 index 0000000..229610a --- /dev/null +++ b/arguments_classes/w2v_stv_arguments.py @@ -0,0 +1,29 @@ +"""This file contains the arguments for the Wav2Vec2STVHandler.""" +from dataclasses import dataclass, field + +@dataclass +class Wav2Vec2STVHandlerArguments: + stv_model_name: str = field( + default="bookbot/wav2vec2-ljspeech-gruut", + metadata={ + "help": "The pretrained language model to use. Default is 'bookbot/wav2vec2-ljspeech-gruut'." + }, + ) + stv_device: str = field( + default="cuda", + metadata={ + "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." + }, + ) + stv_blocksize: int = field( + default=512, + metadata={ + "help": "The blocksize of the model. Default is 512." + }, + ) + stv_skip: bool = field( + default=False, + metadata={ + "help": "If True, skips the STV generation. Default is False." + }, + ) diff --git a/connections/local_audio_streamer.py b/connections/local_audio_streamer.py index 389dcb8..99b9d83 100644 --- a/connections/local_audio_streamer.py +++ b/connections/local_audio_streamer.py @@ -27,7 +27,8 @@ def callback(indata, outdata, frames, time, status): self.input_queue.put(indata.copy()) outdata[:] = 0 * outdata else: - outdata[:] = self.output_queue.get()[:, np.newaxis] + data = self.output_queue.get() + outdata[:] = data['audio']['waveform'][:, np.newaxis] logger.debug("Available devices:") logger.debug(sd.query_devices()) diff --git a/connections/socket_sender.py b/connections/socket_sender.py index 11ed210..f849bf3 100644 --- a/connections/socket_sender.py +++ b/connections/socket_sender.py @@ -1,6 +1,8 @@ import socket from rich.console import Console import logging +import pickle +import struct logger = logging.getLogger(__name__) @@ -11,7 +13,6 @@ class SocketSender: """ Handles sending generated audio packets to the clients. 
""" - def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): self.stop_event = stop_event self.queue_in = queue_in @@ -28,9 +29,31 @@ def run(self): logger.info("sender connected") while not self.stop_event.is_set(): - audio_chunk = self.queue_in.get() - self.conn.sendall(audio_chunk) - if isinstance(audio_chunk, bytes) and audio_chunk == b"END": - break + data = self.queue_in.get() + packet = {} + if 'audio' in data and data['audio'] is not None: + audio_chunk = data['audio'] + packet['audio'] = audio_chunk + if 'text' in data and data['text'] is not None: + packet['text'] = data['text'] + if 'visemes' in data and data['visemes'] is not None: + packet['visemes'] = data['visemes'] + + # Serialize the packet using pickle + serialized_packet = pickle.dumps(packet) + + # Compute the length of the serialized packet + packet_length = len(serialized_packet) + + # Send the packet length as a 4-byte integer using struct + self.conn.sendall(struct.pack('!I', packet_length)) + + # Send the serialized packet + self.conn.sendall(serialized_packet) + + if 'audio' in data and data['audio'] is not None: + if isinstance(audio_chunk, bytes) and audio_chunk == b"END": + break + self.conn.close() logger.info("Sender closed") diff --git a/listen_and_play.py b/listen_and_play.py index 35eabd6..b1f282f 100644 --- a/listen_and_play.py +++ b/listen_and_play.py @@ -4,15 +4,16 @@ from dataclasses import dataclass, field import sounddevice as sd from transformers import HfArgumentParser - +import struct +import pickle @dataclass class ListenAndPlayArguments: send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) list_play_chunk_size: int = field( - default=1024, - metadata={"help": "The size of data chunks (in bytes). Default is 1024."}, + default=512, + metadata={"help": "The size of data chunks (in bytes). 
Default is 512."}, ) host: str = field( default="localhost", @@ -33,7 +34,7 @@ class ListenAndPlayArguments: def listen_and_play( send_rate=16000, recv_rate=44100, - list_play_chunk_size=1024, + list_play_chunk_size=512, host="localhost", send_port=12345, recv_port=12346, @@ -79,9 +80,22 @@ def receive_full_chunk(conn, chunk_size): return data while not stop_event.is_set(): - data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) - if data: - recv_queue.put(data) + # Step 1: Receive the first 4 bytes to get the packet length + length_data = receive_full_chunk(recv_socket, 4) + if not length_data: + continue # Handle disconnection or data not available + + # Step 2: Unpack the length (4 bytes) + packet_length = struct.unpack('!I', length_data)[0] + + # Step 3: Receive the full packet based on the length + serialized_packet = receive_full_chunk(recv_socket, packet_length) + if serialized_packet: + # Step 4: Deserialize the packet using pickle + packet = pickle.loads(serialized_packet) + # Step 5: Put the packet audio data into the queue for sending, if any + if 'audio' in packet and packet['audio'] is not None and 'waveform' in packet['audio'] and packet['audio']['waveform'] is not None: + recv_queue.put(packet['audio']['waveform'].tobytes()) try: send_stream = sd.RawInputStream( @@ -123,4 +137,4 @@ def receive_full_chunk(conn, chunk_size): if __name__ == "__main__": parser = HfArgumentParser((ListenAndPlayArguments,)) (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses() - listen_and_play(**vars(listen_and_play_kwargs)) + listen_and_play(**vars(listen_and_play_kwargs)) \ No newline at end of file diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 1da202e..4c86bf5 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -8,6 +8,7 @@ from typing import Optional from sys import platform from VAD.vad_handler import VADHandler +from STV.w2v_stv_handler import Wav2Vec2STVHandler from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments from arguments_classes.language_model_arguments import LanguageModelHandlerArguments from arguments_classes.mlx_language_model_arguments import ( @@ -22,6 +23,7 @@ from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments +from arguments_classes.w2v_stv_arguments import Wav2Vec2STVHandlerArguments import torch import nltk from rich.console import Console @@ -82,6 +84,7 @@ def parse_arguments(): ParlerTTSHandlerArguments, MeloTTSHandlerArguments, ChatTTSHandlerArguments, + Wav2Vec2STVHandlerArguments, ) ) @@ -148,6 +151,8 @@ def overwrite_device_argument(common_device: Optional[str], *handler_kwargs): kwargs.stt_device = common_device if hasattr(kwargs, "paraformer_stt_device"): kwargs.paraformer_stt_device = common_device + if hasattr(kwargs, "stv_device"): + kwargs.stv_device = common_device def prepare_module_args(module_kwargs, *handler_kwargs): @@ -167,6 +172,7 @@ def prepare_all_args( parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs, ): prepare_module_args( module_kwargs, @@ -178,6 +184,7 @@ def prepare_all_args( parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs ) @@ -189,6 +196,7 @@ def prepare_all_args( rename_args(parler_tts_handler_kwargs, "tts") rename_args(melo_tts_handler_kwargs, "melo") rename_args(chat_tts_handler_kwargs, 
"chat_tts") + rename_args(stv_handler_kwargs, "stv") def initialize_queues_and_events(): @@ -200,6 +208,7 @@ def initialize_queues_and_events(): "spoken_prompt_queue": Queue(), "text_prompt_queue": Queue(), "lm_response_queue": Queue(), + "send_viseme_queue": Queue(), } @@ -216,6 +225,7 @@ def build_pipeline( parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs, queues_and_events, ): stop_event = queues_and_events["stop_event"] @@ -225,11 +235,13 @@ def build_pipeline( spoken_prompt_queue = queues_and_events["spoken_prompt_queue"] text_prompt_queue = queues_and_events["text_prompt_queue"] lm_response_queue = queues_and_events["lm_response_queue"] + send_viseme_queue = queues_and_events["send_viseme_queue"] + if module_kwargs.mode == "local": from connections.local_audio_streamer import LocalAudioStreamer local_audio_streamer = LocalAudioStreamer( - input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue + input_queue=recv_audio_chunks_queue, output_queue=send_viseme_queue ) comms_handlers = [local_audio_streamer] should_listen.set() @@ -248,7 +260,7 @@ def build_pipeline( ), SocketSender( stop_event, - send_audio_chunks_queue, + send_viseme_queue, host=socket_sender_kwargs.send_host, port=socket_sender_kwargs.send_port, ), @@ -264,9 +276,17 @@ def build_pipeline( stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs) - tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) + tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) + + stv = Wav2Vec2STVHandler( + stop_event, + queue_in=send_audio_chunks_queue, + queue_out=send_viseme_queue, + setup_args=(should_listen,), + setup_kwargs=vars(stv_handler_kwargs), + ) - return ThreadManager([*comms_handlers, vad, stt, lm, tts]) + return ThreadManager([*comms_handlers, vad, stt, lm, tts, stv]) def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs): @@ -337,14 +357,14 @@ def get_llm_handler( raise ValueError("The LLM should be either transformers or mlx-lm") -def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs): +def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs): if module_kwargs.tts == "parler": from TTS.parler_handler import ParlerTTSHandler return ParlerTTSHandler( stop_event, queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, - setup_args=(should_listen,), + setup_args=(), setup_kwargs=vars(parler_tts_handler_kwargs), ) elif module_kwargs.tts == "melo": @@ -355,11 +375,12 @@ def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chu "Error importing MeloTTSHandler. 
You might need to run: python -m unidic download" ) raise e + return MeloTTSHandler( stop_event, queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, - setup_args=(should_listen,), + setup_args=(), setup_kwargs=vars(melo_tts_handler_kwargs), ) elif module_kwargs.tts == "chatTTS": @@ -372,7 +393,7 @@ def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chu stop_event, queue_in=lm_response_queue, queue_out=send_audio_chunks_queue, - setup_args=(should_listen,), + setup_args=(), setup_kwargs=vars(chat_tts_handler_kwargs), ) else: @@ -393,6 +414,7 @@ def main(): parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs, ) = parse_arguments() setup_logger(module_kwargs.log_level) @@ -407,6 +429,7 @@ def main(): parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs ) queues_and_events = initialize_queues_and_events() @@ -424,6 +447,7 @@ def main(): parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, + stv_handler_kwargs, queues_and_events, )
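With `SocketSender` now framing every message as a 4-byte `!I` length header followed by a pickled dict, and `listen_and_play.py` only unpacking the audio field, a custom client that also wants the text and viseme stream can read the same protocol. The following is a sketch under a few assumptions: the default receive port 12346, a server on `localhost`, and numpy installed on the client so the pickled waveforms can be deserialized.

```python
import pickle
import socket
import struct


def recv_exact(conn: socket.socket, num_bytes: int) -> bytes:
    """Read exactly num_bytes from the socket (packets are length-prefixed)."""
    buf = b""
    while len(buf) < num_bytes:
        chunk = conn.recv(num_bytes - len(buf))
        if not chunk:
            raise ConnectionError("socket closed by sender")
        buf += chunk
    return buf


with socket.create_connection(("localhost", 12346)) as conn:
    try:
        while True:
            # 4-byte big-endian length header, then the pickled packet itself
            (packet_length,) = struct.unpack("!I", recv_exact(conn, 4))
            packet = pickle.loads(recv_exact(conn, packet_length))

            if packet.get("audio") is not None:
                waveform = packet["audio"]["waveform"]  # int16 numpy block at 16 kHz
                # e.g. hand waveform.tobytes() to a sounddevice output stream
            if packet.get("text"):
                print("ASSISTANT:", packet["text"])
            for viseme in packet.get("visemes", []):
                print(f"viseme {viseme['viseme']} at {viseme['timestamp']}")
    except ConnectionError:
        pass
```

Pickle keeps the framing simple, but it should only be used between trusted endpoints; sending JSON metadata plus raw PCM bytes would be a straightforward hardening alternative.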