All checks were successful
Build ROCm Image / build (push) Successful in 2m39s
import inside a function creates a local binding that shadows the module-level torch import, breaking all earlier torch references in the same function scope. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
138 lines
4.2 KiB
Python
138 lines
4.2 KiB
Python
import functools
import logging
import time
from typing import Optional, Tuple

import torch
import torch._dynamo
|
|
|
|
logger = logging.getLogger(__name__)

# Loaded TTS model instance; stays None until load_model() assigns one.
# synthesize() raises RuntimeError while this is None.
chatterbox_model = None

# Sample rate reported by get_sample_rate() and returned alongside every
# waveform from synthesize(); load_model() reassigns it when it loads a model.
_sample_rate: int = 24000

# True when the loaded model is the turbo variant (repo id contains "turbo");
# synthesize() uses it to choose which keyword arguments go to generate().
_is_turbo: bool = False
|
|
|
|
|
|
def _test_cuda() -> bool:
    """Return True if a CUDA (or ROCm, exposed via ``torch.cuda``) device works.

    ``torch.cuda.is_available()`` only says a device is visible; moving a tiny
    tensor onto it confirms the runtime can actually allocate there. Any
    failure is treated as "no usable GPU" so the caller falls back to CPU, but
    the reason is logged rather than silently discarded.
    """
    try:
        if not torch.cuda.is_available():
            return False
        # Probe allocation: surfaces broken drivers/runtimes that still
        # report the device as available.
        torch.zeros(1).cuda()
    except Exception:
        # Deliberately best-effort: keep the CPU fallback, but leave a trace.
        logger.warning(
            "CUDA device reported but unusable; falling back to CPU",
            exc_info=True,
        )
        return False
    return True
|
|
|
|
|
|
def detect_device() -> str:
    """Choose the torch device string used for inference.

    Returns "cuda" when ``_test_cuda()`` confirms a working GPU, else "cpu".
    """
    gpu_ok = _test_cuda()
    if gpu_ok:
        return "cuda"
    return "cpu"
|
|
|
|
|
|
def load_model() -> bool:
    """Load the configured Chatterbox model and publish it as module state.

    The repo id comes from ``config.get_model_repo_id()``. If it contains
    "turbo" (case-insensitive), ``ChatterboxTurboTTS`` is loaded; otherwise
    ``ChatterboxTTS``. The device is ``config.get_device_override()`` when set,
    otherwise ``detect_device()``.

    The module globals ``chatterbox_model``, ``_is_turbo`` and ``_sample_rate``
    are only updated after loading, optional compilation and timing
    instrumentation have all succeeded. A failed load therefore never leaves
    a half-initialised model visible to ``synthesize()``.

    Returns:
        True on success; False if loading failed (the error is logged).
    """
    global chatterbox_model, _sample_rate, _is_turbo

    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()

    logger.info("Loading model '%s' on device '%s'", repo_id, device)

    try:
        model, is_turbo = _instantiate_model(repo_id, device)

        # Gate on the device actually chosen, not on GPU presence, so a CPU
        # override (or a failed CUDA probe) doesn't trigger GPU-motivated
        # compilation of a CPU model.
        if str(device).startswith("cuda"):
            _compile_s3gen(model)

        _patch_timing(model)

        # Prefer the model's own sample rate when it exposes one (TODO confirm
        # both chatterbox variants provide ``sr``); 24 kHz is the long-standing
        # default this module assumed.
        sample_rate = int(getattr(model, "sr", 24000))
    except Exception:
        logger.exception("Failed to load model")
        return False

    # Commit all module state together, only after everything succeeded.
    chatterbox_model = model
    _is_turbo = is_turbo
    _sample_rate = sample_rate

    logger.info("Model loaded successfully")
    return True


def _instantiate_model(repo_id: str, device):
    """Construct the Chatterbox variant selected by *repo_id*; return (model, is_turbo)."""
    # NOTE(review): repo_id only selects the class here; from_pretrained()
    # receives just the device, so the weights come from whatever repo the
    # library's from_pretrained() uses by default. Confirm that is intended.
    if "turbo" in repo_id.lower():
        from chatterbox.tts_turbo import ChatterboxTurboTTS

        return ChatterboxTurboTTS.from_pretrained(device), True

    from chatterbox.tts import ChatterboxTTS

    return ChatterboxTTS.from_pretrained(device), False


def _compile_s3gen(model) -> None:
    """Best-effort ``torch.compile`` of ``model.s3gen``; stays eager on failure."""
    # torch.compile replaces MIOpen's convolution path with Triton-generated
    # kernels, bypassing the workspace=0 fallback entirely. We compile only
    # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
    # suppress_errors=True falls back to eager for any op compile can't handle.
    #
    # NOTE(review): torch.compile(module) returns a wrapper that compiles the
    # module's forward()/__call__; other attribute lookups (e.g.
    # ``s3gen.inference``, the method _patch_timing instruments) are forwarded
    # to the original, uncompiled module. If generation calls
    # ``s3gen.inference`` rather than ``s3gen(...)``, this compile may have no
    # effect on that path. Verify before relying on it.
    try:
        torch._dynamo.config.suppress_errors = True
        model.s3gen = torch.compile(model.s3gen, dynamic=True)
        logger.info("s3gen compiled with torch.compile")
    except Exception:
        logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True)
|
|
|
|
|
|
def _patch_timing(model) -> None:
    """Wrap key sub-model methods of *model* with wall-clock timing logs.

    Each wrapped call logs ``[timing] <label>: <seconds>s`` at INFO level.
    When CUDA is available, the wrapper synchronises after the call so the
    figure includes queued GPU work, not just kernel launch time. Wrappers
    use ``functools.wraps``, so ``__name__``/``__doc__`` are kept and the
    original is reachable as ``__wrapped__``.

    Each target is looked up through a list of candidate attribute paths.
    The first path that resolves is instrumented. A target with no
    resolvable path is skipped and logged at DEBUG rather than silently
    ignored.
    """

    def _wrap(obj, method_name: str, label: str) -> None:
        original = getattr(obj, method_name)

        @functools.wraps(original)
        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                # Wait for async GPU work so the timing is meaningful.
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    def _resolve(path: str):
        # Follow a dotted attribute path such as "s3gen.tokenizer" from model.
        obj = model
        for part in path.split("."):
            obj = getattr(obj, part)
        return obj

    # (candidate attribute paths, method to wrap, log label).
    # NOTE(review): the second paths ("s3gen.tokenizer", "ve") are believed to
    # be where chatterbox's model classes keep these sub-models; the original
    # names are tried first. Confirm against the installed chatterbox version.
    targets = (
        # S3 tokenizer — processes reference audio through a conformer
        (("s3tokenizer", "s3gen.tokenizer"), "forward", "s3tokenizer (ref audio encoding)"),
        # Speaker/voice encoder — xvector embedding from reference audio
        (("voice_encoder", "ve"), "forward", "voice_encoder (speaker embedding)"),
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        (("s3gen",), "inference", "s3gen.inference (flow+vocoder)"),
    )

    for paths, method_name, label in targets:
        for path in paths:
            try:
                _wrap(_resolve(path), method_name, label)
            except AttributeError:
                continue
            break
        else:
            logger.debug("[timing] %s not found on model; not instrumented", label)
|
|
|
|
|
|
def get_sample_rate() -> int:
    """Return the sample rate currently paired with synthesized audio.

    This is the module-level ``_sample_rate`` (initialised to 24000 and
    reassigned by load_model()), the same value synthesize() returns next to
    each waveform.
    """
    rate = _sample_rate
    return rate
|
|
|
|
|
|
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for *text* with the loaded Chatterbox model.

    Args:
        text: Text to speak.
        audio_prompt_path: Reference-audio path forwarded to ``generate()``;
            omitted when None or empty.
        exaggeration: Forwarded only to the non-turbo model.
        cfg_weight: Forwarded only to the non-turbo model.
        temperature: Forwarded only to the turbo model.
            NOTE(review): the non-turbo model never receives it. Confirm
            whether that is intended.
        seed: When > 0, seeds torch (and every CUDA device, if available)
            before generation; 0 or negative leaves RNG state untouched.

    Returns:
        ``(wav, sample_rate)``: the value returned by ``generate()`` and the
        module's current sample rate.

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    model = chatterbox_model
    if model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # The two model variants accept different sampling controls.
    if _is_turbo:
        gen_args: dict = {"temperature": temperature}
    else:
        gen_args = {"exaggeration": exaggeration, "cfg_weight": cfg_weight}
    if audio_prompt_path:
        gen_args["audio_prompt_path"] = audio_prompt_path

    with torch.inference_mode():
        waveform = model.generate(text=text, **gen_args)

    if torch.cuda.is_available():
        # Finish queued GPU work, then hand cached allocator blocks back.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return waveform, _sample_rate
|