Switch to ONNX Runtime with chatterbox-turbo-ONNX (fp16)

Replaces the PyTorch/chatterbox-tts stack with direct ONNX inference
using ResembleAI/chatterbox-turbo-ONNX fp16 weights.

- engine.py: full rewrite — ONNX sessions, autoregressive KV-cache LM
  loop, voice conditionals cache via speech_encoder outputs
- wyoming_handler.py: remove torch dep, use np.asarray for audio bytes
- requirements-rocm-init.txt: onnxruntime-rocm replaces torch wheels
- requirements-rocm.txt: drop chatterbox/torch deps, keep audio utils
- Dockerfile.rocm: remove chatterbox-tts install step

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 19:08:26 -04:00
parent 4c79a82428
commit 2b1398109d
5 changed files with 209 additions and 103 deletions
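
The new engine drives four exported ONNX modules (speech_encoder, embed_tokens, language_model, conditional_decoder) that engine.py wires together below. Since the engine discovers input/output names from session metadata at runtime, a quick way to inspect those interfaces up front is a sketch like this (repo and file layout taken from the diff; everything else is just inspection code):

import onnxruntime as ort
from huggingface_hub import hf_hub_download

REPO = "ResembleAI/chatterbox-turbo-ONNX"

for name in ("speech_encoder", "embed_tokens", "language_model", "conditional_decoder"):
    # Each module ships as an .onnx graph plus an external-data file
    path = hf_hub_download(REPO, f"onnx/{name}_fp16.onnx")
    hf_hub_download(REPO, f"onnx/{name}_fp16.onnx_data")
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    print(name)
    print("  inputs: ", [(i.name, i.shape) for i in sess.get_inputs()])
    print("  outputs:", [(o.name, o.shape) for o in sess.get_outputs()])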

Dockerfile.rocm

@@ -17,18 +17,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 WORKDIR /app

-# Step 1: Install ROCm-compatible PyTorch stack first.
-# This must happen before anything else to prevent pip from pulling CPU wheels.
+# Step 1: Install onnxruntime-rocm first so it claims the onnxruntime namespace
+# before any other package can pull in the CPU-only onnxruntime wheel.
 COPY requirements-rocm-init.txt .
 RUN pip3 install -r requirements-rocm-init.txt

-# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
+# Step 2: Install remaining dependencies.
 COPY requirements-rocm.txt .
 RUN pip3 install -r requirements-rocm.txt

-# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
-RUN pip3 install --no-deps chatterbox-tts
-
 # Application source
 COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
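
The reordering matters because onnxruntime-rocm and the CPU-only onnxruntime wheel both install the same onnxruntime package, and whichever pip resolves last wins. A minimal sanity check for the built image (a suggestion, not part of this commit):

import onnxruntime as ort

# The ROCm wheel registers ROCMExecutionProvider; a CPU-only wheel
# would report just CPUExecutionProvider here.
assert "ROCMExecutionProvider" in ort.get_available_providers()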

engine.py

@@ -1,84 +1,116 @@
 import logging
-import torch
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
+
+import librosa
+import numpy as np

 logger = logging.getLogger(__name__)

-chatterbox_model = None
-_sample_rate = 24000
-_is_turbo = False
+ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
+PRECISION_SUFFIX = "_fp16"
+SAMPLE_RATE = 24000
+START_SPEECH_TOKEN = 6561
+STOP_SPEECH_TOKEN = 6562
+SILENCE_TOKEN = 4299
+NUM_KV_HEADS = 16
+HEAD_DIM = 64
+MAX_NEW_TOKENS = 1024
+REPETITION_PENALTY = 1.2

-# Cache: voice file path → prepared conditionals object.
-# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
-# builds mel embeddings — expensive work that only depends on the reference
-# audio, not the text. Cache it so multi-chunk requests pay the cost once.
-_cond_cache: dict = {}
-
-
-def _test_cuda() -> bool:
-    try:
-        if torch.cuda.is_available():
-            torch.zeros(1).cuda()
-            return True
-    except Exception:
-        pass
-    return False
+# ONNX inference sessions keyed by module name
+_sessions: Dict = {}
+_tokenizer = None
+
+# Cache: voice file path → speech_encoder outputs dict
+_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}


 def detect_device() -> str:
-    return "cuda" if _test_cuda() else "cpu"
+    try:
+        import onnxruntime as ort
+
+        if "ROCMExecutionProvider" in ort.get_available_providers():
+            return "rocm"
+    except Exception:
+        pass
+    return "cpu"
+
+
+def _get_providers(device: str) -> list:
+    if device in ("rocm", "cuda"):
+        return [("ROCMExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
+    return ["CPUExecutionProvider"]


 def load_model() -> bool:
-    global chatterbox_model, _sample_rate, _is_turbo
-    from config import get_model_repo_id, get_device_override
+    global _sessions, _tokenizer
+    import onnxruntime as ort
+    from huggingface_hub import hf_hub_download
+    from transformers import AutoTokenizer
+    from config import get_device_override

     device = get_device_override() or detect_device()
-    repo_id = get_model_repo_id()
-
-    logger.info(f"Loading model '{repo_id}' on device '{device}'")
+    providers = _get_providers(device)
+    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
     try:
-        if "turbo" in repo_id.lower():
-            from chatterbox.tts_turbo import ChatterboxTurboTTS
-            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
-            _is_turbo = True
-        else:
-            from chatterbox.tts import ChatterboxTTS
-            chatterbox_model = ChatterboxTTS.from_pretrained(device)
-            _is_turbo = False
-
-        _sample_rate = 24000
-        logger.info("Model loaded successfully")
+        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
+        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
+
+        module_names = [
+            "speech_encoder",
+            "embed_tokens",
+            "language_model",
+            "conditional_decoder",
+        ]
+        sess_options = ort.SessionOptions()
+        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        for name in module_names:
+            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
+            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
+            logger.info(f"Downloading {onnx_file} ...")
+            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
+            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
+            logger.info(f"Creating session: {name}")
+            _sessions[name] = ort.InferenceSession(
+                onnx_path, sess_options=sess_options, providers=providers
+            )
+        logger.info("ONNX model loaded successfully")
         return True
     except Exception:
-        logger.exception("Failed to load model")
+        logger.exception("Failed to load ONNX model")
         return False


+def get_sample_rate() -> int:
+    return SAMPLE_RATE
+
+
 def prepare_voice(audio_prompt_path: str) -> None:
-    """
-    Pre-compute and cache the voice conditionals for a reference audio file.
-    Calling this once avoids repeating the s3tokenizer + voice encoder work
-    on every synthesis chunk that uses the same voice.
-    """
-    if chatterbox_model is None:
-        return
-    if audio_prompt_path in _cond_cache:
-        return
-    if not _is_turbo:
-        return  # only turbo exposes prepare_conditionals
+    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
+    if not _sessions or audio_prompt_path in _cond_cache:
+        return

     logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
-    with torch.inference_mode():
-        chatterbox_model.prepare_conditionals(audio_prompt_path)
-    _cond_cache[audio_prompt_path] = chatterbox_model.conds
-    logger.info("Voice conditionals cached")
-
-
-def get_sample_rate() -> int:
-    return _sample_rate
+    try:
+        audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
+        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]
+        session = _sessions["speech_encoder"]
+        input_name = session.get_inputs()[0].name
+        output_names = [o.name for o in session.get_outputs()]
+        outputs = session.run(None, {input_name: audio})
+        _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
+        logger.info("Voice conditionals cached")
+    except Exception:
+        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")


 def synthesize(
@@ -88,38 +120,128 @@ def synthesize(
     cfg_weight: float = 0.5,
     temperature: float = 0.8,
     seed: int = 0,
-) -> Tuple[torch.Tensor, int]:
-    if chatterbox_model is None:
+) -> Tuple[np.ndarray, int]:
+    """Synthesize speech. Returns (waveform_float32_1d, sample_rate)."""
+    if not _sessions:
         raise RuntimeError("Model not loaded. Call load_model() first.")
+    if audio_prompt_path is None:
+        raise RuntimeError(
+            "ONNX model requires a reference audio file for voice cloning."
+        )

     if seed > 0:
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)
+        np.random.seed(seed)

-    # Restore cached conditionals so generate() skips prepare_conditionals.
-    if audio_prompt_path and _is_turbo:
-        if audio_prompt_path not in _cond_cache:
-            prepare_voice(audio_prompt_path)
-        chatterbox_model.conds = _cond_cache[audio_prompt_path]
-
-    kwargs: dict = {}
-    # Don't pass audio_prompt_path — conds are already set above.
-    # For non-turbo models there's no cache, pass path as normal.
-    if audio_prompt_path and not _is_turbo:
-        kwargs["audio_prompt_path"] = audio_prompt_path
-
-    if _is_turbo:
-        kwargs["temperature"] = temperature
-    else:
-        kwargs["exaggeration"] = exaggeration
-        kwargs["cfg_weight"] = cfg_weight
-
-    with torch.inference_mode():
-        wav = chatterbox_model.generate(text=text, **kwargs)
-
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        torch.cuda.empty_cache()
-
-    return wav, _sample_rate
+    if audio_prompt_path not in _cond_cache:
+        prepare_voice(audio_prompt_path)
+    cond = _cond_cache[audio_prompt_path]
+
+    # Tokenize input text
+    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
+
+    # Generate speech tokens autoregressively
+    speech_tokens = _generate(input_ids, cond)
+
+    # Prepend prompt token (speaker identity anchor) and append silence tokens
+    prompt_token = cond.get("prompt_token")
+    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
+    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
+    decoder_input = np.concatenate(parts, axis=1)
+
+    wav = _decode(decoder_input, cond)
+    return wav, SAMPLE_RATE
+
+
+def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
+    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
+    embed_sess = _sessions["embed_tokens"]
+    lm_sess = _sessions["language_model"]
+    embed_input_name = embed_sess.get_inputs()[0].name
+
+    # Discover KV cache slot names from session metadata
+    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
+    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
+    lm_out_names = [o.name for o in lm_sess.get_outputs()]
+    kv_dtype = np.float16
+
+    # Embed full text sequence
+    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]
+
+    # Prepend conditioning embeddings from speech encoder
+    cond_emb = cond.get("cond_emb")
+    inputs_embeds = (
+        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
+    )
+
+    seq_len = inputs_embeds.shape[1]
+    attention_mask = np.ones((1, seq_len), dtype=np.int64)
+    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]
+
+    # Empty KV cache to start
+    past_kv = {
+        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
+        for name in past_names
+    }
+
+    generated: list = []
+    for _ in range(MAX_NEW_TOKENS):
+        feed = {
+            "inputs_embeds": inputs_embeds.astype(kv_dtype),
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+        }
+        feed.update(past_kv)
+        raw = lm_sess.run(None, feed)
+        out = dict(zip(lm_out_names, raw))
+        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]
+
+        # Repetition penalty (greedy decoding)
+        for tok in set(generated):
+            if logits[tok] > 0:
+                logits[tok] /= REPETITION_PENALTY
+            else:
+                logits[tok] *= REPETITION_PENALTY
+
+        next_token = int(np.argmax(logits))
+        if next_token == STOP_SPEECH_TOKEN:
+            break
+        generated.append(next_token)
+
+        # Roll KV cache forward
+        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}
+
+        # Embed only the new token for next step
+        inputs_embeds = embed_sess.run(
+            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
+        )[0]
+        new_pos = attention_mask.shape[1]
+        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
+        position_ids = np.array([[new_pos]], dtype=np.int64)
+
+    if not generated:
+        logger.warning("No speech tokens generated — returning start token")
+        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
+    return np.array([generated], dtype=np.int64)
+
+
+def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
+    """Run conditional_decoder to produce a float32 waveform."""
+    dec_sess = _sessions["conditional_decoder"]
+    dec_inputs = dec_sess.get_inputs()
+    feed = {dec_inputs[0].name: speech_tokens}
+    if len(dec_inputs) > 1:
+        feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
+    if len(dec_inputs) > 2:
+        feed[dec_inputs[2].name] = cond.get("speaker_features")
+    wav = dec_sess.run(None, feed)[0]  # [1, T] or [T]
+    return wav.squeeze().astype(np.float32)
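
End to end, the rewritten module is used roughly like this (a sketch: the engine function names and parameters follow the diff, while voice.wav and the output path are illustrative):

import soundfile as sf

import engine

# One-time startup: download the fp16 ONNX files and build the four sessions
assert engine.load_model()

# Optional warm-up: cache speech_encoder outputs so multi-chunk requests
# pay the reference-audio cost once
engine.prepare_voice("voice.wav")  # hypothetical reference clip

wav, sr = engine.synthesize(text="Hello from ONNX.", audio_prompt_path="voice.wav")
sf.write("out.wav", wav, sr)  # 1-D float32 waveform at 24 kHz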

requirements-rocm-init.txt

@@ -1,5 +1 @@
---index-url https://download.pytorch.org/whl/rocm6.1
-torch==2.5.1
-torchaudio==2.5.1
-torchvision==0.20.1
-pytorch_triton_rocm==3.1.0
+onnxruntime-rocm

requirements-rocm.txt

@@ -2,21 +2,10 @@
 numpy>=1.24.0,<2.0.0
 soundfile
 librosa==0.11.0
-pyloudnorm

-# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
-transformers==4.46.3
-diffusers==0.29.0
-safetensors>=0.4.1
+# ONNX model dependencies
+transformers>=4.40.0
 huggingface-hub
-omegaconf
-
-# Chatterbox dependencies (installed separately since chatterbox uses --no-deps)
-conformer==0.3.2
-s3tokenizer==0.3.0
-spacy-pkuseg
-pykakasi==2.3.0
-resemble-perth==1.0.1

 # Wyoming protocol
 wyoming>=1.5.4

wyoming_handler.py

@@ -3,6 +3,8 @@ import logging
 import time
 from typing import Dict, Optional

+import numpy as np
+
 from wyoming.audio import AudioChunk, AudioStart, AudioStop
 from wyoming.event import Event
 from wyoming.info import Describe, Info
@@ -151,7 +153,7 @@ class ChatterboxWyomingHandler(AsyncEventHandler):
                 continue

             audio_bytes = (
-                audio_tensor.cpu().numpy().squeeze() * 32767
+                np.asarray(audio_tensor).squeeze() * 32767
             ).astype("int16").tobytes()

             if first_chunk:
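
One caveat on this conversion: a float sample above 1.0 overflows when multiplied by 32767 and cast, wrapping to a loud click. If the decoder ever emits peaks past full scale, a clipped variant would be safer (a suggestion, not part of this commit):

import numpy as np

def to_pcm16(wav) -> bytes:
    # Clamp to [-1, 1] before scaling so out-of-range floats saturate
    # instead of wrapping around int16
    samples = np.clip(np.asarray(wav).squeeze(), -1.0, 1.0)
    return (samples * 32767).astype(np.int16).tobytes()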