rocm-chatterbox-whisper/engine.py
scott 51188ca973
[dev-fp16] Only convert T3 to fp16, leave s3gen/ve in fp32
s3gen.speaker_encoder (CAMPPlus xvector) hardcodes float32 inputs in
its inference() method, causing dtype mismatch when weights are fp16.
T3 (the autoregressive GPT-2-medium LLM) has no such constraint and
is the token-generation bottleneck that benefits most from fp16.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 20:41:24 -04:00
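
For context, the failure the commit message describes is PyTorch's generic dtype rule: a module converted with .half() carries float16 weights and rejects float32 inputs at the first matmul. A minimal stand-alone sketch of both sides (plain torch.nn, not code from this repo; the CAMPPlus detail is taken from the message above):

import torch

# Stand-in for any fp16-converted module: .half() casts the weights.
layer = torch.nn.Linear(4, 4).half()

# A caller that hardcodes float32 inputs (as CAMPPlus's inference()
# reportedly does) now feeds a mismatched dtype into the matmul.
x = torch.randn(1, 4, dtype=torch.float32)
try:
    layer(x)
except RuntimeError as err:
    print(err)  # e.g. "mat1 and mat2 must have the same dtype"

# Modules with no hardcoded-dtype call sites (like T3) are safe: the
# caller just has to supply matching fp16 inputs.
if torch.cuda.is_available():
    print(layer.cuda()(x.half().cuda()).dtype)  # torch.float16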


import logging
from typing import Optional, Tuple

import torch

logger = logging.getLogger(__name__)

chatterbox_model = None
_sample_rate = 24000
_is_turbo = False

# Cache: voice file path → prepared conditionals object.
# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
# builds mel embeddings: expensive work that only depends on the reference
# audio, not the text. Cache it so multi-chunk requests pay the cost once
# (see the usage sketch after the file).
_cond_cache: dict = {}

def _test_cuda() -> bool:
    # Probe CUDA by actually allocating a tensor: on some ROCm setups
    # torch.cuda.is_available() reports True but device ops still fail.
    try:
        if torch.cuda.is_available():
            torch.zeros(1).cuda()
            return True
    except Exception:
        pass
    return False

def detect_device() -> str:
    return "cuda" if _test_cuda() else "cpu"

def load_model() -> bool:
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS
            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS
            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        _sample_rate = 24000

        # Convert T3 (the autoregressive LLM) to fp16 for faster token
        # generation. s3gen and ve stay in fp32: s3gen.speaker_encoder
        # (CAMPPlus xvector) hardcodes float32 inputs in its inference()
        # method and errors on fp16 weights.
        if torch.cuda.is_available():
            try:
                if hasattr(chatterbox_model, "t3"):
                    chatterbox_model.t3.half()
                    logger.info("T3 converted to fp16")
            except Exception:
                logger.warning("T3 fp16 conversion failed, running in fp32", exc_info=True)

        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False

def prepare_voice(audio_prompt_path: str) -> None:
    """
    Pre-compute and cache the voice conditionals for a reference audio file.

    Calling this once avoids repeating the s3tokenizer + voice encoder work
    on every synthesis chunk that uses the same voice.
    """
    if chatterbox_model is None:
        return
    if audio_prompt_path in _cond_cache:
        return
    if not _is_turbo:
        return  # only the turbo model exposes prepare_conditionals

    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    with torch.inference_mode():
        chatterbox_model.prepare_conditionals(audio_prompt_path)
    _cond_cache[audio_prompt_path] = chatterbox_model.conds
    logger.info("Voice conditionals cached")

def get_sample_rate() -> int:
    return _sample_rate

def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Restore cached conditionals so generate() skips prepare_conditionals.
    if audio_prompt_path and _is_turbo:
        if audio_prompt_path not in _cond_cache:
            prepare_voice(audio_prompt_path)
        chatterbox_model.conds = _cond_cache[audio_prompt_path]

    kwargs: dict = {}
    # Turbo: don't pass audio_prompt_path, the conds were already set above.
    # Non-turbo: there is no cache, so the path is passed through as usual.
    if audio_prompt_path and not _is_turbo:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return wav, _sample_rate
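
A hedged usage sketch of the intended call pattern (the driver below is illustrative, not part of the repo; the voice path and chunk texts are made up): prepare the voice once, then synthesize many chunks against the cached conditionals.

# Illustrative driver: pay the s3tokenizer + voice-encoder cost once,
# then reuse the cached conditionals across every chunk of a long text.
import engine

if not engine.load_model():
    raise SystemExit("model failed to load")

voice = "/voices/reference.wav"  # hypothetical reference clip
engine.prepare_voice(voice)      # cached; repeat calls are no-ops

chunks = ["First sentence of a long text.", "And the second one."]
for chunk in chunks:
    wav, sr = engine.synthesize(chunk, audio_prompt_path=voice, seed=42)
    print(tuple(wav.shape), sr)  # waveform tensor at 24 kHz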