Files
rocm-chatterbox-whisper/engine.py
scott dc179ad8c6
All checks were successful
Build ROCm Image / build (push) Successful in 14m40s
Fix input dtype mismatches; drop MIGraphXExecutionProvider
- Drop MIGraphXExecutionProvider — symbol mismatch with apt migraphx,
  ROCMExecutionProvider handles GPU execution fine without it
- Add _ort_type_to_np() helper to read expected dtypes from session metadata
- prepare_voice: cast audio to session's declared input dtype (float32, not float16)
- _generate: read embed_dtype and kv_dtype from LM session metadata

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:11:57 -04:00

284 lines
9.5 KiB
Python

import logging
from typing import Dict, Optional, Tuple
import librosa
import numpy as np
logger = logging.getLogger(__name__)
# Hugging Face repo hosting the exported Chatterbox-Turbo ONNX modules.
ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
# Filename suffix selecting the fp16 export of each module.
PRECISION_SUFFIX = "_fp16"
# Output sample rate (Hz) of the decoded waveform.
SAMPLE_RATE = 24000
# Special speech-token ids used by the autoregressive language model.
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
# KV-cache geometry: past tensors are shaped [batch, NUM_KV_HEADS, seq, HEAD_DIM].
NUM_KV_HEADS = 16
HEAD_DIM = 64
# Hard cap on generated speech tokens per utterance.
MAX_NEW_TOKENS = 1024
# Greedy-decoding repetition penalty applied to already-emitted tokens.
REPETITION_PENALTY = 1.2
# ONNX inference sessions keyed by module name
_sessions: Dict = {}
# HF tokenizer instance; populated by load_model()
_tokenizer = None
# Cache: voice file path → speech_encoder outputs dict
_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
def detect_device() -> str:
    """Return "rocm" if ONNX Runtime exposes a GPU provider, else "cpu".

    Any failure (onnxruntime missing, provider query error) degrades to "cpu".
    """
    try:
        import onnxruntime as ort

        gpu_providers = {"ROCMExecutionProvider", "MIGraphXExecutionProvider"}
        if gpu_providers & set(ort.get_available_providers()):
            return "rocm"
    except Exception:
        pass
    return "cpu"
def _get_providers(device: str) -> list:
if device not in ("rocm", "cuda"):
return ["CPUExecutionProvider"]
import onnxruntime as ort
available = set(ort.get_available_providers())
providers = []
# MIGraphXExecutionProvider excluded — symbol mismatch between onnxruntime-rocm
# and apt migraphx; ROCMExecutionProvider covers GPU execution adequately.
if "ROCMExecutionProvider" in available:
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
providers.append("CPUExecutionProvider")
logger.info(f"Available ORT providers: {available}")
return providers
def _ort_type_to_np(ort_type: str):
"""Convert ORT type string (e.g. 'tensor(float16)') to numpy dtype."""
mapping = {
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
"tensor(double)": np.float64,
"tensor(int64)": np.int64,
"tensor(int32)": np.int32,
"tensor(int16)": np.int16,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(bool)": np.bool_,
}
return mapping.get(ort_type, np.float32)
def load_model() -> bool:
    """Download and initialise the tokenizer plus all four ONNX modules.

    Populates the module-level ``_sessions`` and ``_tokenizer`` globals.
    Returns True on success; on any failure logs the traceback and returns
    False without raising.
    """
    global _sessions, _tokenizer
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer
    from config import get_device_override

    device = get_device_override() or detect_device()
    providers = _get_providers(device)
    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
    try:
        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        for name in ("speech_encoder", "embed_tokens", "language_model", "conditional_decoder"):
            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
            logger.info(f"Downloading {onnx_file} ...")
            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
            # External-weights file must land beside the graph for ORT to find it.
            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
            logger.info(f"Creating session: {name}")
            _sessions[name] = ort.InferenceSession(
                onnx_path, sess_options=opts, providers=providers
            )
        logger.info("ONNX model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load ONNX model")
        return False
def get_sample_rate() -> int:
    """Return the sample rate (Hz) of audio produced by synthesize()."""
    return SAMPLE_RATE
def prepare_voice(audio_prompt_path: str) -> None:
    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
    # No-op if the model isn't loaded or this voice is already cached.
    if not _sessions or audio_prompt_path in _cond_cache:
        return
    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    try:
        waveform, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
        encoder = _sessions["speech_encoder"]
        meta = encoder.get_inputs()[0]
        # Cast to the dtype the exported graph declares (float32, not float16).
        batch = waveform[np.newaxis, :].astype(_ort_type_to_np(meta.type))  # [1, T]
        out_names = [o.name for o in encoder.get_outputs()]
        results = encoder.run(None, {meta.name: batch})
        _cond_cache[audio_prompt_path] = dict(zip(out_names, results))
        logger.info("Voice conditionals cached")
    except Exception:
        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[np.ndarray, int]:
    """Synthesize speech. Returns (waveform_float32_1d, sample_rate).

    Args:
        text: Input text to speak.
        audio_prompt_path: Required reference audio file for voice cloning.
        exaggeration, cfg_weight, temperature: Accepted for interface
            compatibility; not referenced by this greedy ONNX decoding path.
        seed: If > 0, seeds numpy's global RNG.

    Raises:
        RuntimeError: if the model is not loaded, no reference audio is given,
            or the voice conditionals could not be prepared.
    """
    if not _sessions:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if audio_prompt_path is None:
        raise RuntimeError(
            "ONNX model requires a reference audio file for voice cloning."
        )
    if seed > 0:
        np.random.seed(seed)
    if audio_prompt_path not in _cond_cache:
        prepare_voice(audio_prompt_path)
    cond = _cond_cache.get(audio_prompt_path)
    if cond is None:
        # prepare_voice() logs and swallows its own errors; surface a clear
        # failure here instead of the opaque KeyError the bare dict lookup
        # would raise.
        raise RuntimeError(
            f"Failed to prepare voice conditionals from '{audio_prompt_path}'"
        )
    # Tokenize input text
    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
    # Generate speech tokens autoregressively
    speech_tokens = _generate(input_ids, cond)
    # Prepend prompt token (speaker identity anchor) and append silence tokens
    prompt_token = cond.get("prompt_token")
    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
    decoder_input = np.concatenate(parts, axis=1)
    wav = _decode(decoder_input, cond)
    return wav, SAMPLE_RATE
def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T].

    Greedy decoding with a repetition penalty; stops at STOP_SPEECH_TOKEN or
    after MAX_NEW_TOKENS steps. `cond` is the cached speech_encoder output dict.
    """
    embed_sess = _sessions["embed_tokens"]
    lm_sess = _sessions["language_model"]
    embed_input_name = embed_sess.get_inputs()[0].name
    # Discover KV cache slot names and dtypes from session metadata.
    # NOTE(review): assumes past_names and present_names enumerate layers in
    # the same order, since they are paired by position below — holds for the
    # standard export, but verify if the model export changes.
    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
    lm_out_names = [o.name for o in lm_sess.get_outputs()]
    # Read expected embed dtype from language_model's inputs_embeds slot
    embeds_meta = next((i for i in lm_sess.get_inputs() if "embeds" in i.name), None)
    embed_dtype = _ort_type_to_np(embeds_meta.type) if embeds_meta else np.float32
    # Read KV cache dtype from first past_key_values slot
    kv_meta = next((i for i in lm_sess.get_inputs() if "past_key_values" in i.name), None)
    kv_dtype = _ort_type_to_np(kv_meta.type) if kv_meta else np.float16
    # Embed full text sequence
    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]
    # Prepend conditioning embeddings from speech encoder (when present)
    cond_emb = cond.get("cond_emb")
    inputs_embeds = (
        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
    )
    seq_len = inputs_embeds.shape[1]
    attention_mask = np.ones((1, seq_len), dtype=np.int64)
    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]
    # Empty KV cache to start: zero-length sequence axis per layer slot
    past_kv = {
        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
        for name in past_names
    }
    generated: list = []
    for _ in range(MAX_NEW_TOKENS):
        feed = {
            "inputs_embeds": inputs_embeds.astype(embed_dtype),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        feed.update(past_kv)
        raw = lm_sess.run(None, feed)
        out = dict(zip(lm_out_names, raw))
        # Logits for the last position only; float32 for stable penalty math
        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]
        # Repetition penalty (greedy decoding): dampen every token already
        # emitted — divide positive logits, multiply negative ones
        for tok in set(generated):
            if logits[tok] > 0:
                logits[tok] /= REPETITION_PENALTY
            else:
                logits[tok] *= REPETITION_PENALTY
        next_token = int(np.argmax(logits))
        if next_token == STOP_SPEECH_TOKEN:
            break
        generated.append(next_token)
        # Roll KV cache forward: this step's present_* outputs become next
        # step's past_key_values inputs (paired by list position)
        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}
        # Embed only the new token for next step
        inputs_embeds = embed_sess.run(
            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
        )[0]
        # Mask grows by one; position_ids holds just the new token's position
        new_pos = attention_mask.shape[1]
        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
        position_ids = np.array([[new_pos]], dtype=np.int64)
    if not generated:
        # Model emitted STOP immediately; return a minimal valid sequence
        logger.warning("No speech tokens generated — returning start token")
        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
    return np.array([generated], dtype=np.int64)
def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Run conditional_decoder to produce a float32 waveform."""
    session = _sessions["conditional_decoder"]
    inputs = session.get_inputs()
    feed = {inputs[0].name: speech_tokens}
    # Second/third inputs are speaker conditioning, present depending on export.
    for idx, cond_key in ((1, "speaker_embeddings"), (2, "speaker_features")):
        if len(inputs) > idx:
            feed[inputs[idx].name] = cond.get(cond_key)
    audio = session.run(None, feed)[0]  # [1, T] or [T]
    return audio.squeeze().astype(np.float32)