Backward compatible KModel.Output and KPipeline.Result dataclasses (#40)

* Backward compatible KModel.Output and KPipeline.Result dataclasses

* Typo: Bool => bool

* Allow voice=None for quiet pipelines

* Specify class names

* Fixed and tested

* Update README.md
This commit is contained in:
hexgrad
2025-02-02 11:05:25 -08:00
committed by GitHub
parent 0abd867239
commit 2f4d94bba2
5 changed files with 69 additions and 28 deletions

View File

@@ -8,7 +8,7 @@ An inference library for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)
You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md). You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md).
```py ```py
# 1⃣ Install kokoro # 1⃣ Install kokoro
!pip install -q kokoro>=0.3.1 soundfile !pip install -q kokoro>=0.3.5 soundfile
# 2⃣ Install espeak, used for English OOD fallback and some non-English languages # 2⃣ Install espeak, used for English OOD fallback and some non-English languages
!apt-get -qq -y install espeak-ng > /dev/null 2>&1 !apt-get -qq -y install espeak-ng > /dev/null 2>&1
# 🇪🇸 'e' => Spanish es # 🇪🇸 'e' => Spanish es

View File

@@ -1,4 +1,4 @@
__version__ = '0.3.4' __version__ = '0.3.5'
from loguru import logger from loguru import logger
import sys import sys
@@ -17,7 +17,7 @@ logger.add(
) )
# Disable before release or as needed # Disable before release or as needed
# logger.disable("kokoro") logger.disable("kokoro")
from .model import KModel from .model import KModel
from .pipeline import KPipeline from .pipeline import KPipeline

View File

@@ -1,10 +1,11 @@
from .istftnet import Decoder from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from loguru import logger
from numbers import Number from numbers import Number
from transformers import AlbertConfig from transformers import AlbertConfig
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from loguru import logger
import json import json
import torch import torch
@@ -65,8 +66,19 @@ class KModel(torch.nn.Module):
def device(self): def device(self):
return self.bert.device return self.bert.device
@dataclass
class Output:
audio: torch.FloatTensor
pred_dur: Optional[torch.LongTensor] = None
@torch.no_grad() @torch.no_grad()
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor: def forward(
self,
phonemes: str,
ref_s: torch.FloatTensor,
speed: Number = 1,
return_output: bool = False # MARK: BACKWARD COMPAT
) -> Union['KModel.Output', torch.FloatTensor]:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
@@ -82,16 +94,15 @@ class KModel(torch.nn.Module):
x, _ = self.predictor.lstm(d) x, _ = self.predictor.lstm(d)
duration = self.predictor.duration_proj(x) duration = self.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long() pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
logger.debug(f"pred_dur: {pred_dur}") logger.debug(f"pred_dur: {pred_dur}")
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
c_frame = 0 pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
for i in range(pred_aln_trg.size(0)): pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
c_frame += pred_dur[0,i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device) pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
en = d.transpose(-1, -2) @ pred_aln_trg en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = self.predictor.F0Ntrain(en, s) F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
t_en = self.text_encoder(input_ids, input_lengths, text_mask) t_en = self.text_encoder(input_ids, input_lengths, text_mask)
asr = t_en @ pred_aln_trg asr = t_en @ pred_aln_trg
return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu() audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio

View File

@@ -1,9 +1,10 @@
from .model import KModel from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak from misaki import en, espeak
from numbers import Number from numbers import Number
from typing import Generator, List, Optional, Tuple, Union from typing import Generator, List, Optional, Tuple, Union
from loguru import logger
import re import re
import torch import torch
@@ -134,6 +135,7 @@ class KPipeline:
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor: def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
if voice in self.voices: if voice in self.voices:
return self.voices[voice] return self.voices[voice]
logger.debug(f"Loading voice: {voice}")
packs = [self.load_single_voice(v) for v in voice.split(delimiter)] packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
if len(packs) == 1: if len(packs) == 1:
return packs[0] return packs[0]
@@ -194,26 +196,52 @@ class KPipeline:
@classmethod @classmethod
def infer( def infer(
cls, cls,
model: Optional[KModel], model: KModel,
ps: str, ps: str,
pack: torch.FloatTensor, pack: torch.FloatTensor,
speed: Number speed: Number = 1
) -> Optional[torch.FloatTensor]: ) -> KModel.Output:
return model(ps, pack[len(ps)-1], speed) if model else None return model(ps, pack[len(ps)-1], speed, return_output=True)
@dataclass
class Result:
graphemes: str
phonemes: str
output: Optional[KModel.Output] = None
@property
def audio(self) -> Optional[torch.FloatTensor]:
return None if self.output is None else self.output.audio
@property
def pred_dur(self) -> Optional[torch.LongTensor]:
return None if self.output is None else self.output.pred_dur
### MARK: BEGIN BACKWARD COMPAT ###
def __iter__(self):
yield self.graphemes
yield self.phonemes
yield self.audio
def __getitem__(self, index):
return [self.graphemes, self.phonemes, self.audio][index]
def __len__(self):
return 3
#### MARK: END BACKWARD COMPAT ####
def __call__( def __call__(
self, self,
text: Union[str, List[str]], text: Union[str, List[str]],
voice: str, voice: Optional[str] = None,
speed: Number = 1, speed: Number = 1,
split_pattern: Optional[str] = r'\n+', split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None model: Optional[KModel] = None
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]: ) -> Generator['KPipeline.Result', None, None]:
logger.debug(f"Loading voice: {voice}")
pack = self.load_voice(voice)
model = model or self.model model = model or self.model
pack = pack.to(model.device) if model else pack if model and voice is None:
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}") raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
if isinstance(text, str): if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text] text = re.split(split_pattern, text.strip()) if split_pattern else [text]
for graphemes in text: for graphemes in text:
@@ -227,7 +255,8 @@ class KPipeline:
elif len(ps) > 510: elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510] ps = ps[:510]
yield gs, ps, KPipeline.infer(model, ps, pack, speed) output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=gs, phonemes=ps, output=output)
else: else:
ps = self.g2p(graphemes) ps = self.g2p(graphemes)
if not ps: if not ps:
@@ -235,4 +264,5 @@ class KPipeline:
elif len(ps) > 510: elif len(ps) > 510:
logger.warning(f'Truncating len(ps) == {len(ps)} > 510') logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510] ps = ps[:510]
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed) output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=graphemes, phonemes=ps, output=output)

View File

@@ -2,18 +2,18 @@ from setuptools import setup, find_packages
setup( setup(
name='kokoro', name='kokoro',
version='0.3.4', version='0.3.5',
packages=find_packages(), packages=find_packages(),
install_requires=[ install_requires=[
'huggingface_hub', 'huggingface_hub',
'loguru', 'loguru',
'misaki[en]>=0.6.5', 'misaki[en]>=0.6.7',
'numpy==1.26.4', 'numpy==1.26.4',
'scipy', 'scipy',
'torch', 'torch',
'transformers', 'transformers',
], ],
python_requires='>=3.6', python_requires='>=3.7',
author='hexgrad', author='hexgrad',
author_email='hello@hexgrad.com', author_email='hello@hexgrad.com',
description='TTS', description='TTS',