Word level timestamps (#46)

* Use MToken

* ALIASES #43

* Typo: add missing comma

* Change conditions to is not None

* Finish WLTs
This commit is contained in:
hexgrad
2025-02-03 18:27:12 -08:00
committed by GitHub
parent 2f4d94bba2
commit 43bc156514
3 changed files with 103 additions and 44 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.3.5' __version__ = '0.7.0'
from loguru import logger from loguru import logger
import sys import sys

View File

@@ -8,6 +8,18 @@ from typing import Generator, List, Optional, Tuple, Union
import re import re
import torch import torch
# Map common BCP-47-style locale tags to the single-letter language codes
# used as keys of LANG_CODES (e.g. 'en-us' -> 'a'); looked up after
# lowercasing the caller-supplied lang_code, unmatched codes pass through.
ALIASES = {
    'en-us': 'a',
    'en-gb': 'b',
    'es': 'e',
    'fr-fr': 'f',
    'hi': 'h',
    'it': 'i',
    'pt-br': 'p',
    'ja': 'j',
    'zh': 'z',
}
LANG_CODES = dict( LANG_CODES = dict(
# pip install misaki[en] # pip install misaki[en]
a='American English', a='American English',
@@ -66,6 +78,8 @@ class KPipeline:
If None, will auto-select cuda if available If None, will auto-select cuda if available
If 'cuda' and not available, will explicitly raise an error If 'cuda' and not available, will explicitly raise an error
""" """
lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code)
assert lang_code in LANG_CODES, (lang_code, LANG_CODES) assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code self.lang_code = lang_code
self.model = None self.model = None
@@ -91,7 +105,7 @@ class KPipeline:
logger.warning("EspeakFallback not Enabled: OOD words will be skipped") logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
logger.warning({str(e)}) logger.warning({str(e)})
fallback = None fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
elif lang_code == 'j': elif lang_code == 'j':
try: try:
from misaki import ja from misaki import ja
@@ -142,56 +156,60 @@ class KPipeline:
self.voices[voice] = torch.mean(torch.stack(packs), dim=0) self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
return self.voices[voice] return self.voices[voice]
@classmethod
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
    """Join token phonemes into one phoneme string for a chunk.

    Each token contributes its phonemes plus a single space when the
    token carries trailing whitespace; the joined result is stripped so
    the chunk string has no leading/trailing whitespace.
    (Indentation restored — the diff rendering flattened it.)
    """
    return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
@classmethod @classmethod
def waterfall_last( def waterfall_last(
cls, cls,
pairs: List[Tuple[str, str]], tokens: List[en.MToken],
next_count: int, next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'], waterfall: List[str] = ['!.?…', ':;', ',—'],
bumps: List[str] = [')', ''] bumps: List[str] = [')', '']
) -> int: ) -> int:
for w in waterfall: for w in waterfall:
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None) z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
if z is not None: if z is None:
continue
z += 1
if z < len(tokens) and tokens[z].phonemes in bumps:
z += 1 z += 1
if z < len(pairs) and pairs[z][1].strip() in bumps: if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
z += 1 return z
_, ps = zip(*pairs[:z]) return len(tokens)
if next_count - len(''.join(ps)) <= 510:
return z @classmethod
return len(pairs) def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
return ''.join(t.text + t.whitespace for t in tokens).strip()
def en_tokenize( def en_tokenize(
self, self,
tokens: List[Union[en.MutableToken, List[en.MutableToken]]] tokens: List[en.MToken]
) -> Generator[Tuple[str, str], None, None]: ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
pairs = [] tks = []
count = 0 pcount = 0
for w in tokens: for t in tokens:
for t in (w if isinstance(w, list) else [w]): # American English: ɾ => T
if t.phonemes is None: t.phonemes = '' if t.phonemes is None else t.phonemes.replace('ɾ', 'T')
continue next_ps = t.phonemes + (' ' if t.whitespace else '')
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else '' next_pcount = pcount + len(next_ps.rstrip())
next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T if next_pcount > 510:
next_ps += ' ' if t.whitespace else '' z = KPipeline.waterfall_last(tks, next_pcount)
next_count = count + len(next_ps.rstrip()) text = KPipeline.tokens_to_text(tks[:z])
if next_count > 510: logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
z = KPipeline.waterfall_last(pairs, next_count) ps = KPipeline.tokens_to_ps(tks[:z])
text, ps = zip(*pairs[:z]) yield text, ps, tks[:z]
ps = ''.join(ps) tks = tks[z:]
text_chunk = ''.join(text).strip() pcount = len(KPipeline.tokens_to_ps(tks))
ps_chunk = ps.strip() if not tks:
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'") next_ps = next_ps.lstrip()
yield text_chunk, ps_chunk tks.append(t)
pairs = pairs[z:] pcount += len(next_ps)
count -= len(ps) if tks:
if not pairs: text = KPipeline.tokens_to_text(tks)
next_ps = next_ps.lstrip() ps = KPipeline.tokens_to_ps(tks)
pairs.append((t.text + t.whitespace, next_ps)) yield ''.join(text).strip(), ''.join(ps).strip(), tks
count += len(next_ps)
if pairs:
text, ps = zip(*pairs)
yield ''.join(text).strip(), ''.join(ps).strip()
@classmethod @classmethod
def infer( def infer(
@@ -203,10 +221,49 @@ class KPipeline:
) -> KModel.Output: ) -> KModel.Output:
return model(ps, pack[len(ps)-1], speed, return_output=True) return model(ps, pack[len(ps)-1], speed, return_output=True)
@classmethod
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor) -> None:
    """Annotate tokens in-place with word-level start_ts/end_ts (seconds).

    Walks the model's predicted per-phoneme durations (pred_dur, in
    frames) alongside the token sequence, assigning each token with
    phonemes a start and end timestamp. Durations are counted in
    half-frames so space characters can be split evenly between the
    tokens on either side.
    (Indentation restored from the flattened diff rendering; the nesting
    of the whitespace-only branch is reconstructed from statement order
    — verify against the upstream commit.)
    """
    # Multiply by 600 to go from pred_dur frames to sample_rate 24000
    # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
    # We will count nice round half-frames, so the divisor is 80
    MAGIC_DIVISOR = 80
    if not tokens or len(pred_dur) < 3:
        # We expect at least 3: <bos>, token, <eos>
        return
    # We track 2 counts, measured in half-frames: (left, right)
    # This way we can cut space characters in half
    # TODO: Is -3 an appropriate offset?
    left = right = 2 * max(0, pred_dur[0].item() - 3)
    # Updates:
    # left = right + (2 * token_dur) + space_dur
    # right = left + space_dur
    i = 1
    for t in tokens:
        if i >= len(pred_dur)-1:
            break
        if not t.phonemes:
            # Token contributed no phonemes; if it carries whitespace,
            # consume its space frame, splitting it across left/right.
            if t.whitespace:
                i += 1
                left = right + pred_dur[i].item()
                right = left + pred_dur[i].item()
                i += 1
            continue
        j = i + len(t.phonemes)
        if j >= len(pred_dur):
            break
        t.start_ts = left / MAGIC_DIVISOR
        token_dur = pred_dur[i: j].sum().item()
        space_dur = pred_dur[j].item() if t.whitespace else 0
        left = right + (2 * token_dur) + space_dur
        t.end_ts = left / MAGIC_DIVISOR
        right = left + space_dur
        i = j + (1 if t.whitespace else 0)
@dataclass @dataclass
class Result: class Result:
graphemes: str graphemes: str
phonemes: str phonemes: str
tokens: Optional[List[en.MToken]] = None
output: Optional[KModel.Output] = None output: Optional[KModel.Output] = None
@property @property
@@ -249,14 +306,16 @@ class KPipeline:
if self.lang_code in 'ab': if self.lang_code in 'ab':
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}") logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
_, tokens = self.g2p(graphemes) _, tokens = self.g2p(graphemes)
for gs, ps in self.en_tokenize(tokens): for gs, ps, tks in self.en_tokenize(tokens):
if not ps: if not ps:
continue continue
elif len(ps) > 510: elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510] ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=gs, phonemes=ps, output=output) if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
else: else:
ps = self.g2p(graphemes) ps = self.g2p(graphemes)
if not ps: if not ps:

View File

@@ -2,12 +2,12 @@ from setuptools import setup, find_packages
setup( setup(
name='kokoro', name='kokoro',
version='0.3.5', version='0.7.0',
packages=find_packages(), packages=find_packages(),
install_requires=[ install_requires=[
'huggingface_hub', 'huggingface_hub',
'loguru', 'loguru',
'misaki[en]>=0.6.7', 'misaki[en]>=0.7.0',
'numpy==1.26.4', 'numpy==1.26.4',
'scipy', 'scipy',
'torch', 'torch',