"""Application configuration: YAML loading plus typed accessors.

The config file path comes from the ``CONFIG_PATH`` environment variable
(default ``config.yaml``). :func:`load_config` populates the module-level
configuration by deep-merging that file over ``DEFAULTS``; the ``get_*``
accessors read from the merged result and fall back to ``DEFAULTS`` for any
setting that is missing or null.
"""

import copy
import logging
import os
from pathlib import Path
from typing import Any, Optional

import yaml

logger = logging.getLogger(__name__)

# Merged configuration (DEFAULTS + user YAML); populated by load_config().
_config: dict = {}
_config_path = Path(os.environ.get("CONFIG_PATH", "config.yaml"))

# Built-in defaults. Also the single source of fallback values used by the
# get_* accessors, so the two can never disagree.
DEFAULTS = {
    "model": {
        "repo_id": "chatterbox-turbo",
    },
    "tts_engine": {
        "device": "",  # empty = auto-detect
        "predefined_voices_path": "voices",
        "reference_audio_path": "reference_audio",
        "default_voice_id": "default.wav",
    },
    "generation_defaults": {
        "temperature": 0.8,
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "seed": 0,
    },
    "wyoming": {
        "host": "0.0.0.0",
        "port": 10200,
        "chunk_size": 120,
    },
    "paths": {
        "model_cache": "/app/hf_cache",
    },
}


def _deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict: *base* recursively overlaid with *override*.

    Nested dicts present in both inputs are merged key by key; any other value
    from *override* replaces the one in *base*. The result shares no mutable
    objects with either input, so mutating it can never corrupt *base*
    (in particular ``DEFAULTS``).
    """
    result = copy.deepcopy(base)
    _merge_into(result, override)
    return result


def _merge_into(target: dict, override: dict) -> None:
    """Recursively overlay *override* onto *target* in place, copying values."""
    for key, value in override.items():
        current = target.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            _merge_into(current, value)
        else:
            target[key] = copy.deepcopy(value)


def _section(name: str) -> dict:
    """Return config section *name*, or ``{}`` if it is missing or not a mapping.

    Guards against YAML like ``wyoming:`` (null) or ``wyoming: 5`` replacing a
    whole section with a non-dict value.
    """
    section = _config.get(name)
    return section if isinstance(section, dict) else {}


def _setting(section: str, key: str) -> Any:
    """Return ``_config[section][key]``, or the ``DEFAULTS`` value if unset or null."""
    value = _section(section).get(key)
    if value is None:
        value = DEFAULTS[section][key]
    return value


def load_config() -> None:
    """(Re)load the module configuration from ``_config_path``.

    Always starts from a fresh deep copy of ``DEFAULTS``. If the file exists,
    it is parsed with ``yaml.safe_load`` and deep-merged on top; any failure to
    read/parse it, or a top-level value that is not a mapping, is logged and
    the defaults are kept. Afterwards, a non-empty ``paths.model_cache`` is
    exported as ``HF_HOME`` unless that variable is already set.
    """
    global _config
    _config = copy.deepcopy(DEFAULTS)

    if not _config_path.exists():
        logger.warning("Config file not found at %s, using defaults", _config_path)
    else:
        try:
            with open(_config_path, encoding="utf-8") as f:
                # An empty file (or a falsy scalar) means "no overrides".
                user_config = yaml.safe_load(f) or {}
        except Exception:
            # Best-effort: a broken config file must not prevent startup.
            logger.exception("Failed to load %s, using defaults", _config_path)
        else:
            if isinstance(user_config, dict):
                _config = _deep_merge(_config, user_config)
                logger.info("Loaded config from %s", _config_path)
            else:
                logger.error(
                    "Failed to load %s, using defaults: top-level YAML value "
                    "must be a mapping, got %s",
                    _config_path,
                    type(user_config).__name__,
                )

    # Set HuggingFace cache path from config (without overriding an explicit env var).
    cache_path = _section("paths").get("model_cache")
    if cache_path:
        # os.environ only accepts str values.
        os.environ.setdefault("HF_HOME", str(cache_path))


def get_model_repo_id() -> str:
    """Return ``model.repo_id``."""
    return _setting("model", "repo_id")


def get_device_override() -> Optional[str]:
    """Return ``tts_engine.device``, or ``None`` when it is empty (auto-detect)."""
    return _setting("tts_engine", "device") or None


def get_predefined_voices_path() -> Path:
    """Return ``tts_engine.predefined_voices_path`` as a :class:`Path`."""
    return Path(_setting("tts_engine", "predefined_voices_path"))


def get_reference_audio_path() -> Path:
    """Return ``tts_engine.reference_audio_path`` as a :class:`Path`."""
    return Path(_setting("tts_engine", "reference_audio_path"))


def get_default_voice_id() -> str:
    """Return ``tts_engine.default_voice_id``."""
    return _setting("tts_engine", "default_voice_id")


def get_wyoming_host() -> str:
    """Return ``wyoming.host``."""
    return _setting("wyoming", "host")


def get_wyoming_port() -> int:
    """Return ``wyoming.port`` as an int."""
    return int(_setting("wyoming", "port"))


def get_wyoming_chunk_size() -> int:
    """Return ``wyoming.chunk_size`` as an int (fallback matches ``DEFAULTS``)."""
    return int(_setting("wyoming", "chunk_size"))


def get_gen_temperature() -> float:
    """Return ``generation_defaults.temperature`` as a float."""
    return float(_setting("generation_defaults", "temperature"))


def get_gen_exaggeration() -> float:
    """Return ``generation_defaults.exaggeration`` as a float."""
    return float(_setting("generation_defaults", "exaggeration"))


def get_gen_cfg_weight() -> float:
    """Return ``generation_defaults.cfg_weight`` as a float."""
    return float(_setting("generation_defaults", "cfg_weight"))


def get_gen_seed() -> int:
    """Return ``generation_defaults.seed`` as an int."""
    return int(_setting("generation_defaults", "seed"))