diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 30ba3b2..257b4c4 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -17,18 +17,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 WORKDIR /app
 
-# Step 1: Install ROCm-compatible PyTorch stack first.
-# This must happen before anything else to prevent pip from pulling CPU wheels.
+# Step 1: Install onnxruntime-rocm first so it claims the onnxruntime namespace
+# before any other package can pull in the CPU-only onnxruntime wheel.
 COPY requirements-rocm-init.txt .
 RUN pip3 install -r requirements-rocm-init.txt
 
-# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
+# Step 2: Install remaining dependencies.
 COPY requirements-rocm.txt .
 RUN pip3 install -r requirements-rocm.txt
 
-# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
-RUN pip3 install --no-deps chatterbox-tts
-
 # Application source
 COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
 
diff --git a/engine.py b/engine.py
index 766ed87..56b5a4b 100644
--- a/engine.py
+++ b/engine.py
@@ -1,84 +1,116 @@
 import logging
-import torch
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
+
+import librosa
+import numpy as np
 
 logger = logging.getLogger(__name__)
 
-chatterbox_model = None
-_sample_rate = 24000
-_is_turbo = False
+ONNX_REPO = "ResembleAI/chatterbox-turbo-ONNX"
+PRECISION_SUFFIX = "_fp16"
+SAMPLE_RATE = 24000
+START_SPEECH_TOKEN = 6561
+STOP_SPEECH_TOKEN = 6562
+SILENCE_TOKEN = 4299
+NUM_KV_HEADS = 16
+HEAD_DIM = 64
+MAX_NEW_TOKENS = 1024
+REPETITION_PENALTY = 1.2
 
-# Cache: voice file path → prepared conditionals object.
-# prepare_conditionals loads audio, runs s3tokenizer + voice encoder, and
-# builds mel embeddings — expensive work that only depends on the reference
-# audio, not the text. Cache it so multi-chunk requests pay the cost once.
-_cond_cache: dict = {}
+# ONNX inference sessions keyed by module name
+_sessions: Dict = {}
+_tokenizer = None
 
-
-def _test_cuda() -> bool:
-    try:
-        if torch.cuda.is_available():
-            torch.zeros(1).cuda()
-            return True
-    except Exception:
-        pass
-    return False
+# Cache: voice file path → speech_encoder outputs dict
+_cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
 
 
 def detect_device() -> str:
-    return "cuda" if _test_cuda() else "cpu"
+    try:
+        import onnxruntime as ort
+        if "ROCMExecutionProvider" in ort.get_available_providers():
+            return "rocm"
+    except Exception:
+        pass
+    return "cpu"
+
+
+def _get_providers(device: str) -> list:
+    if device in ("rocm", "cuda"):
+        return [("ROCMExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
+    return ["CPUExecutionProvider"]
 
 
 def load_model() -> bool:
-    global chatterbox_model, _sample_rate, _is_turbo
+    global _sessions, _tokenizer
 
-    from config import get_model_repo_id, get_device_override
+    import onnxruntime as ort
+    from huggingface_hub import hf_hub_download
+    from transformers import AutoTokenizer
+
+    from config import get_device_override
 
     device = get_device_override() or detect_device()
-    repo_id = get_model_repo_id()
-
-    logger.info(f"Loading model '{repo_id}' on device '{device}'")
+    providers = _get_providers(device)
+    logger.info(f"Loading ONNX model (fp16) on device='{device}', providers={providers}")
 
     try:
-        if "turbo" in repo_id.lower():
-            from chatterbox.tts_turbo import ChatterboxTurboTTS
-            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
-            _is_turbo = True
-        else:
-            from chatterbox.tts import ChatterboxTTS
-            chatterbox_model = ChatterboxTTS.from_pretrained(device)
-            _is_turbo = False
+        logger.info(f"Loading tokenizer from '{ONNX_REPO}'")
+        _tokenizer = AutoTokenizer.from_pretrained(ONNX_REPO)
 
-        _sample_rate = 24000
-        logger.info("Model loaded successfully")
+        module_names = [
+            "speech_encoder",
+            "embed_tokens",
+            "language_model",
+            "conditional_decoder",
+        ]
+        sess_options = ort.SessionOptions()
+        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        for name in module_names:
+            onnx_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx"
+            data_file = f"onnx/{name}{PRECISION_SUFFIX}.onnx_data"
+
+            logger.info(f"Downloading {onnx_file} ...")
+            onnx_path = hf_hub_download(repo_id=ONNX_REPO, filename=onnx_file)
+            hf_hub_download(repo_id=ONNX_REPO, filename=data_file)
+
+            logger.info(f"Creating session: {name}")
+            _sessions[name] = ort.InferenceSession(
+                onnx_path, sess_options=sess_options, providers=providers
+            )
+
+        logger.info("ONNX model loaded successfully")
         return True
+
     except Exception:
-        logger.exception("Failed to load model")
+        logger.exception("Failed to load ONNX model")
         return False
 
 
+def get_sample_rate() -> int:
+    return SAMPLE_RATE
+
+
 def prepare_voice(audio_prompt_path: str) -> None:
-    """
-    Pre-compute and cache the voice conditionals for a reference audio file.
-    Calling this once avoids repeating the s3tokenizer + voice encoder work
-    on every synthesis chunk that uses the same voice.
- """ - if chatterbox_model is None: + """Pre-compute and cache speech_encoder outputs for a reference audio file.""" + if not _sessions or audio_prompt_path in _cond_cache: return - if audio_prompt_path in _cond_cache: - return - if not _is_turbo: - return # only turbo exposes prepare_conditionals logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'") - with torch.inference_mode(): - chatterbox_model.prepare_conditionals(audio_prompt_path) - _cond_cache[audio_prompt_path] = chatterbox_model.conds - logger.info("Voice conditionals cached") + try: + audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True) + audio = audio[np.newaxis, :].astype(np.float16) # [1, T] + session = _sessions["speech_encoder"] + input_name = session.get_inputs()[0].name + output_names = [o.name for o in session.get_outputs()] + outputs = session.run(None, {input_name: audio}) -def get_sample_rate() -> int: - return _sample_rate + _cond_cache[audio_prompt_path] = dict(zip(output_names, outputs)) + logger.info("Voice conditionals cached") + except Exception: + logger.exception(f"Failed to prepare voice '{audio_prompt_path}'") def synthesize( @@ -88,38 +120,128 @@ def synthesize( cfg_weight: float = 0.5, temperature: float = 0.8, seed: int = 0, -) -> Tuple[torch.Tensor, int]: - if chatterbox_model is None: +) -> Tuple[np.ndarray, int]: + """Synthesize speech. Returns (waveform_float32_1d, sample_rate).""" + if not _sessions: raise RuntimeError("Model not loaded. Call load_model() first.") + if audio_prompt_path is None: + raise RuntimeError( + "ONNX model requires a reference audio file for voice cloning." + ) if seed > 0: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + np.random.seed(seed) - # Restore cached conditionals so generate() skips prepare_conditionals. - if audio_prompt_path and _is_turbo: - if audio_prompt_path not in _cond_cache: - prepare_voice(audio_prompt_path) - chatterbox_model.conds = _cond_cache[audio_prompt_path] + if audio_prompt_path not in _cond_cache: + prepare_voice(audio_prompt_path) - kwargs: dict = {} - # Don't pass audio_prompt_path — conds are already set above. - # For non-turbo models there's no cache, pass path as normal. - if audio_prompt_path and not _is_turbo: - kwargs["audio_prompt_path"] = audio_prompt_path + cond = _cond_cache[audio_prompt_path] - if _is_turbo: - kwargs["temperature"] = temperature - else: - kwargs["exaggeration"] = exaggeration - kwargs["cfg_weight"] = cfg_weight + # Tokenize input text + input_ids = _tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64) - with torch.inference_mode(): - wav = chatterbox_model.generate(text=text, **kwargs) + # Generate speech tokens autoregressively + speech_tokens = _generate(input_ids, cond) - if torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() + # Prepend prompt token (speaker identity anchor) and append silence tokens + prompt_token = cond.get("prompt_token") + silence = np.full((1, 3), SILENCE_TOKEN, dtype=np.int64) + parts = ([prompt_token] if prompt_token is not None else []) + [speech_tokens, silence] + decoder_input = np.concatenate(parts, axis=1) - return wav, _sample_rate + wav = _decode(decoder_input, cond) + return wav, SAMPLE_RATE + + +def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray: + """Autoregressive LM loop with KV cache. 
Returns int64 speech token array [1, T].""" + embed_sess = _sessions["embed_tokens"] + lm_sess = _sessions["language_model"] + + embed_input_name = embed_sess.get_inputs()[0].name + + # Discover KV cache slot names from session metadata + past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name] + present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name] + lm_out_names = [o.name for o in lm_sess.get_outputs()] + + kv_dtype = np.float16 + + # Embed full text sequence + text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden] + + # Prepend conditioning embeddings from speech encoder + cond_emb = cond.get("cond_emb") + inputs_embeds = ( + np.concatenate([cond_emb, text_embeds], axis=1) if cond_emb is not None else text_embeds + ) + + seq_len = inputs_embeds.shape[1] + attention_mask = np.ones((1, seq_len), dtype=np.int64) + position_ids = np.arange(seq_len, dtype=np.int64)[np.newaxis, :] + + # Empty KV cache to start + past_kv = { + name: np.zeros((1, NUM_KV_HEADS, 0, HEAD_DIM), dtype=kv_dtype) + for name in past_names + } + + generated: list = [] + + for _ in range(MAX_NEW_TOKENS): + feed = { + "inputs_embeds": inputs_embeds.astype(kv_dtype), + "attention_mask": attention_mask, + "position_ids": position_ids, + } + feed.update(past_kv) + + raw = lm_sess.run(None, feed) + out = dict(zip(lm_out_names, raw)) + + logits = out["logits"][0, -1, :].astype(np.float32) # [vocab] + + # Repetition penalty (greedy decoding) + for tok in set(generated): + if logits[tok] > 0: + logits[tok] /= REPETITION_PENALTY + else: + logits[tok] *= REPETITION_PENALTY + + next_token = int(np.argmax(logits)) + if next_token == STOP_SPEECH_TOKEN: + break + + generated.append(next_token) + + # Roll KV cache forward + past_kv = {pname: out[prname] for pname, prname in zip(past_names, present_names)} + + # Embed only the new token for next step + inputs_embeds = embed_sess.run( + None, {embed_input_name: np.array([[next_token]], dtype=np.int64)} + )[0] + new_pos = attention_mask.shape[1] + attention_mask = np.ones((1, new_pos + 1), dtype=np.int64) + position_ids = np.array([[new_pos]], dtype=np.int64) + + if not generated: + logger.warning("No speech tokens generated — returning start token") + return np.array([[START_SPEECH_TOKEN]], dtype=np.int64) + + return np.array([generated], dtype=np.int64) + + +def _decode(speech_tokens: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray: + """Run conditional_decoder to produce a float32 waveform.""" + dec_sess = _sessions["conditional_decoder"] + dec_inputs = dec_sess.get_inputs() + + feed = {dec_inputs[0].name: speech_tokens} + if len(dec_inputs) > 1: + feed[dec_inputs[1].name] = cond.get("speaker_embeddings") + if len(dec_inputs) > 2: + feed[dec_inputs[2].name] = cond.get("speaker_features") + + wav = dec_sess.run(None, feed)[0] # [1, T] or [T] + return wav.squeeze().astype(np.float32) diff --git a/requirements-rocm-init.txt b/requirements-rocm-init.txt index 44a44f9..8c57eb8 100644 --- a/requirements-rocm-init.txt +++ b/requirements-rocm-init.txt @@ -1,5 +1 @@ ---index-url https://download.pytorch.org/whl/rocm6.1 -torch==2.5.1 -torchaudio==2.5.1 -torchvision==0.20.1 -pytorch_triton_rocm==3.1.0 +onnxruntime-rocm diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 8ce84d8..51ebf7e 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,21 +2,10 @@ numpy>=1.24.0,<2.0.0 soundfile librosa==0.11.0 -pyloudnorm -# ML dependencies (pinned to match chatterbox without overwriting ROCm 
-transformers==4.46.3
-diffusers==0.29.0
-safetensors>=0.4.1
+# ONNX model dependencies
+transformers>=4.40.0
 huggingface-hub
-omegaconf
-
-# Chatterbox dependencies (installed separately since chatterbox uses --no-deps)
-conformer==0.3.2
-s3tokenizer==0.3.0
-spacy-pkuseg
-pykakasi==2.3.0
-resemble-perth==1.0.1
 
 # Wyoming protocol
 wyoming>=1.5.4
diff --git a/wyoming_handler.py b/wyoming_handler.py
index 2cb2f0c..5b3e1a7 100644
--- a/wyoming_handler.py
+++ b/wyoming_handler.py
@@ -3,6 +3,8 @@ import logging
 import time
 from typing import Dict, Optional
 
+import numpy as np
+
 from wyoming.audio import AudioChunk, AudioStart, AudioStop
 from wyoming.event import Event
 from wyoming.info import Describe, Info
@@ -151,7 +153,7 @@
             continue
 
         audio_bytes = (
-            audio_tensor.cpu().numpy().squeeze() * 32767
+            np.clip(np.asarray(audio_tensor).squeeze(), -1.0, 1.0) * 32767
         ).astype("int16").tobytes()
 
         if first_chunk:
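
---

Since the whole point of the Dockerfile reordering is that onnxruntime-rocm must win the onnxruntime namespace, a quick sanity check inside the built image can confirm it. This is a sketch, not part of the patch; the engine._sessions lookup pokes at a private module attribute and assumes load_model() has already run:

    import onnxruntime as ort

    # A CPU-only onnxruntime wheel pulled in by a transitive dependency
    # would list only CPUExecutionProvider here.
    print(ort.get_available_providers())
    assert "ROCMExecutionProvider" in ort.get_available_providers()

    import engine

    assert engine.load_model()
    # Confirm the language model session actually bound to the ROCm EP
    # instead of silently falling back to CPU.
    print(engine._sessions["language_model"].get_providers())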
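
For reference, a minimal end-to-end smoke test of the reworked engine API. The reference-WAV path is illustrative; synthesize() now requires one, since the ONNX pipeline only does voice cloning:

    import engine
    import soundfile as sf

    assert engine.load_model()

    ref = "/voices/reference.wav"   # illustrative path
    engine.prepare_voice(ref)       # caches speech_encoder outputs once per voice
    wav, sr = engine.synthesize(
        "Hello from the ONNX backend.",
        audio_prompt_path=ref,
    )

    sf.write("/tmp/out.wav", wav, sr)   # wav is a 1-D float32 array at 24 kHz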