Initial implementation: Chatterbox TTS with ROCm and Wyoming
All checks were successful
Build ROCm Image / build (push) Successful in 15m27s
Wyoming-only server built around the official chatterbox TTS model. Includes ROCm/AMD GPU support, sentence-level streaming, config.yaml management, and Gitea CI for container builds. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
166
wyoming_handler.py
Normal file
166
wyoming_handler.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import asyncio
import logging
import re
import time
from typing import Dict, Optional

import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import (
    Synthesize,
    SynthesizeChunk,
    SynthesizeStart,
    SynthesizeStop,
    SynthesizeStopped,
)

import engine
from config import (
    get_gen_cfg_weight,
    get_gen_exaggeration,
    get_gen_seed,
    get_gen_temperature,
    get_wyoming_chunk_size,
)
from wyoming_voices import resolve_voice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _split_text(text: str, max_chunk: int = 300) -> list[str]:
|
||||
"""Split text at sentence boundaries to keep chunks under max_chunk chars."""
|
||||
import re
|
||||
|
||||
# Split on sentence-ending punctuation
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text.strip())
|
||||
|
||||
chunks = []
|
||||
current = ""
|
||||
for sentence in sentences:
|
||||
if not sentence:
|
||||
continue
|
||||
if current and len(current) + 1 + len(sentence) > max_chunk:
|
||||
chunks.append(current.strip())
|
||||
current = sentence
|
||||
else:
|
||||
current = (current + " " + sentence).strip() if current else sentence
|
||||
|
||||
if current:
|
||||
chunks.append(current.strip())
|
||||
|
||||
return chunks or [text]
|
||||
|
||||
|
||||
class ChatterboxWyomingHandler(AsyncEventHandler):
    """Wyoming protocol handler that synthesizes speech with the chatterbox engine.

    Supports both the one-shot ``Synthesize`` request and the streaming
    ``SynthesizeStart`` / ``SynthesizeChunk`` / ``SynthesizeStop`` sequence.
    Audio is emitted as 16-bit mono PCM ``AudioChunk`` events, one per
    sentence-level text chunk, so playback can begin before the full text
    has been synthesized.
    """

    def __init__(
        self,
        wyoming_info: Info,
        voices: Dict[str, str],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._info = wyoming_info
        # Voice-name -> audio-prompt mapping, resolved per request via resolve_voice().
        self._voices = voices

        # Accumulator state for the streaming synthesize sequence.
        self._streaming = False
        self._streaming_text = ""
        self._streaming_voice: Optional[str] = None

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event. Always returns True to keep the connection open."""
        if Describe.is_type(event.type):
            await self.write_event(self._info.event())
            return True

        if SynthesizeStart.is_type(event.type):
            # Begin accumulating streamed text; synthesis is deferred to SynthesizeStop.
            start = SynthesizeStart.from_event(event)
            self._streaming = True
            self._streaming_text = ""
            self._streaming_voice = start.voice.name if start.voice else None
            return True

        if SynthesizeChunk.is_type(event.type):
            chunk = SynthesizeChunk.from_event(event)
            if self._streaming:
                self._streaming_text += chunk.text
            return True

        if SynthesizeStop.is_type(event.type):
            if self._streaming and self._streaming_text:
                await self._synthesize_and_stream(
                    self._streaming_text,
                    self._streaming_voice,
                )
            # Reset streaming state regardless of whether anything was synthesized.
            self._streaming = False
            self._streaming_text = ""
            self._streaming_voice = None
            await self.write_event(SynthesizeStopped().event())
            return True

        if Synthesize.is_type(event.type):
            # Non-streaming path: full text arrives in a single event.
            synth = Synthesize.from_event(event)
            voice_name = synth.voice.name if synth.voice else None
            await self._synthesize_and_stream(synth.text, voice_name)
            await self.write_event(SynthesizeStopped().event())
            return True

        return True

    async def _synthesize_and_stream(self, text: str, voice_name: Optional[str]) -> None:
        """Synthesize *text* sentence-chunk by sentence-chunk and stream PCM audio.

        Emits AudioStart, then one AudioChunk per successfully synthesized text
        chunk, then AudioStop. Chunks that fail to synthesize are logged and
        skipped (best-effort streaming).
        """
        audio_prompt = resolve_voice(voice_name, self._voices)
        sample_rate = engine.get_sample_rate()
        chunk_size = get_wyoming_chunk_size()

        chunks = _split_text(text, max_chunk=chunk_size)
        # Lazy %-args avoid formatting work when the log level is disabled.
        logger.info(
            "Synthesizing %d chunk(s) for voice='%s', prompt='%s'",
            len(chunks), voice_name, audio_prompt,
        )

        await self.write_event(
            AudioStart(rate=sample_rate, width=2, channels=1).event()
        )

        first_chunk = True
        start_time = time.monotonic()

        for i, chunk_text in enumerate(chunks):
            logger.debug("Chunk %d/%d: %r", i + 1, len(chunks), chunk_text[:60])

            try:
                # Model inference is blocking; run it off the event loop.
                # get_running_loop() is the non-deprecated form inside a coroutine.
                audio_tensor, _sr = await asyncio.get_running_loop().run_in_executor(
                    None,
                    self._synthesize_chunk,
                    chunk_text,
                    audio_prompt,
                )
            except Exception:
                # Best-effort: skip the failed chunk and keep streaming the rest.
                logger.exception(f"Synthesis failed for chunk {i+1}")
                continue

            # Convert float audio (presumably in [-1, 1] — engine contract) to
            # 16-bit PCM bytes.
            audio_np = audio_tensor.cpu().numpy().squeeze()
            audio_bytes = (
                (audio_np * 32767).clip(-32768, 32767).astype(np.int16).tobytes()
            )

            if first_chunk:
                ttfa = time.monotonic() - start_time
                logger.info("Time to first audio: %.3fs", ttfa)
                first_chunk = False

            await self.write_event(
                AudioChunk(
                    audio=audio_bytes,
                    rate=sample_rate,
                    width=2,
                    channels=1,
                ).event()
            )

        await self.write_event(AudioStop().event())
        total = time.monotonic() - start_time
        logger.info("Synthesis complete in %.3fs", total)

    def _synthesize_chunk(self, text: str, audio_prompt: Optional[str]):
        """Run one blocking engine synthesis call (executed in a worker thread).

        Returns whatever engine.synthesize returns — the caller unpacks it as
        (audio_tensor, sample_rate).
        """
        return engine.synthesize(
            text=text,
            audio_prompt_path=audio_prompt,
            temperature=get_gen_temperature(),
            exaggeration=get_gen_exaggeration(),
            cfg_weight=get_gen_cfg_weight(),
            seed=get_gen_seed(),
        )
|
||||
Reference in New Issue
Block a user