Refactor (#16)
* Refactor * Bump to 0.2.4 * Fix typo * Add missing @classmethod * Simplify REPO_ID * Use explicit class names * Fix input_lengths typo * Read config with utf-8 encoding, issue #18
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = '0.2.3'
|
||||
__version__ = '0.3.0'
|
||||
|
||||
from .models import KModel
|
||||
from .model import KModel
|
||||
from .pipeline import KPipeline
|
||||
|
||||
91
kokoro/model.py
Normal file
91
kokoro/model.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from .istftnet import Decoder
|
||||
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
||||
from huggingface_hub import hf_hub_download
|
||||
from numbers import Number
|
||||
from transformers import AlbertConfig
|
||||
from typing import Dict, Optional, Union
|
||||
import json
|
||||
import torch
|
||||
|
||||
class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + kokoro-v1_0.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    REPO_ID = 'hexgrad/Kokoro-82M'

    def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None):
        '''
        Build the sub-networks described by `config` and load their weights.

        config: an already-parsed config dict, a path to a JSON config file,
            or None/empty to download config.json from REPO_ID.
        model: path to a checkpoint file, or None/empty to download
            kokoro-v1_0.pth from REPO_ID.

        Raises AssertionError if the checkpoint contains a top-level key that
        is not an attribute of this module.
        '''
        super().__init__()
        if not isinstance(config, dict):
            if not config:
                config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
            # Explicit utf-8: the config holds the phoneme vocab, which must
            # not be decoded with the platform's default encoding.
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        # Upper bound on len(input_ids) + 2 that forward() will accept.
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], **config['istftnet']
        )
        if not model:
            model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
        # The checkpoint maps attribute names of this module to state dicts.
        checkpoint = torch.load(model, map_location='cpu', weights_only=True)
        for key, state_dict in checkpoint.items():
            assert hasattr(self, key), key
            submodule = getattr(self, key)
            try:
                submodule.load_state_dict(state_dict)
            except RuntimeError:
                # load_state_dict reports key/shape mismatches as RuntimeError;
                # catching only that keeps KeyboardInterrupt and unrelated
                # failures from being swallowed by the fallback.
                # The fallback drops the first 7 characters of every key
                # (presumably a 'module.' wrapper prefix - TODO confirm) and
                # retries non-strictly.
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                submodule.load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        '''Device of the model, as reported by the bert sub-module.'''
        return self.bert.device

    @torch.no_grad()
    def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
        '''
        Synthesize audio for one phoneme string.

        phonemes: characters are looked up in self.vocab; unknown characters
            are silently dropped.
        ref_s: reference tensor; columns 128 onward are fed to the prosody
            predictor and the first 128 columns to the decoder.
        speed: predicted per-token durations are divided by this value.

        Returns the decoder output, squeezed and moved to the CPU.
        Raises AssertionError if the token count plus the two boundary tokens
        exceeds self.context_length.
        '''
        input_ids = [i for i in map(self.vocab.get, phonemes) if i is not None]
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        device = self.device
        # Wrap the sequence in a leading and trailing 0 token; batch size is 1.
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device)
        n_tokens = input_ids.shape[-1]
        input_lengths = torch.LongTensor([n_tokens]).to(device)
        # True marks positions past the sequence length (none for one sequence).
        text_mask = torch.arange(n_tokens).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        ref_s = ref_s.to(device)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        # Every token gets at least one frame.
        pred_dur = torch.round(duration).clamp(min=1).long()
        # Read the durations back once instead of calling .item() per token.
        token_frames = pred_dur[0].tolist()
        # Alignment matrix: row i is 1 over the frames assigned to token i.
        # Sized with a plain int rather than the input_lengths tensor.
        pred_aln_trg = torch.zeros(n_tokens, pred_dur.sum().item())
        c_frame = 0
        for i in range(n_tokens):
            n_frames = token_frames[i]
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(device)
        # Expand token-level features to frame level via the alignment.
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
||||
@@ -1,7 +1,7 @@
|
||||
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
||||
from .istftnet import AdaIN1d, AdainResBlk1d, Decoder
|
||||
from .istftnet import AdainResBlk1d
|
||||
from torch.nn.utils import weight_norm
|
||||
from transformers import AlbertConfig, AlbertModel
|
||||
from transformers import AlbertModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -182,60 +182,3 @@ class CustomAlbert(AlbertModel):
|
||||
def forward(self, *args, **kwargs):
|
||||
outputs = super().forward(*args, **kwargs)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
|
||||
class KModel(nn.Module):
    '''
    Bundles the bert, bert_encoder, predictor, text_encoder and decoder
    sub-modules, loads their weights from a checkpoint, and runs them in
    sequence to turn input ids plus a reference tensor into audio.
    '''

    def __init__(self, config, path):
        '''
        config: parsed config dict providing the sub-module hyperparameters.
        path: checkpoint file mapping attribute names of this module to
            state dicts.

        Raises AssertionError if the checkpoint contains a top-level key that
        is not an attribute of this module.
        '''
        super().__init__()
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], **config['istftnet']
        )
        checkpoint = torch.load(path, map_location='cpu', weights_only=True)
        for key, state_dict in checkpoint.items():
            assert hasattr(self, key), key
            submodule = getattr(self, key)
            try:
                submodule.load_state_dict(state_dict)
            except RuntimeError:
                # Only key/shape mismatches (RuntimeError) trigger the
                # fallback, so KeyboardInterrupt and unrelated failures are
                # no longer swallowed.
                # The fallback drops the first 7 characters of every key
                # (presumably a 'module.' wrapper prefix - TODO confirm) and
                # retries non-strictly.
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                submodule.load_state_dict(state_dict, strict=False)

    @classmethod
    def length_to_mask(cls, lengths):
        '''
        Return a boolean mask of shape (len(lengths), lengths.max()) that is
        True at positions at or beyond each row's length.
        '''
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

    @torch.no_grad()
    def forward(self, input_ids, ref_s, speed):
        '''
        Synthesize audio for one sequence of input ids.

        input_ids: iterable of ints; a 0 is added at both ends.
        ref_s: reference tensor; its device is used for all inputs. Columns
            128 onward feed the prosody predictor, the first 128 the decoder.
        speed: predicted per-token durations are divided by this value.

        Returns the decoder output, squeezed, as a NumPy array.
        '''
        device = ref_s.device
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(device)
        n_tokens = input_ids.shape[-1]
        input_lengths = torch.LongTensor([n_tokens]).to(device)
        text_mask = type(self).length_to_mask(input_lengths).to(device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        # Every token gets at least one frame.
        pred_dur = torch.round(duration).clamp(min=1).long()
        # Read the durations back once instead of calling .item() per token.
        token_frames = pred_dur[0].tolist()
        # Alignment matrix: row i is 1 over the frames assigned to token i.
        # Sized with a plain int rather than the input_lengths tensor.
        pred_aln_trg = torch.zeros(n_tokens, pred_dur.sum().item())
        c_frame = 0
        for i in range(n_tokens):
            n_frames = token_frames[i]
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        # Batch and move the alignment once; it is used for both products.
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||
@@ -1,8 +1,8 @@
|
||||
from .models import KModel
|
||||
from .model import KModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
from misaki import en, espeak
|
||||
import json
|
||||
import os
|
||||
from numbers import Number
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
import re
|
||||
import torch
|
||||
|
||||
@@ -15,50 +15,74 @@ LANG_CODES = dict(
|
||||
i='it',
|
||||
p='pt-br',
|
||||
)
|
||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||
|
||||
class KPipeline:
|
||||
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None):
|
||||
'''
|
||||
KPipeline is a language-aware support class with 2 main responsibilities:
|
||||
1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
|
||||
2. Manage and store voices, lazily downloaded from HF if needed
|
||||
|
||||
You are expected to have one KPipeline per language. If you have multiple
|
||||
KPipelines, you should reuse one KModel instance across all of them.
|
||||
|
||||
KPipeline is designed to work with a KModel, but this is not required.
|
||||
There are 2 ways to pass an existing model into a pipeline:
|
||||
1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
|
||||
2. On call: us_pipeline(text, voice, model=model)
|
||||
|
||||
By default, KPipeline will automatically initialize its own KModel. To
|
||||
suppress this, construct a "quiet" KPipeline with model=False.
|
||||
|
||||
A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
|
||||
any audio. You can use this to phonemize and chunk your text in advance.
|
||||
|
||||
A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
|
||||
'''
|
||||
def __init__(self, lang_code: str, model: Union[KModel, bool] = True, trf: bool = False):
|
||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||
self.lang_code = lang_code
|
||||
if config_path is None:
|
||||
config_path = hf_hub_download(repo_id=REPO_ID, filename='config.json')
|
||||
assert os.path.exists(config_path)
|
||||
with open(config_path, 'r') as r:
|
||||
config = json.load(r)
|
||||
if model_path is None:
|
||||
model_path = hf_hub_download(repo_id=REPO_ID, filename='kokoro-v1_0.pth')
|
||||
assert os.path.exists(model_path)
|
||||
self.vocab = config['vocab']
|
||||
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
self.model = KModel(config, model_path).to(self.device).eval()
|
||||
self.model = None
|
||||
if isinstance(model, KModel):
|
||||
self.model = model
|
||||
elif model:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.model = KModel().to(device).eval()
|
||||
self.voices = {}
|
||||
if lang_code in 'ab':
|
||||
try:
|
||||
fallback = espeak.EspeakFallback(british=lang_code=='b')
|
||||
except Exception as e:
|
||||
print('WARNING: EspeakFallback not enabled. Out-of-dictionary words will be skipped.', e)
|
||||
print('⚠️ WARNING: EspeakFallback not enabled. OOD words will be skipped.', e)
|
||||
fallback = None
|
||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||
else:
|
||||
language = LANG_CODES[lang_code]
|
||||
print(f"WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||
print(f"⚠️ WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
|
||||
self.g2p = espeak.EspeakG2P(language=language)
|
||||
|
||||
def load_voice(self, voice):
|
||||
def load_voice(self, voice: str) -> torch.FloatTensor:
|
||||
if voice in self.voices:
|
||||
return
|
||||
v = voice.split('/')[-1]
|
||||
if not v.startswith(self.lang_code):
|
||||
v = LANG_CODES.get(v, voice)
|
||||
return self.voices[voice]
|
||||
if voice.endswith('.pt'):
|
||||
f = voice
|
||||
else:
|
||||
f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
|
||||
if not voice.startswith(self.lang_code):
|
||||
v = LANG_CODES.get(voice, voice)
|
||||
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
|
||||
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
|
||||
assert os.path.exists(voice_path)
|
||||
self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device)
|
||||
print(f'⚠️ WARNING: Language mismatch, loading {v} voice into {p} pipeline.')
|
||||
pack = torch.load(f, weights_only=True)
|
||||
self.voices[voice] = pack
|
||||
return pack
|
||||
|
||||
@classmethod
|
||||
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}):
|
||||
def waterfall_last(
|
||||
cls,
|
||||
pairs: List[Tuple[str, str]],
|
||||
next_count: int,
|
||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||
bumps: List[str] = [')', '”']
|
||||
) -> int:
|
||||
for w in waterfall:
|
||||
z = next((i for i, (_, ps) in reversed(list(enumerate(pairs))) if ps.strip() in set(w)), None)
|
||||
if z is not None:
|
||||
@@ -70,7 +94,10 @@ class KPipeline:
|
||||
return z
|
||||
return len(pairs)
|
||||
|
||||
def en_tokenize(self, tokens):
|
||||
def en_tokenize(
|
||||
self,
|
||||
tokens: List[Union[en.MutableToken, List[en.MutableToken]]]
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
pairs = []
|
||||
count = 0
|
||||
for w in tokens:
|
||||
@@ -78,11 +105,11 @@ class KPipeline:
|
||||
if t.phonemes is None:
|
||||
continue
|
||||
next_ps = ' ' if t.prespace and pairs and not pairs[-1][1].endswith(' ') and t.phonemes else ''
|
||||
next_ps += ''.join(filter(lambda p: p in self.vocab, t.phonemes.replace('ɾ', 'T'))) # American English: ɾ => T
|
||||
next_ps += t.phonemes.replace('ɾ', 'T') # American English: ɾ => T
|
||||
next_ps += ' ' if t.whitespace else ''
|
||||
next_count = count + len(next_ps.rstrip())
|
||||
if next_count > 510:
|
||||
z = type(self).waterfall_last(pairs, next_count)
|
||||
z = KPipeline.waterfall_last(pairs, next_count)
|
||||
text, ps = zip(*pairs[:z])
|
||||
ps = ''.join(ps)
|
||||
yield ''.join(text).strip(), ps.strip()
|
||||
@@ -96,14 +123,27 @@ class KPipeline:
|
||||
text, ps = zip(*pairs)
|
||||
yield ''.join(text).strip(), ''.join(ps).strip()
|
||||
|
||||
def p2ii(self, ps):
    '''
    Map each character of `ps` to its id in self.vocab, dropping characters
    that have no entry.

    Raises AssertionError if nothing maps, or if more than 510 ids result.
    '''
    lookup = self.vocab.get
    input_ids = [token_id for token_id in map(lookup, ps) if token_id is not None]
    assert input_ids and len(input_ids) <= 510, input_ids
    return input_ids
|
||||
@classmethod
def infer(
    cls,
    model: Optional[KModel],
    ps: str,
    pack: torch.FloatTensor,
    speed: Number
) -> Optional[torch.FloatTensor]:
    '''
    Run `model` on one phoneme string, or return None when no model is given.

    The entry of `pack` at index len(ps)-1 is passed to the model as its
    second argument, followed by `speed`.
    '''
    if not model:
        return None
    ref_s = pack[len(ps) - 1]
    return model(ps, ref_s, speed)
|
||||
|
||||
def __call__(self, text, voice, speed=1, split_pattern=r'\n+'):
|
||||
assert isinstance(text, str) or isinstance(text, list), type(text)
|
||||
self.load_voice(voice)
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
voice: str,
|
||||
speed: Number = 1,
|
||||
split_pattern: Optional[str] = r'\n+',
|
||||
model: Optional[KModel] = None
|
||||
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
|
||||
pack = self.load_voice(voice)
|
||||
model = model or self.model
|
||||
pack = pack.to(model.device) if model else pack
|
||||
if isinstance(text, str):
|
||||
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||
for graphemes in text:
|
||||
@@ -113,16 +153,14 @@ class KPipeline:
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
continue
|
||||
input_ids = self.p2ii(ps)
|
||||
yield gs, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
||||
print(f"⚠️ WARNING: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
ps = ps[:510]
|
||||
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
|
||||
else:
|
||||
ps = self.g2p(graphemes)
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
print(f'WARNING: Truncating len(ps) == {len(ps)} > 510')
|
||||
print(f'⚠️ WARNING: Truncating len(ps) == {len(ps)} > 510')
|
||||
ps = ps[:510]
|
||||
input_ids = self.p2ii(ps)
|
||||
yield graphemes, ps, self.model(input_ids, self.voices[voice][len(input_ids)-1], speed)
|
||||
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='kokoro', # Name of the package
|
||||
version='0.2.3', # Initial version
|
||||
version='0.3.0', # Initial version
|
||||
packages=find_packages(), # Automatically finds packages
|
||||
install_requires=[ # List your dependencies here
|
||||
'huggingface_hub',
|
||||
|
||||
Reference in New Issue
Block a user