Refactor (#16)

* Refactor

* Bump to 0.2.4

* Fix typo

* Add missing @classmethod

* Simplify REPO_ID

* Use explicit class names

* Fix input_lengths typo

* Read config with utf-8 encoding, issue #18
This commit is contained in:
hexgrad
2025-01-29 10:28:49 -08:00
committed by GitHub
parent d388ee9e0b
commit aed687eab3
5 changed files with 180 additions and 108 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.2.3' __version__ = '0.3.0'
from .models import KModel from .model import KModel
from .pipeline import KPipeline from .pipeline import KPipeline

91
kokoro/model.py Normal file
View File

@@ -0,0 +1,91 @@
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from huggingface_hub import hf_hub_download
from numbers import Number
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch
class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + kokoro-v1_0.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    REPO_ID = 'hexgrad/Kokoro-82M'

    def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None):
        '''
        Build the sub-networks from config, then load their weights.

        config: an already-parsed dict, a path to a JSON file, or None/empty
            to download config.json from REPO_ID.
        model: path to a checkpoint readable by torch.load, or None/empty to
            download kokoro-v1_0.pth from REPO_ID.
        '''
        super().__init__()
        if not isinstance(config, dict):
            if not config:
                config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
            # Explicit utf-8 so decoding does not depend on the platform default.
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        # forward() refuses inputs longer than this many ids (including the 2 pad ids).
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], **config['istftnet']
        )
        if not model:
            model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
        # The checkpoint maps attribute names on this module to their state dicts.
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            KModel._load_submodule_weights(getattr(self, key), state_dict)

    @staticmethod
    def _load_submodule_weights(module: torch.nn.Module, state_dict: Dict) -> None:
        '''Load state_dict into module, retrying leniently without a 'module.' key prefix.'''
        try:
            module.load_state_dict(state_dict)
        except RuntimeError:
            # A strict load raises RuntimeError on missing/unexpected keys.
            # Presumably the checkpoint was saved from a wrapped model whose
            # keys carry a 'module.' prefix (7 characters) - TODO confirm.
            # Only strip the prefix where it is actually present, so keys that
            # already match are not mangled and silently skipped.
            prefix = 'module.'
            stripped = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            module.load_state_dict(stripped, strict=False)

    @staticmethod
    def _alignment_matrix(pred_dur: torch.Tensor, n_tokens: int) -> torch.Tensor:
        '''Return an (n_tokens, total_frames) 0/1 matrix; row i is 1 on token i's consecutive frames.'''
        n_frames = int(pred_dur.sum().item())
        aln = torch.zeros(n_tokens, n_frames)
        start = 0
        # One bulk transfer of the durations instead of an .item() call per token.
        for i, n in enumerate(pred_dur[0, :n_tokens].tolist()):
            aln[i, start:start + n] = 1
            start += n
        return aln

    @property
    def device(self):
        '''Device reported by the bert sub-module; forward() moves its inputs there.'''
        return self.bert.device

    @torch.no_grad()
    def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
        '''
        Synthesize audio for a phoneme string.

        phonemes: each character is looked up in self.vocab; characters that
            are not in the vocab are dropped.
        ref_s: reference tensor, moved to self.device. Columns 128 onward are
            passed to the prosody predictor and the first 128 columns to the
            decoder (assumes a 2-D (1, >128) tensor - TODO confirm).
        speed: predicted per-token durations are divided by this value before
            rounding, so larger values give fewer frames.

        Returns the decoder output, squeezed and moved to the CPU.
        Raises AssertionError if the padded input exceeds self.context_length.
        '''
        input_ids = [i for i in map(self.vocab.get, phonemes) if i is not None]
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        # Id 0 is added on both ends of the sequence.
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        n_tokens = input_ids.shape[-1]
        input_lengths = torch.LongTensor([n_tokens]).to(self.device)
        # True marks positions at or beyond the sequence length.
        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        ref_s = ref_s.to(self.device)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        # Every token gets at least one frame.
        pred_dur = torch.round(duration).clamp(min=1).long()
        pred_aln_trg = KModel._alignment_matrix(pred_dur, n_tokens).unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()

View File

@@ -1,7 +1,7 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py # https://github.com/yl4579/StyleTTS2/blob/main/models.py
from .istftnet import AdaIN1d, AdainResBlk1d, Decoder from .istftnet import AdainResBlk1d
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from transformers import AlbertConfig, AlbertModel from transformers import AlbertModel
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -182,60 +182,3 @@ class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
outputs = super().forward(*args, **kwargs) outputs = super().forward(*args, **kwargs)
return outputs.last_hidden_state return outputs.last_hidden_state
class KModel(nn.Module):
    '''
    Wraps the bert, bert_encoder, predictor, text_encoder and decoder
    sub-modules, loads their weights from a checkpoint, and runs inference
    from input ids to a numpy audio array.
    '''

    def __init__(self, config, path):
        '''
        config: dict with the hyper-parameters read below.
        path: checkpoint readable by torch.load, mapping attribute names on
            this module to their state dicts.
        '''
        super().__init__()
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], **config['istftnet']
        )
        for key, state_dict in torch.load(path, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            KModel._load_submodule_weights(getattr(self, key), state_dict)

    @staticmethod
    def _load_submodule_weights(module, state_dict):
        '''Load state_dict into module, retrying leniently without a 'module.' key prefix.'''
        try:
            module.load_state_dict(state_dict)
        except RuntimeError:
            # A strict load raises RuntimeError on missing/unexpected keys.
            # Presumably the keys carry a 'module.' prefix (7 characters) from
            # a wrapped model - TODO confirm. Strip it only where present so
            # keys that already match are not mangled and silently skipped.
            prefix = 'module.'
            stripped = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            module.load_state_dict(stripped, strict=False)

    @staticmethod
    def _alignment_matrix(pred_dur, n_tokens):
        '''Return an (n_tokens, total_frames) 0/1 matrix; row i is 1 on token i's consecutive frames.'''
        n_frames = int(pred_dur.sum().item())
        aln = torch.zeros(n_tokens, n_frames)
        start = 0
        # One bulk transfer of the durations instead of an .item() call per token.
        for i, n in enumerate(pred_dur[0, :n_tokens].tolist()):
            aln[i, start:start + n] = 1
            start += n
        return aln

    @classmethod
    def length_to_mask(cls, lengths):
        '''Return a (len(lengths), lengths.max()) bool mask that is True at positions >= each length.'''
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

    @torch.no_grad()
    def forward(self, input_ids, ref_s, speed):
        '''
        input_ids: iterable of ints; id 0 is added on both ends.
        ref_s: reference tensor; its device is used for all intermediate
            tensors. Columns 128 onward go to the prosody predictor and the
            first 128 columns to the decoder.
        speed: predicted durations are divided by this value before rounding.

        Returns the decoder output, squeezed, as a numpy array.
        '''
        device = ref_s.device
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device)
        n_tokens = input_ids.shape[-1]
        input_lengths = torch.LongTensor([n_tokens]).to(device)
        text_mask = type(self).length_to_mask(input_lengths).to(device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        # Every token gets at least one frame.
        pred_dur = torch.round(duration).clamp(min=1).long()
        # Built once and reused for both matrix products below.
        pred_aln_trg = KModel._alignment_matrix(pred_dur, n_tokens).unsqueeze(0).to(device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()

View File

@@ -1,8 +1,8 @@
from .models import KModel from .model import KModel
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from misaki import en, espeak from misaki import en, espeak
import json from numbers import Number
import os from typing import Generator, List, Optional, Tuple, Union
import re import re
import torch import torch
@@ -15,50 +15,74 @@ LANG_CODES = dict(
i='it', i='it',
p='pt-br', p='pt-br',
) )
REPO_ID = 'hexgrad/Kokoro-82M'
class KPipeline: class KPipeline:
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None): '''
KPipeline is a language-aware support class with 2 main responsibilities:
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
2. Manage and store voices, lazily downloaded from HF if needed
You are expected to have one KPipeline per language. If you have multiple
KPipelines, you should reuse one KModel instance across all of them.
KPipeline is designed to work with a KModel, but this is not required.
There are 2 ways to pass an existing model into a pipeline:
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
2. On call: us_pipeline(text, voice, model=model)
By default, KPipeline will automatically initialize its own KModel. To
suppress this, construct a "quiet" KPipeline with model=False.
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
any audio. You can use this to phonemize and chunk your text in advance.
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
'''
def __init__(self, lang_code: str, model: Union[KModel, bool] = True, trf: bool = False):
assert lang_code in LANG_CODES, (lang_code, LANG_CODES) assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code self.lang_code = lang_code
if config_path is None: self.model = None
config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json') if isinstance(model, KModel):
assert os.path.exists(config_path) self.model = model
with open(config_path, 'r') as r: elif model:
config = json.load(r) device = 'cuda' if torch.cuda.is_available() else 'cpu'
if model_path is None: self.model = KModel().to(device).eval()
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
assert os.path.exists(model_path)
self.vocab = config['vocab']
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.model = KModel(config, model_path).to(self.device).eval()
self.voices = {} self.voices = {}
if lang_code in 'ab': if lang_code in 'ab':
try: try:
fallback = espeak.EspeakFallback(british=lang_code=='b') fallback = espeak.EspeakFallback(british=lang_code=='b')
except Exception as e: except Exception as e:
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e) print('⚠️ WARNING: EspeakFallback not enabled. OOD words will be skipped.', e)
fallback = None fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
else: else:
language = LANG_CODES[lang_code] language = LANG_CODES[lang_code]
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.") print(f"⚠️ WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
self.g2p = espeak.EspeakG2P(language=language) self.g2p = espeak.EspeakG2P(language=language)
def load_voice(self, voice): def load_voice(self, voice: str) -> torch.FloatTensor:
if voice in self.voices: if voice in self.voices:
return return self.voices[voice]
v = voice.split('/')[-1] if voice.endswith('.pt'):
if not v.startswith(self.lang_code): f = voice
v = LANG_CODES.get(v, voice) else:
f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
if not voice.startswith(self.lang_code):
v = LANG_CODES.get(voice, voice)
p = LANG_CODES.get(self.lang_code, self.lang_code) p = LANG_CODES.get(self.lang_code, self.lang_code)
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.') print(f'⚠️ WARNING: Language mismatch, loading {v} voice into {p} pipeline.')
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt') pack = torch.load(f, weights_only=True)
assert os.path.exists(voice_path) self.voices[voice] = pack
self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device) return pack
@classmethod @classmethod
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', ''}): def waterfall_last(
cls,
pairs: List[Tuple[str, str]],
next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'],
bumps: List[str] = [')', '']
) -> int:
for w in waterfall: for w in waterfall:
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None) z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
if z is not None: if z is not None:
@@ -70,7 +94,10 @@ class KPipeline:
return z return z
return len(pairs) return len(pairs)
def en_tokenize(self, tokens): def en_tokenize(
self,
tokens: List[Union[en.MutableToken, List[en.MutableToken]]]
) -> Generator[Tuple[str, str], None, None]:
pairs = [] pairs = []
count = 0 count = 0
for w in tokens: for w in tokens:
@@ -78,11 +105,11 @@ class KPipeline:
if t.phonemes is None: if t.phonemes is None:
continue continue
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else '' next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T
next_ps += ' ' if t.whitespace else '' next_ps += ' ' if t.whitespace else ''
next_count = count + len(next_ps.rstrip()) next_count = count + len(next_ps.rstrip())
if next_count > 510: if next_count > 510:
z = type(self).waterfall_last(pairs, next_count) z = KPipeline.waterfall_last(pairs, next_count)
text, ps = zip(*pairs[:z]) text, ps = zip(*pairs[:z])
ps = ''.join(ps) ps = ''.join(ps)
yield ''.join(text).strip(), ps.strip() yield ''.join(text).strip(), ps.strip()
@@ -96,14 +123,27 @@ class KPipeline:
text, ps = zip(*pairs) text, ps = zip(*pairs)
yield ''.join(text).strip(), ''.join(ps).strip() yield ''.join(text).strip(), ''.join(ps).strip()
def p2ii(self, ps): @classmethod
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), ps))) def infer(
assert input_ids and len(input_ids) <= 510, input_ids cls,
return input_ids model: Optional[KModel],
ps: str,
pack: torch.FloatTensor,
speed: Number
) -> Optional[torch.FloatTensor]:
return model(ps, pack[len(ps)-1], speed) if model else None
def __call__(self, text, voice, speed=1, split_pattern=r'\n+'): def __call__(
assert isinstance(text, str) or isinstance(text, list), type(text) self,
self.load_voice(voice) text: Union[str, List[str]],
voice: str,
speed: Number = 1,
split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
pack = self.load_voice(voice)
model = model or self.model
pack = pack.to(model.device) if model else pack
if isinstance(text, str): if isinstance(text, str):
text = re.split(split_pattern, text.strip()) if split_pattern else [text] text = re.split(split_pattern, text.strip()) if split_pattern else [text]
for graphemes in text: for graphemes in text:
@@ -113,16 +153,14 @@ class KPipeline:
if not ps: if not ps:
continue continue
elif len(ps) > 510: elif len(ps) > 510:
print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'") print(f"⚠️ WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
continue ps = ps[:510]
input_ids = self.p2ii(ps) yield gs, ps, KPipeline.infer(model, ps, pack, speed)
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
else: else:
ps = self.g2p(graphemes) ps = self.g2p(graphemes)
if not ps: if not ps:
continue continue
elif len(ps) > 510: elif len(ps) > 510:
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510') print(f'⚠️ WARNING: Truncating len(ps) == {len(ps)} > 510')
ps = ps[:510] ps = ps[:510]
input_ids = self.p2ii(ps) yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name='kokoro', # Name of the package name='kokoro', # Name of the package
version='0.2.3', # Initial version version='0.3.0', # Initial version
packages=find_packages(), # Automatically finds packages packages=find_packages(), # Automatically finds packages
install_requires=[ # List your dependencies here install_requires=[ # List your dependencies here
'huggingface_hub', 'huggingface_hub',