import logging
import time
from typing import Optional, Tuple

import torch
import torch._dynamo

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False


def _test_cuda() -> bool:
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS

            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        _sample_rate = 24000

        if torch.cuda.is_available():
            # torch.compile replaces MIOpen's convolution path with Triton-generated
            # kernels, bypassing the workspace=0 fallback entirely. We compile only
            # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
            # suppress_errors=True falls back to eager for any op compile can't handle.
            try:
                torch._dynamo.config.suppress_errors = True
                chatterbox_model.s3gen = torch.compile(
                    chatterbox_model.s3gen, dynamic=True
                )
                logger.info("s3gen compiled with torch.compile")
            except Exception:
                logger.warning(
                    "torch.compile unavailable, running s3gen in eager mode",
                    exc_info=True,
                )

        _patch_timing(chatterbox_model)
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def _patch_timing(model) -> None:
    """Wrap key sub-model forward() calls with timing logs."""

    def _wrap(obj, method_name, label):
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    try:
        # S3 tokenizer: processes reference audio through a conformer
        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
    except AttributeError:
        pass
    try:
        # Speaker/voice encoder: xvector embedding from reference audio
        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
    except AttributeError:
        pass
    try:
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
    except AttributeError:
        pass


def get_sample_rate() -> int:
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate
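
# A minimal smoke-test entry point: a sketch, not part of the original module's
# API. Assumptions are flagged inline: torchaudio is assumed to be installed,
# the demo text and the "out.wav" path are made up for illustration, and
# generate() is assumed to return either a (channels, samples) or a 1-D
# (samples,) tensor.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    import torchaudio

    if load_model():
        wav, sr = synthesize("Hello from Chatterbox.", seed=42)
        # torchaudio.save expects a 2-D (channels, frames) tensor on CPU.
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        torchaudio.save("out.wav", wav.cpu(), sr)
        logger.info("Wrote out.wav at %d Hz", sr)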