Word level timestamps (#46)
* Use MToken * ALIASES #43 * Typo: add missing comma * Change conditions to is not None * Finish WLTs
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = '0.3.5'
|
||||
__version__ = '0.7.0'
|
||||
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
@@ -8,6 +8,18 @@ from typing import Generator, List, Optional, Tuple, Union
|
||||
import re
|
||||
import torch
|
||||
|
||||
ALIASES = {
|
||||
'en-us': 'a',
|
||||
'en-gb': 'b',
|
||||
'es': 'e',
|
||||
'fr-fr': 'f',
|
||||
'hi': 'h',
|
||||
'it': 'i',
|
||||
'pt-br': 'p',
|
||||
'ja': 'j',
|
||||
'zh': 'z',
|
||||
}
|
||||
|
||||
LANG_CODES = dict(
|
||||
# pip install misaki[en]
|
||||
a='American English',
|
||||
@@ -66,6 +78,8 @@ class KPipeline:
|
||||
If None, will auto-select cuda if available
|
||||
If 'cuda' and not available, will explicitly raise an error
|
||||
"""
|
||||
lang_code = lang_code.lower()
|
||||
lang_code = ALIASES.get(lang_code, lang_code)
|
||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||
self.lang_code = lang_code
|
||||
self.model = None
|
||||
@@ -91,7 +105,7 @@ class KPipeline:
|
||||
logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
|
||||
logger.warning({str(e)})
|
||||
fallback = None
|
||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
|
||||
elif lang_code == 'j':
|
||||
try:
|
||||
from misaki import ja
|
||||
@@ -142,56 +156,60 @@ class KPipeline:
|
||||
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||
return self.voices[voice]
|
||||
|
||||
@classmethod
|
||||
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||
|
||||
@classmethod
|
||||
def waterfall_last(
|
||||
cls,
|
||||
pairs: List[Tuple[str, str]],
|
||||
tokens: List[en.MToken],
|
||||
next_count: int,
|
||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||
bumps: List[str] = [')', '”']
|
||||
) -> int:
|
||||
for w in waterfall:
|
||||
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
|
||||
if z is not None:
|
||||
z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
|
||||
if z is None:
|
||||
continue
|
||||
z += 1
|
||||
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||
z += 1
|
||||
if z < len(pairs) and pairs[z][1].strip() in bumps:
|
||||
z += 1
|
||||
_, ps = zip(*pairs[:z])
|
||||
if next_count - len(''.join(ps)) <= 510:
|
||||
return z
|
||||
return len(pairs)
|
||||
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
|
||||
return z
|
||||
return len(tokens)
|
||||
|
||||
@classmethod
|
||||
def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||
|
||||
def en_tokenize(
|
||||
self,
|
||||
tokens: List[Union[en.MutableToken, List[en.MutableToken]]]
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
pairs = []
|
||||
count = 0
|
||||
for w in tokens:
|
||||
for t in (w if isinstance(w, list) else [w]):
|
||||
if t.phonemes is None:
|
||||
continue
|
||||
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
|
||||
next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T
|
||||
next_ps += ' ' if t.whitespace else ''
|
||||
next_count = count + len(next_ps.rstrip())
|
||||
if next_count > 510:
|
||||
z = KPipeline.waterfall_last(pairs, next_count)
|
||||
text, ps = zip(*pairs[:z])
|
||||
ps = ''.join(ps)
|
||||
text_chunk = ''.join(text).strip()
|
||||
ps_chunk = ps.strip()
|
||||
logger.debug(f"Chunking text at {z}: '{text_chunk[:30]}{'...' if len(text_chunk) > 30 else ''}'")
|
||||
yield text_chunk, ps_chunk
|
||||
pairs = pairs[z:]
|
||||
count -= len(ps)
|
||||
if not pairs:
|
||||
next_ps = next_ps.lstrip()
|
||||
pairs.append((t.text + t.whitespace, next_ps))
|
||||
count += len(next_ps)
|
||||
if pairs:
|
||||
text, ps = zip(*pairs)
|
||||
yield ''.join(text).strip(), ''.join(ps).strip()
|
||||
tokens: List[en.MToken]
|
||||
) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
|
||||
tks = []
|
||||
pcount = 0
|
||||
for t in tokens:
|
||||
# American English: ɾ => T
|
||||
t.phonemes = '' if t.phonemes is None else t.phonemes.replace('ɾ', 'T')
|
||||
next_ps = t.phonemes + (' ' if t.whitespace else '')
|
||||
next_pcount = pcount + len(next_ps.rstrip())
|
||||
if next_pcount > 510:
|
||||
z = KPipeline.waterfall_last(tks, next_pcount)
|
||||
text = KPipeline.tokens_to_text(tks[:z])
|
||||
logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
|
||||
ps = KPipeline.tokens_to_ps(tks[:z])
|
||||
yield text, ps, tks[:z]
|
||||
tks = tks[z:]
|
||||
pcount = len(KPipeline.tokens_to_ps(tks))
|
||||
if not tks:
|
||||
next_ps = next_ps.lstrip()
|
||||
tks.append(t)
|
||||
pcount += len(next_ps)
|
||||
if tks:
|
||||
text = KPipeline.tokens_to_text(tks)
|
||||
ps = KPipeline.tokens_to_ps(tks)
|
||||
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
||||
|
||||
@classmethod
|
||||
def infer(
|
||||
@@ -203,10 +221,49 @@ class KPipeline:
|
||||
) -> KModel.Output:
|
||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||
|
||||
@classmethod
|
||||
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||
# We will count nice round half-frames, so the divisor is 80
|
||||
MAGIC_DIVISOR = 80
|
||||
if not tokens or len(pred_dur) < 3:
|
||||
# We expect at least 3: <bos>, token, <eos>
|
||||
return
|
||||
# We track 2 counts, measured in half-frames: (left, right)
|
||||
# This way we can cut space characters in half
|
||||
# TODO: Is -3 an appropriate offset?
|
||||
left = right = 2 * max(0, pred_dur[0].item() - 3)
|
||||
# Updates:
|
||||
# left = right + (2 * token_dur) + space_dur
|
||||
# right = left + space_dur
|
||||
i = 1
|
||||
for t in tokens:
|
||||
if i >= len(pred_dur)-1:
|
||||
break
|
||||
if not t.phonemes:
|
||||
if t.whitespace:
|
||||
i += 1
|
||||
left = right + pred_dur[i].item()
|
||||
right = left + pred_dur[i].item()
|
||||
i += 1
|
||||
continue
|
||||
j = i + len(t.phonemes)
|
||||
if j >= len(pred_dur):
|
||||
break
|
||||
t.start_ts = left / MAGIC_DIVISOR
|
||||
token_dur = pred_dur[i: j].sum().item()
|
||||
space_dur = pred_dur[j].item() if t.whitespace else 0
|
||||
left = right + (2 * token_dur) + space_dur
|
||||
t.end_ts = left / MAGIC_DIVISOR
|
||||
right = left + space_dur
|
||||
i = j + (1 if t.whitespace else 0)
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
graphemes: str
|
||||
phonemes: str
|
||||
tokens: Optional[List[en.MToken]] = None
|
||||
output: Optional[KModel.Output] = None
|
||||
|
||||
@property
|
||||
@@ -249,14 +306,16 @@ class KPipeline:
|
||||
if self.lang_code in 'ab':
|
||||
logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
|
||||
_, tokens = self.g2p(graphemes)
|
||||
for gs, ps in self.en_tokenize(tokens):
|
||||
for gs, ps, tks in self.en_tokenize(tokens):
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
ps = ps[:510]
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
yield self.Result(graphemes=gs, phonemes=ps, output=output)
|
||||
if output is not None and output.pred_dur is not None:
|
||||
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||
else:
|
||||
ps = self.g2p(graphemes)
|
||||
if not ps:
|
||||
|
||||
Reference in New Issue
Block a user