From 16ea2853f526f714530f903744393327f03bbcba Mon Sep 17 00:00:00 2001 From: scott Date: Sun, 5 Apr 2026 09:51:09 -0400 Subject: [PATCH] Initial implementation: Chatterbox TTS with ROCm and Wyoming Wyoming-only server built around the official chatterbox TTS model. Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml management, and Gitea CI for container builds. Co-Authored-By: Claude Sonnet 4.6 --- .gitea/workflows/build.yml | 33 ++++++++ .gitignore | 16 ++++ Dockerfile.rocm | 43 ++++++++++ config.py | 115 +++++++++++++++++++++++++ config.yaml | 29 +++++++ docker-compose.yml | 35 ++++++++ engine.py | 91 ++++++++++++++++++++ main.py | 42 ++++++++++ requirements-rocm-init.txt | 5 ++ requirements-rocm.txt | 17 ++++ wyoming_handler.py | 166 +++++++++++++++++++++++++++++++++++++ wyoming_voices.py | 99 ++++++++++++++++++++++ 12 files changed, 691 insertions(+) create mode 100644 .gitea/workflows/build.yml create mode 100644 .gitignore create mode 100644 Dockerfile.rocm create mode 100644 config.py create mode 100644 config.yaml create mode 100644 docker-compose.yml create mode 100644 engine.py create mode 100644 main.py create mode 100644 requirements-rocm-init.txt create mode 100644 requirements-rocm.txt create mode 100644 wyoming_handler.py create mode 100644 wyoming_voices.py diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml new file mode 100644 index 0000000..71c1c71 --- /dev/null +++ b/.gitea/workflows/build.yml @@ -0,0 +1,33 @@ +name: Build ROCm Image + +on: + push: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Gitea Registry + uses: docker/login-action@v3 + with: + registry: git.sdgarren.com + username: ${{ secrets.REGISTRY_USERNAME }} + password: ${{ secrets.REGISTRY_TOKEN }} + + - name: Build and Push + uses: docker/build-push-action@v6 + with: + context: . 
+ file: Dockerfile.rocm + push: true + tags: | + git.sdgarren.com/scott/rocm-chatterbox-whisper:latest + git.sdgarren.com/scott/rocm-chatterbox-whisper:${{ gitea.sha }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fe062fe --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +__pycache__/ +*.py[cod] +*.egg-info/ +.env + +# Voice and audio files (mount via volume) +voices/ +reference_audio/ +outputs/ + +# Model cache +hf_cache/ + +# Logs +logs/ +*.log diff --git a/Dockerfile.rocm b/Dockerfile.rocm new file mode 100644 index 0000000..cdb8034 --- /dev/null +++ b/Dockerfile.rocm @@ -0,0 +1,43 @@ +FROM rocm/dev-ubuntu-22.04:latest + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + HF_HOME=/app/hf_cache \ + PIP_NO_CACHE_DIR=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + git \ + ffmpeg \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Step 1: Install ROCm-compatible PyTorch stack first. +# This must happen before anything else to prevent pip from pulling CPU wheels. +COPY requirements-rocm-init.txt . +RUN pip3 install -r requirements-rocm-init.txt + +# Step 2: Install remaining dependencies (pinned to avoid overwriting torch). +COPY requirements-rocm.txt . +RUN pip3 install -r requirements-rocm.txt + +# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch. +RUN pip3 install --no-deps chatterbox-tts + +# Application source +COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./ + +# Default config (can be overridden by volume mount) +COPY config.yaml . 
import copy
import logging
import os
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# Module-level config state, populated by load_config().
_config: dict = {}
_config_path = Path(os.environ.get("CONFIG_PATH", "config.yaml"))

# Baseline settings used when config.yaml is missing or only partial.
DEFAULTS = {
    "model": {
        "repo_id": "chatterbox-turbo",
    },
    "tts_engine": {
        "device": "",  # empty = auto-detect
        "predefined_voices_path": "voices",
        "reference_audio_path": "reference_audio",
        "default_voice_id": "default.wav",
    },
    "generation_defaults": {
        "temperature": 0.8,
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "seed": 0,
    },
    "wyoming": {
        "host": "0.0.0.0",
        "port": 10200,
        "chunk_size": 300,
    },
    "paths": {
        "model_cache": "/app/hf_cache",
    },
}


def _deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict with *override* recursively merged onto *base*.

    Nested dicts are merged key-by-key; any other value in *override*
    replaces the corresponding value in *base*. Neither input is mutated.
    """
    result = base.copy()
    for key, value in override.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = _deep_merge(result[key], value)
        else:
            result[key] = value
    return result


def load_config() -> None:
    """Load config.yaml (if present) over DEFAULTS into module state.

    Also exports paths.model_cache as HF_HOME unless the environment
    already defines it. Never raises: a missing or broken config file
    is logged and the defaults are used instead.
    """
    global _config
    # deepcopy, not _deep_merge(DEFAULTS, {}): the latter only copies the
    # top level, so later in-place edits of _config could silently
    # corrupt the DEFAULTS constant.
    _config = copy.deepcopy(DEFAULTS)
    if _config_path.exists():
        try:
            # Deferred import: PyYAML is only required when a config file
            # actually exists, keeping the module importable without it.
            import yaml

            with open(_config_path, encoding="utf-8") as f:
                user_config = yaml.safe_load(f) or {}
            _config = _deep_merge(_config, user_config)
            logger.info(f"Loaded config from {_config_path}")
        except Exception:
            logger.exception(f"Failed to load {_config_path}, using defaults")
    else:
        logger.warning(f"Config file not found at {_config_path}, using defaults")

    # Set HuggingFace cache path from config (an existing env var wins)
    cache_path = _config.get("paths", {}).get("model_cache", "")
    if cache_path:
        os.environ.setdefault("HF_HOME", cache_path)


def get_model_repo_id() -> str:
    """Model identifier; 'turbo' in the name selects the turbo engine."""
    return _config.get("model", {}).get("repo_id", "chatterbox-turbo")


def get_device_override() -> Optional[str]:
    """Explicit device from config, or None to auto-detect."""
    return _config.get("tts_engine", {}).get("device") or None


def get_predefined_voices_path() -> Path:
    """Directory of curated voice prompt files."""
    return Path(_config.get("tts_engine", {}).get("predefined_voices_path", "voices"))


def get_reference_audio_path() -> Path:
    """Directory of user-supplied reference audio files."""
    return Path(_config.get("tts_engine", {}).get("reference_audio_path", "reference_audio"))


def get_default_voice_id() -> str:
    """Voice file used when a request names no voice."""
    return _config.get("tts_engine", {}).get("default_voice_id", "default.wav")


def get_wyoming_host() -> str:
    """Bind address for the Wyoming server."""
    return _config.get("wyoming", {}).get("host", "0.0.0.0")


def get_wyoming_port() -> int:
    """TCP port for the Wyoming server."""
    return int(_config.get("wyoming", {}).get("port", 10200))


def get_wyoming_chunk_size() -> int:
    """Max characters per synthesis chunk (split at sentence boundaries)."""
    return int(_config.get("wyoming", {}).get("chunk_size", 300))


def get_gen_temperature() -> float:
    """Sampling temperature (turbo model only)."""
    return float(_config.get("generation_defaults", {}).get("temperature", 0.8))


def get_gen_exaggeration() -> float:
    """Expressiveness control (standard model only)."""
    return float(_config.get("generation_defaults", {}).get("exaggeration", 0.5))


def get_gen_cfg_weight() -> float:
    """Classifier-free-guidance weight (standard model only)."""
    return float(_config.get("generation_defaults", {}).get("cfg_weight", 0.5))


def get_gen_seed() -> int:
    """Seed for generation; 0 means random each call."""
    return int(_config.get("generation_defaults", {}).get("seed", 0))
import logging
import torch
from typing import Optional, Tuple

logger = logging.getLogger(__name__)

# Populated by load_model(); chatterbox_model stays None until loading succeeds.
chatterbox_model = None
_sample_rate = 24000
_is_turbo = False


def _test_cuda() -> bool:
    """Return True only when a CUDA/ROCm device is reported AND actually usable."""
    try:
        if not torch.cuda.is_available():
            return False
        # A tiny allocation catches devices that advertise CUDA but fail at runtime.
        torch.zeros(1).cuda()
        return True
    except Exception:
        return False


def detect_device() -> str:
    """Pick 'cuda' when a working GPU is present, otherwise 'cpu'."""
    return "cuda" if _test_cuda() else "cpu"


def load_model() -> bool:
    """Load the configured Chatterbox model into module state.

    Returns True on success, False on any failure (logged, never raised).
    """
    global chatterbox_model, _sample_rate, _is_turbo

    from config import get_model_repo_id, get_device_override

    repo_id = get_model_repo_id()
    device = get_device_override() or detect_device()

    logger.info(f"Loading model '{repo_id}' on device '{device}'")

    try:
        wants_turbo = "turbo" in repo_id.lower()
        if wants_turbo:
            from chatterbox.tts_turbo import ChatterboxTurboTTS as _ModelClass
        else:
            from chatterbox.tts import ChatterboxTTS as _ModelClass
        chatterbox_model = _ModelClass.from_pretrained(device)
        _is_turbo = wants_turbo
        _sample_rate = 24000
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False


def get_sample_rate() -> int:
    """Sample rate (Hz) of audio produced by synthesize()."""
    return _sample_rate


def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for *text*; returns (waveform tensor, sample_rate).

    Turbo models honour only *temperature*; the standard model honours
    *exaggeration* and *cfg_weight*. A seed > 0 makes output reproducible.

    Raises:
        RuntimeError: if load_model() has not succeeded yet.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    gen_kwargs: dict = {"audio_prompt_path": audio_prompt_path} if audio_prompt_path else {}
    if _is_turbo:
        gen_kwargs["temperature"] = temperature
    else:
        gen_kwargs["exaggeration"] = exaggeration
        gen_kwargs["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **gen_kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate
import asyncio
import logging
import sys
from functools import partial

from wyoming.server import AsyncServer

import engine
from config import get_wyoming_host, get_wyoming_port, load_config
from wyoming_handler import ChatterboxWyomingHandler
from wyoming_voices import create_wyoming_info, load_voices

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


async def main() -> None:
    """Entry point: load config and model, then serve Wyoming TTS forever."""
    load_config()

    logger.info("Loading TTS model...")
    model_ok = engine.load_model()
    if not model_ok:
        logger.error("Failed to load model, exiting")
        sys.exit(1)

    voice_map = load_voices()
    info = create_wyoming_info(engine.get_sample_rate(), voice_map)

    bind_uri = f"tcp://{get_wyoming_host()}:{get_wyoming_port()}"
    logger.info(f"Starting Wyoming server on {bind_uri}")

    server = AsyncServer.from_uri(bind_uri)
    handler_factory = partial(ChatterboxWyomingHandler, info, voice_map)
    await server.run(handler_factory)


if __name__ == "__main__":
    asyncio.run(main())
import asyncio
import logging
import re
import time
from typing import Dict, Optional

import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop

import engine
from config import (
    get_gen_cfg_weight,
    get_gen_exaggeration,
    get_gen_seed,
    get_gen_temperature,
    get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice

logger = logging.getLogger(__name__)

# Compiled once at import time (hoisted out of _split_text): split after
# sentence-ending punctuation followed by whitespace.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


def _split_text(text: str, max_chunk: int = 300) -> list[str]:
    """Split text at sentence boundaries to keep chunks under max_chunk chars.

    Sentences are greedily packed into chunks; a single sentence longer than
    max_chunk is kept whole rather than split mid-sentence.
    """
    sentences = _SENTENCE_SPLIT_RE.split(text.strip())

    chunks = []
    current = ""
    for sentence in sentences:
        if not sentence:
            continue
        if current and len(current) + 1 + len(sentence) > max_chunk:
            chunks.append(current.strip())
            current = sentence
        else:
            current = (current + " " + sentence).strip() if current else sentence

    if current:
        chunks.append(current.strip())

    return chunks or [text]


class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming event handler: answers Describe and serves TTS requests.

    Supports both the streaming protocol (synthesize-start / synthesize-chunk /
    synthesize-stop) and the legacy single Synthesize event.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        self._voices = voices

        # Per-connection streaming state, reset on synthesize-stop.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event; always returns True to keep the connection."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True

        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            # BUGFIX: during a streaming session, clients also send a legacy
            # Synthesize event carrying the full text (for servers without
            # streaming support). Handling it here would synthesize the whole
            # utterance a second time and emit a premature SynthesizeStopped,
            # so ignore it while a stream is active.
            if self._streaming:
                return True
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* chunk-by-chunk and stream PCM audio to the client.

        Audio is 16-bit mono at the engine's sample rate. A failed chunk is
        logged and skipped so one bad sentence cannot kill the whole response.
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()

        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}'"
        )

        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        first_chunk = True
        start_time = time.monotonic()

        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")

            try:
                # Run the blocking model call off the event loop.
                # get_running_loop: get_event_loop is deprecated inside coroutines.
                audio_tensor, _ = await asyncio.get_running_loop().run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue

            # float waveform (assumed in [-1, 1]) -> 16-bit PCM bytes
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()

            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False

            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Blocking wrapper around engine.synthesize using configured defaults."""
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )
import logging
from pathlib import Path
from typing import Dict, Optional

from wyoming.info import Attribution, Info, TtsProgram, TtsVoice, TtsVoiceSpeaker

logger = logging.getLogger(__name__)

# Audio file types accepted as voice prompts.
VOICE_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg"}


def load_voices() -> Dict[str, str]:
    """Scan voice directories and return {voice_name: file_path} mapping.

    The voice name is the file stem. On a name collision, the first
    directory scanned wins.
    """
    from config import get_predefined_voices_path, get_reference_audio_path

    voices: Dict[str, str] = {}

    def _scan_dir(directory: Path) -> None:
        if not directory.exists():
            return
        for f in sorted(directory.iterdir()):
            if f.suffix.lower() in VOICE_EXTENSIONS:
                name = f.stem
                if name not in voices:
                    voices[name] = str(f)

    # BUGFIX: the scan is first-wins, so predefined voices must be scanned
    # FIRST for them to take priority on collision. The previous order
    # scanned reference audio first, which made reference audio win —
    # the opposite of the stated intent.
    _scan_dir(get_predefined_voices_path())
    _scan_dir(get_reference_audio_path())

    logger.info(f"Discovered {len(voices)} voice(s): {list(voices.keys())}")
    return voices


def resolve_voice(voice_name: Optional[str], voices: Dict[str, str]) -> Optional[str]:
    """Resolve a voice name to its audio file path.

    Resolution order: configured default (when no name given), exact match
    in *voices*, then direct file probes in the predefined and reference
    directories. As a last resort, falls back to any discovered voice, or
    None when no voices exist at all.
    """
    from config import get_predefined_voices_path, get_reference_audio_path, get_default_voice_id

    if not voice_name:
        default = get_default_voice_id()
        voice_name = Path(default).stem

    # Exact name match in discovered voices
    if voice_name in voices:
        return voices[voice_name]

    # Try predefined voices dir with extensions
    for ext in VOICE_EXTENSIONS:
        p = get_predefined_voices_path() / f"{voice_name}{ext}"
        if p.exists():
            return str(p)

    # Try reference audio dir with extensions
    for ext in VOICE_EXTENSIONS:
        p = get_reference_audio_path() / f"{voice_name}{ext}"
        if p.exists():
            return str(p)

    # Fall back to any voice
    if voices:
        fallback = next(iter(voices.values()))
        logger.warning(f"Voice '{voice_name}' not found, falling back to '{fallback}'")
        return fallback

    return None


def create_wyoming_info(sample_rate: int, voices: Dict[str, str]) -> Info:
    """Build the Wyoming Info object advertised to Home Assistant.

    Each discovered voice becomes a TtsVoice entry; the sample_rate
    parameter is accepted for interface stability even though Info does
    not carry it directly.
    """
    tts_voices = [
        TtsVoice(
            name=name,
            description=f"Chatterbox voice: {name}",
            attribution=Attribution(
                name="ResembleAI",
                url="https://github.com/resemble-ai/chatterbox",
            ),
            installed=True,
            languages=["en"],
            speakers=[TtsVoiceSpeaker(name=name)],
        )
        for name in sorted(voices.keys())
    ]

    return Info(
        tts=[
            TtsProgram(
                name="chatterbox",
                description="Chatterbox TTS with ROCm/AMD GPU support",
                attribution=Attribution(
                    name="ResembleAI",
                    url="https://github.com/resemble-ai/chatterbox",
                ),
                installed=True,
                voices=tts_voices,
                version="1.0",
            )
        ]
    )