All checks were successful
Build ROCm Image / build (push) Successful in 2m47s
The GemmFwdRest workspace=0 issue is in MIOpen itself — PyTorch's ROCm backend does not allocate workspace for convolutions, causing HiFiGAN to use a slow fallback solver regardless of benchmark settings. torch.compile(s3gen, dynamic=True) replaces MIOpen's conv path with Triton-generated kernels, bypassing the issue entirely. dynamic=True handles variable audio lengths without recompiling per request. The warmup triggers JIT compilation so the first HA request is fast. This change also removes the fp16 autocast (Triton handles precision internally) and cudnn.benchmark (no longer needed without MIOpen convs).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
138 lines
4.2 KiB
Python
import logging
import time
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False


def _test_cuda() -> bool:
    """Return True only if a CUDA/ROCm device is actually usable.

    torch.cuda.is_available() only checks that a device is visible;
    allocating a tensor on it confirms the runtime actually works.
    """
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False


def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo

    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()

    logger.info(f"Loading model '{repo_id}' on device '{device}'")

    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS
            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS
            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False

        _sample_rate = 24000

        if torch.cuda.is_available():
            # torch.compile replaces MIOpen's convolution path with Triton-generated
            # kernels, bypassing the workspace=0 fallback entirely. We compile only
            # s3gen (HiFiGAN vocoder + flow matching) since that's the bottleneck.
            # suppress_errors=True falls back to eager for any op compile can't handle.
            try:
                import torch._dynamo

                torch._dynamo.config.suppress_errors = True
                chatterbox_model.s3gen = torch.compile(chatterbox_model.s3gen, dynamic=True)
                logger.info("s3gen compiled with torch.compile")
            except Exception:
                logger.warning("torch.compile unavailable, running s3gen in eager mode", exc_info=True)

        _patch_timing(chatterbox_model)
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


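# NOTE: the commit message mentions a warmup that triggers JIT compilation so
# the first real request is fast. That warmup is not defined in this module;
# the function below is a minimal sketch of what it could look like (the name
# "warmup" and the warmup text are assumptions, not the actual implementation).
def warmup() -> None:
    """Run one short synthesis so torch.compile traces s3gen before traffic."""
    if chatterbox_model is None:
        return
    try:
        # synthesize() is defined later in this module; names resolve at call
        # time, so the forward reference here is fine.
        synthesize("Warmup.")
        logger.info("Warmup synthesis complete")
    except Exception:
        logger.warning("Warmup failed; first request will pay the compile cost", exc_info=True)

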
def _patch_timing(model) -> None:
    """Wrap key sub-model forward() calls with timing logs."""
    def _wrap(obj, method_name, label):
        original = getattr(obj, method_name)

        def timed(*args, **kwargs):
            t0 = time.monotonic()
            result = original(*args, **kwargs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            logger.info(f"[timing] {label}: {time.monotonic() - t0:.3f}s")
            return result

        setattr(obj, method_name, timed)

    try:
        # S3 tokenizer — processes reference audio through a conformer
        _wrap(model.s3tokenizer, "forward", "s3tokenizer (ref audio encoding)")
    except AttributeError:
        pass
    try:
        # Speaker/voice encoder — xvector embedding from reference audio
        _wrap(model.voice_encoder, "forward", "voice_encoder (speaker embedding)")
    except AttributeError:
        pass
    try:
        # S3Gen decode: flow matching (token -> mel) + HiFiGAN (mel -> wav)
        _wrap(model.s3gen, "inference", "s3gen.inference (flow+vocoder)")
    except AttributeError:
        pass


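# Illustrative log output from the wrappers installed by _patch_timing
# (labels match the code above; the durations are made-up examples):
#   [timing] s3tokenizer (ref audio encoding): 0.041s
#   [timing] voice_encoder (speaker embedding): 0.012s
#   [timing] s3gen.inference (flow+vocoder): 1.214s

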
def get_sample_rate() -> int:
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path

    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate

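# Example usage (illustrative sketch; "tts_model" is a hypothetical name for
# this module and torchaudio is assumed to be installed):
#
#   import torchaudio
#   import tts_model
#
#   if tts_model.load_model():
#       wav, sr = tts_model.synthesize("Hello there.", seed=42)
#       torchaudio.save("out.wav", wav.cpu(), sr)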