diff --git a/kokoro/__init__.py b/kokoro/__init__.py index d1675d6..c9dbfa3 100644 --- a/kokoro/__init__.py +++ b/kokoro/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.2.3' +__version__ = '0.3.0' -from .models import KModel +from .model import KModel from .pipeline import KPipeline diff --git a/kokoro/model.py b/kokoro/model.py new file mode 100644 index 0000000..be5ad30 --- /dev/null +++ b/kokoro/model.py @@ -0,0 +1,91 @@ +from .istftnet import Decoder +from .modules import CustomAlbert, ProsodyPredictor, TextEncoder +from huggingface_hub import hf_hub_download +from numbers import Number +from transformers import AlbertConfig +from typing import Dict, Optional, Union +import json +import torch + +class KModel(torch.nn.Module): + ''' + KModel is a torch.nn.Module with 2 main responsibilities: + 1. Init weights, downloading config.json + model.pth from HF if needed + 2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor) + + You likely only need one KModel instance, and it can be reused across + multiple KPipelines to avoid redundant memory allocation. + + Unlike KPipeline, KModel is language-blind. + + KModel stores self.vocab and thus knows how to map phonemes -> input_ids, + so there is no need to repeatedly download config.json outside of KModel. + ''' + + REPO_ID = 'hexgrad/Kokoro-82M' + + def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None): + super().__init__() + if not isinstance(config, dict): + if not config: + config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json') + with open(config, 'r', encoding='utf-8') as r: + config = json.load(r) + self.vocab = config['vocab'] + self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) + self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) + self.context_length = self.bert.config.max_position_embeddings + self.predictor = ProsodyPredictor( + style_dim=config['style_dim'], d_hid=config['hidden_dim'], + nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout'] + ) + self.text_encoder = TextEncoder( + channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'], + depth=config['n_layer'], n_symbols=config['n_token'] + ) + self.decoder = Decoder( + dim_in=config['hidden_dim'], style_dim=config['style_dim'], + dim_out=config['n_mels'], **config['istftnet'] + ) + if not model: + model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth') + for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items(): + assert hasattr(self, key), key + try: + getattr(self, key).load_state_dict(state_dict) + except: + state_dict = {k[7:]: v for k, v in state_dict.items()} + getattr(self, key).load_state_dict(state_dict, strict=False) + + @property + def device(self): + return self.bert.device + + @torch.no_grad() + def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor: + input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) + assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) + input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) + input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device) + text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths) + text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device) + bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) + d_en = self.bert_encoder(bert_dur).transpose(-1, -2) + ref_s = ref_s.to(self.device) + s = ref_s[:, 128:] + d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) + x, _ = self.predictor.lstm(d) + duration = self.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) / speed + pred_dur = torch.round(duration).clamp(min=1).long() + pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1 + c_frame += pred_dur[0,i].item() + pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device) + en = d.transpose(-1, -2) @ pred_aln_trg + F0_pred, N_pred = self.predictor.F0Ntrain(en, s) + t_en = self.text_encoder(input_ids, input_lengths, text_mask) + asr = t_en @ pred_aln_trg + return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu() diff --git a/kokoro/models.py b/kokoro/modules.py similarity index 72% rename from kokoro/models.py rename to kokoro/modules.py index 0344a6e..2b72307 100644 --- a/kokoro/models.py +++ b/kokoro/modules.py @@ -1,7 +1,7 @@ # https://github.com/yl4579/StyleTTS2/blob/main/models.py -from .istftnet import AdaIN1d, AdainResBlk1d, Decoder +from .istftnet import AdainResBlk1d from torch.nn.utils import weight_norm -from transformers import AlbertConfig, AlbertModel +from transformers import AlbertModel import numpy as np import torch import torch.nn as nn @@ -182,60 +182,3 @@ class CustomAlbert(AlbertModel): def forward(self, *args, **kwargs): outputs = super().forward(*args, **kwargs) return outputs.last_hidden_state - - -class KModel(nn.Module): - def __init__(self, config, path): - super().__init__() - self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) - self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) - self.predictor = ProsodyPredictor( - style_dim=config['style_dim'], d_hid=config['hidden_dim'], - nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout'] - ) - self.text_encoder = TextEncoder( - channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'], - depth=config['n_layer'], n_symbols=config['n_token'] - ) - self.decoder = Decoder( - dim_in=config['hidden_dim'], style_dim=config['style_dim'], - dim_out=config['n_mels'], **config['istftnet'] - ) - for key, state_dict in torch.load(path, map_location='cpu', weights_only=True).items(): - assert hasattr(self, key), key - try: - getattr(self, key).load_state_dict(state_dict) - except: - state_dict = {k[7:]: v for k, v in state_dict.items()} - getattr(self, key).load_state_dict(state_dict, strict=False) - - @classmethod - def length_to_mask(cls, lengths): - mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) - mask = torch.gt(mask+1, lengths.unsqueeze(1)) - return mask - - @torch.no_grad() - def forward(self, input_ids, ref_s, speed): - device = ref_s.device - input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device) - input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(device) - text_mask = type(self).length_to_mask(input_lengths).to(device) - bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) - d_en = self.bert_encoder(bert_dur).transpose(-1, -2) - s = ref_s[:, 128:] - d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) - x, _ = self.predictor.lstm(d) - duration = self.predictor.duration_proj(x) - duration = torch.sigmoid(duration).sum(axis=-1) / speed - pred_dur = torch.round(duration).clamp(min=1).long() - pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item()) - c_frame = 0 - for i in range(pred_aln_trg.size(0)): - pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1 - c_frame += pred_dur[0,i].item() - en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device) - F0_pred, N_pred = self.predictor.F0Ntrain(en, s) - t_en = self.text_encoder(input_ids, input_lengths, text_mask) - asr = t_en @ pred_aln_trg.unsqueeze(0).to(device) - return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy() diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py index a2ce329..b91bd85 100644 --- a/kokoro/pipeline.py +++ b/kokoro/pipeline.py @@ -1,8 +1,8 @@ -from .models import KModel +from .model import KModel from huggingface_hub import hf_hub_download from misaki import en, espeak -import json -import os +from numbers import Number +from typing import Generator, List, Optional, Tuple, Union import re import torch @@ -15,50 +15,74 @@ LANG_CODES = dict( i='it', p='pt-br', ) -REPO_ID = 'hexgrad/Kokoro-82M' class KPipeline: - def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None): + ''' + KPipeline is a language-aware support class with 2 main responsibilities: + 1. Perform language-specific G2P, mapping (and chunking) text -> phonemes + 2. Manage and store voices, lazily downloaded from HF if needed + + You are expected to have one KPipeline per language. If you have multiple + KPipelines, you should reuse one KModel instance across all of them. + + KPipeline is designed to work with a KModel, but this is not required. + There are 2 ways to pass an existing model into a pipeline: + 1. On init: us_pipeline = KPipeline(lang_code='a', model=model) + 2. On call: us_pipeline(text, voice, model=model) + + By default, KPipeline will automatically initialize its own KModel. To + suppress this, construct a "quiet" KPipeline with model=False. + + A "quiet" KPipeline yields (graphemes, phonemes, None) without generating + any audio. You can use this to phonemize and chunk your text in advance. + + A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio). + ''' + def __init__(self, lang_code: str, model: Union[KModel, bool] = True, trf: bool = False): assert lang_code in LANG_CODES, (lang_code, LANG_CODES) self.lang_code = lang_code - if config_path is None: - config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json') - assert os.path.exists(config_path) - with open(config_path, 'r') as r: - config = json.load(r) - if model_path is None: - model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth') - assert os.path.exists(model_path) - self.vocab = config['vocab'] - self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device - self.model = KModel(config, model_path).to(self.device).eval() + self.model = None + if isinstance(model, KModel): + self.model = model + elif model: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = KModel().to(device).eval() self.voices = {} if lang_code in 'ab': try: fallback = espeak.EspeakFallback(british=lang_code=='b') except Exception as e: - print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e) + print('⚠️ WARNING: EspeakFallback not enabled. OOD words will be skipped.', e) fallback = None self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) else: language = LANG_CODES[lang_code] - print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") + print(f"⚠️ WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") self.g2p = espeak.EspeakG2P(language=language) - def load_voice(self, voice): + def load_voice(self, voice: str) -> torch.FloatTensor: if voice in self.voices: - return - v = voice.split('/')[-1] - if not v.startswith(self.lang_code): - v = LANG_CODES.get(v, voice) - p = LANG_CODES.get(self.lang_code, self.lang_code) - print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.') - voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt') - assert os.path.exists(voice_path) - self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device) + return self.voices[voice] + if voice.endswith('.pt'): + f = voice + else: + f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt') + if not voice.startswith(self.lang_code): + v = LANG_CODES.get(voice, voice) + p = LANG_CODES.get(self.lang_code, self.lang_code) + print(f'⚠️ WARNING: Language mismatch, loading {v} voice into {p} pipeline.') + pack = torch.load(f, weights_only=True) + self.voices[voice] = pack + return pack @classmethod - def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}): + def waterfall_last( + cls, + pairs: List[Tuple[str, str]], + next_count: int, + waterfall: List[str] = ['!.?…', ':;', ',—'], + bumps: List[str] = [')', '”'] + ) -> int: for w in waterfall: z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None) if z is not None: @@ -70,7 +94,10 @@ class KPipeline: return z return len(pairs) - def en_tokenize(self, tokens): + def en_tokenize( + self, + tokens: List[Union[en.MutableToken, List[en.MutableToken]]] + ) -> Generator[Tuple[str, str], None, None]: pairs = [] count = 0 for w in tokens: @@ -78,11 +105,11 @@ class KPipeline: if t.phonemes is None: continue next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else '' - next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T + next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T next_ps += ' ' if t.whitespace else '' next_count = count + len(next_ps.rstrip()) if next_count > 510: - z = type(self).waterfall_last(pairs, next_count) + z = KPipeline.waterfall_last(pairs, next_count) text, ps = zip(*pairs[:z]) ps = ''.join(ps) yield ''.join(text).strip(), ps.strip() @@ -96,14 +123,27 @@ class KPipeline: text, ps = zip(*pairs) yield ''.join(text).strip(), ''.join(ps).strip() - def p2ii(self, ps): - input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps))) - assert input_ids and len(input_ids) <= 510, input_ids - return input_ids + @classmethod + def infer( + cls, + model: Optional[KModel], + ps: str, + pack: torch.FloatTensor, + speed: Number + ) -> Optional[torch.FloatTensor]: + return model(ps, pack[len(ps)-1], speed) if model else None - def __call__(self, text, voice, speed=1, split_pattern=r'\n+'): - assert isinstance(text, str) or isinstance(text, list), type(text) - self.load_voice(voice) + def __call__( + self, + text: Union[str, List[str]], + voice: str, + speed: Number = 1, + split_pattern: Optional[str] = r'\n+', + model: Optional[KModel] = None + ) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]: + pack = self.load_voice(voice) + model = model or self.model + pack = pack.to(model.device) if model else pack if isinstance(text, str): text = re.split(split_pattern, text.strip()) if split_pattern else [text] for graphemes in text: @@ -113,16 +153,14 @@ class KPipeline: if not ps: continue elif len(ps) > 510: - print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") - continue - input_ids = self.p2ii(ps) - yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed) + print(f"⚠️ WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") + ps = ps[:510] + yield gs, ps, KPipeline.infer(model, ps, pack, speed) else: ps = self.g2p(graphemes) if not ps: continue elif len(ps) > 510: - print(f'WARNING: Truncating len(ps) == {len(ps)} > 510') + print(f'⚠️ WARNING: Truncating len(ps) == {len(ps)} > 510') ps = ps[:510] - input_ids = self.p2ii(ps) - yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed) + yield graphemes, ps, KPipeline.infer(model, ps, pack, speed) diff --git a/setup.py b/setup.py index 21a2438..7d10e33 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name='kokoro', # Name of the package - version='0.2.3', # Initial version + version='0.3.0', # Initial version packages=find_packages(), # Automatically finds packages install_requires=[ # List your dependencies here 'huggingface_hub',