feat: integrate loguru for centralized logging, debugs, etc (#30)

* feat: integrate loguru for controllable logging, debug, etc

* Update pipeline.py

- Characters swapped out from under me by my IDE
Author:    remsky
Date:      2025-01-31 19:59:08 -07:00
Committed: GitHub
Parent:    396766a5b0
Commit:    b6cb300d50

5 changed files with 51 additions and 20 deletions
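
The conversion pattern is uniform across all five files: bare print() calls with hand-written severity prefixes become leveled loguru calls, so severity is structured metadata that handlers can filter, color, or route. A self-contained before/after sketch (the message text mirrors a pipeline.py warning from the diff; the concrete number is illustrative):

from loguru import logger

# Before this commit: severity was only a string prefix, always on stdout.
print("WARNING: Truncating len(ps) == 511 > 510")

# After: the level is structural, so handlers decide where (and whether) it prints.
logger.warning("Truncating len(ps) == 511 > 510")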


@@ -1,12 +1,9 @@
 """
 Quick example to show how device selection can be controlled, and was checked
 """
-# import warnings, time
 import time
-# import warnings
 from kokoro import KPipeline
-# warnings.filterwarnings('ignore')
-# import torch; torch.set_warn_always(False)
+from loguru import logger
 
 def generate_audio(pipeline, text):
     for _, _, audio in pipeline(text, voice='af_bella'):
@@ -20,9 +17,9 @@ def time_synthesis(device=None):
         pipeline = KPipeline(lang_code='a', device=device)
         samples = generate_audio(pipeline, "The quick brown fox jumps over the lazy dog.")
         ms = (time.perf_counter() - start) * 1000
-        print(f"{device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples")
+        logger.info(f"{device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples")
     except RuntimeError as e:
-        print(f"{'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}")
+        logger.error(f"{'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}")
 
 def compare_shared_model():
     try:
@@ -34,15 +31,15 @@ def compare_shared_model():
         generate_audio(pipeline, "Testing model reuse.")
         ms = (time.perf_counter() - start) * 1000
-        print(f"✓ reuse | {ms:>5.1f}ms for both models")
+        logger.info(f"✓ reuse | {ms:>5.1f}ms for both models")
     except Exception as e:
-        print(f"✗ reuse | {str(e)}")
+        logger.error(f"✗ reuse | {str(e)}")
 
 if __name__ == '__main__':
-    print("\nDevice Selection & Performance:")
-    print("----------------------------------------")
+    logger.info("Device Selection & Performance")
+    logger.info("-" * 40)
     time_synthesis()
     time_synthesis('cuda')
     time_synthesis('cpu')
-    print("----------------------------------------")
+    logger.info("-" * 40)
     compare_shared_model()

--- a/kokoro/__init__.py
+++ b/kokoro/__init__.py

@@ -1,4 +1,23 @@
 __version__ = '0.3.2'
+from loguru import logger
+import sys
+
+# Remove default handler
+logger.remove()
+
+# Add custom handler with clean format including module and line number
+logger.add(
+    sys.stderr,
+    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
+    colorize=True,
+    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
+                  # "ERROR" to enable only logger.error("message") prints
+                  # etc.
+)
+
+# Disable before release or as needed
+# logger.disable("kokoro")
+
 from .model import KModel
 from .pipeline import KPipeline
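
Because the handler is installed at import time in kokoro/__init__.py, applications that embed the package can reconfigure it afterwards. A minimal sketch, assuming only stock loguru API (disable/enable/remove/add); the DEBUG handler is illustrative:

import sys
from loguru import logger
import kokoro  # installs the INFO handler shown above

# Silence everything logged under the "kokoro" namespace...
logger.disable("kokoro")

# ...or re-enable it and drop to DEBUG to see the new logger.debug()
# traces added to model.py and pipeline.py in this commit.
logger.enable("kokoro")
logger.remove()
logger.add(sys.stderr, level="DEBUG")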

--- a/kokoro/model.py
+++ b/kokoro/model.py

@@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download
 from numbers import Number
 from transformers import AlbertConfig
 from typing import Dict, Optional, Union
+from loguru import logger
 import json
 import torch
@@ -28,9 +29,11 @@ class KModel(torch.nn.Module):
         super().__init__()
         if not isinstance(config, dict):
             if not config:
+                logger.debug("No config provided, downloading from HF")
                 config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
             with open(config, 'r', encoding='utf-8') as r:
                 config = json.load(r)
+        logger.debug(f"Loaded config: {config}")
         self.vocab = config['vocab']
         self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
         self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
@@ -54,6 +57,7 @@ class KModel(torch.nn.Module):
             try:
                 getattr(self, key).load_state_dict(state_dict)
             except:
+                logger.debug(f"Did not load {key} from state_dict")
                 state_dict = {k[7:]: v for k, v in state_dict.items()}
                 getattr(self, key).load_state_dict(state_dict, strict=False)
@@ -64,6 +68,7 @@ class KModel(torch.nn.Module):
     @torch.no_grad()
     def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
         input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
+        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
         assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
         input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
         input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
@@ -78,6 +83,7 @@ class KModel(torch.nn.Module):
         duration = self.predictor.duration_proj(x)
         duration = torch.sigmoid(duration).sum(axis=-1) / speed
         pred_dur = torch.round(duration).clamp(min=1).long()
+        logger.debug(f"pred_dur: {pred_dur}")
         pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
         c_frame = 0
         for i in range(pred_aln_trg.size(0)):

--- a/kokoro/pipeline.py
+++ b/kokoro/pipeline.py

@@ -3,6 +3,7 @@ from huggingface_hub import hf_hub_download
 from misaki import en, espeak
 from numbers import Number
 from typing import Generator, List, Optional, Tuple, Union
+from loguru import logger
 import re
 import torch
@@ -86,7 +87,8 @@ class KPipeline:
             try:
                 fallback = espeak.EspeakFallback(british=lang_code=='b')
             except Exception as e:
-                print('WARNING: EspeakFallback not enabled. OOD words will be skipped.', e)
+                logger.warning("EspeakFallback not enabled: OOD words will be skipped")
+                logger.warning(str(e))
                 fallback = None
             self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
         elif lang_code == 'j':
@@ -94,18 +96,18 @@ class KPipeline:
                 from misaki import ja
                 self.g2p = ja.JAG2P()
             except ImportError:
-                print("ERROR: You need to `pip install misaki[ja]` to use lang_code='j'")
+                logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
                 raise
         elif lang_code == 'z':
             try:
                 from misaki import zh
                 self.g2p = zh.ZHG2P()
             except ImportError:
-                print("ERROR: You need to `pip install misaki[zh]` to use lang_code='z'")
+                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                 raise
         else:
             language = LANG_CODES[lang_code]
-            print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
+            logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
             self.g2p = espeak.EspeakG2P(language=language)
 
     def load_single_voice(self, voice: str):
@@ -118,7 +120,7 @@ class KPipeline:
         if not voice.startswith(self.lang_code):
             v = LANG_CODES.get(voice, voice)
             p = LANG_CODES.get(self.lang_code, self.lang_code)
-            print(f'WARNING: Language mismatch, loading {v} voice into {p} pipeline.')
+            logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
         pack = torch.load(f, weights_only=True)
         self.voices[voice] = pack
         return pack
@@ -175,7 +177,10 @@ class KPipeline:
                 z = KPipeline.waterfall_last(pairs, next_count)
                 text, ps = zip(*pairs[:z])
                 ps = ''.join(ps)
-                yield ''.join(text).strip(), ps.strip()
+                text_chunk = ''.join(text).strip()
+                ps_chunk = ps.strip()
+                logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
+                yield text_chunk, ps_chunk
                 pairs = pairs[z:]
                 count -= len(ps)
                 if not pairs:
@@ -204,20 +209,23 @@ class KPipeline:
         split_pattern: Optional[str] = r'\n+',
         model: Optional[KModel] = None
     ) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
+        logger.debug(f"Loading voice: {voice}")
         pack = self.load_voice(voice)
         model = model or self.model
         pack = pack.to(model.device) if model else pack
+        logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
         if isinstance(text, str):
             text = re.split(split_pattern, text.strip()) if split_pattern else [text]
         for graphemes in text:
             # TODO(misaki): Unify G2P interface between English and non-English
             if self.lang_code in 'ab':
+                logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
                 _, tokens = self.g2p(graphemes)
                 for gs, ps in self.en_tokenize(tokens):
                     if not ps:
                         continue
                     elif len(ps) > 510:
-                        print(f"WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
+                        logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
                         ps = ps[:510]
                     yield gs, ps, KPipeline.infer(model, ps, pack, speed)
             else:
@@ -225,6 +233,6 @@ class KPipeline:
                 if not ps:
                     continue
                 elif len(ps) > 510:
-                    print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
+                    logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
                     ps = ps[:510]
                 yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
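
A consumer-visible effect of the pipeline changes: truncation and language-mismatch messages are now leveled warnings rather than prints, so they can be captured separately from normal output. A minimal sketch using stock loguru calls (the log-file path is illustrative):

from loguru import logger
import kokoro  # the package's own stderr handler stays installed

# Additionally route WARNING-and-up records, e.g. the len(ps) > 510
# truncation warnings, to a file for later inspection.
logger.add("kokoro_warnings.log", level="WARNING")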

--- a/requirements.txt
+++ b/requirements.txt

@@ -2,4 +2,5 @@ numpy
 phonemizer
 scipy
 torch
 transformers
+loguru