diff --git a/README.md b/README.md index 91108ce..28dabfb 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ An inference library for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) You can run this cell on [Google Colab](https://colab.research.google.com/). [Listen to samples](https://huggingface.co/hexgrad/Kokoro-82M/blob/main/SAMPLES.md). ```py # 1️⃣ Install kokoro -!pip install -q kokoro>=0.3.1 soundfile +!pip install -q kokoro>=0.3.5 soundfile # 2️⃣ Install espeak, used for English OOD fallback and some non-English languages !apt-get -qq -y install espeak-ng > /dev/null 2>&1 # 🇪🇸 'e' => Spanish es diff --git a/kokoro/__init__.py b/kokoro/__init__.py index 2f300e7..565e84d 100644 --- a/kokoro/__init__.py +++ b/kokoro/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.4' +__version__ = '0.3.5' from loguru import logger import sys @@ -17,7 +17,7 @@ logger.add( ) # Disable before release or as needed -# logger.disable("kokoro") +logger.disable("kokoro") from .model import KModel from .pipeline import KPipeline diff --git a/kokoro/model.py b/kokoro/model.py index 4a2428d..27ab5c1 100644 --- a/kokoro/model.py +++ b/kokoro/model.py @@ -1,10 +1,11 @@ from .istftnet import Decoder from .modules import CustomAlbert, ProsodyPredictor, TextEncoder +from dataclasses import dataclass from huggingface_hub import hf_hub_download +from loguru import logger from numbers import Number from transformers import AlbertConfig from typing import Dict, Optional, Union -from loguru import logger import json import torch @@ -65,8 +66,19 @@ class KModel(torch.nn.Module): def device(self): return self.bert.device + @dataclass + class Output: + audio: torch.FloatTensor + pred_dur: Optional[torch.LongTensor] = None + @torch.no_grad() - def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor: + def forward( + self, + phonemes: str, + ref_s: torch.FloatTensor, + speed: Number = 1, + return_output: bool = False # MARK: BACKWARD COMPAT + ) -> Union['KModel.Output', torch.FloatTensor]: input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) @@ -82,16 +94,15 @@ class KModel(torch.nn.Module): x, _ = self.predictor.lstm(d) duration = self.predictor.duration_proj(x) duration = torch.sigmoid(duration).sum(axis=-1) / speed - pred_dur = torch.round(duration).clamp(min=1).long() + pred_dur = torch.round(duration).clamp(min=1).long().squeeze() logger.debug(f"pred_dur: {pred_dur}") - pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) - c_frame = 0 - for i in range(pred_aln_trg.size(0)): - pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1 - c_frame += pred_dur[0,i].item() + indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur) + pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device) + pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1 pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device) en = d.transpose(-1, -2) @ pred_aln_trg F0_pred, N_pred = self.predictor.F0Ntrain(en, s) t_en = self.text_encoder(input_ids, input_lengths, text_mask) asr = t_en @ pred_aln_trg - return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu() + audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu() + return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py index ef0f135..6e0fc87 100644 --- a/kokoro/pipeline.py +++ b/kokoro/pipeline.py @@ -1,9 +1,10 @@ from .model import KModel +from dataclasses import dataclass from huggingface_hub import hf_hub_download +from loguru import logger from misaki import en, espeak from numbers import Number from typing import Generator, List, Optional, Tuple, Union -from loguru import logger import re import torch @@ -134,6 +135,7 @@ class KPipeline: def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor: if voice in self.voices: return self.voices[voice] + logger.debug(f"Loading voice: {voice}") packs = [self.load_single_voice(v) for v in voice.split(delimiter)] if len(packs) == 1: return packs[0] @@ -194,26 +196,52 @@ class KPipeline: @classmethod def infer( cls, - model: Optional[KModel], + model: KModel, ps: str, pack: torch.FloatTensor, - speed: Number - ) -> Optional[torch.FloatTensor]: - return model(ps, pack[len(ps)-1], speed) if model else None + speed: Number = 1 + ) -> KModel.Output: + return model(ps, pack[len(ps)-1], speed, return_output=True) + + @dataclass + class Result: + graphemes: str + phonemes: str + output: Optional[KModel.Output] = None + + @property + def audio(self) -> Optional[torch.FloatTensor]: + return None if self.output is None else self.output.audio + + @property + def pred_dur(self) -> Optional[torch.LongTensor]: + return None if self.output is None else self.output.pred_dur + + ### MARK: BEGIN BACKWARD COMPAT ### + def __iter__(self): + yield self.graphemes + yield self.phonemes + yield self.audio + + def __getitem__(self, index): + return [self.graphemes, self.phonemes, self.audio][index] + + def __len__(self): + return 3 + #### MARK: END BACKWARD COMPAT #### def __call__( self, text: Union[str, List[str]], - voice: str, + voice: Optional[str] = None, speed: Number = 1, split_pattern: Optional[str] = r'\n+', model: Optional[KModel] = None - ) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]: - logger.debug(f"Loading voice: {voice}") - pack = self.load_voice(voice) + ) -> Generator['KPipeline.Result', None, None]: model = model or self.model - pack = pack.to(model.device) if model else pack - logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}") + if model and voice is None: + raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")') + pack = self.load_voice(voice).to(model.device) if model else None if isinstance(text, str): text = re.split(split_pattern, text.strip()) if split_pattern else [text] for graphemes in text: @@ -227,7 +255,8 @@ class KPipeline: elif len(ps) > 510: logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") ps = ps[:510] - yield gs, ps, KPipeline.infer(model, ps, pack, speed) + output = KPipeline.infer(model, ps, pack, speed) if model else None + yield self.Result(graphemes=gs, phonemes=ps, output=output) else: ps = self.g2p(graphemes) if not ps: @@ -235,4 +264,5 @@ class KPipeline: elif len(ps) > 510: logger.warning(f'Truncating len(ps) == {len(ps)} > 510') ps = ps[:510] - yield graphemes, ps, KPipeline.infer(model, ps, pack, speed) + output = KPipeline.infer(model, ps, pack, speed) if model else None + yield self.Result(graphemes=graphemes, phonemes=ps, output=output) diff --git a/setup.py b/setup.py index b8a6222..ab480dc 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,18 @@ from setuptools import setup, find_packages setup( name='kokoro', - version='0.3.4', + version='0.3.5', packages=find_packages(), install_requires=[ 'huggingface_hub', 'loguru', - 'misaki[en]>=0.6.5', + 'misaki[en]>=0.6.7', 'numpy==1.26.4', 'scipy', 'torch', 'transformers', ], - python_requires='>=3.6', + python_requires='>=3.7', author='hexgrad', author_email='hello@hexgrad.com', description='TTS',