feat: integrate loguru for centralized logging, debugs, etc (#30)
* feat: integrate loguru for controllable logging, debug, etc * Update pipeline.py - Characters swapped out from under me by my IDE
This commit is contained in:
@@ -1,12 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Quick example to show how device selection can be controlled, and was checked
|
Quick example to show how device selection can be controlled, and was checked
|
||||||
"""
|
"""
|
||||||
# import warnings, time
|
|
||||||
import time
|
import time
|
||||||
# import warnings
|
|
||||||
from kokoro import KPipeline
|
from kokoro import KPipeline
|
||||||
# warnings.filterwarnings('ignore')
|
from loguru import logger
|
||||||
# import torch; torch.set_warn_always(False)
|
|
||||||
|
|
||||||
def generate_audio(pipeline, text):
|
def generate_audio(pipeline, text):
|
||||||
for _, _, audio in pipeline(text, voice='af_bella'):
|
for _, _, audio in pipeline(text, voice='af_bella'):
|
||||||
@@ -20,9 +17,9 @@ def time_synthesis(device=None):
|
|||||||
pipeline = KPipeline(lang_code='a', device=device)
|
pipeline = KPipeline(lang_code='a', device=device)
|
||||||
samples = generate_audio(pipeline, "The quick brown fox jumps over the lazy dog.")
|
samples = generate_audio(pipeline, "The quick brown fox jumps over the lazy dog.")
|
||||||
ms = (time.perf_counter() - start) * 1000
|
ms = (time.perf_counter() - start) * 1000
|
||||||
print(f"✓ {device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples")
|
logger.info(f"✓ {device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples")
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(f"✗ {'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}")
|
logger.error(f"✗ {'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}")
|
||||||
|
|
||||||
def compare_shared_model():
|
def compare_shared_model():
|
||||||
try:
|
try:
|
||||||
@@ -34,15 +31,15 @@ def compare_shared_model():
|
|||||||
generate_audio(pipeline, "Testing model reuse.")
|
generate_audio(pipeline, "Testing model reuse.")
|
||||||
|
|
||||||
ms = (time.perf_counter() - start) * 1000
|
ms = (time.perf_counter() - start) * 1000
|
||||||
print(f"✓ reuse | {ms:>5.1f}ms for both models")
|
logger.info(f"✓ reuse | {ms:>5.1f}ms for both models")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ reuse | {str(e)}")
|
logger.error(f"✗ reuse | {str(e)}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print("\nDevice Selection & Performance:")
|
logger.info("Device Selection & Performance")
|
||||||
print("----------------------------------------")
|
logger.info("-" * 40)
|
||||||
time_synthesis()
|
time_synthesis()
|
||||||
time_synthesis('cuda')
|
time_synthesis('cuda')
|
||||||
time_synthesis('cpu')
|
time_synthesis('cpu')
|
||||||
print("----------------------------------------")
|
logger.info("-" * 40)
|
||||||
compare_shared_model()
|
compare_shared_model()
|
||||||
@@ -1,4 +1,23 @@
|
|||||||
__version__ = '0.3.2'
|
__version__ = '0.3.2'
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Remove default handler
|
||||||
|
logger.remove()
|
||||||
|
|
||||||
|
# Add custom handler with clean format including module and line number
|
||||||
|
logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
|
||||||
|
colorize=True,
|
||||||
|
level="INFO" # "DEBUG" to enable logger.debug("message") and up prints
|
||||||
|
# "ERROR" to enable only logger.error("message") prints
|
||||||
|
# etc
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable before release or as needed
|
||||||
|
# logger.disable("kokoro")
|
||||||
|
|
||||||
from .model import KModel
|
from .model import KModel
|
||||||
from .pipeline import KPipeline
|
from .pipeline import KPipeline
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from huggingface_hub import hf_hub_download
|
|||||||
from numbers import Number
|
from numbers import Number
|
||||||
from transformers import AlbertConfig
|
from transformers import AlbertConfig
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
from loguru import logger
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -28,9 +29,11 @@ class KModel(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
if not config:
|
if not config:
|
||||||
|
logger.debug("No config provided, downloading from HF")
|
||||||
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
|
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
|
||||||
with open(config, 'r', encoding='utf-8') as r:
|
with open(config, 'r', encoding='utf-8') as r:
|
||||||
config = json.load(r)
|
config = json.load(r)
|
||||||
|
logger.debug(f"Loaded config: {config}")
|
||||||
self.vocab = config['vocab']
|
self.vocab = config['vocab']
|
||||||
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
|
self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
|
||||||
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
|
self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
|
||||||
@@ -54,6 +57,7 @@ class KModel(torch.nn.Module):
|
|||||||
try:
|
try:
|
||||||
getattr(self, key).load_state_dict(state_dict)
|
getattr(self, key).load_state_dict(state_dict)
|
||||||
except:
|
except:
|
||||||
|
logger.debug(f"Did not load {key} from state_dict")
|
||||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
getattr(self, key).load_state_dict(state_dict, strict=False)
|
getattr(self, key).load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
@@ -64,6 +68,7 @@ class KModel(torch.nn.Module):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
|
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
|
||||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||||
|
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
||||||
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
||||||
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
|
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
|
||||||
input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
|
input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
|
||||||
@@ -78,6 +83,7 @@ class KModel(torch.nn.Module):
|
|||||||
duration = self.predictor.duration_proj(x)
|
duration = self.predictor.duration_proj(x)
|
||||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||||
|
logger.debug(f"pred_dur: {pred_dur}")
|
||||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||||
c_frame = 0
|
c_frame = 0
|
||||||
for i in range(pred_aln_trg.size(0)):
|
for i in range(pred_aln_trg.size(0)):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from huggingface_hub import hf_hub_download
|
|||||||
from misaki import en, espeak
|
from misaki import en, espeak
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
|
from loguru import logger
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -86,7 +87,8 @@ class KPipeline:
|
|||||||
try:
|
try:
|
||||||
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('WARNING: EspeakFallback not enabled. OOD words will be skipped.', e)
|
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
||||||
|
# Log the exception text itself; `{str(e)}` outside an f-string is a set literal.
logger.warning(str(e))
|
||||||
fallback = None
|
fallback = None
|
||||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||||
elif lang_code == 'j':
|
elif lang_code == 'j':
|
||||||
@@ -94,18 +96,18 @@ class KPipeline:
|
|||||||
from misaki import ja
|
from misaki import ja
|
||||||
self.g2p = ja.JAG2P()
|
self.g2p = ja.JAG2P()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("ERROR: You need to `pip install misaki[ja]` to use lang_code='j'")
|
logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
|
||||||
raise
|
raise
|
||||||
elif lang_code == 'z':
|
elif lang_code == 'z':
|
||||||
try:
|
try:
|
||||||
from misaki import zh
|
from misaki import zh
|
||||||
self.g2p = zh.ZHG2P()
|
self.g2p = zh.ZHG2P()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("ERROR: You need to `pip install misaki[zh]` to use lang_code='z'")
|
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
language = LANG_CODES[lang_code]
|
language = LANG_CODES[lang_code]
|
||||||
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||||
self.g2p = espeak.EspeakG2P(language=language)
|
self.g2p = espeak.EspeakG2P(language=language)
|
||||||
|
|
||||||
def load_single_voice(self, voice: str):
|
def load_single_voice(self, voice: str):
|
||||||
@@ -118,7 +120,7 @@ class KPipeline:
|
|||||||
if not voice.startswith(self.lang_code):
|
if not voice.startswith(self.lang_code):
|
||||||
v = LANG_CODES.get(voice, voice)
|
v = LANG_CODES.get(voice, voice)
|
||||||
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||||
print(f'WARNING: Language mismatch, loading {v} voice into {p} pipeline.')
|
logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
|
||||||
pack = torch.load(f, weights_only=True)
|
pack = torch.load(f, weights_only=True)
|
||||||
self.voices[voice] = pack
|
self.voices[voice] = pack
|
||||||
return pack
|
return pack
|
||||||
@@ -175,7 +177,10 @@ class KPipeline:
|
|||||||
z = KPipeline.waterfall_last(pairs, next_count)
|
z = KPipeline.waterfall_last(pairs, next_count)
|
||||||
text, ps = zip(*pairs[:z])
|
text, ps = zip(*pairs[:z])
|
||||||
ps = ''.join(ps)
|
ps = ''.join(ps)
|
||||||
yield ''.join(text).strip(), ps.strip()
|
text_chunk = ''.join(text).strip()
|
||||||
|
ps_chunk = ps.strip()
|
||||||
|
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
|
||||||
|
yield text_chunk, ps_chunk
|
||||||
pairs = pairs[z:]
|
pairs = pairs[z:]
|
||||||
count -= len(ps)
|
count -= len(ps)
|
||||||
if not pairs:
|
if not pairs:
|
||||||
@@ -204,20 +209,23 @@ class KPipeline:
|
|||||||
split_pattern: Optional[str] = r'\n+',
|
split_pattern: Optional[str] = r'\n+',
|
||||||
model: Optional[KModel] = None
|
model: Optional[KModel] = None
|
||||||
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
|
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
|
||||||
|
logger.debug(f"Loading voice: {voice}")
|
||||||
pack = self.load_voice(voice)
|
pack = self.load_voice(voice)
|
||||||
model = model or self.model
|
model = model or self.model
|
||||||
pack = pack.to(model.device) if model else pack
|
pack = pack.to(model.device) if model else pack
|
||||||
|
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||||
for graphemes in text:
|
for graphemes in text:
|
||||||
# TODO(misaki): Unify G2P interface between English and non-English
|
# TODO(misaki): Unify G2P interface between English and non-English
|
||||||
if self.lang_code in 'ab':
|
if self.lang_code in 'ab':
|
||||||
|
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
||||||
_, tokens = self.g2p(graphemes)
|
_, tokens = self.g2p(graphemes)
|
||||||
for gs, ps in self.en_tokenize(tokens):
|
for gs, ps in self.en_tokenize(tokens):
|
||||||
if not ps:
|
if not ps:
|
||||||
continue
|
continue
|
||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
print(f"WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
ps = ps[:510]
|
ps = ps[:510]
|
||||||
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
|
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
|
||||||
else:
|
else:
|
||||||
@@ -225,6 +233,6 @@ class KPipeline:
|
|||||||
if not ps:
|
if not ps:
|
||||||
continue
|
continue
|
||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
|
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
||||||
ps = ps[:510]
|
ps = ps[:510]
|
||||||
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
|
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ phonemizer
|
|||||||
scipy
|
scipy
|
||||||
torch
|
torch
|
||||||
transformers
|
transformers
|
||||||
|
loguru
|
||||||
Reference in New Issue
Block a user