Initial implementation: Chatterbox TTS with ROCm and Wyoming
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
Wyoming-only server built around the official chatterbox TTS model. Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml management, and Gitea CI for container builds. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
33
.gitea/workflows/build.yml
Normal file
33
.gitea/workflows/build.yml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
name: Build ROCm Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Login to Gitea Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: git.sdgarren.com
|
||||||
|
username: ${{ secrets.REGISTRY_USERNAME }}
|
||||||
|
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||||
|
|
||||||
|
- name: Build and Push
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: Dockerfile.rocm
|
||||||
|
push: true
|
||||||
|
tags: |
|
||||||
|
git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
|
||||||
|
git.sdgarren.com/scott/rocm-chatterbox-whisper:${{ gitea.sha }}
|
||||||
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.egg-info/
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Voice and audio files (mount via volume)
|
||||||
|
voices/
|
||||||
|
reference_audio/
|
||||||
|
outputs/
|
||||||
|
|
||||||
|
# Model cache
|
||||||
|
hf_cache/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
43
Dockerfile.rocm
Normal file
43
Dockerfile.rocm
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
FROM rocm/dev-ubuntu-22.04:latest
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
|
PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1 \
|
||||||
|
HF_HOME=/app/hf_cache \
|
||||||
|
PIP_NO_CACHE_DIR=1
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
python3 \
|
||||||
|
python3-pip \
|
||||||
|
python3-dev \
|
||||||
|
git \
|
||||||
|
ffmpeg \
|
||||||
|
libsndfile1 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Step 1: Install ROCm-compatible PyTorch stack first.
|
||||||
|
# This must happen before anything else to prevent pip from pulling CPU wheels.
|
||||||
|
COPY requirements-rocm-init.txt .
|
||||||
|
RUN pip3 install -r requirements-rocm-init.txt
|
||||||
|
|
||||||
|
# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
|
||||||
|
COPY requirements-rocm.txt .
|
||||||
|
RUN pip3 install -r requirements-rocm.txt
|
||||||
|
|
||||||
|
# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
|
||||||
|
RUN pip3 install --no-deps chatterbox-tts
|
||||||
|
|
||||||
|
# Application source
|
||||||
|
COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
|
||||||
|
|
||||||
|
# Default config (can be overridden by volume mount)
|
||||||
|
COPY config.yaml .
|
||||||
|
|
||||||
|
# Create default directories
|
||||||
|
RUN mkdir -p voices reference_audio hf_cache
|
||||||
|
|
||||||
|
EXPOSE 10200
|
||||||
|
|
||||||
|
CMD ["python3", "main.py"]
|
||||||
115
config.py
Normal file
115
config.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Module-level configuration state: populated by load_config(), read by
# the getter helpers below.
_config: dict = {}
# Config file location; overridable for containers via CONFIG_PATH.
_config_path = Path(os.environ.get("CONFIG_PATH", "config.yaml"))

# Built-in defaults; values from config.yaml are deep-merged on top of
# these, so every key is guaranteed to exist after load_config().
DEFAULTS = {
    "model": {
        # Selects the model family; "turbo" in the id switches loaders.
        "repo_id": "chatterbox-turbo",
    },
    "tts_engine": {
        "device": "",  # empty = auto-detect
        "predefined_voices_path": "voices",
        "reference_audio_path": "reference_audio",
        "default_voice_id": "default.wav",
    },
    "generation_defaults": {
        # Turbo model uses temperature; standard model uses
        # exaggeration/cfg_weight (see engine.synthesize).
        "temperature": 0.8,
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "seed": 0,  # 0 = random each call, >0 = reproducible
    },
    "wyoming": {
        "host": "0.0.0.0",
        "port": 10200,
        # Max characters per synthesis chunk (sentence-boundary split).
        "chunk_size": 300,
    },
    "paths": {
        "model_cache": "/app/hf_cache",
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_merge(base: dict, override: dict) -> dict:
|
||||||
|
result = base.copy()
|
||||||
|
for key, value in override.items():
|
||||||
|
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||||
|
result[key] = _deep_merge(result[key], value)
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> None:
    """Populate the module-level config from DEFAULTS plus the YAML file.

    A missing or unreadable config file is not fatal: the defaults are
    kept and the problem is logged. Also exports the configured model
    cache directory as HF_HOME so HuggingFace downloads land there.
    """
    global _config
    _config = _deep_merge(DEFAULTS, {})

    if not _config_path.exists():
        logger.warning(f"Config file not found at {_config_path}, using defaults")
    else:
        try:
            with open(_config_path) as fh:
                overrides = yaml.safe_load(fh) or {}
            _config = _deep_merge(_config, overrides)
            logger.info(f"Loaded config from {_config_path}")
        except Exception:
            logger.exception(f"Failed to load {_config_path}, using defaults")

    # Point HuggingFace at the configured cache, but never clobber an
    # HF_HOME the operator already set in the environment.
    hf_cache = _config.get("paths", {}).get("model_cache", "")
    if hf_cache:
        os.environ.setdefault("HF_HOME", hf_cache)
|
||||||
|
|
||||||
|
|
||||||
|
def _section(name: str) -> dict:
    # Fetch one top-level config section, tolerating its absence.
    return _config.get(name, {})


def get_model_repo_id() -> str:
    """Model identifier, e.g. "chatterbox" or "chatterbox-turbo"."""
    return _section("model").get("repo_id", "chatterbox-turbo")


def get_device_override() -> Optional[str]:
    """Explicit device from config, or None to auto-detect."""
    return _section("tts_engine").get("device") or None


def get_predefined_voices_path() -> Path:
    """Directory holding curated voice audio files."""
    return Path(_section("tts_engine").get("predefined_voices_path", "voices"))


def get_reference_audio_path() -> Path:
    """Directory holding user-supplied reference audio for cloning."""
    return Path(_section("tts_engine").get("reference_audio_path", "reference_audio"))


def get_default_voice_id() -> str:
    """Fallback voice file name used when a client names no voice."""
    return _section("tts_engine").get("default_voice_id", "default.wav")


def get_wyoming_host() -> str:
    """Bind address for the Wyoming TCP server."""
    return _section("wyoming").get("host", "0.0.0.0")


def get_wyoming_port() -> int:
    """TCP port for the Wyoming server."""
    return int(_section("wyoming").get("port", 10200))


def get_wyoming_chunk_size() -> int:
    """Max characters per synthesis chunk (sentence-boundary split)."""
    return int(_section("wyoming").get("chunk_size", 300))


def get_gen_temperature() -> float:
    """Sampling temperature (used by the turbo model)."""
    return float(_section("generation_defaults").get("temperature", 0.8))


def get_gen_exaggeration() -> float:
    """Expressiveness knob (used by the standard model)."""
    return float(_section("generation_defaults").get("exaggeration", 0.5))


def get_gen_cfg_weight() -> float:
    """Classifier-free-guidance weight (used by the standard model)."""
    return float(_section("generation_defaults").get("cfg_weight", 0.5))


def get_gen_seed() -> int:
    """RNG seed: 0 = random each call, >0 = reproducible output."""
    return int(_section("generation_defaults").get("seed", 0))
|
||||||
29
config.yaml
Normal file
29
config.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
model:
|
||||||
|
# Options: chatterbox, chatterbox-turbo
|
||||||
|
repo_id: chatterbox-turbo
|
||||||
|
|
||||||
|
tts_engine:
|
||||||
|
# Device: cuda, cpu, or leave empty for auto-detect
|
||||||
|
device: ""
|
||||||
|
predefined_voices_path: voices
|
||||||
|
reference_audio_path: reference_audio
|
||||||
|
# Fallback voice (stem name, e.g. "default" matches default.wav)
|
||||||
|
default_voice_id: default.wav
|
||||||
|
|
||||||
|
generation_defaults:
|
||||||
|
# Turbo model: uses temperature only (exaggeration/cfg_weight ignored)
|
||||||
|
# Standard model: uses exaggeration and cfg_weight (temperature ignored)
|
||||||
|
temperature: 0.8
|
||||||
|
exaggeration: 0.5
|
||||||
|
cfg_weight: 0.5
|
||||||
|
# seed: 0 = random each call, >0 = reproducible output
|
||||||
|
seed: 0
|
||||||
|
|
||||||
|
wyoming:
|
||||||
|
host: "0.0.0.0"
|
||||||
|
port: 10200
|
||||||
|
# Max characters per synthesis chunk (split at sentence boundaries)
|
||||||
|
chunk_size: 300
|
||||||
|
|
||||||
|
paths:
|
||||||
|
model_cache: /app/hf_cache
|
||||||
35
docker-compose.yml
Normal file
35
docker-compose.yml
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
services:
|
||||||
|
chatterbox-whisper:
|
||||||
|
image: git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.rocm
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "${WYOMING_PORT:-10200}:10200"
|
||||||
|
devices:
|
||||||
|
- /dev/kfd
|
||||||
|
- /dev/dri
|
||||||
|
group_add:
|
||||||
|
- video
|
||||||
|
- render
|
||||||
|
ipc: host
|
||||||
|
shm_size: 8g
|
||||||
|
security_opt:
|
||||||
|
- seccomp=unconfined
|
||||||
|
volumes:
|
||||||
|
- ./config.yaml:/app/config.yaml
|
||||||
|
- ./voices:/app/voices
|
||||||
|
- ./reference_audio:/app/reference_audio
|
||||||
|
- hf_cache:/app/hf_cache
|
||||||
|
environment:
|
||||||
|
- HF_HUB_ENABLE_HF_TRANSFER=1
|
||||||
|
# Set your GPU architecture:
|
||||||
|
# 10.3.0 = RX 6000 series (RDNA2, gfx103x); RX 5000 (RDNA1) is gfx101x and typically needs 10.1.0
|
||||||
|
# 11.0.0 = RX 7000 series
|
||||||
|
# 9.0.6 = Radeon VII / Instinct MI50 (Vega 20, gfx906); original Vega 56/64 is 9.0.0 (gfx900)
|
||||||
|
- HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||||
|
# - HF_TOKEN=your_token_here
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
hf_cache:
|
||||||
91
engine.py
Normal file
91
engine.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# The loaded Chatterbox model; None until load_model() succeeds.
chatterbox_model = None
# Output sample rate in Hz (set by load_model; currently hard-coded).
_sample_rate = 24000
# True when the turbo model variant is loaded — it takes different
# generate() kwargs than the standard model (see synthesize()).
_is_turbo = False
|
||||||
|
|
||||||
|
|
||||||
|
def _test_cuda() -> bool:
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.zeros(1).cuda()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def detect_device() -> str:
|
||||||
|
return "cuda" if _test_cuda() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def load_model() -> bool:
    """Load the configured Chatterbox model onto the chosen device.

    Device comes from the config override when set, otherwise from
    auto-detection. The loaded model and its metadata are stored in
    module globals for synthesize(). Returns True on success; failures
    are logged and reported as False rather than raised.
    """
    global chatterbox_model, _sample_rate, _is_turbo

    from config import get_device_override, get_model_repo_id

    repo_id = get_model_repo_id()
    device = get_device_override() or detect_device()

    logger.info(f"Loading model '{repo_id}' on device '{device}'")

    use_turbo = "turbo" in repo_id.lower()
    try:
        # Import lazily so the heavy chatterbox package is only pulled
        # in when a model is actually loaded.
        if use_turbo:
            from chatterbox.tts_turbo import ChatterboxTurboTTS as _Model
        else:
            from chatterbox.tts import ChatterboxTTS as _Model

        chatterbox_model = _Model.from_pretrained(device)
        _is_turbo = use_turbo
        _sample_rate = 24000
        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_sample_rate() -> int:
    """Return the output sample rate (Hz) of the loaded model.

    NOTE(review): load_model() hard-codes this to 24000; presumably both
    Chatterbox variants emit 24 kHz audio — confirm against the model's
    own sample-rate attribute if one exists.
    """
    return _sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for *text* and return (waveform, sample_rate).

    A positive *seed* makes output reproducible by seeding torch's
    global RNGs (note: this mutates process-wide RNG state). Which
    tuning knobs are honoured depends on the loaded model: the turbo
    variant uses *temperature*; the standard one uses *exaggeration*
    and *cfg_weight*. *audio_prompt_path* selects a reference voice.

    Raises:
        RuntimeError: if load_model() has not succeeded yet.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    gen_args: dict = {"text": text}
    if audio_prompt_path:
        gen_args["audio_prompt_path"] = audio_prompt_path
    if _is_turbo:
        gen_args["temperature"] = temperature
    else:
        gen_args["exaggeration"] = exaggeration
        gen_args["cfg_weight"] = cfg_weight

    with torch.inference_mode():
        wav = chatterbox_model.generate(**gen_args)

    if torch.cuda.is_available():
        # Flush pending kernels and release cached VRAM so memory stays
        # flat across requests on long-running servers.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return wav, _sample_rate
|
||||||
42
main.py
Normal file
42
main.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from wyoming.server import AsyncServer
|
||||||
|
|
||||||
|
import engine
|
||||||
|
from config import get_wyoming_host, get_wyoming_port, load_config
|
||||||
|
from wyoming_handler import ChatterboxWyomingHandler
|
||||||
|
from wyoming_voices import create_wyoming_info, load_voices
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
stream=sys.stdout,
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
    """Entry point: load config and model, then serve Wyoming over TCP.

    Exits the process with status 1 when the model cannot be loaded —
    there is nothing useful the server can do without it.
    """
    load_config()

    logger.info("Loading TTS model...")
    if not engine.load_model():
        logger.error("Failed to load model, exiting")
        sys.exit(1)

    voices = load_voices()
    wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices)

    uri = f"tcp://{get_wyoming_host()}:{get_wyoming_port()}"
    logger.info(f"Starting Wyoming server on {uri}")

    server = AsyncServer.from_uri(uri)
    # Each client connection gets a fresh handler sharing info/voices.
    await server.run(partial(ChatterboxWyomingHandler, wyoming_info, voices))


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
5
requirements-rocm-init.txt
Normal file
5
requirements-rocm-init.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
--index-url https://download.pytorch.org/whl/rocm6.1
|
||||||
|
torch==2.5.1
|
||||||
|
torchaudio==2.5.1
|
||||||
|
torchvision==0.20.1
|
||||||
|
pytorch_triton_rocm==3.1.0
|
||||||
17
requirements-rocm.txt
Normal file
17
requirements-rocm.txt
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Audio processing
|
||||||
|
numpy
|
||||||
|
soundfile
|
||||||
|
librosa
|
||||||
|
|
||||||
|
# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
|
||||||
|
transformers==4.46.3
|
||||||
|
diffusers==0.29.0
|
||||||
|
safetensors
|
||||||
|
huggingface-hub
|
||||||
|
accelerate
|
||||||
|
|
||||||
|
# Wyoming protocol
|
||||||
|
wyoming>=1.5.4
|
||||||
|
|
||||||
|
# Config / utilities
|
||||||
|
PyYAML>=6.0
|
||||||
166
wyoming_handler.py
Normal file
166
wyoming_handler.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||||
|
from wyoming.event import Event
|
||||||
|
from wyoming.info import Describe, Info
|
||||||
|
from wyoming.server import AsyncEventHandler
|
||||||
|
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
|
||||||
|
|
||||||
|
import engine
|
||||||
|
from config import (
|
||||||
|
get_gen_cfg_weight,
|
||||||
|
get_gen_exaggeration,
|
||||||
|
get_gen_seed,
|
||||||
|
get_gen_temperature,
|
||||||
|
get_wyoming_chunk_size,
|
||||||
|
)
|
||||||
|
from wyoming_voices import resolve_voice
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
|
||||||
|
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Split on sentence-ending punctuation
|
||||||
|
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
current = ""
|
||||||
|
for sentence in sentences:
|
||||||
|
if not sentence:
|
||||||
|
continue
|
||||||
|
if current and len(current) + 1 + len(sentence) > max_chunk:
|
||||||
|
chunks.append(current.strip())
|
||||||
|
current = sentence
|
||||||
|
else:
|
||||||
|
current = (current + " " + sentence).strip() if current else sentence
|
||||||
|
|
||||||
|
if current:
|
||||||
|
chunks.append(current.strip())
|
||||||
|
|
||||||
|
return chunks or [text]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Per-connection Wyoming event handler for Chatterbox TTS.

    Handles both the legacy one-shot ``synthesize`` event and the
    streaming ``synthesize-start`` / ``synthesize-chunk`` /
    ``synthesize-stop`` sequence. Streamed text is accumulated in full
    and synthesized when the stop event arrives (sentence-level
    chunking happens afterwards, not per incoming chunk).
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._info = wyoming_info  # advertised in response to Describe
        self._voices = voices  # {voice_name: audio_file_path}

        # Streaming state for synthesize-start/chunk/stop.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event; always returns True (keep connection)."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            # Begin accumulating streamed text; remember requested voice.
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True

        if SynthesizeChunk.is_type(event.type):
            # Chunks outside an active stream are silently ignored.
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            # Synthesize everything accumulated, then reset stream state.
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            # Legacy one-shot request.
            # NOTE(review): some streaming clients also send a trailing
            # Synthesize for backward compatibility; that would trigger a
            # second synthesis here — confirm against the Wyoming spec.
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* chunk-by-chunk and stream PCM to the client.

        Emits AudioStart, one AudioChunk per synthesized text chunk
        (16-bit mono PCM), and AudioStop. A failed chunk is logged and
        skipped so the rest of the utterance still plays.
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()

        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}'"
        )

        # width=2, channels=1: 16-bit mono PCM.
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        first_chunk = True
        start_time = time.monotonic()

        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")

            try:
                # Run blocking model inference off the event loop.
                audio_tensor, sr = await asyncio.get_event_loop().run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue

            # Float -> int16 PCM. Assumes model output is float audio in
            # [-1, 1] — TODO confirm against chatterbox's generate().
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()

            if first_chunk:
                # Latency metric: time until the first audible chunk.
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False

            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Blocking wrapper around engine.synthesize with configured defaults.

        Runs inside the executor thread. Returns (waveform, sample_rate);
        the returned sample rate is currently ignored by the caller.
        """
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )
|
||||||
99
wyoming_voices.py
Normal file
99
wyoming_voices.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from wyoming.info import Attribution, Info, TtsProgram, TtsVoice, TtsVoiceSpeaker
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VOICE_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg"}
|
||||||
|
|
||||||
|
|
||||||
|
def load_voices() -> Dict[str, str]:
    """Scan voice directories and return a {voice_name: file_path} mapping.

    Voice names are the file stems of supported audio files (see
    VOICE_EXTENSIONS). The predefined-voices directory is scanned
    first; since only the first file seen for a given stem is kept,
    predefined voices take priority over reference audio on collision.
    Missing directories are silently skipped.
    """
    from config import get_predefined_voices_path, get_reference_audio_path

    voices: Dict[str, str] = {}

    def _scan_dir(directory: Path) -> None:
        # Register each supported audio file under its stem; first wins.
        if not directory.exists():
            return
        for f in sorted(directory.iterdir()):
            if f.suffix.lower() in VOICE_EXTENSIONS:
                name = f.stem
                if name not in voices:
                    voices[name] = str(f)

    # BUG FIX: scan predefined voices first. The previous code scanned
    # reference_audio first, which — with first-wins insertion — gave
    # reference audio priority, the opposite of the documented intent.
    _scan_dir(get_predefined_voices_path())
    _scan_dir(get_reference_audio_path())

    logger.info(f"Discovered {len(voices)} voice(s): {list(voices.keys())}")
    return voices
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_voice(voice_name: Optional[str], voices: Dict[str, str]) -> Optional[str]:
    """Resolve a voice name to its audio file path.

    Lookup order: the discovered-voices map by stem, then the
    predefined and reference directories probed with each supported
    extension, and finally an arbitrary discovered voice as a logged
    last resort. Returns None only when nothing is available at all.
    When *voice_name* is falsy, the configured default voice is used.
    """
    from config import get_default_voice_id, get_predefined_voices_path, get_reference_audio_path

    name = voice_name if voice_name else Path(get_default_voice_id()).stem

    # Exact match among the voices discovered at startup.
    known = voices.get(name)
    if known is not None:
        return known

    # Probe both voice directories for <name><ext> on disk.
    for directory in (get_predefined_voices_path(), get_reference_audio_path()):
        for ext in VOICE_EXTENSIONS:
            candidate = directory / f"{name}{ext}"
            if candidate.exists():
                return str(candidate)

    # Last resort: any discovered voice at all.
    if voices:
        fallback = next(iter(voices.values()))
        logger.warning(f"Voice '{name}' not found, falling back to '{fallback}'")
        return fallback

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_wyoming_info(sample_rate: int, voices: Dict[str, str]) -> Info:
    """Build the Wyoming Info object advertised to Home Assistant.

    One TtsVoice entry is created per discovered voice (sorted by
    name), all wrapped in a single "chatterbox" TtsProgram. The
    *sample_rate* parameter is accepted for interface stability but is
    not embedded in the Info payload.
    """
    attribution = Attribution(
        name="ResembleAI",
        url="https://github.com/resemble-ai/chatterbox",
    )

    tts_voices = []
    for name in sorted(voices):
        tts_voices.append(
            TtsVoice(
                name=name,
                description=f"Chatterbox voice: {name}",
                attribution=attribution,
                installed=True,
                languages=["en"],
                speakers=[TtsVoiceSpeaker(name=name)],
            )
        )

    program = TtsProgram(
        name="chatterbox",
        description="Chatterbox TTS with ROCm/AMD GPU support",
        attribution=attribution,
        installed=True,
        voices=tts_voices,
        version="1.0",
    )
    return Info(tts=[program])
|
||||||
Reference in New Issue
Block a user