Use EspeakG2P

This commit is contained in:
hexgrad
2025-01-28 14:26:16 -08:00
parent c56d8e1e5d
commit 44ab97292e
3 changed files with 43 additions and 19 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.2.1'
__version__ = '0.2.2'
from .models import KModel
from .pipeline import KPipeline

View File

@@ -9,6 +9,11 @@ import torch
LANG_CODES = dict(
a='American English',
b='British English',
e='es',
f='fr-fr',
h='hi',
i='it',
p='pt-br',
)
REPO_ID = 'hexgrad/Kokoro-82M'
@@ -24,16 +29,21 @@ class KPipeline:
if model_path is None:
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
assert os.path.exists(model_path)
self.vocab = config['vocab']
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.model = KModel(config, model_path).to(self.device).eval()
self.voices = {}
if lang_code in 'ab':
try:
fallback = espeak.EspeakFallback(british=lang_code=='b')
except Exception as e:
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
self.vocab = config['vocab']
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.model = KModel(config, model_path).to(self.device).eval()
self.voices = {}
else:
language = LANG_CODES[lang_code]
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
self.g2p = espeak.EspeakG2P(language=language)
def load_voice(self, voice):
if voice in self.voices:
@@ -60,7 +70,7 @@ class KPipeline:
return z
return len(pairs)
def tokenize(self, tokens):
def en_tokenize(self, tokens):
pairs = []
count = 0
for w in tokens:
@@ -86,19 +96,33 @@ class KPipeline:
text, ps = zip(*pairs)
yield ''.join(text).strip(), ''.join(ps).strip()
def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'):
def p2ii(self, ps):
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
assert input_ids and len(input_ids) <= 510, input_ids
return input_ids
def __call__(self, text, voice, speed=1, split_pattern=r'\n+'):
assert isinstance(text, str) or isinstance(text, list), type(text)
self.load_voice(voice)
if isinstance(text, str) and split_pattern:
text = re.split(split_pattern, text.strip())
if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
for graphemes in text:
if self.lang_code in 'ab':
_, tokens = self.g2p(graphemes)
for gs, ps in self.tokenize(tokens):
for gs, ps in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
print('TODO: Unexpected len(ps) > 510', len(ps), ps)
print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
continue
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
assert input_ids and len(input_ids) <= 510, input_ids
input_ids = self.p2ii(ps)
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
else:
ps = self.g2p(graphemes)
if not ps:
continue
elif len(ps) > 510:
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510]
input_ids = self.p2ii(ps)
yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name='kokoro', # Name of the package
version='0.2.1', # Initial version
version='0.2.2', # Initial version
packages=find_packages(), # Automatically finds packages
install_requires=[ # List your dependencies here
'huggingface_hub',