"""ONNX inference engine for the Chatterbox-turbo TTS model.

Queries ort.get_available_providers() at runtime and only requests execution
providers that actually loaded, falling back to CPU instead of crashing on a
missing ROCm/MIGraphX runtime.
"""
import logging
from typing import Dict, Optional, Tuple

import librosa
import numpy as np

logger = logging.getLogger(__name__)

ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
PRECISION_SUFFIX = "_fp16"
SAMPLE_RATE = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
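# KV-cache geometry of the language model: each past/present entry has shape
# [1, NUM_KV_HEADS, T, HEAD_DIM] (see _generate).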
NUM_KV_HEADS = 16
HEAD_DIM = 64
MAX_NEW_TOKENS = 1024
REPETITION_PENALTY = 1.2

# ONNX inference sessions keyed by module name
_sessions: Dict = {}
_tokenizer = None

# Cache: voice file path → speech_encoder outputs dict
_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}


def detect_device() -> str:
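    """Return "rocm" if a ROCm or MIGraphX ORT provider is available, else "cpu"."""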
    try:
        import onnxruntime as ort
        available = ort.get_available_providers()
        if "ROCMExecutionProvider" in available or "MIGraphXExecutionProvider" in available:
            return "rocm"
    except Exception:
        pass
    return "cpu"


def _get_providers(device: str) -> list:
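    """Build the ORT provider list for `device`, requesting only providers that actually loaded (CPU last as fallback)."""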
    if device not in ("rocm", "cuda"):
        return ["CPUExecutionProvider"]

    import onnxruntime as ort
    available = set(ort.get_available_providers())
    providers = []
    if "MIGraphXExecutionProvider" in available:
        providers.append("MIGraphXExecutionProvider")
    if "ROCMExecutionProvider" in available:
        providers.append(("ROCMExecutionProvider", {"device_id": 0}))
    providers.append("CPUExecutionProvider")
    logger.info(f"Available ORT providers: {available}")
    return providers


def load_model() -> bool:
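    """Download the fp16 ONNX modules from the Hub and create inference sessions.

    Returns True on success, False on failure (errors are logged).
    """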
    global _sessions, _tokenizer

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer

    from config import get_device_override

    device = get_device_override() or detect_device()
    providers = _get_providers(device)
    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")

    try:
        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)

        module_names = [
            "speech_encoder",
            "embed_tokens",
            "language_model",
            "conditional_decoder",
        ]
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        for name in module_names:
            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"

            logger.info(f"Downloading {onnx_file} ...")
            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)

            logger.info(f"Creating session: {name}")
            _sessions[name] = ort.InferenceSession(
                onnx_path, sess_options=sess_options, providers=providers
            )

        logger.info("ONNX model loaded successfully")
        return True

    except Exception:
        logger.exception("Failed to load ONNX model")
        return False


def get_sample_rate() -> int:
    return SAMPLE_RATE


def prepare_voice(audio_prompt_path: str) -> None:
    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
    if not _sessions or audio_prompt_path in _cond_cache:
        return

    logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
    try:
        audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]

        session = _sessions["speech_encoder"]
        input_name = session.get_inputs()[0].name
        output_names = [o.name for o in session.get_outputs()]
        outputs = session.run(None, {input_name: audio})

        _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
        logger.info("Voice conditionals cached")
    except Exception:
        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[np.ndarray, int]:
    """Synthesize speech. Returns (waveform_float32_1d, sample_rate)."""
    if not _sessions:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    if audio_prompt_path is None:
        raise RuntimeError(
            "ONNX model requires a reference audio file for voice cloning."
        )

    if seed > 0:
        np.random.seed(seed)

    if audio_prompt_path not in _cond_cache:
        prepare_voice(audio_prompt_path)

    cond = _cond_cache.get(audio_prompt_path)
    if cond is None:
        # prepare_voice logs and swallows failures, so surface a clear error
        # here instead of an opaque KeyError.
        raise RuntimeError(f"Failed to prepare voice conditionals from '{audio_prompt_path}'")

    # Tokenize input text
    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

    # Generate speech tokens autoregressively
    speech_tokens = _generate(input_ids, cond)

    # Prepend prompt token (speaker identity anchor) and append silence tokens
    prompt_token = cond.get("prompt_token")
    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
    decoder_input = np.concatenate(parts, axis=1)

    wav = _decode(decoder_input, cond)
    return wav, SAMPLE_RATE


def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
    embed_sess = _sessions["embed_tokens"]
    lm_sess = _sessions["language_model"]

    embed_input_name = embed_sess.get_inputs()[0].name

    # Discover KV cache slot names from session metadata
    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
    lm_out_names = [o.name for o in lm_sess.get_outputs()]

    kv_dtype = np.float16

    # Embed full text sequence
    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]

    # Prepend conditioning embeddings from speech encoder
    cond_emb = cond.get("cond_emb")
    inputs_embeds = (
        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
    )

    seq_len = inputs_embeds.shape[1]
    attention_mask = np.ones((1, seq_len), dtype=np.int64)
    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]

    # Empty KV cache to start
    past_kv = {
        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
        for name in past_names
    }

    generated: list = []

    for _ in range(MAX_NEW_TOKENS):
        feed = {
            "inputs_embeds": inputs_embeds.astype(kv_dtype),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        feed.update(past_kv)

        raw = lm_sess.run(None, feed)
        out = dict(zip(lm_out_names, raw))

        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]

        # Repetition penalty (greedy decoding)
        for tok in set(generated):
            if logits[tok] > 0:
                logits[tok] /= REPETITION_PENALTY
            else:
                logits[tok] *= REPETITION_PENALTY

        next_token = int(np.argmax(logits))
        if next_token == STOP_SPEECH_TOKEN:
            break

        generated.append(next_token)

        # Roll KV cache forward
        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}

        # Embed only the new token for next step
        inputs_embeds = embed_sess.run(
            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
        )[0]
        new_pos = attention_mask.shape[1]
        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
        position_ids = np.array([[new_pos]], dtype=np.int64)

    if not generated:
        logger.warning("No speech tokens generated — returning start token")
        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)

    return np.array([generated], dtype=np.int64)


def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
    """Run conditional_decoder to produce a float32 waveform."""
    dec_sess = _sessions["conditional_decoder"]
    dec_inputs = dec_sess.get_inputs()

    feed = {dec_inputs[0].name: speech_tokens}
    if len(dec_inputs) > 1:
        feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
    if len(dec_inputs) > 2:
        feed[dec_inputs[2].name] = cond.get("speaker_features")

    wav = dec_sess.run(None, feed)[0]  # [1, T] or [T]
    return wav.squeeze().astype(np.float32)
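

if __name__ == "__main__":
    # Minimal smoke-test sketch, not the service entry point. It assumes a
    # local reference clip "voice.wav" exists and that the third-party
    # soundfile package is installed; both are illustrative assumptions.
    import soundfile as sf

    logging.basicConfig(level=logging.INFO)
    if load_model():
        wav, sr = synthesize("Hello from the ONNX engine.", audio_prompt_path="voice.wav")
        sf.write("out.wav", wav, sr)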