Initial implementation: Chatterbox TTS with ROCm and Wyoming
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
Wyoming-only server built around the official chatterbox TTS model. Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml management, and Gitea CI for container builds. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
33
.gitea/workflows/build.yml
Normal file
33
.gitea/workflows/build.yml
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Build ROCm Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Gitea Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: git.sdgarren.com
|
||||
username: ${{ secrets.REGISTRY_USERNAME }}
|
||||
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||
|
||||
- name: Build and Push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile.rocm
|
||||
push: true
|
||||
tags: |
|
||||
git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
|
||||
git.sdgarren.com/scott/rocm-chatterbox-whisper:${{ gitea.sha }}
|
||||
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
.env
|
||||
|
||||
# Voice and audio files (mount via volume)
|
||||
voices/
|
||||
reference_audio/
|
||||
outputs/
|
||||
|
||||
# Model cache
|
||||
hf_cache/
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
43
Dockerfile.rocm
Normal file
43
Dockerfile.rocm
Normal file
@@ -0,0 +1,43 @@
|
||||
FROM rocm/dev-ubuntu-22.04:latest
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
HF_HOME=/app/hf_cache \
|
||||
PIP_NO_CACHE_DIR=1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-dev \
|
||||
git \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Step 1: Install ROCm-compatible PyTorch stack first.
|
||||
# This must happen before anything else to prevent pip from pulling CPU wheels.
|
||||
COPY requirements-rocm-init.txt .
|
||||
RUN pip3 install -r requirements-rocm-init.txt
|
||||
|
||||
# Step 2: Install remaining dependencies (pinned to avoid overwriting torch).
|
||||
COPY requirements-rocm.txt .
|
||||
RUN pip3 install -r requirements-rocm.txt
|
||||
|
||||
# Step 3: Install chatterbox with --no-deps so pip cannot replace ROCm torch.
|
||||
RUN pip3 install --no-deps chatterbox-tts
|
||||
|
||||
# Application source
|
||||
COPY engine.py config.py wyoming_handler.py wyoming_voices.py main.py ./
|
||||
|
||||
# Default config (can be overridden by volume mount)
|
||||
COPY config.yaml .
|
||||
|
||||
# Create default directories
|
||||
RUN mkdir -p voices reference_audio hf_cache
|
||||
|
||||
EXPOSE 10200
|
||||
|
||||
CMD ["python3", "main.py"]
|
||||
115
config.py
Normal file
115
config.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_config: dict = {}
|
||||
_config_path = Path(os.environ.get("CONFIG_PATH", "config.yaml"))
|
||||
|
||||
# Built-in fallback configuration. Values from config.yaml are deep-merged
# on top of these by load_config(), so every key here is always present.
DEFAULTS = {
    "model": {
        # Model identifier; "turbo" in the name selects the turbo engine
        # (see engine.load_model()).
        "repo_id": "chatterbox-turbo",
    },
    "tts_engine": {
        "device": "",  # empty = auto-detect
        # Directory of curated voice prompt files.
        "predefined_voices_path": "voices",
        # Directory of user-supplied reference clips for cloning.
        "reference_audio_path": "reference_audio",
        # Voice used when a request names no voice (stem is what matters).
        "default_voice_id": "default.wav",
    },
    "generation_defaults": {
        # Turbo model reads temperature; standard model reads
        # exaggeration/cfg_weight (the other values are ignored).
        "temperature": 0.8,
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "seed": 0,  # 0 = non-deterministic, >0 = reproducible
    },
    "wyoming": {
        "host": "0.0.0.0",
        "port": 10200,
        # Max characters per synthesis chunk (split at sentence boundaries).
        "chunk_size": 300,
    },
    "paths": {
        # Exported as HF_HOME by load_config() unless already set.
        "model_cache": "/app/hf_cache",
    },
}
|
||||
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
result = base.copy()
|
||||
for key, value in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def load_config() -> None:
    """Load config.yaml over DEFAULTS into the module-level ``_config``.

    A missing or unreadable config file is not fatal: the built-in
    DEFAULTS are used and the problem is logged. Also exports the
    configured model cache directory as HF_HOME unless that variable is
    already set in the environment.
    """
    global _config
    # Merge with an empty override to get a copy rather than aliasing DEFAULTS.
    _config = _deep_merge(DEFAULTS, {})
    if _config_path.exists():
        try:
            # Config files are UTF-8 regardless of the container locale;
            # relying on the platform default encoding is a latent bug.
            with open(_config_path, encoding="utf-8") as f:
                user_config = yaml.safe_load(f) or {}
            _config = _deep_merge(_config, user_config)
            logger.info(f"Loaded config from {_config_path}")
        except Exception:
            logger.exception(f"Failed to load {_config_path}, using defaults")
    else:
        logger.warning(f"Config file not found at {_config_path}, using defaults")

    # Set HuggingFace cache path from config (environment wins if present).
    cache_path = _config.get("paths", {}).get("model_cache", "")
    if cache_path:
        os.environ.setdefault("HF_HOME", cache_path)
|
||||
|
||||
|
||||
def _setting(section: str, key: str, fallback):
    """Read one value from the loaded config, falling back when absent."""
    return _config.get(section, {}).get(key, fallback)


def get_model_repo_id() -> str:
    """Configured model repo id (e.g. ``chatterbox-turbo``)."""
    return _setting("model", "repo_id", "chatterbox-turbo")


def get_device_override() -> Optional[str]:
    """Explicit device from config; empty string or missing means auto-detect."""
    return _setting("tts_engine", "device", None) or None


def get_predefined_voices_path() -> Path:
    """Directory holding curated voice prompt files."""
    return Path(_setting("tts_engine", "predefined_voices_path", "voices"))


def get_reference_audio_path() -> Path:
    """Directory holding user-supplied reference audio clips."""
    return Path(_setting("tts_engine", "reference_audio_path", "reference_audio"))


def get_default_voice_id() -> str:
    """Voice file used when a request names no voice."""
    return _setting("tts_engine", "default_voice_id", "default.wav")


def get_wyoming_host() -> str:
    """Bind address for the Wyoming TCP server."""
    return _setting("wyoming", "host", "0.0.0.0")


def get_wyoming_port() -> int:
    """Bind port for the Wyoming TCP server."""
    return int(_setting("wyoming", "port", 10200))


def get_wyoming_chunk_size() -> int:
    """Maximum characters per synthesis chunk."""
    return int(_setting("wyoming", "chunk_size", 300))


def get_gen_temperature() -> float:
    """Sampling temperature (consumed by the turbo model)."""
    return float(_setting("generation_defaults", "temperature", 0.8))


def get_gen_exaggeration() -> float:
    """Exaggeration strength (consumed by the standard model)."""
    return float(_setting("generation_defaults", "exaggeration", 0.5))


def get_gen_cfg_weight() -> float:
    """Classifier-free-guidance weight (consumed by the standard model)."""
    return float(_setting("generation_defaults", "cfg_weight", 0.5))


def get_gen_seed() -> int:
    """RNG seed for generation; 0 means non-deterministic."""
    return int(_setting("generation_defaults", "seed", 0))
|
||||
29
config.yaml
Normal file
29
config.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
model:
|
||||
# Options: chatterbox, chatterbox-turbo
|
||||
repo_id: chatterbox-turbo
|
||||
|
||||
tts_engine:
|
||||
# Device: cuda, cpu, or leave empty for auto-detect
|
||||
device: ""
|
||||
predefined_voices_path: voices
|
||||
reference_audio_path: reference_audio
|
||||
# Fallback voice (stem name, e.g. "default" matches default.wav)
|
||||
default_voice_id: default.wav
|
||||
|
||||
generation_defaults:
|
||||
# Turbo model: uses temperature only (exaggeration/cfg_weight ignored)
|
||||
# Standard model: uses exaggeration and cfg_weight (temperature ignored)
|
||||
temperature: 0.8
|
||||
exaggeration: 0.5
|
||||
cfg_weight: 0.5
|
||||
# seed: 0 = random each call, >0 = reproducible output
|
||||
seed: 0
|
||||
|
||||
wyoming:
|
||||
host: "0.0.0.0"
|
||||
port: 10200
|
||||
# Max characters per synthesis chunk (split at sentence boundaries)
|
||||
chunk_size: 300
|
||||
|
||||
paths:
|
||||
model_cache: /app/hf_cache
|
||||
35
docker-compose.yml
Normal file
35
docker-compose.yml
Normal file
@@ -0,0 +1,35 @@
|
||||
services:
|
||||
chatterbox-whisper:
|
||||
image: git.sdgarren.com/scott/rocm-chatterbox-whisper:latest
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.rocm
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${WYOMING_PORT:-10200}:10200"
|
||||
devices:
|
||||
- /dev/kfd
|
||||
- /dev/dri
|
||||
group_add:
|
||||
- video
|
||||
- render
|
||||
ipc: host
|
||||
shm_size: 8g
|
||||
security_opt:
|
||||
- seccomp=unconfined
|
||||
volumes:
|
||||
- ./config.yaml:/app/config.yaml
|
||||
- ./voices:/app/voices
|
||||
- ./reference_audio:/app/reference_audio
|
||||
- hf_cache:/app/hf_cache
|
||||
environment:
|
||||
- HF_HUB_ENABLE_HF_TRANSFER=1
|
||||
# Set your GPU architecture:
|
||||
# 10.3.0 = RX 6000 series (RDNA2 / gfx1030; RX 5000 is gfx1010 = 10.1.0)
|
||||
# 11.0.0 = RX 7000 series
|
||||
# 9.0.6 = Vega
|
||||
- HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
# - HF_TOKEN=your_token_here
|
||||
|
||||
volumes:
|
||||
hf_cache:
|
||||
91
engine.py
Normal file
91
engine.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import logging
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level model state, populated by load_model().
chatterbox_model = None  # loaded ChatterboxTTS / ChatterboxTurboTTS instance
_sample_rate = 24000     # output sample rate in Hz
_is_turbo = False        # True when the turbo model variant is loaded
|
||||
|
||||
|
||||
def _test_cuda() -> bool:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.zeros(1).cuda()
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def detect_device() -> str:
    """Pick the torch device string: "cuda" when a working GPU is present,
    otherwise "cpu"."""
    if _test_cuda():
        return "cuda"
    return "cpu"
|
||||
|
||||
|
||||
def load_model() -> bool:
    """Instantiate the Chatterbox model selected by the config.

    The turbo variant is chosen when the configured repo id contains
    "turbo" (case-insensitive). On success the module globals
    ``chatterbox_model``, ``_sample_rate`` and ``_is_turbo`` are set and
    True is returned; on failure the traceback is logged and False is
    returned.
    """
    global chatterbox_model, _sample_rate, _is_turbo

    # Imported lazily to avoid a circular import at module load time.
    from config import get_device_override, get_model_repo_id

    repo_id = get_model_repo_id()
    device = get_device_override() or detect_device()

    logger.info(f"Loading model '{repo_id}' on device '{device}'")

    use_turbo = "turbo" in repo_id.lower()
    try:
        if use_turbo:
            from chatterbox.tts_turbo import ChatterboxTurboTTS as _ModelClass
        else:
            from chatterbox.tts import ChatterboxTTS as _ModelClass

        chatterbox_model = _ModelClass.from_pretrained(device)
        _is_turbo = use_turbo
        _sample_rate = 24000  # same fixed rate for both variants

        logger.info("Model loaded successfully")
        return True
    except Exception:
        logger.exception("Failed to load model")
        return False
|
||||
|
||||
|
||||
def get_sample_rate() -> int:
    """Output sample rate (Hz) of the loaded model."""
    return _sample_rate
|
||||
|
||||
|
||||
def synthesize(
    text: str,
    audio_prompt_path: Optional[str] = None,
    exaggeration: float = 0.5,
    cfg_weight: float = 0.5,
    temperature: float = 0.8,
    seed: int = 0,
) -> Tuple[torch.Tensor, int]:
    """Generate speech for *text* and return ``(waveform, sample_rate)``.

    audio_prompt_path: optional reference clip for voice cloning.
    seed: >0 seeds torch RNGs for reproducible output; 0 leaves them alone.
    Raises RuntimeError when load_model() has not succeeded yet.
    """
    if chatterbox_model is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if seed > 0:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # The two model variants accept different knobs; pass only what applies.
    if _is_turbo:
        gen_args: dict = {"temperature": temperature}
    else:
        gen_args = {"exaggeration": exaggeration, "cfg_weight": cfg_weight}
    if audio_prompt_path:
        gen_args["audio_prompt_path"] = audio_prompt_path

    with torch.inference_mode():
        waveform = chatterbox_model.generate(text=text, **gen_args)

    # Flush pending GPU work and release cached blocks between requests.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return waveform, _sample_rate
|
||||
42
main.py
Normal file
42
main.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from wyoming.server import AsyncServer
|
||||
|
||||
import engine
|
||||
from config import get_wyoming_host, get_wyoming_port, load_config
|
||||
from wyoming_handler import ChatterboxWyomingHandler
|
||||
from wyoming_voices import create_wyoming_info, load_voices
|
||||
|
||||
# Root logging goes to stdout at INFO so container logs capture everything.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main() -> None:
    """Process entry point: load config and model, then serve Wyoming over TCP.

    Exits with status 1 when the TTS model cannot be loaded, since the
    server is useless without it.
    """
    # Config must be loaded first: model choice, paths and ports all read it.
    load_config()

    logger.info("Loading TTS model...")
    if not engine.load_model():
        logger.error("Failed to load model, exiting")
        sys.exit(1)

    # Discover voice prompt files and advertise them in the Wyoming Info reply.
    voices = load_voices()
    wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices)

    host = get_wyoming_host()
    port = get_wyoming_port()
    uri = f"tcp://{host}:{port}"

    logger.info(f"Starting Wyoming server on {uri}")
    server = AsyncServer.from_uri(uri)
    # partial(...) lets the server construct handlers with our shared
    # info/voices pre-bound.
    await server.run(partial(ChatterboxWyomingHandler, wyoming_info, voices))


if __name__ == "__main__":
    asyncio.run(main())
|
||||
5
requirements-rocm-init.txt
Normal file
5
requirements-rocm-init.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--index-url https://download.pytorch.org/whl/rocm6.1
|
||||
torch==2.5.1
|
||||
torchaudio==2.5.1
|
||||
torchvision==0.20.1
|
||||
pytorch_triton_rocm==3.1.0
|
||||
17
requirements-rocm.txt
Normal file
17
requirements-rocm.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Audio processing
|
||||
numpy
|
||||
soundfile
|
||||
librosa
|
||||
|
||||
# ML dependencies (pinned to match chatterbox without overwriting ROCm torch)
|
||||
transformers==4.46.3
|
||||
diffusers==0.29.0
|
||||
safetensors
|
||||
huggingface-hub
|
||||
accelerate
|
||||
|
||||
# Wyoming protocol
|
||||
wyoming>=1.5.4
|
||||
|
||||
# Config / utilities
|
||||
PyYAML>=6.0
|
||||
166
wyoming_handler.py
Normal file
166
wyoming_handler.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import Describe, Info
|
||||
from wyoming.server import AsyncEventHandler
|
||||
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
|
||||
|
||||
import engine
|
||||
from config import (
|
||||
get_gen_cfg_weight,
|
||||
get_gen_exaggeration,
|
||||
get_gen_seed,
|
||||
get_gen_temperature,
|
||||
get_wyoming_chunk_size,
|
||||
)
|
||||
from wyoming_voices import resolve_voice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
|
||||
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
|
||||
import re
|
||||
|
||||
# Split on sentence-ending punctuation
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
|
||||
|
||||
chunks = []
|
||||
current = ""
|
||||
for sentence in sentences:
|
||||
if not sentence:
|
||||
continue
|
||||
if current and len(current) + 1 + len(sentence) > max_chunk:
|
||||
chunks.append(current.strip())
|
||||
current = sentence
|
||||
else:
|
||||
current = (current + " " + sentence).strip() if current else sentence
|
||||
|
||||
if current:
|
||||
chunks.append(current.strip())
|
||||
|
||||
return chunks or [text]
|
||||
|
||||
|
||||
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming event handler bridging the protocol to the Chatterbox engine.

    Supports both the one-shot ``Synthesize`` event and the streaming
    ``SynthesizeStart`` / ``SynthesizeChunk`` / ``SynthesizeStop``
    sequence. Streamed text is buffered and synthesized once on stop.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        """wyoming_info: Info advertised on Describe.
        voices: mapping of voice name -> reference audio file path.
        Remaining args are forwarded to AsyncEventHandler.
        """
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        self._voices = voices

        # Streaming state: text accumulates between Start and Stop events.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event; always returns True to keep the
        connection open."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            # Begin a streaming session; reset any stale buffered text.
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True

        if SynthesizeChunk.is_type(event.type):
            # Chunks outside an active session are silently ignored.
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            # All streamed text is synthesized in one pass at stop time.
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            # Non-streaming path: synthesize the full request immediately.
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* sentence-chunk by sentence-chunk, emitting
        AudioStart / AudioChunk* / AudioStop events as audio is produced."""
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()

        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}'"
        )

        # width=2, channels=1: 16-bit mono PCM.
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        first_chunk = True
        start_time = time.monotonic()

        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")

            try:
                # Synthesis is blocking; run it in the default executor so
                # the event loop stays responsive.
                audio_tensor, sr = await asyncio.get_event_loop().run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                # A failed chunk is skipped rather than aborting the stream.
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue

            # Float waveform -> 16-bit signed PCM bytes for AudioChunk.
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()

            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False

            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Blocking helper run in the executor: one engine.synthesize call
        using the configured generation defaults."""
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )
|
||||
99
wyoming_voices.py
Normal file
99
wyoming_voices.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from wyoming.info import Attribution, Info, TtsProgram, TtsVoice, TtsVoiceSpeaker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOICE_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg"}
|
||||
|
||||
|
||||
def load_voices() -> Dict[str, str]:
    """Scan voice directories and return {voice_name: file_path} mapping.

    Both the predefined-voices and reference-audio directories are
    scanned. The first file seen for a given stem wins, so predefined
    voices take priority over reference audio on name collisions.
    """
    from config import get_predefined_voices_path, get_reference_audio_path

    voices: Dict[str, str] = {}

    def _scan_dir(directory: Path) -> None:
        # First writer wins: an already-registered stem is never overwritten.
        if not directory.exists():
            return
        for f in sorted(directory.iterdir()):
            if f.suffix.lower() in VOICE_EXTENSIONS:
                name = f.stem
                if name not in voices:
                    voices[name] = str(f)

    # Predefined voices are scanned FIRST: because _scan_dir never
    # overwrites, scanning first is what grants collision priority.
    # (Previously reference audio was scanned first, which contradicted
    # the stated intent of predefined voices winning.)
    _scan_dir(get_predefined_voices_path())
    _scan_dir(get_reference_audio_path())

    logger.info(f"Discovered {len(voices)} voice(s): {list(voices.keys())}")
    return voices
|
||||
|
||||
|
||||
def resolve_voice(voice_name: Optional[str], voices: Dict[str, str]) -> Optional[str]:
    """Resolve a voice name to its audio file path.

    Lookup order: the discovered-voice map, then the predefined voices
    directory, then the reference audio directory (each probed with every
    known extension), and finally any discovered voice as a last resort.
    Returns None only when no voice file exists at all.
    """
    from config import get_default_voice_id, get_predefined_voices_path, get_reference_audio_path

    # No voice requested: fall back to the configured default (stem only).
    if not voice_name:
        voice_name = Path(get_default_voice_id()).stem

    hit = voices.get(voice_name)
    if hit is not None:
        return hit

    # Probe both voice directories for a file named <voice_name>.<ext>.
    for base_dir in (get_predefined_voices_path(), get_reference_audio_path()):
        for ext in VOICE_EXTENSIONS:
            candidate = base_dir / f"{voice_name}{ext}"
            if candidate.exists():
                return str(candidate)

    # Last resort: use any discovered voice rather than failing outright.
    if voices:
        fallback = next(iter(voices.values()))
        logger.warning(f"Voice '{voice_name}' not found, falling back to '{fallback}'")
        return fallback

    return None
|
||||
|
||||
|
||||
def create_wyoming_info(sample_rate: int, voices: Dict[str, str]) -> Info:
    """Build the Wyoming Info object advertised to Home Assistant.

    sample_rate is accepted for interface stability but not embedded in
    the Info payload.
    """
    def _voice_entry(voice_name: str) -> TtsVoice:
        # Each discovered voice becomes one installed English voice entry.
        return TtsVoice(
            name=voice_name,
            description=f"Chatterbox voice: {voice_name}",
            attribution=Attribution(
                name="ResembleAI",
                url="https://github.com/resemble-ai/chatterbox",
            ),
            installed=True,
            languages=["en"],
            speakers=[TtsVoiceSpeaker(name=voice_name)],
        )

    program = TtsProgram(
        name="chatterbox",
        description="Chatterbox TTS with ROCm/AMD GPU support",
        attribution=Attribution(
            name="ResembleAI",
            url="https://github.com/resemble-ai/chatterbox",
        ),
        installed=True,
        voices=[_voice_entry(v) for v in sorted(voices.keys())],
        version="1.0",
    )

    return Info(tts=[program])
|
||||
Reference in New Issue
Block a user