Word level timestamps (#46)
* Use MToken * ALIASES #43 * Typo: add missing comma * Change conditions to is not None * Finish WLTs
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
__version__ = '0.3.5'
|
__version__ = '0.7.0'
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -8,6 +8,18 @@ from typing import Generator, List, Optional, Tuple, Union
|
|||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
ALIASES = {
|
||||||
|
'en-us': 'a',
|
||||||
|
'en-gb': 'b',
|
||||||
|
'es': 'e',
|
||||||
|
'fr-fr': 'f',
|
||||||
|
'hi': 'h',
|
||||||
|
'it': 'i',
|
||||||
|
'pt-br': 'p',
|
||||||
|
'ja': 'j',
|
||||||
|
'zh': 'z',
|
||||||
|
}
|
||||||
|
|
||||||
LANG_CODES = dict(
|
LANG_CODES = dict(
|
||||||
# pip install misaki[en]
|
# pip install misaki[en]
|
||||||
a='American English',
|
a='American English',
|
||||||
@@ -66,6 +78,8 @@ class KPipeline:
|
|||||||
If None, will auto-select cuda if available
|
If None, will auto-select cuda if available
|
||||||
If 'cuda' and not available, will explicitly raise an error
|
If 'cuda' and not available, will explicitly raise an error
|
||||||
"""
|
"""
|
||||||
|
lang_code = lang_code.lower()
|
||||||
|
lang_code = ALIASES.get(lang_code, lang_code)
|
||||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||||
self.lang_code = lang_code
|
self.lang_code = lang_code
|
||||||
self.model = None
|
self.model = None
|
||||||
@@ -91,7 +105,7 @@ class KPipeline:
|
|||||||
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
||||||
logger.warning({str(e)})
|
logger.warning({str(e)})
|
||||||
fallback = None
|
fallback = None
|
||||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
|
||||||
elif lang_code == 'j':
|
elif lang_code == 'j':
|
||||||
try:
|
try:
|
||||||
from misaki import ja
|
from misaki import ja
|
||||||
@@ -142,56 +156,60 @@ class KPipeline:
|
|||||||
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||||
return self.voices[voice]
|
return self.voices[voice]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
|
||||||
|
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def waterfall_last(
|
def waterfall_last(
|
||||||
cls,
|
cls,
|
||||||
pairs: List[Tuple[str, str]],
|
tokens: List[en.MToken],
|
||||||
next_count: int,
|
next_count: int,
|
||||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||||
bumps: List[str] = [')', '”']
|
bumps: List[str] = [')', '”']
|
||||||
) -> int:
|
) -> int:
|
||||||
for w in waterfall:
|
for w in waterfall:
|
||||||
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
|
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
|
||||||
if z is not None:
|
if z is None:
|
||||||
|
continue
|
||||||
z += 1
|
z += 1
|
||||||
if z < len(pairs) and pairs[z][1].strip() in bumps:
|
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||||
z += 1
|
z += 1
|
||||||
_, ps = zip(*pairs[:z])
|
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
|
||||||
if next_count - len(''.join(ps)) <= 510:
|
|
||||||
return z
|
return z
|
||||||
return len(pairs)
|
return len(tokens)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
|
||||||
|
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||||
|
|
||||||
def en_tokenize(
|
def en_tokenize(
|
||||||
self,
|
self,
|
||||||
tokens: List[Union[en.MutableToken, List[en.MutableToken]]]
|
tokens: List[en.MToken]
|
||||||
) -> Generator[Tuple[str, str], None, None]:
|
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
|
||||||
pairs = []
|
tks = []
|
||||||
count = 0
|
pcount = 0
|
||||||
for w in tokens:
|
for t in tokens:
|
||||||
for t in (w if isinstance(w, list) else [w]):
|
# American English: ɾ => T
|
||||||
if t.phonemes is None:
|
t.phonemes = '' if t.phonemes is None else t.phonemes.replace('ɾ', 'T')
|
||||||
continue
|
next_ps = t.phonemes + (' ' if t.whitespace else '')
|
||||||
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
|
next_pcount = pcount + len(next_ps.rstrip())
|
||||||
next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T
|
if next_pcount > 510:
|
||||||
next_ps += ' ' if t.whitespace else ''
|
z = KPipeline.waterfall_last(tks, next_pcount)
|
||||||
next_count = count + len(next_ps.rstrip())
|
text = KPipeline.tokens_to_text(tks[:z])
|
||||||
if next_count > 510:
|
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
|
||||||
z = KPipeline.waterfall_last(pairs, next_count)
|
ps = KPipeline.tokens_to_ps(tks[:z])
|
||||||
text, ps = zip(*pairs[:z])
|
yield text, ps, tks[:z]
|
||||||
ps = ''.join(ps)
|
tks = tks[z:]
|
||||||
text_chunk = ''.join(text).strip()
|
pcount = len(KPipeline.tokens_to_ps(tks))
|
||||||
ps_chunk = ps.strip()
|
if not tks:
|
||||||
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
|
|
||||||
yield text_chunk, ps_chunk
|
|
||||||
pairs = pairs[z:]
|
|
||||||
count -= len(ps)
|
|
||||||
if not pairs:
|
|
||||||
next_ps = next_ps.lstrip()
|
next_ps = next_ps.lstrip()
|
||||||
pairs.append((t.text + t.whitespace, next_ps))
|
tks.append(t)
|
||||||
count += len(next_ps)
|
pcount += len(next_ps)
|
||||||
if pairs:
|
if tks:
|
||||||
text, ps = zip(*pairs)
|
text = KPipeline.tokens_to_text(tks)
|
||||||
yield ''.join(text).strip(), ''.join(ps).strip()
|
ps = KPipeline.tokens_to_ps(tks)
|
||||||
|
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def infer(
|
def infer(
|
||||||
@@ -203,10 +221,49 @@ class KPipeline:
|
|||||||
) -> KModel.Output:
|
) -> KModel.Output:
|
||||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||||
|
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||||
|
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||||
|
# We will count nice round half-frames, so the divisor is 80
|
||||||
|
MAGIC_DIVISOR = 80
|
||||||
|
if not tokens or len(pred_dur) < 3:
|
||||||
|
# We expect at least 3: <bos>, token, <eos>
|
||||||
|
return
|
||||||
|
# We track 2 counts, measured in half-frames: (left, right)
|
||||||
|
# This way we can cut space characters in half
|
||||||
|
# TODO: Is -3 an appropriate offset?
|
||||||
|
left = right = 2 * max(0, pred_dur[0].item() - 3)
|
||||||
|
# Updates:
|
||||||
|
# left = right + (2 * token_dur) + space_dur
|
||||||
|
# right = left + space_dur
|
||||||
|
i = 1
|
||||||
|
for t in tokens:
|
||||||
|
if i >= len(pred_dur)-1:
|
||||||
|
break
|
||||||
|
if not t.phonemes:
|
||||||
|
if t.whitespace:
|
||||||
|
i += 1
|
||||||
|
left = right + pred_dur[i].item()
|
||||||
|
right = left + pred_dur[i].item()
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
j = i + len(t.phonemes)
|
||||||
|
if j >= len(pred_dur):
|
||||||
|
break
|
||||||
|
t.start_ts = left / MAGIC_DIVISOR
|
||||||
|
token_dur = pred_dur[i: j].sum().item()
|
||||||
|
space_dur = pred_dur[j].item() if t.whitespace else 0
|
||||||
|
left = right + (2 * token_dur) + space_dur
|
||||||
|
t.end_ts = left / MAGIC_DIVISOR
|
||||||
|
right = left + space_dur
|
||||||
|
i = j + (1 if t.whitespace else 0)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Result:
|
class Result:
|
||||||
graphemes: str
|
graphemes: str
|
||||||
phonemes: str
|
phonemes: str
|
||||||
|
tokens: Optional[List[en.MToken]] = None
|
||||||
output: Optional[KModel.Output] = None
|
output: Optional[KModel.Output] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -249,14 +306,16 @@ class KPipeline:
|
|||||||
if self.lang_code in 'ab':
|
if self.lang_code in 'ab':
|
||||||
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
||||||
_, tokens = self.g2p(graphemes)
|
_, tokens = self.g2p(graphemes)
|
||||||
for gs, ps in self.en_tokenize(tokens):
|
for gs, ps, tks in self.en_tokenize(tokens):
|
||||||
if not ps:
|
if not ps:
|
||||||
continue
|
continue
|
||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
ps = ps[:510]
|
ps = ps[:510]
|
||||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
yield self.Result(graphemes=gs, phonemes=ps, output=output)
|
if output is not None and output.pred_dur is not None:
|
||||||
|
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||||
|
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||||
else:
|
else:
|
||||||
ps = self.g2p(graphemes)
|
ps = self.g2p(graphemes)
|
||||||
if not ps:
|
if not ps:
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -2,12 +2,12 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='kokoro',
|
name='kokoro',
|
||||||
version='0.3.5',
|
version='0.7.0',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'huggingface_hub',
|
'huggingface_hub',
|
||||||
'loguru',
|
'loguru',
|
||||||
'misaki[en]>=0.6.7',
|
'misaki[en]>=0.7.0',
|
||||||
'numpy==1.26.4',
|
'numpy==1.26.4',
|
||||||
'scipy',
|
'scipy',
|
||||||
'torch',
|
'torch',
|
||||||
|
|||||||
Reference in New Issue
Block a user