Backward compatible KModel.Output and KPipeline.Result dataclasses (#40)
* Backward compatible KModel.Output and KPipeline.Result dataclasses * Typo: Bool => bool * Allow voice=None for quiet pipelines * Specify class names * Fixed and tested * Update README.md
This commit is contained in:
@@ -8,7 +8,7 @@ An inference library for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)
|
|||||||
You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md).
|
You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md).
|
||||||
```py
|
```py
|
||||||
# 1️⃣ Install kokoro
|
# 1️⃣ Install kokoro
|
||||||
!pip install -q kokoro>=0.3.1 soundfile
|
!pip install -q kokoro>=0.3.5 soundfile
|
||||||
# 2️⃣ Install espeak, used for English OOD fallback and some non-English languages
|
# 2️⃣ Install espeak, used for English OOD fallback and some non-English languages
|
||||||
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
|
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
|
||||||
# 🇪🇸 'e' => Spanish es
|
# 🇪🇸 'e' => Spanish es
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = '0.3.4'
|
__version__ = '0.3.5'
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import sys
|
import sys
|
||||||
@@ -17,7 +17,7 @@ logger.add(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Disable before release or as needed
|
# Disable before release or as needed
|
||||||
# logger.disable("kokoro")
|
logger.disable("kokoro")
|
||||||
|
|
||||||
from .model import KModel
|
from .model import KModel
|
||||||
from .pipeline import KPipeline
|
from .pipeline import KPipeline
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from .istftnet import Decoder
|
from .istftnet import Decoder
|
||||||
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
||||||
|
from dataclasses import dataclass
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from loguru import logger
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from transformers import AlbertConfig
|
from transformers import AlbertConfig
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
from loguru import logger
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -65,8 +66,19 @@ class KModel(torch.nn.Module):
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self.bert.device
|
return self.bert.device
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Output:
|
||||||
|
audio: torch.FloatTensor
|
||||||
|
pred_dur: Optional[torch.LongTensor] = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
|
def forward(
|
||||||
|
self,
|
||||||
|
phonemes: str,
|
||||||
|
ref_s: torch.FloatTensor,
|
||||||
|
speed: Number = 1,
|
||||||
|
return_output: bool = False # MARK: BACKWARD COMPAT
|
||||||
|
) -> Union['KModel.Output', torch.FloatTensor]:
|
||||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||||
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
||||||
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
||||||
@@ -82,16 +94,15 @@ class KModel(torch.nn.Module):
|
|||||||
x, _ = self.predictor.lstm(d)
|
x, _ = self.predictor.lstm(d)
|
||||||
duration = self.predictor.duration_proj(x)
|
duration = self.predictor.duration_proj(x)
|
||||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
|
||||||
logger.debug(f"pred_dur: {pred_dur}")
|
logger.debug(f"pred_dur: {pred_dur}")
|
||||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
|
||||||
c_frame = 0
|
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
|
||||||
for i in range(pred_aln_trg.size(0)):
|
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
|
||||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
|
||||||
c_frame += pred_dur[0,i].item()
|
|
||||||
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
|
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
|
||||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||||
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
||||||
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
||||||
asr = t_en @ pred_aln_trg
|
asr = t_en @ pred_aln_trg
|
||||||
return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
||||||
|
return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from .model import KModel
|
from .model import KModel
|
||||||
|
from dataclasses import dataclass
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from loguru import logger
|
||||||
from misaki import en, espeak
|
from misaki import en, espeak
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
from loguru import logger
|
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -134,6 +135,7 @@ class KPipeline:
|
|||||||
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
|
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
|
||||||
if voice in self.voices:
|
if voice in self.voices:
|
||||||
return self.voices[voice]
|
return self.voices[voice]
|
||||||
|
logger.debug(f"Loading voice: {voice}")
|
||||||
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
||||||
if len(packs) == 1:
|
if len(packs) == 1:
|
||||||
return packs[0]
|
return packs[0]
|
||||||
@@ -194,26 +196,52 @@ class KPipeline:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def infer(
|
def infer(
|
||||||
cls,
|
cls,
|
||||||
model: Optional[KModel],
|
model: KModel,
|
||||||
ps: str,
|
ps: str,
|
||||||
pack: torch.FloatTensor,
|
pack: torch.FloatTensor,
|
||||||
speed: Number
|
speed: Number = 1
|
||||||
) -> Optional[torch.FloatTensor]:
|
) -> KModel.Output:
|
||||||
return model(ps, pack[len(ps)-1], speed) if model else None
|
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Result:
|
||||||
|
graphemes: str
|
||||||
|
phonemes: str
|
||||||
|
output: Optional[KModel.Output] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio(self) -> Optional[torch.FloatTensor]:
|
||||||
|
return None if self.output is None else self.output.audio
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_dur(self) -> Optional[torch.LongTensor]:
|
||||||
|
return None if self.output is None else self.output.pred_dur
|
||||||
|
|
||||||
|
### MARK: BEGIN BACKWARD COMPAT ###
|
||||||
|
def __iter__(self):
|
||||||
|
yield self.graphemes
|
||||||
|
yield self.phonemes
|
||||||
|
yield self.audio
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return [self.graphemes, self.phonemes, self.audio][index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 3
|
||||||
|
#### MARK: END BACKWARD COMPAT ####
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
text: Union[str, List[str]],
|
text: Union[str, List[str]],
|
||||||
voice: str,
|
voice: Optional[str] = None,
|
||||||
speed: Number = 1,
|
speed: Number = 1,
|
||||||
split_pattern: Optional[str] = r'\n+',
|
split_pattern: Optional[str] = r'\n+',
|
||||||
model: Optional[KModel] = None
|
model: Optional[KModel] = None
|
||||||
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
|
) -> Generator['KPipeline.Result', None, None]:
|
||||||
logger.debug(f"Loading voice: {voice}")
|
|
||||||
pack = self.load_voice(voice)
|
|
||||||
model = model or self.model
|
model = model or self.model
|
||||||
pack = pack.to(model.device) if model else pack
|
if model and voice is None:
|
||||||
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
|
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
|
||||||
|
pack = self.load_voice(voice).to(model.device) if model else None
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||||
for graphemes in text:
|
for graphemes in text:
|
||||||
@@ -227,7 +255,8 @@ class KPipeline:
|
|||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||||
ps = ps[:510]
|
ps = ps[:510]
|
||||||
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
|
yield self.Result(graphemes=gs, phonemes=ps, output=output)
|
||||||
else:
|
else:
|
||||||
ps = self.g2p(graphemes)
|
ps = self.g2p(graphemes)
|
||||||
if not ps:
|
if not ps:
|
||||||
@@ -235,4 +264,5 @@ class KPipeline:
|
|||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
||||||
ps = ps[:510]
|
ps = ps[:510]
|
||||||
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
|
yield self.Result(graphemes=graphemes, phonemes=ps, output=output)
|
||||||
|
|||||||
6
setup.py
6
setup.py
@@ -2,18 +2,18 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='kokoro',
|
name='kokoro',
|
||||||
version='0.3.4',
|
version='0.3.5',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'huggingface_hub',
|
'huggingface_hub',
|
||||||
'loguru',
|
'loguru',
|
||||||
'misaki[en]>=0.6.5',
|
'misaki[en]>=0.6.7',
|
||||||
'numpy==1.26.4',
|
'numpy==1.26.4',
|
||||||
'scipy',
|
'scipy',
|
||||||
'torch',
|
'torch',
|
||||||
'transformers',
|
'transformers',
|
||||||
],
|
],
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.7',
|
||||||
author='hexgrad',
|
author='hexgrad',
|
||||||
author_email='hello@hexgrad.com',
|
author_email='hello@hexgrad.com',
|
||||||
description='TTS',
|
description='TTS',
|
||||||
|
|||||||
Reference in New Issue
Block a user