Files
rocm-chatterbox-whisper/wyoming_handler.py
scott 2b1398109d
Some checks failed
Build ROCm Image / build (push) Has been cancelled
Switch to ONNX runtime with chatterbox-turbo-ONNX (fp16)
Replaces the PyTorch/chatterbox-tts stack with direct ONNX inference
using ResembleAI/chatterbox-turbo-ONNX fp16 weights.

- engine.py: full rewrite — ONNX sessions, autoregressive KV-cache LM
  loop, voice conditionals cache via speech_encoder outputs
- wyoming_handler.py: remove torch dep, use np.asarray for audio bytes
- requirements-rocm-init.txt: onnxruntime-rocm replaces torch wheels
- requirements-rocm.txt: drop chatterbox/torch deps, keep audio utils
- Dockerfile.rocm: remove chatterbox-tts install step

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 19:08:34 -04:00

186 lines
6.2 KiB
Python

import asyncio
import logging
import time
from typing import Dict, Optional
import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize, SynthesizeChunk, SynthesizeStart, SynthesizeStopped, SynthesizeStop
import engine
from config import (
get_gen_cfg_weight,
get_gen_exaggeration,
get_gen_seed,
get_gen_temperature,
get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice
logger = logging.getLogger(__name__)
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
import re
# Split on sentence-ending punctuation
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
chunks = []
current = ""
for sentence in sentences:
if not sentence:
continue
if current and len(current) + 1 + len(sentence) > max_chunk:
chunks.append(current.strip())
current = sentence
else:
current = (current + " " + sentence).strip() if current else sentence
if current:
chunks.append(current.strip())
return chunks or [text]
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming event handler that answers TTS requests using ``engine``.

    Two request styles are handled:

    * one-shot: a single ``Synthesize`` event carrying the full text;
    * streaming: ``SynthesizeStart``, any number of ``SynthesizeChunk``
      events whose text is buffered, then ``SynthesizeStop`` which triggers
      synthesis of the buffered text and is answered with
      ``SynthesizeStopped``.

    In both cases the audio is written back as ``AudioStart``, one
    ``AudioChunk`` per text chunk, and ``AudioStop``.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        """Store the service description and the voice table.

        Args:
            wyoming_info: Description sent back in reply to ``Describe``.
            voices: Voice table passed to ``resolve_voice`` (presumably voice
                name -> reference audio path; confirm in ``wyoming_voices``).
            *args: Forwarded unchanged to ``AsyncEventHandler``.
            **kwargs: Forwarded unchanged to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        self._voices = voices
        # State of the streaming protocol for this handler instance.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Log and dispatch one event without letting an exception escape.

        Any exception raised while handling the event is logged with its
        traceback and swallowed.

        Returns:
            Always ``True`` (presumably "keep the connection open" for the
            Wyoming server loop; confirm against ``AsyncEventHandler``).
        """
        logger.info(f"Event received: {event.type}")
        try:
            return await self._handle_event(event)
        except Exception:
            logger.exception(f"Unhandled exception in handle_event for {event.type}")
            return True

    async def _handle_event(self, event: Event) -> bool:
        """Dispatch one event by type; see the class docstring for the flow.

        Returns:
            Always ``True``.  Exceptions propagate to ``handle_event``.
        """
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            logger.info(f"Streaming started, voice='{self._streaming_voice}'")
            return True

        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            # Chunks that arrive outside a streaming session are dropped.
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            logger.info(f"SynthesizeStop — text: {self._streaming_text!r}")
            try:
                if self._streaming and self._streaming_text:
                    await self._synthesize_and_stream(
                        self._streaming_text,
                        self._streaming_voice,
                    )
            finally:
                # Reset even when synthesis raises; otherwise _streaming would
                # stay True and every later Synthesize event on this handler
                # would be ignored as "streaming protocol active".
                self._streaming = False
                self._streaming_text = ""
                self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            if self._streaming:
                # Ignore duplicate Synthesize events sent alongside streaming protocol
                logger.info("Ignoring Synthesize (streaming protocol active)")
                return True
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            logger.info(f"Synthesize — voice='{voice_name}', text: {synth.text!r}")
            await self._synthesize_and_stream(synth.text, voice_name)
            # NOTE: SynthesizeStopped is NOT sent here — it belongs only to the
            # streaming protocol (SynthesizeStop path). Sending it here confuses HA.
            return True

        logger.warning(f"Unhandled event type: {event.type}")
        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize ``text`` chunk by chunk and write the audio events.

        The text is split with ``_split_text`` using the configured chunk
        size.  Each chunk is synthesized off the event loop and written as one
        ``AudioChunk`` (16-bit mono).  A chunk whose synthesis fails is logged
        and skipped; the remaining chunks are still processed.

        Args:
            text: Full text to speak.
            voice_name: Requested voice name, or ``None`` if none was given;
                resolved through ``resolve_voice``.
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()
        chunks = _split_text(text, max_chunk=chunk_size)
        logger.info(
            f"Synthesizing {len(chunks)} chunk(s) for voice='{voice_name}', "
            f"prompt='{audio_prompt}' — text: {text!r}"
        )
        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        # We are inside a coroutine, so a loop is guaranteed to be running.
        loop = asyncio.get_running_loop()
        first_chunk = True
        start_time = time.monotonic()
        for i, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {i+1}/{len(chunks)}: {chunk_text[:60]!r}")
            try:
                # Run the blocking synthesis call on the default executor so
                # the event loop stays responsive.  The second element of the
                # result is not used; the rate from engine.get_sample_rate()
                # is the one advertised to the client.
                audio, _ = await loop.run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue

            audio_bytes = self._to_pcm16(audio)
            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info(f"Time to first audio: {ttfa:.3f}s")
                first_chunk = False
            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info(f"Synthesis complete in {total:.3f}s")

    @staticmethod
    def _to_pcm16(audio) -> bytes:
        """Convert float audio samples to 16-bit signed PCM bytes.

        Assumes ``audio`` is array-like with float samples nominally in
        ``[-1.0, 1.0]`` (implied by the 32767 scaling) — TODO confirm against
        ``engine.synthesize``.

        The samples are cast to float32 first so the scaling is not done in
        half precision, and clipped so out-of-range samples saturate instead
        of wrapping around when converted to int16.
        """
        samples = np.asarray(audio, dtype=np.float32).squeeze()
        return (np.clip(samples, -1.0, 1.0) * 32767).astype("int16").tobytes()

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Synthesize one text chunk with the configured generation settings.

        Called through ``run_in_executor`` by ``_synthesize_and_stream``.

        Args:
            text: The chunk of text to synthesize.
            audio_prompt: Value returned by ``resolve_voice``, passed to the
                engine as ``audio_prompt_path``.

        Returns:
            Whatever ``engine.synthesize`` returns; the caller unpacks it as a
            two-element tuple whose first element is the audio.
        """
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )