From 44ab97292ee1ad92573b95061c7a608dc53016fe Mon Sep 17 00:00:00 2001
From: hexgrad
Date: Tue, 28 Jan 2025 14:26:16 -0800
Subject: [PATCH] Use EspeakG2P

---
Review note (ignored by `git am`, not part of the commit message):
  This patch was restored from a whitespace-collapsed copy. Hunk line counts
  were checked against every @@ header and the diffstat, but the leading
  indentation of context lines was re-derived from Python block structure.
  TODO: confirm against the tree before applying.

 kokoro/__init__.py |  2 +-
 kokoro/pipeline.py | 58 ++++++++++++++++++++++++++++++++--------------
 setup.py           |  2 +-
 3 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/kokoro/__init__.py b/kokoro/__init__.py
index 5fe3637..4046dd4 100644
--- a/kokoro/__init__.py
+++ b/kokoro/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.2.1'
+__version__ = '0.2.2'
 
 from .models import KModel
 from .pipeline import KPipeline
diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py
index 31df88c..a2ce329 100644
--- a/kokoro/pipeline.py
+++ b/kokoro/pipeline.py
@@ -9,6 +9,11 @@ import torch
 LANG_CODES = dict(
     a='American English',
     b='British English',
+    e='es',
+    f='fr-fr',
+    h='hi',
+    i='it',
+    p='pt-br',
 )
 
 REPO_ID = 'hexgrad/Kokoro-82M'
@@ -24,16 +29,21 @@ class KPipeline:
         if model_path is None:
             model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
         assert os.path.exists(model_path)
-        try:
-            fallback = espeak.EspeakFallback(british=lang_code=='b')
-        except Exception as e:
-            print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
-            fallback = None
-        self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
         self.vocab = config['vocab']
         self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
         self.model = KModel(config, model_path).to(self.device).eval()
         self.voices = {}
+        if lang_code in 'ab':
+            try:
+                fallback = espeak.EspeakFallback(british=lang_code=='b')
+            except Exception as e:
+                print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
+                fallback = None
+            self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
+        else:
+            language = LANG_CODES[lang_code]
+            print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
+            self.g2p = espeak.EspeakG2P(language=language)
 
     def load_voice(self, voice):
         if voice in self.voices:
@@ -60,7 +70,7 @@ class KPipeline:
                 return z
         return len(pairs)
 
-    def tokenize(self, tokens):
+    def en_tokenize(self, tokens):
         pairs = []
         count = 0
         for w in tokens:
@@ -86,19 +96,33 @@ class KPipeline:
             text, ps = zip(*pairs)
             yield ''.join(text).strip(), ''.join(ps).strip()
 
-    def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'):
+    def p2ii(self, ps):
+        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
+        assert input_ids and len(input_ids) <= 510, input_ids
+        return input_ids
+
+    def __call__(self, text, voice, speed=1, split_pattern=r'\n+'):
         assert isinstance(text, str) or isinstance(text, list), type(text)
         self.load_voice(voice)
-        if isinstance(text, str) and split_pattern:
-            text = re.split(split_pattern, text.strip())
+        if isinstance(text, str):
+            text = re.split(split_pattern, text.strip()) if split_pattern else [text]
         for graphemes in text:
-            _, tokens = self.g2p(graphemes)
-            for gs, ps in self.tokenize(tokens):
+            if self.lang_code in 'ab':
+                _, tokens = self.g2p(graphemes)
+                for gs, ps in self.en_tokenize(tokens):
+                    if not ps:
+                        continue
+                    elif len(ps) > 510:
+                        print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
+                        continue
+                    input_ids = self.p2ii(ps)
+                    yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
+            else:
+                ps = self.g2p(graphemes)
                 if not ps:
                     continue
                 elif len(ps) > 510:
-                    print('TODO: Unexpected len(ps) > 510', len(ps), ps)
-                    continue
-                input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
-                assert input_ids and len(input_ids) <= 510, input_ids
-                yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
+                    print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
+                    ps = ps[:510]
+                input_ids = self.p2ii(ps)
+                yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
diff --git a/setup.py b/setup.py
index 64faef4..8b43fd5 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 
 setup(
     name='kokoro', # Name of the package
-    version='0.2.1', # Initial version
+    version='0.2.2', # Initial version
     packages=find_packages(), # Automatically finds packages
     install_requires=[ # List your dependencies here
         'huggingface_hub',