Switch to ONNX runtime with chatterbox-turbo-ONNX (fp16)
Some checks failed
Build ROCm Image / build (push) Has been cancelled
Some checks failed
Build ROCm Image / build (push) Has been cancelled
Replaces the PyTorch/chatterbox-tts stack with direct ONNX inference using ResembleAI/chatterbox-turbo-ONNX fp16 weights. - engine.py: full rewrite — ONNX sessions, autoregressive KV-cache LM loop, voice conditionals cache via speech_encoder outputs - wyoming_handler.py: remove torch dep, use np.asarray for audio bytes - requirements-rocm-init.txt: onnxruntime-rocm replaces torch wheels - requirements-rocm.txt: drop chatterbox/torch deps, keep audio utils - Dockerfile.rocm: remove chatterbox-tts install step Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,18 +17,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Step 1: Install ROCm-compatible PyTorch stack first.
|
||||
# This must happen before anything else to prevent pip from pulling CPU wheels.
|
||||
# Step 1: Install onnxruntime-rocm first so it claims the onnxruntime namespace
|
||||
# before any other package can pull in the CPU-only onnxruntime wheel.
|
||||
COPY requirements-rocm-init.txt .
|
||||
RUN pip3 install -r requirements-rocm-init.txt
|
||||
|
||||
# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
|
||||
# Step 2: Install remaining dependencies.
|
||||
COPY requirements-rocm.txt .
|
||||
RUN pip3 install -r requirements-rocm.txt
|
||||
|
||||
# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
|
||||
RUN pip3 install --no-deps chatterbox-tts
|
||||
|
||||
# Application source
|
||||
COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
|
||||
|
||||
|
||||
274
engine.py
274
engine.py
@@ -1,84 +1,116 @@
|
||||
import logging
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
chatterbox_model = None
|
||||
_sample_rate = 24000
|
||||
_is_turbo = False
|
||||
ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
|
||||
PRECISION_SUFFIX = "_fp16"
|
||||
SAMPLE_RATE = 24000
|
||||
START_SPEECH_TOKEN = 6561
|
||||
STOP_SPEECH_TOKEN = 6562
|
||||
SILENCE_TOKEN = 4299
|
||||
NUM_KV_HEADS = 16
|
||||
HEAD_DIM = 64
|
||||
MAX_NEW_TOKENS = 1024
|
||||
REPETITION_PENALTY = 1.2
|
||||
|
||||
# Cache: voice file path → prepared conditionals object.
|
||||
# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
|
||||
# builds mel embeddings — expensive work that only depends on the reference
|
||||
# audio, not the text. Cache it so multi-chunk requests pay the cost once.
|
||||
_cond_cache: dict = {}
|
||||
# ONNX inference sessions keyed by module name
|
||||
_sessions: Dict = {}
|
||||
_tokenizer = None
|
||||
|
||||
|
||||
def _test_cuda() -> bool:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.zeros(1).cuda()
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
# Cache: voice file path → speech_encoder outputs dict
|
||||
_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
|
||||
|
||||
|
||||
def detect_device() -> str:
|
||||
return "cuda" if _test_cuda() else "cpu"
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers():
|
||||
return "rocm"
|
||||
except Exception:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _get_providers(device: str) -> list:
|
||||
if device in ("rocm", "cuda"):
|
||||
return [("ROCMExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
|
||||
return ["CPUExecutionProvider"]
|
||||
|
||||
|
||||
def load_model() -> bool:
|
||||
global chatterbox_model, _sample_rate, _is_turbo
|
||||
global _sessions, _tokenizer
|
||||
|
||||
from config import get_model_repo_id, get_device_override
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from config import get_device_override
|
||||
|
||||
device = get_device_override() or detect_device()
|
||||
repo_id = get_model_repo_id()
|
||||
|
||||
logger.info(f"Loading model '{repo_id}' on device '{device}'")
|
||||
providers = _get_providers(device)
|
||||
logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
|
||||
|
||||
try:
|
||||
if "turbo" in repo_id.lower():
|
||||
from chatterbox.tts_turbo import ChatterboxTurboTTS
|
||||
chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
|
||||
_is_turbo = True
|
||||
else:
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
chatterbox_model = ChatterboxTTS.from_pretrained(device)
|
||||
_is_turbo = False
|
||||
logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
|
||||
_tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
|
||||
|
||||
_sample_rate = 24000
|
||||
logger.info("Model loaded successfully")
|
||||
module_names = [
|
||||
"speech_encoder",
|
||||
"embed_tokens",
|
||||
"language_model",
|
||||
"conditional_decoder",
|
||||
]
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
for name in module_names:
|
||||
onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
|
||||
data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
|
||||
|
||||
logger.info(f"Downloading {onnx_file} ...")
|
||||
onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
|
||||
hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
|
||||
|
||||
logger.info(f"Creating session: {name}")
|
||||
_sessions[name] = ort.InferenceSession(
|
||||
onnx_path, sess_options=sess_options, providers=providers
|
||||
)
|
||||
|
||||
logger.info("ONNX model loaded successfully")
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to load model")
|
||||
logger.exception("Failed to load ONNX model")
|
||||
return False
|
||||
|
||||
|
||||
def get_sample_rate() -> int:
|
||||
return SAMPLE_RATE
|
||||
|
||||
|
||||
def prepare_voice(audio_prompt_path: str) -> None:
|
||||
"""
|
||||
Pre-compute and cache the voice conditionals for a reference audio file.
|
||||
Calling this once avoids repeating the s3tokenizer + voice encoder work
|
||||
on every synthesis chunk that uses the same voice.
|
||||
"""
|
||||
if chatterbox_model is None:
|
||||
"""Pre-compute and cache speech_encoder outputs for a reference audio file."""
|
||||
if not _sessions or audio_prompt_path in _cond_cache:
|
||||
return
|
||||
if audio_prompt_path in _cond_cache:
|
||||
return
|
||||
if not _is_turbo:
|
||||
return # only turbo exposes prepare_conditionals
|
||||
|
||||
logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
|
||||
with torch.inference_mode():
|
||||
chatterbox_model.prepare_conditionals(audio_prompt_path)
|
||||
_cond_cache[audio_prompt_path] = chatterbox_model.conds
|
||||
try:
|
||||
audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
|
||||
audio = audio[np.newaxis, :].astype(np.float16) # [1, T]
|
||||
|
||||
session = _sessions["speech_encoder"]
|
||||
input_name = session.get_inputs()[0].name
|
||||
output_names = [o.name for o in session.get_outputs()]
|
||||
outputs = session.run(None, {input_name: audio})
|
||||
|
||||
_cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
|
||||
logger.info("Voice conditionals cached")
|
||||
|
||||
|
||||
def get_sample_rate() -> int:
|
||||
return _sample_rate
|
||||
except Exception:
|
||||
logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")
|
||||
|
||||
|
||||
def synthesize(
|
||||
@@ -88,38 +120,128 @@ def synthesize(
|
||||
cfg_weight: float = 0.5,
|
||||
temperature: float = 0.8,
|
||||
seed: int = 0,
|
||||
) -> Tuple[torch.Tensor, int]:
|
||||
if chatterbox_model is None:
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""Synthesize speech. Returns (waveform_float32_1d, sample_rate)."""
|
||||
if not _sessions:
|
||||
raise RuntimeError("Model not loaded. Call load_model() first.")
|
||||
if audio_prompt_path is None:
|
||||
raise RuntimeError(
|
||||
"ONNX model requires a reference audio file for voice cloning."
|
||||
)
|
||||
|
||||
if seed > 0:
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
# Restore cached conditionals so generate() skips prepare_conditionals.
|
||||
if audio_prompt_path and _is_turbo:
|
||||
if audio_prompt_path not in _cond_cache:
|
||||
prepare_voice(audio_prompt_path)
|
||||
chatterbox_model.conds = _cond_cache[audio_prompt_path]
|
||||
|
||||
kwargs: dict = {}
|
||||
# Don't pass audio_prompt_path — conds are already set above.
|
||||
# For non-turbo models there's no cache, pass path as normal.
|
||||
if audio_prompt_path and not _is_turbo:
|
||||
kwargs["audio_prompt_path"] = audio_prompt_path
|
||||
cond = _cond_cache[audio_prompt_path]
|
||||
|
||||
if _is_turbo:
|
||||
kwargs["temperature"] = temperature
|
||||
# Tokenize input text
|
||||
input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
|
||||
|
||||
# Generate speech tokens autoregressively
|
||||
speech_tokens = _generate(input_ids, cond)
|
||||
|
||||
# Prepend prompt token (speaker identity anchor) and append silence tokens
|
||||
prompt_token = cond.get("prompt_token")
|
||||
silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
|
||||
parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
|
||||
decoder_input = np.concatenate(parts, axis=1)
|
||||
|
||||
wav = _decode(decoder_input, cond)
|
||||
return wav, SAMPLE_RATE
|
||||
|
||||
|
||||
def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
"""Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
|
||||
embed_sess = _sessions["embed_tokens"]
|
||||
lm_sess = _sessions["language_model"]
|
||||
|
||||
embed_input_name = embed_sess.get_inputs()[0].name
|
||||
|
||||
# Discover KV cache slot names from session metadata
|
||||
past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
|
||||
present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
|
||||
lm_out_names = [o.name for o in lm_sess.get_outputs()]
|
||||
|
||||
kv_dtype = np.float16
|
||||
|
||||
# Embed full text sequence
|
||||
text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden]
|
||||
|
||||
# Prepend conditioning embeddings from speech encoder
|
||||
cond_emb = cond.get("cond_emb")
|
||||
inputs_embeds = (
|
||||
np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
|
||||
)
|
||||
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
attention_mask = np.ones((1, seq_len), dtype=np.int64)
|
||||
position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]
|
||||
|
||||
# Empty KV cache to start
|
||||
past_kv = {
|
||||
name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
|
||||
for name in past_names
|
||||
}
|
||||
|
||||
generated: list = []
|
||||
|
||||
for _ in range(MAX_NEW_TOKENS):
|
||||
feed = {
|
||||
"inputs_embeds": inputs_embeds.astype(kv_dtype),
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
feed.update(past_kv)
|
||||
|
||||
raw = lm_sess.run(None, feed)
|
||||
out = dict(zip(lm_out_names, raw))
|
||||
|
||||
logits = out["logits"][0, -1, :].astype(np.float32) # [vocab]
|
||||
|
||||
# Repetition penalty (greedy decoding)
|
||||
for tok in set(generated):
|
||||
if logits[tok] > 0:
|
||||
logits[tok] /= REPETITION_PENALTY
|
||||
else:
|
||||
kwargs["exaggeration"] = exaggeration
|
||||
kwargs["cfg_weight"] = cfg_weight
|
||||
logits[tok] *= REPETITION_PENALTY
|
||||
|
||||
with torch.inference_mode():
|
||||
wav = chatterbox_model.generate(text=text, **kwargs)
|
||||
next_token = int(np.argmax(logits))
|
||||
if next_token == STOP_SPEECH_TOKEN:
|
||||
break
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
generated.append(next_token)
|
||||
|
||||
return wav, _sample_rate
|
||||
# Roll KV cache forward
|
||||
past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}
|
||||
|
||||
# Embed only the new token for next step
|
||||
inputs_embeds = embed_sess.run(
|
||||
None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
|
||||
)[0]
|
||||
new_pos = attention_mask.shape[1]
|
||||
attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
|
||||
position_ids = np.array([[new_pos]], dtype=np.int64)
|
||||
|
||||
if not generated:
|
||||
logger.warning("No speech tokens generated — returning start token")
|
||||
return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
|
||||
|
||||
return np.array([generated], dtype=np.int64)
|
||||
|
||||
|
||||
def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
"""Run conditional_decoder to produce a float32 waveform."""
|
||||
dec_sess = _sessions["conditional_decoder"]
|
||||
dec_inputs = dec_sess.get_inputs()
|
||||
|
||||
feed = {dec_inputs[0].name: speech_tokens}
|
||||
if len(dec_inputs) > 1:
|
||||
feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
|
||||
if len(dec_inputs) > 2:
|
||||
feed[dec_inputs[2].name] = cond.get("speaker_features")
|
||||
|
||||
wav = dec_sess.run(None, feed)[0] # [1, T] or [T]
|
||||
return wav.squeeze().astype(np.float32)
|
||||
|
||||
@@ -1,5 +1 @@
|
||||
--index-url https://download.pytorch.org/whl/rocm6.1
|
||||
torch==2.5.1
|
||||
torchaudio==2.5.1
|
||||
torchvision==0.20.1
|
||||
pytorch_triton_rocm==3.1.0
|
||||
onnxruntime-rocm
|
||||
|
||||
@@ -2,21 +2,10 @@
|
||||
numpy>=1.24.0,<2.0.0
|
||||
soundfile
|
||||
librosa==0.11.0
|
||||
pyloudnorm
|
||||
|
||||
# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
|
||||
transformers==4.46.3
|
||||
diffusers==0.29.0
|
||||
safetensors>=0.4.1
|
||||
# ONNX model dependencies
|
||||
transformers>=4.40.0
|
||||
huggingface-hub
|
||||
omegaconf
|
||||
|
||||
# Chatterbox dependencies (installed separately since chatterbox uses --no-deps)
|
||||
conformer==0.3.2
|
||||
s3tokenizer==0.3.0
|
||||
spacy-pkuseg
|
||||
pykakasi==2.3.0
|
||||
resemble-perth==1.0.1
|
||||
|
||||
# Wyoming protocol
|
||||
wyoming>=1.5.4
|
||||
|
||||
@@ -3,6 +3,8 @@ import logging
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import Describe, Info
|
||||
@@ -151,7 +153,7 @@ class ChatterboxWyomingHandler(AsyncEventHandler):
|
||||
continue
|
||||
|
||||
audio_bytes = (
|
||||
audio_tensor.cpu().numpy().squeeze() * 32767
|
||||
np.asarray(audio_tensor).squeeze() * 32767
|
||||
).astype("int16").tobytes()
|
||||
|
||||
if first_chunk:
|
||||
|
||||
Reference in New Issue
Block a user