Skip to content

Inference is too slow #100

@ukemamaster

Description

@ukemamaster

Hi, I am trying the MOSS-TTS-Local-Transformer model with a simple sentence and I get around 4.89 it/s on an NVIDIA Tesla T4 GPU. It takes about 100 s to generate 30 s of audio. Is this normal?

My inference code is:

import torch
import torchaudio
from pathlib import Path
import importlib.util
from transformers import AutoModel, AutoProcessor
from scipy.io import wavfile

# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

# Local checkpoint paths: the TTS model and its separate audio codec/tokenizer.
pretrained_model_name_or_path = "/local_path/downloaded_models/MOSS-TTS-Local-Transformer"
codec_local_path = "/local_path/downloaded_models/MOSS-Audio-Tokenizer"

# Prefer GPU in half precision; fall back to CPU in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

def resolve_attn_implementation() -> str:
    """Pick the best attention backend for the current device and dtype.

    Returns:
        "flash_attention_2" when running on a CUDA device of compute
        capability >= 8 with the ``flash_attn`` package installed and a
        half-precision dtype; "sdpa" on any other CUDA setup; "eager"
        on CPU.
    """
    if device != "cuda":
        # CPU path: plain eager attention.
        return "eager"

    has_flash_pkg = importlib.util.find_spec("flash_attn") is not None
    is_half_precision = dtype in {torch.float16, torch.bfloat16}
    if has_flash_pkg and is_half_precision:
        compute_major = torch.cuda.get_device_capability()[0]
        # FlashAttention 2 kernels need Ampere or newer (SM 8.x+).
        if compute_major >= 8:
            return "flash_attention_2"

    # Any other CUDA configuration: PyTorch's fused SDPA kernels.
    return "sdpa"


# Resolve the attention backend once at startup and log it, so slow runs
# can be correlated with the backend that was actually selected.
attn_implementation = resolve_attn_implementation()
print(f"[INFO] Using attn_implementation={attn_implementation}")

# The processor bundles text preprocessing with the audio codec; the codec
# weights come from a separate local path and are moved to the inference device.
processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path,
    codec_path=codec_local_path,
    trust_remote_code=True,
)
processor.audio_tokenizer = processor.audio_tokenizer.to(device)


# Long Spanish test sentence used as the TTS input (yields ~30 s of audio
# in the reported run).
text_2 = "Nos encontramos en el umbral de la era de la IA. La inteligencia artificial ya no es solo un concepto de laboratorio, sino que está presente en todos los sectores, en todas las iniciativas creativas y en todas las decisiones. Ha aprendido a ver, oír, hablar y pensar, y comienza a convertirse en una extensión de las capacidades humanas. La IA no busca reemplazar a los humanos, sino potenciar su creatividad, hacer que el conocimiento sea más equitativo y eficiente, y permitir que la imaginación alcance mayores horizontes. Ha llegado una nueva era, moldeada conjuntamente por humanos y sistemas inteligentes."

# One conversation per sample; each is a list of chat messages.
conversations = [
    # Direct TTS (no reference)
    [processor.build_user_message(text=text_2)],
]

# Load the TTS model with the selected attention backend and move it to
# the inference device.
# NOTE: recent transformers deprecates `torch_dtype=` (the run log warns
# "`torch_dtype` is deprecated! Use `dtype` instead!"), so pass `dtype=`.
model = AutoModel.from_pretrained(
    pretrained_model_name_or_path,
    trust_remote_code=True,
    attn_implementation=attn_implementation,
    dtype=dtype,
).to(device)
model.eval()  # inference only: disable dropout etc.

batch_size = 1

# Output directory for generated WAV files.
save_dir = Path("inference_root")
save_dir.mkdir(exist_ok=True, parents=True)
sample_idx = 0
# inference_mode() is a stricter, slightly faster variant of no_grad()
# for pure inference — no tensor produced here ever needs autograd.
with torch.inference_mode():
    for start in range(0, len(conversations), batch_size):
        print('inference ...')  # plain string: the original f-string had no placeholders
        batch_conversations = conversations[start : start + batch_size]
        batch = processor(batch_conversations, mode="generation")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Autoregressive token generation — this call dominates runtime
        # (~4.9 it/s on a T4 in the reported run).
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=4096,
        )

        # Decode generated audio codes back to waveforms and write one
        # WAV per message via scipy.
        for message in processor.decode(outputs):
            audio = message.audio_codes_list[0]
            # NOTE(review): assumes decode returns CPU tensors — .numpy()
            # would raise on a CUDA tensor; confirm against processor.decode.
            audio_numpy = audio.squeeze().numpy()
            out_path = save_dir / f"sample{sample_idx}.wav"
            sample_idx += 1
            wavfile.write(out_path, processor.model_config.sampling_rate, audio_numpy)

Logs:

[INFO] Using attn_implementation=sdpa
Loading weights: 100%|███████████████████████████████████████████████████████| 1600/1600 [00:00<00:00, 3922.71it/s, Materializing param=quantizer.quantizers.31.out_proj.parametrizations.weight.original1]
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████| 556/556 [00:00<00:00, 1318.18it/s, Materializing param=speech_embedding_to_local_mlp.up_proj.weight]
inference ...
508it [01:43,  4.89it/s]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions