"""Entry point: load the TTS model, warm it up, and serve it over Wyoming."""

import asyncio
import logging
import sys
from functools import partial

from wyoming.server import AsyncServer

import engine
from config import get_wyoming_host, get_wyoming_port, load_config
from wyoming_handler import ChatterboxWyomingHandler
from wyoming_voices import create_wyoming_info, load_voices, resolve_voice

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# Texts synthesized once each at startup, chosen to span the input lengths
# seen in practice so kernels for each length are compiled ahead of time.
_WARMUP_TEXTS = [
    # Short: covers brief HA notifications (lights on/off, etc.)
    "Okay.",
    # Medium: covers typical HA announcements
    "The front door is open. Please close it.",
    # Long: covers longer TTS requests and pre-compiles dynamic shape graph
    (
        "This is a warmup synthesis to pre-compile neural network kernels "
        "for longer text lengths used in Home Assistant announcements and notifications."
    ),
]


def _warmup(voices: dict) -> None:
    """Synthesize each entry of ``_WARMUP_TEXTS`` once to pre-compile kernels.

    Each synthesis pass is best-effort: a failing pass is logged with its
    traceback and the remaining passes still run. A summary is logged at the
    end, as a warning if any pass failed.

    Args:
        voices: Voice mapping from ``load_voices()``. When non-empty, the voice
            returned by ``resolve_voice(None, voices)`` (presumably the default
            voice; confirm in wyoming_voices) is passed as the audio prompt;
            when empty, synthesis runs with ``audio_prompt_path=None``.
    """
    audio_prompt = resolve_voice(None, voices) if voices else None
    total = len(_WARMUP_TEXTS)
    logger.info(
        "Running %d-pass warmup to pre-compile torch kernels "
        "for short, medium, and long text lengths...",
        total,
    )

    succeeded = 0
    for i, text in enumerate(_WARMUP_TEXTS, 1):
        try:
            engine.synthesize(text=text, audio_prompt_path=audio_prompt)
        except Exception:
            # Non-fatal: the server can still start; the first real request
            # at this length may just pay the compile/benchmark cost instead.
            logger.warning(
                "Warmup pass %d/%d failed (non-fatal)", i, total, exc_info=True
            )
            continue
        succeeded += 1
        logger.info("Warmup pass %d/%d complete", i, total)

    # Only claim success when every pass actually succeeded.
    if succeeded == total:
        logger.info("Warmup complete")
    else:
        logger.warning(
            "Warmup finished with %d/%d passes succeeding", succeeded, total
        )


async def main() -> None:
    """Load config and model, warm up, then serve Wyoming TTS requests.

    Exits the process with status 1 (via ``sys.exit``) if
    ``engine.load_model()`` reports failure. Otherwise it runs the Wyoming
    server, handling each connection with ``ChatterboxWyomingHandler``.
    """
    load_config()

    logger.info("Loading TTS model...")
    if not engine.load_model():
        logger.error("Failed to load model, exiting")
        sys.exit(1)

    voices = load_voices()
    wyoming_info = create_wyoming_info(engine.get_sample_rate(), voices)

    # Run the warmup passes before accepting connections so MIOpen benchmarks
    # and caches the best convolution algorithms for all layer shapes. Without
    # this, the first real HA request triggers benchmarking (hundreds of runs)
    # and times out before any audio is returned. The call is synchronous on
    # purpose: the server is not started yet, so no connections are waiting.
    _warmup(voices)

    host = get_wyoming_host()
    port = get_wyoming_port()
    uri = f"tcp://{host}:{port}"
    logger.info("Starting Wyoming server on %s", uri)

    server = AsyncServer.from_uri(uri)
    await server.run(partial(ChatterboxWyomingHandler, wyoming_info, voices))


if __name__ == "__main__":
    asyncio.run(main())