Files
kokoro/server.py
scott 1cda188c98
All checks were successful
Build and Push Docker Image / build (push) Successful in 2m12s
Close connection after synthesis so HA knows response is complete
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 19:51:57 -04:00

235 lines
7.7 KiB
Python

#!/usr/bin/env python3
"""Kokoro TTS Wyoming protocol server for Home Assistant integration."""
import argparse
import asyncio
import logging
import re
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path

import numpy as np
import yaml
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.info import Attribution, Describe, Info, TtsProgram, TtsVoice
from wyoming.server import AsyncEventHandler
from wyoming.server import AsyncServer
from wyoming.tts import Synthesize
_LOGGER = logging.getLogger(__name__)

# Wire format of every audio event this server emits: 24 kHz, 16-bit, mono PCM.
KOKORO_SAMPLE_RATE = 24000
SAMPLE_WIDTH = 2  # bytes per sample -> 16-bit PCM
CHANNELS = 1

# Splits text on the whitespace that follows '.', '!' or '?'. The lookbehind
# keeps the punctuation attached to the sentence it ends.
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
def load_config(path: str) -> dict:
    """Load the YAML configuration file at ``path``.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects. The file is read as UTF-8 explicitly, so parsing does not depend
    on the container's locale.

    Returns:
        The parsed document, or ``{}`` when the file is empty (``safe_load``
        returns ``None`` for an empty document).
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
def build_wyoming_info(config: dict) -> Info:
    """Build the Wyoming ``Info`` advertised in reply to Describe events.

    A single TTS program named "kokoro" is described. Each entry of
    ``config["tts"]["voices"]`` becomes one ``TtsVoice``:

    - ``name`` is required.
    - ``description`` falls back to the name.
    - ``language`` falls back to ``"en-us"``.
    """
    model_url = "https://huggingface.co/hexgrad/Kokoro-82M"

    def _to_voice(entry: dict) -> TtsVoice:
        voice_name = entry["name"]
        return TtsVoice(
            name=voice_name,
            description=entry.get("description", voice_name),
            attribution=Attribution(name="Kokoro-82M", url=model_url),
            installed=True,
            version="1.0",
            languages=[entry.get("language", "en-us")],
        )

    voice_list = list(map(_to_voice, config["tts"]["voices"]))
    program = TtsProgram(
        name="kokoro",
        description="Kokoro 82M TTS via ROCm",
        attribution=Attribution(name="hexgrad/Kokoro-82M", url=model_url),
        installed=True,
        version="0.9.4",
        voices=voice_list,
    )
    return Info(tts=[program])
class KokoroEventHandler(AsyncEventHandler):
    """Wyoming event handler answering Describe and Synthesize requests.

    ``main`` passes a ``functools.partial`` of this class (pre-bound with the
    shared info, pipeline, config and executor) to ``AsyncServer.run``.
    Blocking Kokoro synthesis runs on ``executor`` so the event loop is never
    stalled by the model.
    """

    def __init__(
        self,
        wyoming_info: Info,
        pipeline,
        config: dict,
        executor: ThreadPoolExecutor,
        *args,
        **kwargs,
    ):
        """Store shared server state and forward the rest to ``AsyncEventHandler``.

        Args:
            wyoming_info: Info sent back in reply to Describe events.
            pipeline: Kokoro pipeline, called as
                ``pipeline(text, voice=..., speed=...)`` and iterated for
                ``(_, _, audio)`` tuples.
            config: Parsed config; only the ``tts`` section is read here.
            executor: Executor that runs the blocking synthesis loop.
        """
        super().__init__(*args, **kwargs)
        self.wyoming_info = wyoming_info
        self.pipeline = pipeline
        self.config = config
        self.executor = executor

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one incoming Wyoming event.

        Returns:
            True to keep the connection open, False to close it.
        """
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info.event())
            return True

        if Synthesize.is_type(event.type):
            await self._handle_synthesize(Synthesize.from_event(event))
            return False  # close connection so HA knows response is complete

        # Any other event type is ignored; keep the connection open.
        return True

    async def _handle_synthesize(self, synth: Synthesize) -> None:
        """Stream synthesized speech for ``synth.text`` back to the client.

        The client always receives AudioStart ... AudioStop, even for empty
        text or a failed synthesis, so it never waits on an unterminated
        stream. Audio is generated on ``self.executor`` and handed to the
        event loop through an ``asyncio.Queue``. Each pipeline output becomes
        one AudioChunk of 16-bit mono PCM.

        Errors during synthesis are logged, not raised.
        """
        text = synth.text.strip()
        if not text:
            await self.write_event(
                AudioStart(
                    rate=KOKORO_SAMPLE_RATE, width=SAMPLE_WIDTH, channels=CHANNELS
                ).event()
            )
            await self.write_event(AudioStop().event())
            return

        tts_config = self.config["tts"]
        # Resolve voice: prefer client request, fall back to config default
        default_voice = tts_config.get("default_voice", "af_heart")
        voice_name = (
            synth.voice.name
            if (synth.voice and synth.voice.name)
            else default_voice
        )
        speed = tts_config.get("default_speed", 1.0)
        _LOGGER.info("Synthesize: voice=%s text=%r", voice_name, text[:80])

        await self.write_event(
            AudioStart(
                rate=KOKORO_SAMPLE_RATE, width=SAMPLE_WIDTH, channels=CHANNELS
            ).event()
        )

        chunk_queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()
        # Set once the consumer below stops reading: finished, failed, client
        # gone, or task cancelled. The worker thread then stops synthesizing
        # audio nobody will receive, freeing the executor for the next request.
        cancelled = threading.Event()

        def _emit(item) -> None:
            """Hand ``item`` to the event loop from the worker thread, without blocking."""
            try:
                loop.call_soon_threadsafe(chunk_queue.put_nowait, item)
            except RuntimeError:
                # The event loop is already closed (shutdown), so no consumer
                # remains. Stop producing instead of blocking forever.
                cancelled.set()

        def _generate() -> None:
            """Worker-thread body that synthesizes ``text`` one sentence at a time.

            It queues one ``bytes`` item per non-empty pipeline output. If the
            pipeline raises, it queues the exception instance. It always
            finishes with a ``None`` sentinel. Callbacks scheduled with
            ``call_soon_threadsafe`` run in FIFO order, so that ordering holds
            on the consumer side.
            """
            chunk_count = 0
            try:
                _LOGGER.debug("Pipeline thread started")
                # Splitting by sentence queues audio for the first sentence
                # before later sentences are synthesized.
                sentences = [s for s in _SENTENCE_RE.split(text) if s] or [text]
                _LOGGER.debug("Split into %d sentence(s)", len(sentences))
                for sentence in sentences:
                    if cancelled.is_set():
                        _LOGGER.debug("Synthesis cancelled, stopping pipeline thread")
                        break
                    _LOGGER.debug("Synthesizing: %r", sentence[:60])
                    for _, _, audio in self.pipeline(sentence, voice=voice_name, speed=speed):
                        if cancelled.is_set():
                            break
                        if audio is None:
                            continue
                        # float32 [-1, 1] -> int16. Clip first so out-of-range
                        # samples saturate instead of wrapping around.
                        audio_np = audio.cpu().numpy() if hasattr(audio, 'cpu') else audio
                        pcm_bytes = (
                            (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
                        )
                        chunk_count += 1
                        _LOGGER.debug("Queueing chunk %d (%d bytes)", chunk_count, len(pcm_bytes))
                        _emit(pcm_bytes)
                _LOGGER.debug("Pipeline finished, %d chunks generated", chunk_count)
            except Exception as exc:
                _LOGGER.exception("Pipeline thread error")
                _emit(exc)
            finally:
                _emit(None)

        try:
            self.executor.submit(_generate)
            chunks_sent = 0
            while True:
                item = await chunk_queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                chunks_sent += 1
                _LOGGER.debug("Sending audio chunk %d", chunks_sent)
                await self.write_event(
                    AudioChunk(
                        rate=KOKORO_SAMPLE_RATE,
                        width=SAMPLE_WIDTH,
                        channels=CHANNELS,
                        audio=item,
                    ).event()
                )
            _LOGGER.info("Synthesis complete, sent %d chunks", chunks_sent)
        except Exception:
            _LOGGER.exception("Error during synthesis")
        finally:
            # Signal the worker before awaiting anything else, so it stops even
            # when this task is being cancelled.
            cancelled.set()
            await self.write_event(AudioStop().event())
async def main() -> None:
    """Parse CLI arguments, load the Kokoro pipeline and serve Wyoming clients.

    Reads ``server.uri`` and the ``tts`` section from the config file given by
    ``--config``. A missing key raises ``KeyError``.
    """
    parser = argparse.ArgumentParser(description="Kokoro TTS Wyoming server")
    parser.add_argument(
        "--config", default="/app/config.yaml", help="Path to config.yaml"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    args = parser.parse_args()

    # Honour --debug. The synthesis path logs every chunk at DEBUG level,
    # which is too noisy to enable unconditionally.
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    config = load_config(args.config)
    uri = config["server"]["uri"]
    tts_config = config["tts"]
    device = tts_config.get("device", "cuda")
    lang = tts_config.get("language", "a")
    _LOGGER.info("Loading Kokoro pipeline (device=%s, lang=%s)...", device, lang)

    # Import here so startup logging appears first
    import torch
    from kokoro import KPipeline

    gpu_available = torch.cuda.is_available()
    # Also match indexed devices such as "cuda:0". ROCm builds of PyTorch
    # expose the GPU through the torch.cuda API.
    if device.startswith("cuda") and not gpu_available:
        _LOGGER.warning("CUDA/ROCm not available, falling back to CPU")
        device = "cpu"
    _LOGGER.info("GPU available: %s", gpu_available)
    if gpu_available:
        _LOGGER.info("Device name: %s", torch.cuda.get_device_name(0))

    pipeline = KPipeline(lang_code=lang, device=device)
    wyoming_info = build_wyoming_info(config)
    # A single worker serializes synthesis. One pipeline instance is shared by
    # every connection (presumably not safe to run concurrently; confirm).
    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="kokoro")

    _LOGGER.info("Starting Wyoming server at %s", uri)
    server = AsyncServer.from_uri(uri)
    try:
        await server.run(
            partial(
                KokoroEventHandler,
                wyoming_info,
                pipeline,
                config,
                executor,
            )
        )
    finally:
        # Stop accepting synthesis work. Don't wait here for an in-flight job.
        executor.shutdown(wait=False)
if __name__ == "__main__":
    # Script entry point: serve until interrupted.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server. Exit quietly rather
        # than printing a traceback.
        pass