Backward compatible KModel.Output and KPipeline.Result dataclasses (#40)

* Backward compatible KModel.Output and KPipeline.Result dataclasses

* Typo: Bool => bool

* Allow voice=None for quiet pipelines

* Specify class names

* Fixed and tested

* Update README.md
This commit is contained in:
hexgrad
2025-02-02 11:05:25 -08:00
committed by GitHub
parent 0abd867239
commit 2f4d94bba2
5 changed files with 69 additions and 28 deletions

View File

@@ -8,7 +8,7 @@ An inference library for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)
You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md).
```py
# 1⃣ Install kokoro
!pip install -q kokoro>=0.3.1 soundfile
!pip install -q kokoro>=0.3.5 soundfile
# 2⃣ Install espeak, used for English OOD fallback and some non-English languages
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
# 🇪🇸 'e' => Spanish es

View File

@@ -1,4 +1,4 @@
__version__ = '0.3.4'
__version__ = '0.3.5'
from loguru import logger
import sys
@@ -17,7 +17,7 @@ logger.add(
)
# Disable before release or as needed
# logger.disable("kokoro")
logger.disable("kokoro")
from .model import KModel
from .pipeline import KPipeline

View File

@@ -1,10 +1,11 @@
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from numbers import Number
from transformers import AlbertConfig
from typing import Dict, Optional, Union
from loguru import logger
import json
import torch
@@ -65,8 +66,19 @@ class KModel(torch.nn.Module):
def device(self):
return self.bert.device
@dataclass
class Output:
audio: torch.FloatTensor
pred_dur: Optional[torch.LongTensor] = None
@torch.no_grad()
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
def forward(
self,
phonemes: str,
ref_s: torch.FloatTensor,
speed: Number = 1,
return_output: bool = False # MARK: BACKWARD COMPAT
) -> Union['KModel.Output', torch.FloatTensor]:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
@@ -82,16 +94,15 @@ class KModel(torch.nn.Module):
x, _ = self.predictor.lstm(d)
duration = self.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
logger.debug(f"pred_dur: {pred_dur}")
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
c_frame += pred_dur[0,i].item()
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio

View File

@@ -1,9 +1,10 @@
from .model import KModel
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
from numbers import Number
from typing import Generator, List, Optional, Tuple, Union
from loguru import logger
import re
import torch
@@ -134,6 +135,7 @@ class KPipeline:
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
if voice in self.voices:
return self.voices[voice]
logger.debug(f"Loading voice: {voice}")
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
if len(packs) == 1:
return packs[0]
@@ -194,26 +196,52 @@ class KPipeline:
@classmethod
def infer(
cls,
model: Optional[KModel],
model: KModel,
ps: str,
pack: torch.FloatTensor,
speed: Number
) -> Optional[torch.FloatTensor]:
return model(ps, pack[len(ps)-1], speed) if model else None
speed: Number = 1
) -> KModel.Output:
return model(ps, pack[len(ps)-1], speed, return_output=True)
@dataclass
class Result:
graphemes: str
phonemes: str
output: Optional[KModel.Output] = None
@property
def audio(self) -> Optional[torch.FloatTensor]:
return None if self.output is None else self.output.audio
@property
def pred_dur(self) -> Optional[torch.LongTensor]:
return None if self.output is None else self.output.pred_dur
### MARK: BEGIN BACKWARD COMPAT ###
def __iter__(self):
yield self.graphemes
yield self.phonemes
yield self.audio
def __getitem__(self, index):
return [self.graphemes, self.phonemes, self.audio][index]
def __len__(self):
return 3
#### MARK: END BACKWARD COMPAT ####
def __call__(
self,
text: Union[str, List[str]],
voice: str,
voice: Optional[str] = None,
speed: Number = 1,
split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
logger.debug(f"Loading voice: {voice}")
pack = self.load_voice(voice)
) -> Generator['KPipeline.Result', None, None]:
model = model or self.model
pack = pack.to(model.device) if model else pack
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
if model and voice is None:
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
pack = self.load_voice(voice).to(model.device) if model else None
if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
for graphemes in text:
@@ -227,7 +255,8 @@ class KPipeline:
elif len(ps) > 510:
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
ps = ps[:510]
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=gs, phonemes=ps, output=output)
else:
ps = self.g2p(graphemes)
if not ps:
@@ -235,4 +264,5 @@ class KPipeline:
elif len(ps) > 510:
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510]
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=graphemes, phonemes=ps, output=output)

View File

@@ -2,18 +2,18 @@ from setuptools import setup, find_packages
setup(
name='kokoro',
version='0.3.4',
version='0.3.5',
packages=find_packages(),
install_requires=[
'huggingface_hub',
'loguru',
'misaki[en]>=0.6.5',
'misaki[en]>=0.6.7',
'numpy==1.26.4',
'scipy',
'torch',
'transformers',
],
python_requires='>=3.6',
python_requires='>=3.7',
author='hexgrad',
author_email='hello@hexgrad.com',
description='TTS',