-
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathstyle_transfer.py
More file actions
58 lines (45 loc) · 1.58 KB
/
style_transfer.py
File metadata and controls
58 lines (45 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gc
import logging
import threading
import torch
# Module-level logger, named for this module per standard logging convention.
logger = logging.getLogger(__name__)
# Cached MuQ-MuLan model instance; populated lazily by get_muq_model(),
# cleared by unload_muq().
_muq_model = None
# Serializes lazy initialization and teardown of _muq_model across threads.
_muq_lock = threading.Lock()
def get_muq_model():
    """Return the shared MuQ-MuLan instance, loading it on first use.

    The model is deliberately kept on CPU in fp32 so it does not compete
    with HeartMuLa for GPU memory. Loading is guarded by ``_muq_lock`` so
    concurrent callers never construct the model twice.
    """
    global _muq_model
    with _muq_lock:
        # Only the first caller pays the load cost; everyone else reuses it.
        if _muq_model is None:
            logger.info("Loading MuQ-MuLan model...")
            # Imported lazily so merely importing this module stays cheap.
            from muq import MuQMuLan
            model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
            _muq_model = model.to("cpu").float().eval()
            logger.info("MuQ-MuLan loaded on CPU")
        return _muq_model
def unload_muq():
    """Release the cached MuQ-MuLan model and reclaim its memory.

    Safe to call even if the model was never loaded; the next call to
    get_muq_model() will reload it from scratch.
    """
    global _muq_model
    with _muq_lock:
        # Drop the only strong reference, then nudge the collector so the
        # (large) model weights are freed promptly rather than eventually.
        _muq_model = None
        gc.collect()
        logger.info("MuQ-MuLan unloaded")
def extract_style_embedding(audio_path: str) -> torch.Tensor:
    """Compute a 512-dimensional style embedding for a reference audio file.

    MuQ-MuLan expects 24 kHz mono fp32 input, so the audio is resampled on
    load and the waveform converted to float32 before inference.

    Args:
        audio_path: Path to an audio file readable by librosa.

    Returns:
        torch.Tensor of shape [512] in bfloat16 (matches HeartMuLa dtype).

    Raises:
        RuntimeError: If the audio file cannot be decoded.
    """
    # Imported here so the module can be used without librosa installed
    # until an actual extraction is requested.
    import librosa

    try:
        # sr=24000 forces the 24 kHz mono input MuQ-MuLan requires.
        wav, _sr = librosa.load(audio_path, sr=24000)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load audio file: {e}. "
            "Ensure the file is a valid audio format (WAV, MP3, FLAC, etc.)."
        ) from e

    model = get_muq_model()
    # Add a leading batch axis and ensure fp32: [1, samples].
    batch = torch.tensor(wav, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        emb = model(wavs=batch)  # [1, 512]
    # Drop the batch axis and cast to HeartMuLa's working dtype.
    return emb.squeeze(0).to(torch.bfloat16)  # [512]