import logging
from typing import Dict, Optional, Tuple

import librosa
import numpy as np

logger = logging.getLogger(__name__)

ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
PRECISION_SUFFIX = "_fp16"
SAMPLE_RATE = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
NUM_KV_HEADS = 16
HEAD_DIM = 64
MAX_NEW_TOKENS = 1024
REPETITION_PENALTY = 1.2

# ONNX inference sessions keyed by module name
_sessions: Dict = {}
_tokenizer = None
# Cache: voice file path → speech_encoder outputs dict
_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}


def detect_device() -> str:
    try:
        import onnxruntime as ort

        if "ROCMExecutionProvider" in ort.get_available_providers():
            return "rocm"
    except Exception:
        pass
    return "cpu"


def _get_providers(device: str) -> list:
    if device in ("rocm", "cuda"):
        return [
            "MIGraphXExecutionProvider",
            ("ROCMExecutionProvider", {"device_id": 0}),
            "CPUExecutionProvider",
        ]
    return ["CPUExecutionProvider"]


def load_model() -> bool:
    global _sessions, _tokenizer
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer

    from config import get_device_override

    device = get_device_override() or detect_device()
    providers = _get_providers(device)
    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
    try:
        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)

        module_names = [
            "speech_encoder",
            "embed_tokens",
            "language_model",
            "conditional_decoder",
        ]
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        for name in module_names:
            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
            logger.info(f"Downloading {onnx_file} ...")
            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
            logger.info(f"Creating session: {name}")
            _sessions[name] = ort.InferenceSession(
                onnx_path, sess_options=sess_options, providers=providers
            )

        logger.info("ONNX model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load ONNX model")
        return False


def get_sample_rate() -> int:
    return SAMPLE_RATE


def prepare_voice(audio_prompt_path: str) -> None:
    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
    if not _sessions or audio_prompt_path in _cond_cache:
        return
    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    try:
        audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]
        session = _sessions["speech_encoder"]
        input_name = session.get_inputs()[0].name
        output_names = [o.name for o in session.get_outputs()]
        outputs = session.run(None, {input_name: audio})
        _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
        logger.info("Voice conditionals cached")
    except Exception:
        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")
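

# NOTE: downstream code reads these speech_encoder outputs back out of
# _cond_cache by name: "cond_emb" (prepended to the text embeddings in
# _generate), "prompt_token" (prepended to the decoder input in synthesize),
# and "speaker_embeddings"/"speaker_features" (extra conditional_decoder
# inputs in _decode).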


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[np.ndarray, int]:
    """Synthesize speech. Returns (waveform_float32_1d, sample_rate)."""
    if not _sessions:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if audio_prompt_path is None:
        raise RuntimeError(
            "ONNX model requires a reference audio file for voice cloning."
        )
    if seed > 0:
        np.random.seed(seed)

    if audio_prompt_path not in _cond_cache:
        prepare_voice(audio_prompt_path)
    cond = _cond_cache[audio_prompt_path]

    # Tokenize input text
    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

    # Generate speech tokens autoregressively
    speech_tokens = _generate(input_ids, cond)

    # Prepend prompt token (speaker identity anchor) and append silence tokens
    prompt_token = cond.get("prompt_token")
    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
    decoder_input = np.concatenate(parts, axis=1)

    wav = _decode(decoder_input, cond)
    return wav, SAMPLE_RATE


def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
    embed_sess = _sessions["embed_tokens"]
    lm_sess = _sessions["language_model"]
    embed_input_name = embed_sess.get_inputs()[0].name

    # Discover KV cache slot names from session metadata
    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
    lm_out_names = [o.name for o in lm_sess.get_outputs()]
    kv_dtype = np.float16

    # Embed full text sequence
    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]

    # Prepend conditioning embeddings from speech encoder
    cond_emb = cond.get("cond_emb")
    inputs_embeds = (
        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
    )

    seq_len = inputs_embeds.shape[1]
    attention_mask = np.ones((1, seq_len), dtype=np.int64)
    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]

    # Empty KV cache to start
    past_kv = {
        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype) for name in past_names
    }

    generated: list = []
    for _ in range(MAX_NEW_TOKENS):
        feed = {
            "inputs_embeds": inputs_embeds.astype(kv_dtype),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        feed.update(past_kv)
        raw = lm_sess.run(None, feed)
        out = dict(zip(lm_out_names, raw))
        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]

        # Repetition penalty (greedy decoding)
        for tok in set(generated):
            if logits[tok] > 0:
                logits[tok] /= REPETITION_PENALTY
            else:
                logits[tok] *= REPETITION_PENALTY

        next_token = int(np.argmax(logits))
        if next_token == STOP_SPEECH_TOKEN:
            break
        generated.append(next_token)

        # Roll KV cache forward
        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}

        # Embed only the new token for next step
        inputs_embeds = embed_sess.run(
            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
        )[0]
        new_pos = attention_mask.shape[1]
        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
        position_ids = np.array([[new_pos]], dtype=np.int64)

    if not generated:
        logger.warning("No speech tokens generated; returning start token")
        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
    return np.array([generated], dtype=np.int64)
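

# Illustrative sketch, not part of the original module: synthesize() accepts a
# `temperature` argument, but the loop above always decodes greedily via
# argmax. If sampling were desired, the argmax line could be swapped for a
# softmax draw like this hypothetical helper (the name _sample_token and the
# `rng` parameter are assumptions, not existing API).
def _sample_token(logits: np.ndarray, temperature: float, rng: np.random.Generator) -> int:
    scaled = logits.astype(np.float64) / max(temperature, 1e-5)  # avoid div-by-zero
    scaled -= scaled.max()  # stabilize exp()
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(probs.size, p=probs))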


def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Run conditional_decoder to produce a float32 waveform."""
    dec_sess = _sessions["conditional_decoder"]
    dec_inputs = dec_sess.get_inputs()
    feed = {dec_inputs[0].name: speech_tokens}
    if len(dec_inputs) > 1:
        feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
    if len(dec_inputs) > 2:
        feed[dec_inputs[2].name] = cond.get("speaker_features")
    wav = dec_sess.run(None, feed)[0]  # [1, T] or [T]
    return wav.squeeze().astype(np.float32)
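

if __name__ == "__main__":
    # Minimal smoke-test sketch (an addition, not original to this module):
    # assumes a local reference clip "voice.wav" exists; both filenames here
    # are placeholders. Output is written as 16-bit PCM with the stdlib
    # `wave` module to avoid extra dependencies.
    import sys
    import wave

    logging.basicConfig(level=logging.INFO)
    if not load_model():
        sys.exit(1)
    audio, sr = synthesize("Hello from the ONNX pipeline.", audio_prompt_path="voice.wav")
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype(np.int16)
    with wave.open("out.wav", "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)  # 16-bit samples
        f.setframerate(sr)
        f.writeframes(pcm.tobytes())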