Files
rocm-chatterbox-whisper/engine.py
scott 169e003a34
All checks were successful
Build ROCm Image / build (push) Successful in 3m35s
Fix warmup text length and ve attribute for torch.compile
- Warmup now uses a ~170-char representative sentence so torch.compile
  JIT-compiles for typical token sequence lengths. Previously "Warmup."
  compiled for very short shapes, causing a full re-compile (17s) on the
  first real HA request and pushing total synthesis past 30s.
- Compile model.ve (voice encoder) in addition to s3gen — both are
  convolutional and hit the MIOpen workspace=0 bug.
- Fix _patch_timing: attribute is model.ve not model.voice_encoder,
  so the timing wrap was silently skipping the speaker embedding.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:51:08 -04:00

141 lines
4.3 KiB
Python

import logging
import time
import torch
import torch._dynamo
from typing import Optional, Tuple
# Module logger.
logger = logging.getLogger(__name__)

# Engine state: assigned by load_model() and read by synthesize() /
# get_sample_rate().
chatterbox_model = None  # the loaded TTS model; None until load_model() assigns one
_sample_rate: int = 24_000  # sample rate returned alongside synthesized audio
_is_turbo: bool = False  # True when the loaded model is ChatterboxTurboTTS
def _test_cuda() -> bool:
    """Return True if a CUDA device (ROCm/HIP is exposed via torch.cuda) is usable.

    Rather than trusting ``torch.cuda.is_available()`` alone, the probe
    actually allocates a one-element tensor on the device. Any failure makes
    the probe return False so the caller can fall back to the CPU.
    """
    try:
        if not torch.cuda.is_available():
            return False
        torch.zeros(1, device="cuda")
    except Exception:
        # Best-effort probe: fall back to CPU, but record why. Otherwise a
        # broken GPU setup silently turns into a much slower CPU run.
        logger.warning("CUDA probe failed; falling back to CPU", exc_info=True)
        return False
    return True
def detect_device() -> str:
    """Return the torch device string to load the model on.

    ``"cuda"`` when the GPU probe succeeds, otherwise ``"cpu"``.
    """
    if _test_cuda():
        return "cuda"
    return "cpu"
def load_model() -> bool:
    """Load the configured Chatterbox TTS model and publish it as module state.

    The repo id comes from ``config.get_model_repo_id()``. A repo id
    containing "turbo" (case-insensitive) selects ``ChatterboxTurboTTS``;
    anything else selects ``ChatterboxTTS``. The device is
    ``config.get_device_override()`` when that returns a truthy value,
    otherwise the result of ``detect_device()``.

    On a CUDA device, the convolution-heavy sub-modules are wrapped with
    ``torch.compile`` (see ``_compile_conv_modules``). Timing logs are then
    installed on key sub-model methods.

    The module globals (``chatterbox_model``, ``_is_turbo``, ``_sample_rate``)
    are updated only after every step has succeeded. A failed load therefore
    leaves the previous state untouched instead of publishing a
    half-initialised model for ``synthesize()`` to use.

    Returns:
        True if the model was loaded. False if any step failed; the
        exception is logged with its traceback.
    """
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        is_turbo = "turbo" in repo_id.lower()
        if is_turbo:
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            model = ChatterboxTurboTTS.from_pretrained(device)
        else:
            from chatterbox.tts import ChatterboxTTS

            model = ChatterboxTTS.from_pretrained(device)
        # The compile step works around a GPU (MIOpen) convolution bug, so it
        # is skipped when the model was placed on the CPU (e.g. via override).
        if str(device).startswith("cuda") and torch.cuda.is_available():
            _compile_conv_modules(model)
        _patch_timing(model)
    except Exception:
        logger.exception("Failed to load model")
        return False

    # Publish only once every step above has succeeded.
    chatterbox_model = model
    _is_turbo = is_turbo
    _sample_rate = 24000
    logger.info("Model loaded successfully")
    return True


def _compile_conv_modules(model) -> None:
    """Replace ``model.s3gen`` and ``model.ve`` with ``torch.compile`` wrappers.

    torch.compile replaces MIOpen's convolution path with Triton-generated
    kernels, bypassing the workspace=0 fallback entirely. s3gen (HiFiGAN
    vocoder + flow matching) and the voice encoder (ve) are both
    convolutional and hit the workspace=0 bug. If a sub-module cannot be
    wrapped, it stays in eager mode and a warning is logged.
    """
    # suppress_errors=True falls back to eager for any op compile can't
    # handle. This is a process-wide Dynamo setting, not scoped to these
    # modules.
    torch._dynamo.config.suppress_errors = True
    # NOTE(review): torch.compile(module) compiles calls to the module itself
    # (__call__/forward). Other methods reached through the wrapper (e.g.
    # s3gen.inference) may still run eagerly. Confirm with the [timing] logs
    # or a profiler that the compiled path is actually exercised.
    for attr in ("s3gen", "ve"):
        try:
            setattr(model, attr, torch.compile(getattr(model, attr), dynamic=True))
            logger.info(f"{attr} compiled with torch.compile")
        except Exception:
            logger.warning(
                f"torch.compile unavailable for {attr}, running in eager mode",
                exc_info=True,
            )
def _patch_timing(model) -> None:
    """Wrap key sub-model methods of *model* so each call logs its duration.

    Every wrapped call logs ``[timing] <label>: <seconds>s`` at INFO level.
    When CUDA is available, the device is synchronized before the clock is
    read, so asynchronously launched GPU work is included in the figure.

    Targets missing from *model* are skipped. Each skip is logged at DEBUG
    level, so a renamed attribute shows up in the logs instead of silently
    dropping that timing line.
    """

    def _wrap(obj, method_name: str, label: str) -> None:
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                # GPU work is launched asynchronously; wait for it so the
                # elapsed time covers the actual compute.
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    targets = (
        # S3 tokenizer — processes reference audio through a conformer
        ("s3tokenizer", "forward", "s3tokenizer (ref audio encoding)"),
        # Speaker/voice encoder — xvector embedding from reference audio
        ("ve", "forward", "ve (speaker embedding)"),
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        ("s3gen", "inference", "s3gen.inference (flow+vocoder)"),
    )
    # NOTE(review): if s3gen/ve were replaced by torch.compile wrappers before
    # this runs, confirm that the method patched here is the one the model's
    # generate() actually calls. Otherwise no [timing] line is emitted for it.
    for attr, method_name, label in targets:
        try:
            _wrap(getattr(model, attr), method_name, label)
        except AttributeError:
            logger.debug(f"[timing] {attr}.{method_name} not found on model; not timed")
def get_sample_rate() -> int:
    """Return the sample rate that accompanies audio from ``synthesize()``."""
    rate = _sample_rate
    return rate
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for *text* with the loaded model.

    Args:
        text: Text to speak.
        audio_prompt_path: Optional reference-audio path forwarded to
            ``generate()``. Omitted entirely when falsy.
        exaggeration: Forwarded only when the non-turbo model is loaded.
        cfg_weight: Forwarded only when the non-turbo model is loaded.
        temperature: Forwarded only when the turbo model is loaded.
        seed: When positive, seeds torch's RNG (and all CUDA RNGs when CUDA
            is available) before generation. 0 or negative leaves the RNG
            state untouched.

    Returns:
        ``(wav, sample_rate)``: whatever ``generate()`` returned, plus the
        module's current sample rate.

    Raises:
        RuntimeError: If no model has been loaded yet.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Each model variant takes a different set of sampling knobs.
    if _is_turbo:
        sampling = {"temperature": temperature}
    else:
        sampling = {"exaggeration": exaggeration, "cfg_weight": cfg_weight}
    prompt = {"audio_prompt_path": audio_prompt_path} if audio_prompt_path else {}

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **prompt, **sampling)

    if torch.cuda.is_available():
        # Let outstanding GPU work finish, then release the caching
        # allocator's unused blocks between requests.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return wav, _sample_rate