From efa91a8a3fa732a71fe0d25cceaf1b95dcef9a25 Mon Sep 17 00:00:00 2001
From: hexgrad <166769057+hexgrad@users.noreply.github.com>
Date: Wed, 26 Feb 2025 17:30:50 -0800
Subject: [PATCH] Match misaki==0.8.0 dev branch (#114)

* Match misaki==0.8.0 dev branch

* en_callable, speed callable
---
 kokoro/__init__.py |  2 +-
 kokoro/model.py    | 28 ++++++++++++++++++++--------
 kokoro/pipeline.py | 46 +++++++++++++++++++++++++++-------------------
 pyproject.toml     |  4 ++--
 4 files changed, 50 insertions(+), 30 deletions(-)

diff --git a/kokoro/__init__.py b/kokoro/__init__.py
index 303b735..2d7dece 100644
--- a/kokoro/__init__.py
+++ b/kokoro/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.7.16'
+__version__ = '0.8.0'
 
 from loguru import logger
 import sys
diff --git a/kokoro/model.py b/kokoro/model.py
index 25659ad..08ffe8f 100644
--- a/kokoro/model.py
+++ b/kokoro/model.py
@@ -3,7 +3,6 @@ from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
 from dataclasses import dataclass
 from huggingface_hub import hf_hub_download
 from loguru import logger
-from numbers import Number
 from transformers import AlbertConfig
 from typing import Dict, Optional, Union
 import json
@@ -24,14 +23,27 @@ class KModel(torch.nn.Module):
     so there is no need to repeatedly download config.json outside of KModel.
     '''
 
-    REPO_ID = 'hexgrad/Kokoro-82M'
+    MODEL_NAMES = {
+        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
+        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
+    }
 
-    def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False):
+    def __init__(
+        self,
+        repo_id: Optional[str] = None,
+        config: Union[Dict, str, None] = None,
+        model: Optional[str] = None,
+        disable_complex: bool = False
+    ):
         super().__init__()
+        if repo_id is None:
+            repo_id = 'hexgrad/Kokoro-82M'
+            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
+        self.repo_id = repo_id
         if not isinstance(config, dict):
             if not config:
                 logger.debug("No config provided, downloading from HF")
-                config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
+                config = hf_hub_download(repo_id=repo_id, filename='config.json')
             with open(config, 'r', encoding='utf-8') as r:
                 config = json.load(r)
             logger.debug(f"Loaded config: {config}")
@@ -52,7 +64,7 @@ class KModel(torch.nn.Module):
             dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
         )
         if not model:
-            model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
+            model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
         for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
             assert hasattr(self, key), key
             try:
@@ -76,7 +88,7 @@ class KModel(torch.nn.Module):
         self,
         input_ids: torch.LongTensor,
         ref_s: torch.FloatTensor,
-        speed: Number = 1
+        speed: float = 1
     ) -> tuple[torch.FloatTensor, torch.LongTensor]:
         input_lengths = torch.full(
             (input_ids.shape[0],),
@@ -110,7 +122,7 @@ class KModel(torch.nn.Module):
         self,
         phonemes: str,
         ref_s: torch.FloatTensor,
-        speed: Number = 1,
+        speed: float = 1,
         return_output: bool = False
     ) -> Union['KModel.Output', torch.FloatTensor]:
         input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
@@ -133,7 +145,7 @@ class KModelForONNX(torch.nn.Module):
         self,
         input_ids: torch.LongTensor,
         ref_s: torch.FloatTensor,
-        speed: Number = 1
+        speed: float = 1
     ) -> tuple[torch.FloatTensor, torch.LongTensor]:
         waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
         return waveform, duration
diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py
index d7969a2..6afa9a4 100644
--- a/kokoro/pipeline.py
+++ b/kokoro/pipeline.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from huggingface_hub import hf_hub_download
 from loguru import logger
 from misaki import en, espeak
-from typing import Generator, List, Optional, Tuple, Union
+from typing import Callable, Generator, List, Optional, Tuple, Union
 import re
 import torch
 
@@ -63,8 +63,10 @@ class KPipeline:
     def __init__(
         self,
         lang_code: str,
+        repo_id: Optional[str] = None,
         model: Union[KModel, bool] = True,
         trf: bool = False,
+        en_callable: Optional[Callable[[str], str]] = None,
         device: Optional[str] = None
     ):
         """Initialize a KPipeline.
@@ -77,6 +79,10 @@ class KPipeline:
             If None, will auto-select cuda if available
             If 'cuda' and not available, will explicitly raise an error
         """
+        if repo_id is None:
+            repo_id = 'hexgrad/Kokoro-82M'
+            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
+        self.repo_id = repo_id
         lang_code = lang_code.lower()
         lang_code = ALIASES.get(lang_code, lang_code)
         assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
@@ -90,7 +96,7 @@ class KPipeline:
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         try:
-            self.model = KModel().to(device).eval()
+            self.model = KModel(repo_id=repo_id).to(device).eval()
         except RuntimeError as e:
             if device == 'cuda':
                 raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
@@ -115,7 +121,10 @@ class KPipeline:
         elif lang_code == 'z':
             try:
                 from misaki import zh
-                self.g2p = zh.ZHG2P()
+                self.g2p = zh.ZHG2P(
+                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
+                    en_callable=en_callable
+                )
             except ImportError:
                 logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                 raise
@@ -130,7 +139,7 @@ class KPipeline:
         if voice.endswith('.pt'):
             f = voice
         else:
-            f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
+            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
         if not voice.startswith(self.lang_code):
             v = LANG_CODES.get(voice, voice)
             p = LANG_CODES.get(self.lang_code, self.lang_code)
@@ -157,13 +166,12 @@ class KPipeline:
             self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
         return self.voices[voice]
 
-    @classmethod
-    def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
+    @staticmethod
+    def tokens_to_ps(tokens: List[en.MToken]) -> str:
         return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
 
-    @classmethod
+    @staticmethod
     def waterfall_last(
-        cls,
         tokens: List[en.MToken],
         next_count: int,
         waterfall: List[str] = ['!.?…', ':;', ',—'],
@@ -176,12 +184,12 @@ class KPipeline:
                 z += 1
             if z < len(tokens) and tokens[z].phonemes in bumps:
                 z += 1
-            if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
+            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
                 return z
         return len(tokens)
 
-    @classmethod
-    def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
+    @staticmethod
+    def tokens_to_text(tokens: List[en.MToken]) -> str:
         return ''.join(t.text + t.whitespace for t in tokens).strip()
 
     def en_tokenize(
@@ -212,14 +220,15 @@ class KPipeline:
                     ps = KPipeline.tokens_to_ps(tks)
             yield ''.join(text).strip(), ''.join(ps).strip(), tks
 
-    @classmethod
+    @staticmethod
     def infer(
-        cls,
         model: KModel,
         ps: str,
         pack: torch.FloatTensor,
-        speed: float = 1
+        speed: Union[float, Callable[[int], float]] = 1
     ) -> KModel.Output:
+        if callable(speed):
+            speed = speed(len(ps))
         return model(ps, pack[len(ps)-1], speed, return_output=True)
 
     def generate_from_tokens(
@@ -272,8 +281,8 @@ class KPipeline:
                 KPipeline.join_timestamps(tks, output.pred_dur)
             yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
 
-    @classmethod
-    def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
+    @staticmethod
+    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
         # Multiply by 600 to go from pred_dur frames to sample_rate 24000
         # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
         # We will count nice round half-frames, so the divisor is 80
@@ -343,7 +352,7 @@ class KPipeline:
         self,
         text: Union[str, List[str]],
         voice: Optional[str] = None,
-        speed: float = 1,
+        speed: Union[float, Callable[[int], float]] = 1,
         split_pattern: Optional[str] = r'\n+',
         model: Optional[KModel] = None
     ) -> Generator['KPipeline.Result', None, None]:
@@ -412,7 +421,7 @@ class KPipeline:
                 if not chunk.strip():
                     continue
 
-                ps = self.g2p(chunk)
+                ps, _ = self.g2p(chunk)
                 if not ps:
                     continue
                 elif len(ps) > 510:
@@ -421,4 +430,3 @@ class KPipeline:
                     ps = ps[:510]
                 output = KPipeline.infer(model, ps, pack, speed) if model else None
                 yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
-
diff --git a/pyproject.toml b/pyproject.toml
index b3cbf9a..f41e626 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "kokoro"
-version = "0.7.16"
+version = "0.8.0"
 description = "TTS"
 readme = "README.md"
= "README.md" authors = [ @@ -20,7 +20,7 @@ requires-python = ">=3.10, <3.13" dependencies = [ "huggingface_hub", "loguru", - "misaki[en]>=0.7.16", + "misaki[en]>=0.8.0", "numpy==1.26.4", "scipy", "torch",