feat: integrate loguru for centralized logging, debugging, etc. (#30)

* feat: integrate loguru for controllable logging, debugging, etc.

* Update pipeline.py

- Characters were swapped out from under me by my IDE
This commit is contained in:
remsky
2025-01-31 19:59:08 -07:00
committed by GitHub
parent 396766a5b0
commit b6cb300d50
5 changed files with 51 additions and 20 deletions

View File

@@ -1,4 +1,23 @@
# Package version string.
__version__ = '0.3.2'
# Logging is configured here, before the submodule imports below, so that
# loguru's shared logger already carries this handler when model/pipeline
# modules are loaded.
from loguru import logger
import sys
# Remove default handler
logger.remove()
# Add custom handler with clean format including module and line number
logger.add(
    sys.stderr,
    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
    colorize=True,
    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
                  # "ERROR" to enable only logger.error("message") prints
                  # etc
)
# Disable before release or as needed
# logger.disable("kokoro")
from .model import KModel
from .pipeline import KPipeline

View File

@@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download
from numbers import Number
from transformers import AlbertConfig
from typing import Dict, Optional, Union
from loguru import logger
import json
import torch
@@ -28,9 +29,11 @@ class KModel(torch.nn.Module):
super().__init__()
if not isinstance(config, dict):
if not config:
logger.debug("No config provided, downloading from HF")
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
with open(config, 'r', encoding='utf-8') as r:
config = json.load(r)
logger.debug(f"Loaded config: {config}")
self.vocab = config['vocab']
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
@@ -54,6 +57,7 @@ class KModel(torch.nn.Module):
try:
getattr(self, key).load_state_dict(state_dict)
except:
logger.debug(f"Did not load {key} from state_dict")
state_dict = {k[7:]: v for k, v in state_dict.items()}
getattr(self, key).load_state_dict(state_dict, strict=False)
@@ -64,6 +68,7 @@ class KModel(torch.nn.Module):
@torch.no_grad()
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
@@ -78,6 +83,7 @@ class KModel(torch.nn.Module):
duration = self.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
logger.debug(f"pred_dur: {pred_dur}")
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):

View File

@@ -3,6 +3,7 @@ from huggingface_hub import hf_hub_download
from misaki import en, espeak
from numbers import Number
from typing import Generator, List, Optional, Tuple, Union
from loguru import logger
import re
import torch
@@ -86,7 +87,8 @@ class KPipeline:
try:
fallback = espeak.EspeakFallback(british=lang_code=='b')
except Exception as e:
print('WARNING: EspeakFallback not enabled. OOD words will be skipped.', e)
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
logger.warning({str(e)})
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
elif lang_code == 'j':
@@ -94,18 +96,18 @@ class KPipeline:
from misaki import ja
self.g2p = ja.JAG2P()
except ImportError:
print("ERROR: You need to `pip install misaki[ja]` to use lang_code='j'")
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
raise
elif lang_code == 'z':
try:
from misaki import zh
self.g2p = zh.ZHG2P()
except ImportError:
print("ERROR: You need to `pip install misaki[zh]` to use lang_code='z'")
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
raise
else:
language = LANG_CODES[lang_code]
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
self.g2p = espeak.EspeakG2P(language=language)
def load_single_voice(self, voice: str):
@@ -118,7 +120,7 @@ class KPipeline:
if not voice.startswith(self.lang_code):
v = LANG_CODES.get(voice, voice)
p = LANG_CODES.get(self.lang_code, self.lang_code)
print(f'WARNING: Language mismatch, loading {v} voice into {p} pipeline.')
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
pack = torch.load(f, weights_only=True)
self.voices[voice] = pack
return pack
@@ -175,7 +177,10 @@ class KPipeline:
z = KPipeline.waterfall_last(pairs, next_count)
text, ps = zip(*pairs[:z])
ps = ''.join(ps)
yield ''.join(text).strip(), ps.strip()
text_chunk = ''.join(text).strip()
ps_chunk = ps.strip()
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
yield text_chunk, ps_chunk
pairs = pairs[z:]
count -= len(ps)
if not pairs:
@@ -204,20 +209,23 @@ class KPipeline:
split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
logger.debug(f"Loading voice: {voice}")
pack = self.load_voice(voice)
model = model or self.model
pack = pack.to(model.device) if model else pack
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
for graphemes in text:
# TODO(misaki): Unify G2P interface between English and non-English
if self.lang_code in 'ab':
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
_, tokens = self.g2p(graphemes)
for gs, ps in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
print(f"WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510]
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
else:
@@ -225,6 +233,6 @@ class KPipeline:
if not ps:
continue
elif len(ps) > 510:
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510]
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)