Initial implementation: Chatterbox TTS with ROCm and Wyoming
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s

Wyoming-only server built around the official chatterbox TTS model.
Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml
management, and Gitea CI for container builds.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 09:51:09 -04:00
parent 4b15e44181
commit 16ea2853f5
12 changed files with 691 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
# Gitea Actions workflow: build and push the ROCm container image on every
# push to main. (Indentation restored — the YAML structure was flattened.)
name: Build ROCm Image

on:
  push:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Setup Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Gitea Registry
        uses: docker/login-action@v3
        with:
          registry: git.sdgarren.com
          username: ${{ secrets.REGISTRY_USERNAME }}
          password: ${{ secrets.REGISTRY_TOKEN }}

      - name: Build and Push
        uses: docker/build-push-action@v6
        with:
          context: .
          file: Dockerfile.rocm
          push: true
          # Tag with both a moving "latest" and the immutable commit SHA.
          tags: |
            git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
            git.sdgarren.com/scott/rocm-chatterbox-whisper:${{ gitea.sha }}

16
.gitignore vendored Normal file
View File

@@ -0,0 +1,16 @@
__pycache__/
*.py[cod]
*.egg-info/
.env
# Voice and audio files (mount via volume)
voices/
reference_audio/
outputs/
# Model cache
hf_cache/
# Logs
logs/
*.log

43
Dockerfile.rocm Normal file
View File

@@ -0,0 +1,43 @@
# ROCm-enabled image for the Chatterbox Wyoming TTS server.
FROM rocm/dev-ubuntu-22.04:latest

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    HF_HOME=/app/hf_cache \
    PIP_NO_CACHE_DIR=1

# System deps: Python toolchain plus audio libraries (ffmpeg, libsndfile).
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    git \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Step 1: Install ROCm-compatible PyTorch stack first.
# This must happen before anything else to prevent pip from pulling CPU wheels.
COPY requirements-rocm-init.txt .
RUN pip3 install -r requirements-rocm-init.txt

# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
COPY requirements-rocm.txt .
RUN pip3 install -r requirements-rocm.txt

# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
RUN pip3 install --no-deps chatterbox-tts

# Application source
COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./

# Default config (can be overridden by volume mount)
COPY config.yaml .

# Create default directories
RUN mkdir -p voices reference_audio hf_cache

# Wyoming protocol TCP port.
EXPOSE 10200

CMD ["python3", "main.py"]

115
config.py Normal file
View File

@@ -0,0 +1,115 @@
import logging
import os
from pathlib import Path
from typing import Optional
import yaml
logger = logging.getLogger(__name__)

# Module-level configuration state, populated by load_config().
_config: dict = {}
# Config file location; override with the CONFIG_PATH environment variable.
_config_path = Path(os.environ.get("CONFIG_PATH", "config.yaml"))

# Built-in defaults. User config.yaml values are deep-merged on top of these,
# so every getter below has a sane fallback even with no config file present.
DEFAULTS = {
    "model": {
        "repo_id": "chatterbox-turbo",
    },
    "tts_engine": {
        "device": "",  # empty = auto-detect
        "predefined_voices_path": "voices",
        "reference_audio_path": "reference_audio",
        "default_voice_id": "default.wav",
    },
    "generation_defaults": {
        "temperature": 0.8,
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "seed": 0,
    },
    "wyoming": {
        "host": "0.0.0.0",
        "port": 10200,
        "chunk_size": 300,
    },
    "paths": {
        "model_cache": "/app/hf_cache",
    },
}
def _deep_merge(base: dict, override: dict) -> dict:
result = base.copy()
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
def load_config() -> None:
    """Load config.yaml (if present) over DEFAULTS into the module-level _config.

    Also propagates paths.model_cache into HF_HOME (without clobbering an
    existing environment value) so HuggingFace model downloads are cached
    where the container expects them.
    """
    global _config
    # Start from the defaults.
    # NOTE(review): _deep_merge(DEFAULTS, {}) is effectively a shallow copy,
    # so un-overridden nested dicts are shared with DEFAULTS — fine as long
    # as _config is never mutated in place.
    _config = _deep_merge(DEFAULTS, {})
    if _config_path.exists():
        try:
            with open(_config_path) as f:
                user_config = yaml.safe_load(f) or {}
            _config = _deep_merge(_config, user_config)
            logger.info(f"Loaded config from {_config_path}")
        except Exception:
            # Malformed/unreadable config is non-fatal: run on defaults.
            logger.exception(f"Failed to load {_config_path}, using defaults")
    else:
        logger.warning(f"Config file not found at {_config_path}, using defaults")
    # Set HuggingFace cache path from config
    cache_path = _config.get("paths", {}).get("model_cache", "")
    if cache_path:
        os.environ.setdefault("HF_HOME", cache_path)
def _lookup(section: str, key: str, default=None):
    """Read _config[section][key], returning *default* when absent."""
    return _config.get(section, {}).get(key, default)


def get_model_repo_id() -> str:
    """Model identifier to load: 'chatterbox' or 'chatterbox-turbo'."""
    return _lookup("model", "repo_id", "chatterbox-turbo")


def get_device_override() -> Optional[str]:
    """Explicit device from config, or None (empty string) for auto-detect."""
    return _lookup("tts_engine", "device") or None


def get_predefined_voices_path() -> Path:
    """Directory containing curated voice prompt files."""
    return Path(_lookup("tts_engine", "predefined_voices_path", "voices"))


def get_reference_audio_path() -> Path:
    """Directory containing user-supplied reference audio files."""
    return Path(_lookup("tts_engine", "reference_audio_path", "reference_audio"))


def get_default_voice_id() -> str:
    """Voice file used when a client does not request a specific voice."""
    return _lookup("tts_engine", "default_voice_id", "default.wav")


def get_wyoming_host() -> str:
    """Interface the Wyoming TCP server binds to."""
    return _lookup("wyoming", "host", "0.0.0.0")


def get_wyoming_port() -> int:
    """TCP port the Wyoming server listens on."""
    return int(_lookup("wyoming", "port", 10200))


def get_wyoming_chunk_size() -> int:
    """Maximum characters per synthesis chunk (sentence-boundary split)."""
    return int(_lookup("wyoming", "chunk_size", 300))
def _gen_defaults() -> dict:
    """The generation_defaults section of the loaded config."""
    return _config.get("generation_defaults", {})


def get_gen_temperature() -> float:
    """Sampling temperature (used by the turbo model)."""
    return float(_gen_defaults().get("temperature", 0.8))


def get_gen_exaggeration() -> float:
    """Expressiveness control (used by the standard model)."""
    return float(_gen_defaults().get("exaggeration", 0.5))


def get_gen_cfg_weight() -> float:
    """Classifier-free-guidance weight (used by the standard model)."""
    return float(_gen_defaults().get("cfg_weight", 0.5))


def get_gen_seed() -> int:
    """RNG seed; 0 means random output on every call."""
    return int(_gen_defaults().get("seed", 0))

29
config.yaml Normal file
View File

@@ -0,0 +1,29 @@
# Runtime configuration for the Chatterbox Wyoming server.
# (Indentation restored — the YAML nesting was flattened.)
model:
  # Options: chatterbox, chatterbox-turbo
  repo_id: chatterbox-turbo

tts_engine:
  # Device: cuda, cpu, or leave empty for auto-detect
  device: ""
  predefined_voices_path: voices
  reference_audio_path: reference_audio
  # Fallback voice (stem name, e.g. "default" matches default.wav)
  default_voice_id: default.wav

generation_defaults:
  # Turbo model: uses temperature only (exaggeration/cfg_weight ignored)
  # Standard model: uses exaggeration and cfg_weight (temperature ignored)
  temperature: 0.8
  exaggeration: 0.5
  cfg_weight: 0.5
  # seed: 0 = random each call, >0 = reproducible output
  seed: 0

wyoming:
  host: "0.0.0.0"
  port: 10200
  # Max characters per synthesis chunk (split at sentence boundaries)
  chunk_size: 300

paths:
  model_cache: /app/hf_cache

35
docker-compose.yml Normal file
View File

@@ -0,0 +1,35 @@
# Compose service for the ROCm Chatterbox Wyoming server.
# (Indentation restored — the YAML nesting was flattened.)
services:
  chatterbox-whisper:
    image: git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
    build:
      context: .
      dockerfile: Dockerfile.rocm
    restart: unless-stopped
    ports:
      - "${WYOMING_PORT:-10200}:10200"
    # AMD GPU device nodes required by ROCm.
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
      - render
    ipc: host
    shm_size: 8g
    security_opt:
      - seccomp=unconfined
    volumes:
      - ./config.yaml:/app/config.yaml
      - ./voices:/app/voices
      - ./reference_audio:/app/reference_audio
      - hf_cache:/app/hf_cache
    environment:
      - HF_HUB_ENABLE_HF_TRANSFER=1
      # Set your GPU architecture:
      # 10.3.0 = RX 5000/6000 series
      # 11.0.0 = RX 7000 series
      # 9.0.6 = Vega
      - HSA_OVERRIDE_GFX_VERSION=10.3.0
      # - HF_TOKEN=your_token_here

volumes:
  hf_cache:

91
engine.py Normal file
View File

@@ -0,0 +1,91 @@
import logging
import torch
from typing import Optional, Tuple
logger = logging.getLogger(__name__)

# Module-level model state, populated by load_model().
chatterbox_model = None  # loaded Chatterbox(Turbo)TTS instance, or None
_sample_rate = 24000  # output audio sample rate in Hz
_is_turbo = False  # True when the turbo model variant is loaded
def _test_cuda() -> bool:
try:
if torch.cuda.is_available():
torch.zeros(1).cuda()
return True
except Exception:
pass
return False
def detect_device() -> str:
return "cuda" if _test_cuda() else "cpu"
def load_model() -> bool:
    """Load the Chatterbox model selected in config onto the chosen device.

    Sets the module globals chatterbox_model, _sample_rate and _is_turbo.
    Returns True on success, False on any failure (logged, not raised).
    """
    global chatterbox_model, _sample_rate, _is_turbo
    from config import get_model_repo_id, get_device_override

    device = get_device_override() or detect_device()
    repo_id = get_model_repo_id()
    logger.info(f"Loading model '{repo_id}' on device '{device}'")
    try:
        # Any repo id containing "turbo" selects the turbo variant.
        if "turbo" in repo_id.lower():
            from chatterbox.tts_turbo import ChatterboxTurboTTS
            chatterbox_model = ChatterboxTurboTTS.from_pretrained(device)
            _is_turbo = True
        else:
            from chatterbox.tts import ChatterboxTTS
            chatterbox_model = ChatterboxTTS.from_pretrained(device)
            _is_turbo = False
        # FIX: read the sample rate from the model instead of hard-coding it,
        # so the rate advertised over Wyoming matches the audio actually
        # produced. Falls back to the previous 24 kHz if the attribute is
        # missing.
        _sample_rate = int(getattr(chatterbox_model, "sr", 24000))
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False
def get_sample_rate() -> int:
    """Sample rate in Hz of the audio produced by the loaded model."""
    return _sample_rate
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech audio for *text* with the loaded Chatterbox model.

    Args:
        text: Text to synthesize.
        audio_prompt_path: Optional reference-audio file for voice cloning.
        exaggeration: Expressiveness (standard model only; ignored by turbo).
        cfg_weight: CFG weight (standard model only; ignored by turbo).
        temperature: Sampling temperature (turbo model only).
        seed: >0 seeds the torch RNGs for reproducible output; 0 leaves
            generation random.

    Returns:
        (waveform tensor, sample rate in Hz). NOTE(review): the tensor's
        shape/device come straight from chatterbox's generate() — presumably
        a float waveform on the model device; confirm against the
        chatterbox API before relying on a specific layout.

    Raises:
        RuntimeError: if load_model() has not successfully run.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    # Optional reproducibility: seed CPU and (if present) all GPU RNGs.
    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    # Only forward the knobs the loaded model variant actually understands.
    kwargs: dict = {}
    if audio_prompt_path:
        kwargs["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        kwargs["temperature"] = temperature
    else:
        kwargs["exaggeration"] = exaggeration
        kwargs["cfg_weight"] = cfg_weight
    with torch.inference_mode():
        wav = chatterbox_model.generate(text=text, **kwargs)
    # Release cached GPU memory between requests to keep VRAM usage flat.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return wav, _sample_rate

42
main.py Normal file
View File

@@ -0,0 +1,42 @@
import asyncio
import logging
import sys
from functools import partial
from wyoming.server import AsyncServer
import engine
from config import get_wyoming_host, get_wyoming_port, load_config
from wyoming_handler import ChatterboxWyomingHandler
from wyoming_voices import create_wyoming_info, load_voices
# Log to stdout so container runtimes capture everything.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


async def main() -> None:
    """Entry point: load config and model, then serve the Wyoming protocol.

    Exits the process with status 1 if the TTS model fails to load.
    """
    load_config()
    logger.info("Loading TTS model...")
    if not engine.load_model():
        logger.error("Failed to load model, exiting")
        sys.exit(1)
    # Discover voice prompt files and advertise them via Wyoming Info.
    voices = load_voices()
    wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices)
    host = get_wyoming_host()
    port = get_wyoming_port()
    uri = f"tcp://{host}:{port}"
    logger.info(f"Starting Wyoming server on {uri}")
    server = AsyncServer.from_uri(uri)
    # A new handler instance is constructed for each client connection.
    await server.run(partial(ChatterboxWyomingHandler, wyoming_info, voices))


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,5 @@
--index-url https://download.pytorch.org/whl/rocm6.1
torch==2.5.1
torchaudio==2.5.1
torchvision==0.20.1
pytorch_triton_rocm==3.1.0

17
requirements-rocm.txt Normal file
View File

@@ -0,0 +1,17 @@
# Audio processing
numpy
soundfile
librosa
# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
transformers==4.46.3
diffusers==0.29.0
safetensors
huggingface-hub
accelerate
# Wyoming protocol
wyoming>=1.5.4
# Config / utilities
PyYAML>=6.0

166
wyoming_handler.py Normal file
View File

@@ -0,0 +1,166 @@
import asyncio
import logging
import time
from typing import Dict, Optional
import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
import engine
from config import (
get_gen_cfg_weight,
get_gen_exaggeration,
get_gen_seed,
get_gen_temperature,
get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice
logger = logging.getLogger(__name__)
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
import re
# Split on sentence-ending punctuation
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
chunks = []
current = ""
for sentence in sentences:
if not sentence:
continue
if current and len(current) + 1 + len(sentence) > max_chunk:
chunks.append(current.strip())
current = sentence
else:
current = (current + " " + sentence).strip() if current else sentence
if current:
chunks.append(current.strip())
return chunks or [text]
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Per-connection Wyoming event handler for Chatterbox TTS.

    Supports both the legacy one-shot `synthesize` event and the streaming
    `synthesize-start` / `synthesize-chunk` / `synthesize-stop` flow. Audio
    is generated per sentence chunk and streamed back as 16-bit mono PCM.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        self._voices = voices
        # Streaming state: text accumulated between start and stop events.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event; always returns True to stay connected."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True
        if SynthesizeStart.is_type(event.type):
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True
        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True
        if SynthesizeStop.is_type(event.type):
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True
        if Synthesize.is_type(event.type):
            # FIX: streaming clients also send a legacy `synthesize` event
            # carrying the full text (for servers without streaming support).
            # Ignore it while a stream is active, otherwise the same text
            # would be synthesized twice.
            if self._streaming:
                return True
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True
        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* chunk-by-chunk, emitting audio-start/chunk/stop."""
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()
        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}'"
        )
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )
        first_chunk = True
        start_time = time.monotonic()
        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")
            try:
                # Run blocking model inference in a worker thread so the
                # event loop keeps servicing the connection.
                # FIX: get_running_loop() — get_event_loop() inside a running
                # coroutine is deprecated since Python 3.10.
                audio_tensor, sr = await asyncio.get_running_loop().run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                # Best-effort: skip a failed chunk rather than killing the stream.
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue
            # Convert model output (assumed float in [-1, 1] — matches the
            # 32767 scaling below) to 16-bit signed PCM bytes.
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()
            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False
            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )
        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Blocking wrapper around engine.synthesize with configured defaults."""
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )

99
wyoming_voices.py Normal file
View File

@@ -0,0 +1,99 @@
import logging
from pathlib import Path
from typing import Dict, Optional
from wyoming.info import Attribution, Info, TtsProgram, TtsVoice, TtsVoiceSpeaker
logger = logging.getLogger(__name__)
VOICE_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg"}
def load_voices() -> Dict[str, str]:
    """Scan voice directories and return {voice_name: file_path} mapping.

    Voice names are file stems. Registration is first-wins, so scan order
    determines collision priority.
    """
    from config import get_predefined_voices_path, get_reference_audio_path

    voices: Dict[str, str] = {}

    def _scan_dir(directory: Path) -> None:
        # Register every supported audio file, keeping the first occurrence
        # of each stem.
        if not directory.exists():
            return
        for f in sorted(directory.iterdir()):
            if f.suffix.lower() in VOICE_EXTENSIONS:
                name = f.stem
                if name not in voices:
                    voices[name] = str(f)

    # FIX: scan predefined voices FIRST. Registration above is first-wins,
    # so the previous order (reference audio first) made reference audio win
    # name collisions — the opposite of the stated intent that predefined
    # voices take priority.
    _scan_dir(get_predefined_voices_path())
    _scan_dir(get_reference_audio_path())
    logger.info(f"Discovered {len(voices)} voice(s): {list(voices.keys())}")
    return voices
def resolve_voice(voice_name: Optional[str], voices: Dict[str, str]) -> Optional[str]:
    """Map a requested voice name to an audio file path, with fallbacks.

    Resolution order: configured default (when no name given) -> exact match
    in *voices* -> predefined-voices dir probed with known extensions ->
    reference-audio dir probed the same way -> any discovered voice at all.
    Returns None only when nothing is available.
    """
    from config import (
        get_default_voice_id,
        get_predefined_voices_path,
        get_reference_audio_path,
    )

    # No voice requested: use the configured default's stem.
    if not voice_name:
        voice_name = Path(get_default_voice_id()).stem
    # 1) Exact match among discovered voices.
    if voice_name in voices:
        return voices[voice_name]
    # 2) Probe both directories for the name with every known extension,
    #    predefined voices first.
    for directory in (get_predefined_voices_path(), get_reference_audio_path()):
        for ext in VOICE_EXTENSIONS:
            candidate = directory / f"{voice_name}{ext}"
            if candidate.exists():
                return str(candidate)
    # 3) Last resort: fall back to any discovered voice.
    if voices:
        fallback = next(iter(voices.values()))
        logger.warning(f"Voice '{voice_name}' not found, falling back to '{fallback}'")
        return fallback
    return None
def create_wyoming_info(sample_rate: int, voices: Dict[str, str]) -> Info:
    """Build the Wyoming Info object advertised to Home Assistant.

    One TtsVoice entry is created per discovered voice name, in sorted
    order. (sample_rate is accepted for interface stability but not used
    in the advertised metadata.)
    """
    def _voice_entry(name: str) -> TtsVoice:
        # Each voice carries its own attribution pointing at upstream.
        return TtsVoice(
            name=name,
            description=f"Chatterbox voice: {name}",
            attribution=Attribution(
                name="ResembleAI",
                url="https://github.com/resemble-ai/chatterbox",
            ),
            installed=True,
            languages=["en"],
            speakers=[TtsVoiceSpeaker(name=name)],
        )

    program = TtsProgram(
        name="chatterbox",
        description="Chatterbox TTS with ROCm/AMD GPU support",
        attribution=Attribution(
            name="ResembleAI",
            url="https://github.com/resemble-ai/chatterbox",
        ),
        installed=True,
        voices=[_voice_entry(name) for name in sorted(voices)],
        version="1.0",
    )
    return Info(tts=[program])