The 6700 XT has significantly higher fp16 throughput than fp32.
autocast("cuda") uses fp16 for matmuls and convolutions (HiFiGAN,
S3 tokenizer, flow matching) while keeping fp32 for precision-sensitive
ops like softmax and layer norm.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
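
As a minimal standalone sketch of the policy described above (not part of this repo): under autocast("cuda"), a matmul result comes back fp16 while softmax stays fp32.

import torch

if torch.cuda.is_available():
    a = torch.randn(64, 64, device="cuda")  # fp32 inputs
    b = torch.randn(64, 64, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        mm = a @ b                       # matmul is on autocast's fp16 list
        sm = torch.softmax(mm, dim=-1)   # softmax is on the fp32 list
    print(mm.dtype, sm.dtype)  # torch.float16 torch.float32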
import logging
import time
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

# Module-level state: the loaded model, its output sample rate, and whether
# the turbo variant is in use (the variants take different generate() kwargs).
chatterbox_model = None
_sample_rate = 24000
_is_turbo = False


def _test_cuda() -> bool:
    """Return True only if a GPU is visible *and* actually usable."""
    try:
        if torch.cuda.is_available():
            # is_available() can report True on a broken ROCm/CUDA setup;
            # allocating a tensor on the device is the real test.
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    """Load the Chatterbox TTS model (turbo or standard). Returns True on success."""
    global chatterbox_model, _sample_rate, _is_turbo

    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()

    logger.info(f"Loading model '{repo_id}' on device '{device}'")

    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS

            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False

        # Both variants emit 24 kHz audio.
        _sample_rate = 24000

        # Enable MIOpen algorithm benchmarking. Without this, PyTorch picks
        # convolution algorithms heuristically and passes a ptr=0/size=0
        # workspace to MIOpen, forcing a slow fallback on every conv op.
        # With benchmark=True, PyTorch evaluates algorithms with a proper
        # workspace on first use and caches the best result (persisted via
        # the MIOPEN_USER_DB_PATH volume mount).
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True

        _patch_timing(chatterbox_model)
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def _patch_timing(model) -> None:
    """Wrap key sub-model forward() calls with timing logs."""

    def _wrap(obj, method_name, label):
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                # GPU kernels launch asynchronously; synchronize so the
                # elapsed time covers the actual device work.
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    # Each sub-model is optional across model variants, hence the
    # per-attribute AttributeError guards.
    try:
        # S3 tokenizer — processes reference audio through a conformer
        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
    except AttributeError:
        pass
    try:
        # Speaker/voice encoder — x-vector embedding from reference audio
        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
    except AttributeError:
        pass
    try:
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
    except AttributeError:
        pass


def get_sample_rate() -> int:
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Synthesize speech for `text`; returns (waveform, sample_rate)."""
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:  # seed 0 leaves generation non-deterministic
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path

    # The two variants expose different generation knobs.
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    # fp16 autocast for matmuls/convs on the GPU path; autocast("cuda")
    # only affects CUDA ops, so CPU-only runs are unaffected.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate
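
A hypothetical usage sketch of this module; the module name (engine) and the use of torchaudio for saving output are assumptions, not part of the file above.

import torchaudio  # assumption: torchaudio is installed

import engine  # assumption: the file above saved as engine.py

if engine.load_model():
    wav, sr = engine.synthesize("Hello from Chatterbox.", seed=42)
    torchaudio.save("out.wav", wav.cpu(), sr)  # expects (channels, samples)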