rocm-chatterbox-whisper/engine.py
scott f20699aed3
Add fp16 autocast to synthesis for faster GPU throughput
The 6700 XT has significantly higher fp16 throughput than fp32.
autocast("cuda") uses fp16 for matmuls and convolutions (HiFiGAN,
S3 tokenizer, flow matching) while keeping fp32 for precision-sensitive
ops like softmax and layer norm.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 13:34:21 -04:00
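
For reference, a minimal sketch of the dtype policy the commit message describes (illustrative tensor only; assumes a working CUDA/ROCm device):

import torch

x = torch.randn(64, 64, device="cuda")           # fp32 input
with torch.autocast("cuda", dtype=torch.float16):
    print(torch.mm(x, x).dtype)                  # torch.float16: matmul is downcast
    print(torch.softmax(x, dim=-1).dtype)        # torch.float32: kept at full precision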

engine.py · 133 lines · 3.9 KiB · Python

import logging
import time
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False

def _test_cuda() -> bool:
    """Check that a CUDA/ROCm device is not just visible but actually usable."""
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()  # a real allocation catches broken runtimes
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"

def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS
            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS
            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        _sample_rate = 24000
        # Enable MIOpen algorithm benchmarking. Without this, PyTorch picks
        # convolution algorithms heuristically and passes ptr=0/size=0 workspace
        # to MIOpen, forcing a slow fallback on every conv op. With benchmark=True,
        # PyTorch evaluates algorithms with proper workspace on first use and caches
        # the best result (persisted via the MIOPEN_USER_DB_PATH volume mount).
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
        _patch_timing(chatterbox_model)
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False

def _patch_timing(model) -> None:
    """Wrap key sub-model forward() calls with timing logs."""

    def _wrap(obj, method_name, label):
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                # Wait for queued GPU kernels so the elapsed time reflects
                # actual work, not just the async launch overhead.
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    try:
        # S3 tokenizer: processes reference audio through a conformer
        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
    except AttributeError:
        pass
    try:
        # Speaker/voice encoder: x-vector embedding from reference audio
        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
    except AttributeError:
        pass
    try:
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
    except AttributeError:
        pass


def get_sample_rate() -> int:
    return _sample_rate

def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight
    # fp16 autocast: matmuls and convolutions run in half precision while
    # precision-sensitive ops (softmax, layer norm) stay fp32. enabled= makes
    # the context a no-op on CPU-only hosts instead of emitting a warning.
    with torch.inference_mode(), torch.autocast(
        "cuda", dtype=torch.float16, enabled=torch.cuda.is_available()
    ):
        wav = chatterbox_model.generate(text=text, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return wav, _sample_rate
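
A minimal usage sketch of this module (module name taken from the file path; the torchaudio call is an assumption, and it assumes generate() returns a (channels, samples) tensor):

import torchaudio  # assumption: available in the image
import engine

if engine.load_model():
    wav, sr = engine.synthesize("Testing fp16 synthesis.", seed=42)
    torchaudio.save("out.wav", wav.cpu(), sr)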