Use EspeakG2P
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
__version__ = '0.2.1'
|
__version__ = '0.2.2'
|
||||||
|
|
||||||
from .models import KModel
|
from .models import KModel
|
||||||
from .pipeline import KPipeline
|
from .pipeline import KPipeline
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ import torch
|
|||||||
LANG_CODES = dict(
|
LANG_CODES = dict(
|
||||||
a='American English',
|
a='American English',
|
||||||
b='British English',
|
b='British English',
|
||||||
|
e='es',
|
||||||
|
f='fr-fr',
|
||||||
|
h='hi',
|
||||||
|
i='it',
|
||||||
|
p='pt-br',
|
||||||
)
|
)
|
||||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||||
|
|
||||||
@@ -24,16 +29,21 @@ class KPipeline:
|
|||||||
if model_path is None:
|
if model_path is None:
|
||||||
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
|
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
|
||||||
assert os.path.exists(model_path)
|
assert os.path.exists(model_path)
|
||||||
try:
|
|
||||||
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
|
||||||
except Exception as e:
|
|
||||||
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
|
|
||||||
fallback = None
|
|
||||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
|
||||||
self.vocab = config['vocab']
|
self.vocab = config['vocab']
|
||||||
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||||
self.model = KModel(config, model_path).to(self.device).eval()
|
self.model = KModel(config, model_path).to(self.device).eval()
|
||||||
self.voices = {}
|
self.voices = {}
|
||||||
|
if lang_code in 'ab':
|
||||||
|
try:
|
||||||
|
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||||
|
except Exception as e:
|
||||||
|
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
|
||||||
|
fallback = None
|
||||||
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||||
|
else:
|
||||||
|
language = LANG_CODES[lang_code]
|
||||||
|
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||||
|
self.g2p = espeak.EspeakG2P(language=language)
|
||||||
|
|
||||||
def load_voice(self, voice):
|
def load_voice(self, voice):
|
||||||
if voice in self.voices:
|
if voice in self.voices:
|
||||||
@@ -60,7 +70,7 @@ class KPipeline:
|
|||||||
return z
|
return z
|
||||||
return len(pairs)
|
return len(pairs)
|
||||||
|
|
||||||
def tokenize(self, tokens):
|
def en_tokenize(self, tokens):
|
||||||
pairs = []
|
pairs = []
|
||||||
count = 0
|
count = 0
|
||||||
for w in tokens:
|
for w in tokens:
|
||||||
@@ -86,19 +96,33 @@ class KPipeline:
|
|||||||
text, ps = zip(*pairs)
|
text, ps = zip(*pairs)
|
||||||
yield ''.join(text).strip(), ''.join(ps).strip()
|
yield ''.join(text).strip(), ''.join(ps).strip()
|
||||||
|
|
||||||
def __call__(self, text, voice='af', speed=1, split_pattern=r'\n+'):
|
def p2ii(self, ps):
|
||||||
|
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
|
||||||
|
assert input_ids and len(input_ids) <= 510, input_ids
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
def __call__(self, text, voice, speed=1, split_pattern=r'\n+'):
|
||||||
assert isinstance(text, str) or isinstance(text, list), type(text)
|
assert isinstance(text, str) or isinstance(text, list), type(text)
|
||||||
self.load_voice(voice)
|
self.load_voice(voice)
|
||||||
if isinstance(text, str) and split_pattern:
|
if isinstance(text, str):
|
||||||
text = re.split(split_pattern, text.strip())
|
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||||
for graphemes in text:
|
for graphemes in text:
|
||||||
_, tokens = self.g2p(graphemes)
|
if self.lang_code in 'ab':
|
||||||
for gs, ps in self.tokenize(tokens):
|
_, tokens = self.g2p(graphemes)
|
||||||
|
for gs, ps in self.en_tokenize(tokens):
|
||||||
|
if not ps:
|
||||||
|
continue
|
||||||
|
elif len(ps) > 510:
|
||||||
|
print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
|
continue
|
||||||
|
input_ids = self.p2ii(ps)
|
||||||
|
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
||||||
|
else:
|
||||||
|
ps = self.g2p(graphemes)
|
||||||
if not ps:
|
if not ps:
|
||||||
continue
|
continue
|
||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
print('TODO: Unexpected len(ps) > 510', len(ps), ps)
|
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
|
||||||
continue
|
ps = ps[:510]
|
||||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps)))
|
input_ids = self.p2ii(ps)
|
||||||
assert input_ids and len(input_ids) <= 510, input_ids
|
yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
||||||
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='kokoro', # Name of the package
|
name='kokoro', # Name of the package
|
||||||
version='0.2.1', # Initial version
|
version='0.2.2', # Initial version
|
||||||
packages=find_packages(), # Automatically finds packages
|
packages=find_packages(), # Automatically finds packages
|
||||||
install_requires=[ # List your dependencies here
|
install_requires=[ # List your dependencies here
|
||||||
'huggingface_hub',
|
'huggingface_hub',
|
||||||
|
|||||||
Reference in New Issue
Block a user