rocm-chatterbox-whisper/engine.py
scott f0ab3c1d59
Add MIGraphX/half deps and use AMD onnxruntime wheel for ROCm 6.1.3
Per AMD docs (rocm.docs.amd.com install-onnx):
- apt install migraphx, migraphx-dev, half (required by onnxruntime-rocm)
- Switch to AMD-hosted wheel: onnxruntime_rocm-1.17.0-cp310 from repo.radeon.com
- Pin numpy==1.26.4 (numpy 2.0 incompatible with this wheel)
- Add MIGraphXExecutionProvider to provider list in engine.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 19:23:28 -04:00
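
A quick way to verify the setup described in this commit (a sanity-check sketch, not part of the committed file; assumes the image was built with the AMD wheel and pinned numpy):

import numpy as np
import onnxruntime as ort

# Expect onnxruntime 1.17.0 (AMD wheel) and numpy 1.26.x; numpy 2.0 is incompatible
print("ort:", ort.__version__, "numpy:", np.__version__)
# MIGraphXExecutionProvider and ROCMExecutionProvider should both appear here
print(ort.get_available_providers())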


import logging
from typing import Dict, Optional, Tuple

import librosa
import numpy as np

logger = logging.getLogger(__name__)

ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
PRECISION_SUFFIX = "_fp16"
SAMPLE_RATE = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
NUM_KV_HEADS = 16
HEAD_DIM = 64
MAX_NEW_TOKENS = 1024
REPETITION_PENALTY = 1.2

# ONNX inference sessions keyed by module name
_sessions: Dict = {}
_tokenizer = None
# Cache: voice file path → speech_encoder outputs dict
_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}

def detect_device() -> str:
    """Return "rocm" if onnxruntime exposes the ROCm provider, else "cpu"."""
    try:
        import onnxruntime as ort

        if "ROCMExecutionProvider" in ort.get_available_providers():
            return "rocm"
    except Exception:
        pass
    return "cpu"

def _get_providers(device: str) -> list:
    """Build the ORT provider list; ORT falls back left-to-right at session creation."""
    if device in ("rocm", "cuda"):
        return [
            "MIGraphXExecutionProvider",
            ("ROCMExecutionProvider", {"device_id": 0}),
            "CPUExecutionProvider",
        ]
    return ["CPUExecutionProvider"]

def load_model() -> bool:
    global _sessions, _tokenizer
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer

    from config import get_device_override

    device = get_device_override() or detect_device()
    providers = _get_providers(device)
    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
    try:
        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
        module_names = [
            "speech_encoder",
            "embed_tokens",
            "language_model",
            "conditional_decoder",
        ]
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        for name in module_names:
            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
            logger.info(f"Downloading {onnx_file} ...")
            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
            # External weights file must be downloaded next to the .onnx graph
            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
            logger.info(f"Creating session: {name}")
            _sessions[name] = ort.InferenceSession(
                onnx_path, sess_options=sess_options, providers=providers
            )
        logger.info("ONNX model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load ONNX model")
        return False

def get_sample_rate() -> int:
    return SAMPLE_RATE

def prepare_voice(audio_prompt_path: str) -> None:
    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
    if not _sessions or audio_prompt_path in _cond_cache:
        return
    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    try:
        audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]
        session = _sessions["speech_encoder"]
        input_name = session.get_inputs()[0].name
        output_names = [o.name for o in session.get_outputs()]
        outputs = session.run(None, {input_name: audio})
        _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
        logger.info("Voice conditionals cached")
    except Exception:
        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")

def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[np.ndarray, int]:
    """Synthesize speech. Returns (waveform_float32_1d, sample_rate).

    Note: exaggeration, cfg_weight, and temperature are accepted for API
    compatibility but currently have no effect, since _generate() decodes greedily.
    """
    if not _sessions:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if audio_prompt_path is None:
        raise RuntimeError(
            "ONNX model requires a reference audio file for voice cloning."
        )
    if seed > 0:
        np.random.seed(seed)
    if audio_prompt_path not in _cond_cache:
        prepare_voice(audio_prompt_path)
    cond = _cond_cache.get(audio_prompt_path)
    if cond is None:
        # prepare_voice() logs and swallows its own errors, so fail loudly here
        raise RuntimeError(f"Could not prepare voice from '{audio_prompt_path}'")
    # Tokenize input text
    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
    # Generate speech tokens autoregressively
    speech_tokens = _generate(input_ids, cond)
    # Prepend prompt token (speaker identity anchor) and append silence tokens
    prompt_token = cond.get("prompt_token")
    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
    decoder_input = np.concatenate(parts, axis=1)
    wav = _decode(decoder_input, cond)
    return wav, SAMPLE_RATE

def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
    embed_sess = _sessions["embed_tokens"]
    lm_sess = _sessions["language_model"]
    embed_input_name = embed_sess.get_inputs()[0].name
    # Discover KV cache slot names from session metadata
    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
    lm_out_names = [o.name for o in lm_sess.get_outputs()]
    kv_dtype = np.float16

    # Embed full text sequence
    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]
    # Prepend conditioning embeddings from speech encoder
    cond_emb = cond.get("cond_emb")
    inputs_embeds = (
        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
    )
    seq_len = inputs_embeds.shape[1]
    attention_mask = np.ones((1, seq_len), dtype=np.int64)
    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]
    # Empty KV cache to start: [batch, n_kv_heads, seq=0, head_dim]
    past_kv = {
        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
        for name in past_names
    }
    generated: list = []
    for _ in range(MAX_NEW_TOKENS):
        feed = {
            "inputs_embeds": inputs_embeds.astype(kv_dtype),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        feed.update(past_kv)
        raw = lm_sess.run(None, feed)
        out = dict(zip(lm_out_names, raw))
        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]
        # Repetition penalty (greedy decoding): divide positive logits of
        # previously generated tokens, multiply negative ones
        for tok in set(generated):
            if logits[tok] > 0:
                logits[tok] /= REPETITION_PENALTY
            else:
                logits[tok] *= REPETITION_PENALTY
        next_token = int(np.argmax(logits))
        if next_token == STOP_SPEECH_TOKEN:
            break
        generated.append(next_token)
        # Roll KV cache forward (assumes past/present lists are ordered alike)
        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}
        # Embed only the new token for next step
        inputs_embeds = embed_sess.run(
            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
        )[0]
        new_pos = attention_mask.shape[1]
        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
        position_ids = np.array([[new_pos]], dtype=np.int64)
    if not generated:
        logger.warning("No speech tokens generated — returning start token")
        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
    return np.array([generated], dtype=np.int64)

def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Run conditional_decoder to produce a float32 waveform."""
    dec_sess = _sessions["conditional_decoder"]
    dec_inputs = dec_sess.get_inputs()
    feed = {dec_inputs[0].name: speech_tokens}
    if len(dec_inputs) > 1:
        feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
    if len(dec_inputs) > 2:
        feed[dec_inputs[2].name] = cond.get("speaker_features")
    wav = dec_sess.run(None, feed)[0]  # [1, T] or [T]
    return wav.squeeze().astype(np.float32)
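
A minimal usage sketch for this module (hedged: the reference path is hypothetical, and soundfile is assumed only for writing the result; run it from the directory containing engine.py):

import soundfile as sf  # assumed here for output; engine.py itself does not use it

import engine

if engine.load_model():
    ref = "voices/reference.wav"  # hypothetical reference clip
    engine.prepare_voice(ref)  # optional warm-up; synthesize() also caches on demand
    wav, sr = engine.synthesize("Hello from Chatterbox on ROCm.", audio_prompt_path=ref)
    sf.write("out.wav", wav, sr)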