"""Thin wrapper around the Chatterbox TTS models.

Handles device detection, lazy model loading into a module-level singleton,
and speech synthesis for both the standard and turbo model variants.
"""

import logging
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

# Module-level state: the loaded model singleton and its properties.
chatterbox_model = None
_sample_rate = 24000  # fallback until load_model() records the model's rate
_is_turbo = False


def _test_cuda() -> bool:
    """Return True only if CUDA is both reported available and actually usable.

    Allocating a tiny tensor on the GPU catches broken driver/toolkit installs
    that ``torch.cuda.is_available()`` alone would miss.
    """
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        # Deliberate best-effort probe: any CUDA failure means "unavailable".
        pass
    return False


def detect_device() -> str:
    """Pick the compute device: ``"cuda"`` when verified working, else ``"cpu"``."""
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    """Load the Chatterbox model named by config into the module singleton.

    Chooses the turbo or standard implementation based on the repo id and
    records the loaded model's sample rate.

    Returns:
        True on success, False on any failure (the error is logged).
    """
    global chatterbox_model, _sample_rate, _is_turbo

    # Imported here so the module can be imported without config present.
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info("Loading model '%s' on device '%s'", repo_id, device)
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS

            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        # Prefer the model's own sample rate when it exposes one (Chatterbox
        # models expose `.sr`); the original hard-coded 24000, which would
        # silently mis-report the rate for a model with a different one.
        # Falls back to 24000 when the attribute is absent.
        _sample_rate = int(getattr(chatterbox_model, "sr", 24000))
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def get_sample_rate() -> int:
    """Return the output sample rate in Hz of the loaded model."""
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Synthesize speech for *text* with the loaded model.

    Args:
        text: Text to speak.
        audio_prompt_path: Optional reference audio for voice cloning.
        exaggeration: Standard-model expressiveness knob (ignored by turbo).
        cfg_weight: Standard-model CFG weight (ignored by turbo).
        temperature: Turbo-model sampling temperature (ignored otherwise).
        seed: If > 0, seeds the torch RNGs for reproducible output.

    Returns:
        Tuple of (waveform tensor, sample rate in Hz).

    Raises:
        RuntimeError: If ``load_model()`` has not been called successfully.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Turbo and standard models accept different generation kwargs.
    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        # Flush pending GPU work and release cached blocks between requests
        # to keep memory pressure low on long-running servers.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate