Files
rocm-chatterbox-whisper/wyoming_handler.py
scott 16ea2853f5
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
Initial implementation: Chatterbox TTS with ROCm and Wyoming
Wyoming-only server built around the official chatterbox TTS model.
Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml
management, and Gitea CI for container builds.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 09:51:09 -04:00

167 lines
5.3 KiB
Python

import asyncio
import logging
import time
from typing import Dict, Optional
import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
import engine
from config import (
get_gen_cfg_weight,
get_gen_exaggeration,
get_gen_seed,
get_gen_temperature,
get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice
logger = logging.getLogger(__name__)
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
import re
# Split on sentence-ending punctuation
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
chunks = []
current = ""
for sentence in sentences:
if not sentence:
continue
if current and len(current) + 1 + len(sentence) > max_chunk:
chunks.append(current.strip())
current = sentence
else:
current = (current + " " + sentence).strip() if current else sentence
if current:
chunks.append(current.strip())
return chunks or [text]
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming event handler that serves TTS requests through ``engine``.

    Supported events:

    * ``Describe`` - replies with the service ``Info``.
    * ``Synthesize`` - one-shot synthesis of the given text.
    * ``SynthesizeStart`` / ``SynthesizeChunk`` / ``SynthesizeStop`` -
      streaming input: text is accumulated and synthesized on stop.

    Audio is returned as ``AudioStart``, one ``AudioChunk`` per text chunk,
    then ``AudioStop``; a ``SynthesizeStopped`` event follows each request.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        """Store service info and the voice table.

        Args:
            wyoming_info: Info object returned in response to ``Describe``.
            voices: Mapping handed to ``resolve_voice`` together with the
                requested voice name.
            *args: Forwarded to ``AsyncEventHandler``.
            **kwargs: Forwarded to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        self._voices = voices
        # State of an in-progress streaming request (SynthesizeStart..Stop).
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one incoming Wyoming event.

        Always returns ``True``; unrecognised events are ignored.
        """
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True

        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            # Snapshot and reset first so a failure during synthesis cannot
            # leave stale streaming state behind for the next request.
            was_streaming = self._streaming
            text = self._streaming_text
            voice_name = self._streaming_voice
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None

            if was_streaming and text:
                await self._synthesize_and_stream(text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            if self._streaming:
                # Streaming clients also send a full-text Synthesize event for
                # compatibility with non-streaming servers. The same text is
                # already being collected from the chunks, so handling it here
                # would synthesize everything twice.
                logger.debug("Ignoring Synthesize event during streaming request")
                return True
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* chunk by chunk and write the audio events.

        The text is split with ``_split_text`` using the configured Wyoming
        chunk size. Each chunk is synthesized off the event loop and sent as
        one ``AudioChunk`` of 16-bit mono PCM. A chunk whose synthesis raises
        is logged and skipped; the remaining chunks are still processed.

        Args:
            text: Full text to speak.
            voice_name: Requested voice name, or ``None``; resolved through
                ``resolve_voice``.
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunks = _split_text(text, max_chunk=get_wyoming_chunk_size())
        total_chunks = len(chunks)

        logger.info(
            "Synthesizing %d chunk(s) for voice='%s', prompt='%s'",
            total_chunks,
            voice_name,
            audio_prompt,
        )

        # width=2 matches the int16 samples produced below.
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        loop = asyncio.get_running_loop()
        first_chunk = True
        start_time = time.monotonic()

        for index, chunk_text in enumerate(chunks, start=1):
            logger.debug("Chunk %d/%d: %r", index, total_chunks, chunk_text[:60])
            try:
                # Run the blocking synthesis call in the default executor.
                # The rate returned alongside the audio is not used; the rate
                # from engine.get_sample_rate() is advertised instead.
                audio_tensor, _ = await loop.run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                # Best effort: drop the failed chunk and keep going.
                logger.exception("Synthesis failed for chunk %d", index)
                continue

            # NOTE(review): assumes float samples roughly in [-1, 1] - confirm
            # against engine.synthesize. Values are scaled and clipped to the
            # int16 range before being serialised.
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (
                (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()
            )

            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info("Time to first audio: %.3fs", ttfa)
                first_chunk = False

            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info("Synthesis complete in %.3fs", total)

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Synthesize one text chunk with the configured generation settings.

        Returns whatever ``engine.synthesize`` returns; the caller unpacks it
        as an ``(audio, sample_rate)`` pair.
        """
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )