Files
rocm-chatterbox-whisper/wyoming_handler.py
scott d0f13dea8d
All checks were successful
Build ROCm Image / build (push) Successful in 3m50s
Log incoming HA text in synthesis request line
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 16:26:52 -04:00

171 lines
5.5 KiB
Python

import asyncio
import logging
import time
from typing import Dict, Optional
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
import engine
from config import (
get_gen_cfg_weight,
get_gen_exaggeration,
get_gen_seed,
get_gen_temperature,
get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice
logger = logging.getLogger(__name__)
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
import re
# Split on sentence-ending punctuation
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
chunks = []
current = ""
for sentence in sentences:
if not sentence:
continue
if current and len(current) + 1 + len(sentence) > max_chunk:
chunks.append(current.strip())
current = sentence
else:
current = (current + " " + sentence).strip() if current else sentence
if current:
chunks.append(current.strip())
return chunks or [text]
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming-protocol event handler that streams Chatterbox TTS audio.

    Handles both the one-shot ``Synthesize`` event and the chunked streaming
    protocol (``SynthesizeStart`` / ``SynthesizeChunk`` / ``SynthesizeStop``).
    Synthesis itself runs in an executor thread so the event loop stays
    responsive while the engine generates audio.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        """Store service info and the voice-name -> prompt-path mapping."""
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        # Voice name -> audio-prompt path; resolved per request via resolve_voice.
        self._voices = voices
        # Streaming-protocol state: text accumulates until SynthesizeStop.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event; always returns True to keep the connection open."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True
        if SynthesizeStart.is_type(event.type):
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True
        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True
        if SynthesizeStop.is_type(event.type):
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            # Reset streaming state whether or not any text arrived.
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True
        if Synthesize.is_type(event.type):
            if self._streaming:
                # Ignore duplicate Synthesize events sent alongside streaming protocol
                return True
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            # NOTE: SynthesizeStopped is NOT sent here — it belongs only to the
            # streaming protocol (SynthesizeStop path). Sending it here confuses HA.
            return True
        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize ``text`` chunk-by-chunk and stream 16-bit mono PCM to the client.

        Emits AudioStart, one AudioChunk per successfully synthesized text
        chunk (failed chunks are logged and skipped), then AudioStop.
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()
        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}' — text: {text!r}"
        )
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )
        first_chunk = True
        start_time = time.monotonic()
        # get_running_loop: get_event_loop() inside a coroutine is deprecated
        # since Python 3.10 and scheduled for removal.
        loop = asyncio.get_running_loop()
        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")
            try:
                audio_tensor, sr = await loop.run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                # Best-effort: skip a failed chunk rather than aborting the stream.
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue
            if sr != sample_rate:
                # AudioStart already advertised sample_rate; a mismatch means
                # the client will play this chunk at the wrong speed/pitch.
                logger.warning(
                    f"Engine sample rate {sr} != advertised {sample_rate}"
                )
            # Clip to [-1, 1] before scaling: samples outside that range would
            # otherwise wrap around in int16 and produce loud clicks.
            audio_bytes = (
                audio_tensor.cpu().numpy().squeeze().clip(-1.0, 1.0) * 32767
            ).astype("int16").tobytes()
            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False
            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )
        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Blocking synthesis of one text chunk; runs in the executor thread."""
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )