Switch to ONNX runtime with chatterbox-turbo-ONNX (fp16)
Replaces the PyTorch/chatterbox-tts stack with direct ONNX inference using ResembleAI/chatterbox-turbo-ONNX fp16 weights.

- engine.py: full rewrite — ONNX sessions, autoregressive KV-cache LM loop, voice conditionals cache via speech_encoder outputs
- wyoming_handler.py: remove torch dep, use np.asarray for audio bytes
- requirements-rocm-init.txt: onnxruntime-rocm replaces torch wheels
- requirements-rocm.txt: drop chatterbox/torch deps, keep audio utils
- Dockerfile.rocm: remove chatterbox-tts install step

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
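For orientation before the diffs: a minimal sketch of how the rewritten engine module might be driven end to end. The function names (load_model, prepare_voice, synthesize, get_sample_rate) come from the engine.py hunks below; the file paths are hypothetical, and the full synthesize() signature is partially elided in this diff, so only the keyword arguments visible in the hunks are used.

# Usage sketch (assumes the module is importable as `engine`;
# "voices/reference.wav" is a placeholder path, not part of the commit).
import soundfile as sf  # already listed in requirements-rocm.txt

import engine

if not engine.load_model():
    raise SystemExit("ONNX model failed to load")

# Optional: warm the voice-conditionals cache once per reference voice.
engine.prepare_voice("voices/reference.wav")

wav, sr = engine.synthesize(
    text="Hello from the ONNX runtime.",
    audio_prompt_path="voices/reference.wav",
)
sf.write("out.wav", wav, sr)  # wav is a float32 1-D numpy array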
Dockerfile.rocm
@@ -17,18 +17,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 WORKDIR /app
 
-# Step 1: Install ROCm-compatible PyTorch stack first.
-# This must happen before anything else to prevent pip from pulling CPU wheels.
+# Step 1: Install onnxruntime-rocm first so it claims the onnxruntime namespace
+# before any other package can pull in the CPU-only onnxruntime wheel.
 COPY requirements-rocm-init.txt .
 RUN pip3 install -r requirements-rocm-init.txt
 
-# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
+# Step 2: Install remaining dependencies.
 COPY requirements-rocm.txt .
 RUN pip3 install -r requirements-rocm.txt
 
-# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
-RUN pip3 install --no-deps chatterbox-tts
-
 # Application source
 COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
 
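A quick way to confirm the Step 1 ordering actually won the namespace race — a hedged side check, not part of the commit — is to ask the installed onnxruntime which execution providers it exposes, the same call detect_device() uses in engine.py below:

# Run inside the built image. If the CPU-only wheel shadowed
# onnxruntime-rocm, ROCMExecutionProvider will be missing here.
import onnxruntime as ort

providers = ort.get_available_providers()
print(providers)
assert "ROCMExecutionProvider" in providers, "ROCm EP not available"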
engine.py
@@ -1,84 +1,116 @@
 import logging
-import torch
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
+
+import librosa
+import numpy as np
 
 logger = logging.getLogger(__name__)
 
-chatterbox_model = None
-_sample_rate = 24000
-_is_turbo = False
+ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
+PRECISION_SUFFIX = "_fp16"
+SAMPLE_RATE = 24000
+START_SPEECH_TOKEN = 6561
+STOP_SPEECH_TOKEN = 6562
+SILENCE_TOKEN = 4299
+NUM_KV_HEADS = 16
+HEAD_DIM = 64
+MAX_NEW_TOKENS = 1024
+REPETITION_PENALTY = 1.2
 
-# Cache: voice file path → prepared conditionals object.
-# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
-# builds mel embeddings — expensive work that only depends on the reference
-# audio, not the text. Cache it so multi-chunk requests pay the cost once.
-_cond_cache: dict = {}
+# ONNX inference sessions keyed by module name
+_sessions: Dict = {}
+_tokenizer = None
 
+# Cache: voice file path → speech_encoder outputs dict
+_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
 
-def _test_cuda() -> bool:
-    try:
-        if torch.cuda.is_available():
-            torch.zeros(1).cuda()
-            return True
-    except Exception:
-        pass
-    return False
-
 
 def detect_device() -> str:
-    return "cuda" if _test_cuda() else "cpu"
+    try:
+        import onnxruntime as ort
+        if "ROCMExecutionProvider" in ort.get_available_providers():
+            return "rocm"
+    except Exception:
+        pass
+    return "cpu"
+
+
+def _get_providers(device: str) -> list:
+    if device in ("rocm", "cuda"):
+        return [("ROCMExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
+    return ["CPUExecutionProvider"]
 
 
 def load_model() -> bool:
-    global chatterbox_model, _sample_rate, _is_turbo
+    global _sessions, _tokenizer
 
-    from config import get_model_repo_id, get_device_override
+    import onnxruntime as ort
+    from huggingface_hub import hf_hub_download
+    from transformers import AutoTokenizer
+
+    from config import get_device_override
 
     device = get_device_override() or detect_device()
-    repo_id = get_model_repo_id()
-
-    logger.info(f"Loading model '{repo_id}' on device '{device}'")
+    providers = _get_providers(device)
+    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
 
     try:
-        if "turbo" in repo_id.lower():
-            from chatterbox.tts_turbo import ChatterboxTurboTTS
-            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
-            _is_turbo = True
-        else:
-            from chatterbox.tts import ChatterboxTTS
-            chatterbox_model = ChatterboxTTS.from_pretrained(device)
-            _is_turbo = False
-
-        _sample_rate = 24000
-        logger.info("Model loaded successfully")
+        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
+        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
+
+        module_names = [
+            "speech_encoder",
+            "embed_tokens",
+            "language_model",
+            "conditional_decoder",
+        ]
+        sess_options = ort.SessionOptions()
+        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        for name in module_names:
+            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
+            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
+
+            logger.info(f"Downloading {onnx_file} ...")
+            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
+            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
+
+            logger.info(f"Creating session: {name}")
+            _sessions[name] = ort.InferenceSession(
+                onnx_path, sess_options=sess_options, providers=providers
+            )
+
+        logger.info("ONNX model loaded successfully")
         return True
 
     except Exception:
-        logger.exception("Failed to load model")
+        logger.exception("Failed to load ONNX model")
        return False
 
 
+def get_sample_rate() -> int:
+    return SAMPLE_RATE
+
+
 def prepare_voice(audio_prompt_path: str) -> None:
-    """
-    Pre-compute and cache the voice conditionals for a reference audio file.
-    Calling this once avoids repeating the s3tokenizer + voice encoder work
-    on every synthesis chunk that uses the same voice.
-    """
-    if chatterbox_model is None:
-        return
-    if audio_prompt_path in _cond_cache:
-        return
-    if not _is_turbo:
-        return  # only turbo exposes prepare_conditionals
+    """Pre-compute and cache speech_encoder outputs for a reference audio file."""
+    if not _sessions or audio_prompt_path in _cond_cache:
+        return
 
     logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
-    with torch.inference_mode():
-        chatterbox_model.prepare_conditionals(audio_prompt_path)
-    _cond_cache[audio_prompt_path] = chatterbox_model.conds
-    logger.info("Voice conditionals cached")
-
-
-def get_sample_rate() -> int:
-    return _sample_rate
+    try:
+        audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
+        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]
+
+        session = _sessions["speech_encoder"]
+        input_name = session.get_inputs()[0].name
+        output_names = [o.name for o in session.get_outputs()]
+        outputs = session.run(None, {input_name: audio})
+
+        _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs))
+        logger.info("Voice conditionals cached")
+    except Exception:
+        logger.exception(f"Failed to prepare voice '{audio_prompt_path}'")
 
 
 def synthesize(
@@ -88,38 +120,128 @@ def synthesize(
     cfg_weight: float = 0.5,
     temperature: float = 0.8,
     seed: int = 0,
-) -> Tuple[torch.Tensor, int]:
-    if chatterbox_model is None:
+) -> Tuple[np.ndarray, int]:
+    """Synthesize speech. Returns (waveform_float32_1d, sample_rate)."""
+    if not _sessions:
         raise RuntimeError("Model not loaded. Call load_model() first.")
+    if audio_prompt_path is None:
+        raise RuntimeError(
+            "ONNX model requires a reference audio file for voice cloning."
+        )
 
     if seed > 0:
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)
-
-    # Restore cached conditionals so generate() skips prepare_conditionals.
-    if audio_prompt_path and _is_turbo:
-        if audio_prompt_path not in _cond_cache:
-            prepare_voice(audio_prompt_path)
-        chatterbox_model.conds = _cond_cache[audio_prompt_path]
-
-    kwargs: dict = {}
-    # Don't pass audio_prompt_path — conds are already set above.
-    # For non-turbo models there's no cache, pass path as normal.
-    if audio_prompt_path and not _is_turbo:
-        kwargs["audio_prompt_path"] = audio_prompt_path
-
-    if _is_turbo:
-        kwargs["temperature"] = temperature
-    else:
-        kwargs["exaggeration"] = exaggeration
-        kwargs["cfg_weight"] = cfg_weight
-
-    with torch.inference_mode():
-        wav = chatterbox_model.generate(text=text, **kwargs)
-
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        torch.cuda.empty_cache()
-
-    return wav, _sample_rate
+        np.random.seed(seed)
+
+    if audio_prompt_path not in _cond_cache:
+        prepare_voice(audio_prompt_path)
+
+    cond = _cond_cache[audio_prompt_path]
+
+    # Tokenize input text
+    input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
+
+    # Generate speech tokens autoregressively
+    speech_tokens = _generate(input_ids, cond)
+
+    # Prepend prompt token (speaker identity anchor) and append silence tokens
+    prompt_token = cond.get("prompt_token")
+    silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64)
+    parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence]
+    decoder_input = np.concatenate(parts, axis=1)
+
+    wav = _decode(decoder_input, cond)
+    return wav, SAMPLE_RATE
+
+
+def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
+    """Autoregressive LM loop with KV cache. Returns int64 speech token array [1, T]."""
+    embed_sess = _sessions["embed_tokens"]
+    lm_sess = _sessions["language_model"]
+
+    embed_input_name = embed_sess.get_inputs()[0].name
+
+    # Discover KV cache slot names from session metadata
+    past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
+    present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
+    lm_out_names = [o.name for o in lm_sess.get_outputs()]
+
+    kv_dtype = np.float16
+
+    # Embed full text sequence
+    text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]
+
+    # Prepend conditioning embeddings from speech encoder
+    cond_emb = cond.get("cond_emb")
+    inputs_embeds = (
+        np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds
+    )
+
+    seq_len = inputs_embeds.shape[1]
+    attention_mask = np.ones((1, seq_len), dtype=np.int64)
+    position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :]
+
+    # Empty KV cache to start
+    past_kv = {
+        name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype)
+        for name in past_names
+    }
+
+    generated: list = []
+
+    for _ in range(MAX_NEW_TOKENS):
+        feed = {
+            "inputs_embeds": inputs_embeds.astype(kv_dtype),
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+        }
+        feed.update(past_kv)
+
+        raw = lm_sess.run(None, feed)
+        out = dict(zip(lm_out_names, raw))
+
+        logits = out["logits"][0, -1, :].astype(np.float32)  # [vocab]
+
+        # Repetition penalty (greedy decoding)
+        for tok in set(generated):
+            if logits[tok] > 0:
+                logits[tok] /= REPETITION_PENALTY
+            else:
+                logits[tok] *= REPETITION_PENALTY
+
+        next_token = int(np.argmax(logits))
+        if next_token == STOP_SPEECH_TOKEN:
+            break
+
+        generated.append(next_token)
+
+        # Roll KV cache forward
+        past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)}
+
+        # Embed only the new token for next step
+        inputs_embeds = embed_sess.run(
+            None, {embed_input_name: np.array([[next_token]], dtype=np.int64)}
+        )[0]
+        new_pos = attention_mask.shape[1]
+        attention_mask = np.ones((1, new_pos + 1), dtype=np.int64)
+        position_ids = np.array([[new_pos]], dtype=np.int64)
+
+    if not generated:
+        logger.warning("No speech tokens generated — returning start token")
+        return np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
+
+    return np.array([generated], dtype=np.int64)
+
+
+def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
+    """Run conditional_decoder to produce a float32 waveform."""
+    dec_sess = _sessions["conditional_decoder"]
+    dec_inputs = dec_sess.get_inputs()
+
+    feed = {dec_inputs[0].name: speech_tokens}
+    if len(dec_inputs) > 1:
+        feed[dec_inputs[1].name] = cond.get("speaker_embeddings")
+    if len(dec_inputs) > 2:
+        feed[dec_inputs[2].name] = cond.get("speaker_features")
+
+    wav = dec_sess.run(None, feed)[0]  # [1, T] or [T]
+    return wav.squeeze().astype(np.float32)
requirements-rocm-init.txt
@@ -1,5 +1 @@
---index-url https://download.pytorch.org/whl/rocm6.1
-torch==2.5.1
-torchaudio==2.5.1
-torchvision==0.20.1
-pytorch_triton_rocm==3.1.0
+onnxruntime-rocm
requirements-rocm.txt
@@ -2,21 +2,10 @@
 numpy>=1.24.0,<2.0.0
 soundfile
 librosa==0.11.0
-pyloudnorm
 
-# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
-transformers==4.46.3
-diffusers==0.29.0
-safetensors>=0.4.1
+# ONNX model dependencies
+transformers>=4.40.0
 huggingface-hub
-omegaconf
-
-# Chatterbox dependencies (installed separately since chatterbox uses --no-deps)
-conformer==0.3.2
-s3tokenizer==0.3.0
-spacy-pkuseg
-pykakasi==2.3.0
-resemble-perth==1.0.1
 
 # Wyoming protocol
 wyoming>=1.5.4
wyoming_handler.py
@@ -3,6 +3,8 @@ import logging
 import time
 from typing import Dict, Optional
 
+import numpy as np
+
 from wyoming.audio import AudioChunk, AudioStart, AudioStop
 from wyoming.event import Event
 from wyoming.info import Describe, Info
@@ -151,7 +153,7 @@ class ChatterboxWyomingHandler(AsyncEventHandler):
                 continue
 
             audio_bytes = (
-                audio_tensor.cpu().numpy().squeeze() * 32767
+                np.asarray(audio_tensor).squeeze() * 32767
             ).astype("int16").tobytes()
 
             if first_chunk:
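A hedged side note on the conversion above, not part of the commit: scaling by 32767 and casting to int16 wraps around for any sample outside [-1.0, 1.0]. If out-of-range samples are possible, a clipping variant avoids that:

import numpy as np

def float_to_pcm16(wav: np.ndarray) -> bytes:
    # Clamp to the valid range before scaling so overflow cannot wrap.
    wav = np.clip(np.asarray(wav).squeeze(), -1.0, 1.0)
    return (wav * 32767).astype(np.int16).tobytes()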