import logging
from contextlib import nullcontext
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False

# Cache: voice file path -> prepared conditionals object.
# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
# builds mel embeddings: expensive work that only depends on the reference
# audio, not the text. Cache it so multi-chunk requests pay the cost once.
_cond_cache: dict = {}


def _test_cuda() -> bool:
    """Return True only if a CUDA tensor can actually be allocated."""
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS

            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        _sample_rate = 24000
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def prepare_voice(audio_prompt_path: str) -> None:
    """
    Pre-compute and cache the voice conditionals for a reference audio file.

    Calling this once avoids repeating the s3tokenizer + voice encoder work
    on every synthesis chunk that uses the same voice.
    """
    if chatterbox_model is None:
        return
    if audio_prompt_path in _cond_cache:
        return
    if not _is_turbo:
        return  # only turbo exposes prepare_conditionals

    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    with torch.inference_mode():
        chatterbox_model.prepare_conditionals(audio_prompt_path)
    _cond_cache[audio_prompt_path] = chatterbox_model.conds
    logger.info("Voice conditionals cached")


def get_sample_rate() -> int:
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Restore cached conditionals so generate() skips prepare_conditionals.
    if audio_prompt_path and _is_turbo:
        if audio_prompt_path not in _cond_cache:
            prepare_voice(audio_prompt_path)
        chatterbox_model.conds = _cond_cache[audio_prompt_path]

    kwargs: dict = {}
    # Don't pass audio_prompt_path for turbo: the conds are already set
    # above. Non-turbo models have no cache here, so pass the path as normal.
    if audio_prompt_path and not _is_turbo:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    # Autocast is CUDA-only here; fall back to a no-op context on CPU so
    # the same code path runs on machines without a GPU.
    amp_ctx = (
        torch.amp.autocast(device_type="cuda", dtype=torch.float16)
        if torch.cuda.is_available()
        else nullcontext()
    )
    with torch.inference_mode(), amp_ctx:
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate
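
# ---------------------------------------------------------------------------
# Usage sketch: shows the intended call order and how the conditionals cache
# pays off across multiple chunks that share one voice. Illustrative only:
# "reference.wav" and the chunk texts are hypothetical placeholders, and
# torchaudio is assumed to be installed for writing the output file.
if __name__ == "__main__":
    import torchaudio

    logging.basicConfig(level=logging.INFO)

    if not load_model():
        raise SystemExit("Model failed to load")

    voice = "reference.wav"  # hypothetical reference clip
    prepare_voice(voice)  # pay the conditioning cost once (turbo only)

    chunks = ["First sentence of a longer text.", "And a second sentence."]
    waveforms = []
    for chunk in chunks:
        # For turbo models this is a cache hit: generate() reuses the conds
        # stored by prepare_voice() instead of re-encoding the reference.
        wav, sr = synthesize(chunk, audio_prompt_path=voice, seed=42)
        waveforms.append(wav)

    # generate() is assumed to return a (channels, samples) tensor, so
    # concatenate the chunks along the time axis before saving.
    combined = torch.cat(waveforms, dim=-1)
    torchaudio.save("out.wav", combined.cpu(), sr)
    logger.info("Wrote out.wav")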