diff --git a/examples/device_examples.py b/examples/device_examples.py index 83fb61d..da277af 100644 --- a/examples/device_examples.py +++ b/examples/device_examples.py @@ -1,12 +1,9 @@ """ Quick example to show how device selection can be controlled, and was checked """ -# import warnings, time import time -# import warnings from kokoro import KPipeline -# warnings.filterwarnings('ignore') -# import torch; torch.set_warn_always(False) +from loguru import logger def generate_audio(pipeline, text): for _, _, audio in pipeline(text, voice='af_bella'): @@ -20,9 +17,9 @@ def time_synthesis(device=None): pipeline = KPipeline(lang_code='a', device=device) samples = generate_audio(pipeline, "The quick brown fox jumps over the lazy dog.") ms = (time.perf_counter() - start) * 1000 - print(f"✓ {device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples") + logger.info(f"✓ {device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples") except RuntimeError as e: - print(f"✗ {'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}") + logger.error(f"✗ {'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}") def compare_shared_model(): try: @@ -34,15 +31,15 @@ def compare_shared_model(): generate_audio(pipeline, "Testing model reuse.") ms = (time.perf_counter() - start) * 1000 - print(f"✓ reuse | {ms:>5.1f}ms for both models") + logger.info(f"✓ reuse | {ms:>5.1f}ms for both models") except Exception as e: - print(f"✗ reuse | {str(e)}") + logger.error(f"✗ reuse | {str(e)}") if __name__ == '__main__': - print("\nDevice Selection & Performance:") - print("----------------------------------------") + logger.info("Device Selection & Performance") + logger.info("-" * 40) time_synthesis() time_synthesis('cuda') time_synthesis('cpu') - print("----------------------------------------") + logger.info("-" * 40) compare_shared_model() \ No newline at end of file diff --git 
a/kokoro/__init__.py b/kokoro/__init__.py index 0e2a5c8..77f3202 100644 --- a/kokoro/__init__.py +++ b/kokoro/__init__.py @@ -1,4 +1,23 @@ __version__ = '0.3.2' +from loguru import logger +import sys + +# Remove default handler +logger.remove() + +# Add custom handler with clean format including module and line number +logger.add( + sys.stderr, + format="{time:HH:mm:ss} | {module:>16}:{line} | {level: >8} | {message}", + colorize=True, + level="INFO" # "DEBUG" to enable logger.debug("message") and up prints + # "ERROR" to enable only logger.error("message") prints + # etc +) + +# Disable before release or as needed +# logger.disable("kokoro") + from .model import KModel from .pipeline import KPipeline diff --git a/kokoro/model.py b/kokoro/model.py index be5ad30..4a2428d 100644 --- a/kokoro/model.py +++ b/kokoro/model.py @@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download from numbers import Number from transformers import AlbertConfig from typing import Dict, Optional, Union +from loguru import logger import json import torch @@ -28,9 +29,11 @@ class KModel(torch.nn.Module): super().__init__() if not isinstance(config, dict): if not config: + logger.debug("No config provided, downloading from HF") config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json') with open(config, 'r', encoding='utf-8') as r: config = json.load(r) + logger.debug(f"Loaded config: {config}") self.vocab = config['vocab'] self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) @@ -54,6 +57,7 @@ class KModel(torch.nn.Module): try: getattr(self, key).load_state_dict(state_dict) except: + logger.debug(f"Did not load {key} from state_dict") state_dict = {k[7:]: v for k, v in state_dict.items()} getattr(self, key).load_state_dict(state_dict, strict=False) @@ -64,6 +68,7 @@ class KModel(torch.nn.Module): @torch.no_grad() def forward(self, phonemes: str, 
ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor: input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) + logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device) @@ -78,6 +83,7 @@ class KModel(torch.nn.Module): duration = self.predictor.duration_proj(x) duration = torch.sigmoid(duration).sum(axis=-1) / speed pred_dur = torch.round(duration).clamp(min=1).long() + logger.debug(f"pred_dur: {pred_dur}") pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) c_frame = 0 for i in range(pred_aln_trg.size(0)): diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py index 366ba3d..ef0f135 100644 --- a/kokoro/pipeline.py +++ b/kokoro/pipeline.py @@ -3,6 +3,7 @@ from huggingface_hub import hf_hub_download from misaki import en, espeak from numbers import Number from typing import Generator, List, Optional, Tuple, Union +from loguru import logger import re import torch @@ -86,7 +87,8 @@ class KPipeline: try: fallback = espeak.EspeakFallback(british=lang_code=='b') except Exception as e: - print('WARNING: EspeakFallback not enabled. 
OOD words will be skipped.', e) + logger.warning("EspeakFallback not enabled: OOD words will be skipped") + logger.warning(str(e)) fallback = None self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) elif lang_code == 'j': @@ -94,18 +96,18 @@ class KPipeline: from misaki import ja self.g2p = ja.JAG2P() except ImportError: - print("ERROR: You need to `pip install misaki[ja]` to use lang_code='j'") + logger.error("You need to `pip install misaki[ja]` to use lang_code='j'") raise elif lang_code == 'z': try: from misaki import zh self.g2p = zh.ZHG2P() except ImportError: - print("ERROR: You need to `pip install misaki[zh]` to use lang_code='z'") + logger.error("You need to `pip install misaki[zh]` to use lang_code='z'") raise else: language = LANG_CODES[lang_code] - print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") + logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") self.g2p = espeak.EspeakG2P(language=language) def load_single_voice(self, voice: str): @@ -118,7 +120,7 @@ class KPipeline: if not voice.startswith(self.lang_code): v = LANG_CODES.get(voice, voice) p = LANG_CODES.get(self.lang_code, self.lang_code) - print(f'WARNING: Language mismatch, loading {v} voice into {p} pipeline.') + logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.') pack = torch.load(f, weights_only=True) self.voices[voice] = pack return pack @@ -175,7 +177,10 @@ class KPipeline: z = KPipeline.waterfall_last(pairs, next_count) text, ps = zip(*pairs[:z]) ps = ''.join(ps) - yield ''.join(text).strip(), ps.strip() + text_chunk = ''.join(text).strip() + ps_chunk = ps.strip() + logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...'
if len(text_chunk) > 30 else ''}'") + yield text_chunk, ps_chunk pairs = pairs[z:] count -= len(ps) if not pairs: @@ -204,20 +209,23 @@ class KPipeline: split_pattern: Optional[str] = r'\n+', model: Optional[KModel] = None ) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]: + logger.debug(f"Loading voice: {voice}") pack = self.load_voice(voice) model = model or self.model pack = pack.to(model.device) if model else pack + logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}") if isinstance(text, str): text = re.split(split_pattern, text.strip()) if split_pattern else [text] for graphemes in text: # TODO(misaki): Unify G2P interface between English and non-English if self.lang_code in 'ab': + logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}") _, tokens = self.g2p(graphemes) for gs, ps in self.en_tokenize(tokens): if not ps: continue elif len(ps) > 510: - print(f"WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") + logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") ps = ps[:510] yield gs, ps, KPipeline.infer(model, ps, pack, speed) else: @@ -225,6 +233,6 @@ class KPipeline: if not ps: continue elif len(ps) > 510: - print(f'WARNING: Truncating len(ps) == {len(ps)} > 510') + logger.warning(f'Truncating len(ps) == {len(ps)} > 510') ps = ps[:510] yield graphemes, ps, KPipeline.infer(model, ps, pack, speed) diff --git a/requirements.txt b/requirements.txt index 89f6e83..47c811c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ numpy phonemizer scipy torch -transformers \ No newline at end of file +transformers +loguru \ No newline at end of file