Word level timestamps (#46)

* Use MToken

* ALIASES #43

* Typo: add missing comma

* Change conditions to is not None

* Finish WLTs
This commit is contained in:
hexgrad
2025-02-03 18:27:12 -08:00
committed by GitHub
parent 2f4d94bba2
commit 43bc156514
3 changed files with 103 additions and 44 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.3.5'
__version__ = '0.7.0'
from loguru import logger
import sys

View File

@@ -8,6 +8,18 @@ from typing import Generator, List, Optional, Tuple, Union
import re
import torch
ALIASES = {
'en-us': 'a',
'en-gb': 'b',
'es': 'e',
'fr-fr': 'f',
'hi': 'h',
'it': 'i',
'pt-br': 'p',
'ja': 'j',
'zh': 'z',
}
LANG_CODES = dict(
# pip install misaki[en]
a='American English',
@@ -66,6 +78,8 @@ class KPipeline:
If None, will auto-select cuda if available
If 'cuda' and not available, will explicitly raise an error
"""
lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code)
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code
self.model = None
@@ -91,7 +105,7 @@ class KPipeline:
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
logger.warning({str(e)})
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
elif lang_code == 'j':
try:
from misaki import ja
@@ -142,56 +156,60 @@ class KPipeline:
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
return self.voices[voice]
@classmethod
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
@classmethod
def waterfall_last(
cls,
pairs: List[Tuple[str, str]],
tokens: List[en.MToken],
next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'],
bumps: List[str] = [')', '']
) -> int:
for w in waterfall:
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
if z is not None:
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
if z is None:
continue
z += 1
if z < len(pairs) and pairs[z][1].strip() in bumps:
if z < len(tokens) and tokens[z].phonemes in bumps:
z += 1
_, ps = zip(*pairs[:z])
if next_count - len(''.join(ps)) <= 510:
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
return z
return len(pairs)
return len(tokens)
@classmethod
def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
return ''.join(t.text + t.whitespace for t in tokens).strip()
def en_tokenize(
self,
tokens: List[Union[en.MutableToken, List[en.MutableToken]]]
) -> Generator[Tuple[str, str], None, None]:
pairs = []
count = 0
for w in tokens:
for t in (w if isinstance(w, list) else [w]):
if t.phonemes is None:
continue
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T
next_ps += ' ' if t.whitespace else ''
next_count = count + len(next_ps.rstrip())
if next_count > 510:
z = KPipeline.waterfall_last(pairs, next_count)
text, ps = zip(*pairs[:z])
ps = ''.join(ps)
text_chunk = ''.join(text).strip()
ps_chunk = ps.strip()
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
yield text_chunk, ps_chunk
pairs = pairs[z:]
count -= len(ps)
if not pairs:
tokens: List[en.MToken]
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
tks = []
pcount = 0
for t in tokens:
# American English: ɾ => T
t.phonemes = '' if t.phonemes is None else t.phonemes.replace('ɾ', 'T')
next_ps = t.phonemes + (' ' if t.whitespace else '')
next_pcount = pcount + len(next_ps.rstrip())
if next_pcount > 510:
z = KPipeline.waterfall_last(tks, next_pcount)
text = KPipeline.tokens_to_text(tks[:z])
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
ps = KPipeline.tokens_to_ps(tks[:z])
yield text, ps, tks[:z]
tks = tks[z:]
pcount = len(KPipeline.tokens_to_ps(tks))
if not tks:
next_ps = next_ps.lstrip()
pairs.append((t.text + t.whitespace, next_ps))
count += len(next_ps)
if pairs:
text, ps = zip(*pairs)
yield ''.join(text).strip(), ''.join(ps).strip()
tks.append(t)
pcount += len(next_ps)
if tks:
text = KPipeline.tokens_to_text(tks)
ps = KPipeline.tokens_to_ps(tks)
yield ''.join(text).strip(), ''.join(ps).strip(), tks
@classmethod
def infer(
@@ -203,10 +221,49 @@ class KPipeline:
) -> KModel.Output:
return model(ps, pack[len(ps)-1], speed, return_output=True)
@classmethod
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
# We will count nice round half-frames, so the divisor is 80
MAGIC_DIVISOR = 80
if not tokens or len(pred_dur) < 3:
# We expect at least 3: <bos>, token, <eos>
return
# We track 2 counts, measured in half-frames: (left, right)
# This way we can cut space characters in half
# TODO: Is -3 an appropriate offset?
left = right = 2 * max(0, pred_dur[0].item() - 3)
# Updates:
# left = right + (2 * token_dur) + space_dur
# right = left + space_dur
i = 1
for t in tokens:
if i >= len(pred_dur)-1:
break
if not t.phonemes:
if t.whitespace:
i += 1
left = right + pred_dur[i].item()
right = left + pred_dur[i].item()
i += 1
continue
j = i + len(t.phonemes)
if j >= len(pred_dur):
break
t.start_ts = left / MAGIC_DIVISOR
token_dur = pred_dur[i: j].sum().item()
space_dur = pred_dur[j].item() if t.whitespace else 0
left = right + (2 * token_dur) + space_dur
t.end_ts = left / MAGIC_DIVISOR
right = left + space_dur
i = j + (1 if t.whitespace else 0)
@dataclass
class Result:
graphemes: str
phonemes: str
tokens: Optional[List[en.MToken]] = None
output: Optional[KModel.Output] = None
@property
@@ -249,14 +306,16 @@ class KPipeline:
if self.lang_code in 'ab':
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
_, tokens = self.g2p(graphemes)
for gs, ps in self.en_tokenize(tokens):
for gs, ps, tks in self.en_tokenize(tokens):
if not ps:
continue
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510]
output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=gs, phonemes=ps, output=output)
if output is not None and output.pred_dur is not None:
KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
else:
ps = self.g2p(graphemes)
if not ps:

View File

@@ -2,12 +2,12 @@ from setuptools import setup, find_packages
setup(
name='kokoro',
version='0.3.5',
version='0.7.0',
packages=find_packages(),
install_requires=[
'huggingface_hub',
'loguru',
'misaki[en]>=0.6.7',
'misaki[en]>=0.7.0',
'numpy==1.26.4',
'scipy',
'torch',