rocm-chatterbox-whisper/engine.py
scott 7babd0584e
All checks were successful
Build ROCm Image / build (push) Successful in 2m47s
Replace MIOpen convolution path with torch.compile on s3gen
The GemmFwdRest workspace=0 issue is in MIOpen itself: PyTorch's ROCm
backend does not allocate workspace for convolutions, so MIOpen drops
HiFiGAN's convs to a slow fallback solver regardless of benchmark settings.
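
A quick way to see the effect (illustrative sketch, not part of this change;
the layer sizes here are made up, not HiFiGAN's real config) is to benchmark
a Conv1d in eager mode against a compiled copy on the ROCm box:

import time

import torch

conv = torch.nn.Conv1d(512, 512, kernel_size=7, padding=3).cuda()
x = torch.randn(1, 512, 4096, device="cuda")

@torch.inference_mode()
def bench(fn, label):
    fn(x)  # warmup; also triggers compilation for the compiled copy
    torch.cuda.synchronize()
    t0 = time.monotonic()
    for _ in range(20):
        fn(x)
    torch.cuda.synchronize()
    print(f"{label}: {(time.monotonic() - t0) / 20 * 1e3:.2f} ms/iter")

bench(conv, "eager (MIOpen)")
bench(torch.compile(conv), "torch.compile")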

torch.compile(s3gen, dynamic=True) replaces MIOpen's conv path with
Triton-generated kernels, bypassing the issue entirely. dynamic=True
handles variable audio lengths without recompiling per request. The warmup
triggers JIT compilation so the first Home Assistant request is fast.
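
The warmup itself lives in the service startup, not in this file; it amounts
to something like this (hypothetical sketch, assuming the module imports as
engine):

from engine import load_model, synthesize

if load_model():
    # The first generate() call traces s3gen and builds the compiled kernels;
    # with dynamic=True, later requests of any length reuse them.
    synthesize("warmup")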

Also removes fp16 autocast (Triton handles precision internally) and
cudnn.benchmark (no longer needed without MIOpen convs).
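
The removed code is not shown in this snapshot; it presumably looked roughly
like this (reconstruction, not the actual diff):

# inside the old load_model()/synthesize() paths
torch.backends.cudnn.benchmark = True  # maps onto MIOpen benchmarking on ROCm

with torch.autocast("cuda", dtype=torch.float16):
    wav = chatterbox_model.generate(text=text, **kwargs)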

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 14:27:09 -04:00


import logging
import time
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False


def _test_cuda() -> bool:
    """Return True only if a tensor can actually be allocated on the GPU."""
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS

            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        _sample_rate = 24000
        if torch.cuda.is_available():
            # torch.compile replaces MIOpen's convolution path with Triton-generated
            # kernels, bypassing the workspace=0 fallback entirely. We compile only
            # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
            # suppress_errors=True falls back to eager for any op compile can't handle.
            try:
                import torch._dynamo

                torch._dynamo.config.suppress_errors = True
                chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True)
                logger.info("s3gen compiled with torch.compile")
            except Exception:
                logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True)
        _patch_timing(chatterbox_model)
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def _patch_timing(model) -> None:
    """Wrap key sub-model forward() calls with timing logs."""

    def _wrap(obj, method_name, label):
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    try:
        # S3 tokenizer — processes reference audio through a conformer
        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
    except AttributeError:
        pass
    try:
        # Speaker/voice encoder — xvector embedding from reference audio
        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
    except AttributeError:
        pass
    try:
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
    except AttributeError:
        pass


def get_sample_rate() -> int:
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for `text`; returns (waveform tensor, sample rate)."""
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight
    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return wav, _sample_rate
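

# Hypothetical smoke test (not part of the committed file). Assumes
# generate() returns a 2-D (channels, samples) tensor, which is the
# layout torchaudio.save() expects.
if __name__ == "__main__":
    import torchaudio

    logging.basicConfig(level=logging.INFO)  # surface the [timing] logs
    if load_model():
        wav, sr = synthesize("Hello from Chatterbox on ROCm.")
        torchaudio.save("out.wav", wav.cpu(), sr)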