Match misaki==0.8.0 dev branch (#114)
* Match misaki==0.8.0 dev branch * en_callable, speed callable
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = '0.7.16'
|
||||
__version__ = '0.8.0'
|
||||
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
@@ -3,7 +3,6 @@ from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from numbers import Number
|
||||
from transformers import AlbertConfig
|
||||
from typing import Dict, Optional, Union
|
||||
import json
|
||||
@@ -24,14 +23,27 @@ class KModel(torch.nn.Module):
|
||||
so there is no need to repeatedly download config.json outside of KModel.
|
||||
'''
|
||||
|
||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||
MODEL_NAMES = {
|
||||
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
|
||||
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
|
||||
}
|
||||
|
||||
def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: Optional[str] = None,
|
||||
config: Union[Dict, str, None] = None,
|
||||
model: Optional[str] = None,
|
||||
disable_complex: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
if repo_id is None:
|
||||
repo_id = 'hexgrad/Kokoro-82M'
|
||||
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||
self.repo_id = repo_id
|
||||
if not isinstance(config, dict):
|
||||
if not config:
|
||||
logger.debug("No config provided, downloading from HF")
|
||||
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
|
||||
config = hf_hub_download(repo_id=repo_id, filename='config.json')
|
||||
with open(config, 'r', encoding='utf-8') as r:
|
||||
config = json.load(r)
|
||||
logger.debug(f"Loaded config: {config}")
|
||||
@@ -52,7 +64,7 @@ class KModel(torch.nn.Module):
|
||||
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
|
||||
)
|
||||
if not model:
|
||||
model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
|
||||
model = hf_hub_download(repo_id=repo_id, filename=MODEL_NAMES[repo_id])
|
||||
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
|
||||
assert hasattr(self, key), key
|
||||
try:
|
||||
@@ -76,7 +88,7 @@ class KModel(torch.nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: Number = 1
|
||||
speed: float = 1
|
||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
input_lengths = torch.full(
|
||||
(input_ids.shape[0],),
|
||||
@@ -110,7 +122,7 @@ class KModel(torch.nn.Module):
|
||||
self,
|
||||
phonemes: str,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: Number = 1,
|
||||
speed: float = 1,
|
||||
return_output: bool = False
|
||||
) -> Union['KModel.Output', torch.FloatTensor]:
|
||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||
@@ -133,7 +145,7 @@ class KModelForONNX(torch.nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: Number = 1
|
||||
speed: float = 1
|
||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
|
||||
return waveform
|
||||
|
||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from misaki import en, espeak
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||
import re
|
||||
import torch
|
||||
|
||||
@@ -63,8 +63,10 @@ class KPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
lang_code: str,
|
||||
repo_id: Optional[str] = None,
|
||||
model: Union[KModel, bool] = True,
|
||||
trf: bool = False,
|
||||
en_callable: Optional[Callable[[str], str]] = None,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""Initialize a KPipeline.
|
||||
@@ -77,6 +79,10 @@ class KPipeline:
|
||||
If None, will auto-select cuda if available
|
||||
If 'cuda' and not available, will explicitly raise an error
|
||||
"""
|
||||
if repo_id is None:
|
||||
repo_id = 'hexgrad/Kokoro-82M'
|
||||
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||
self.repo_id = repo_id
|
||||
lang_code = lang_code.lower()
|
||||
lang_code = ALIASES.get(lang_code, lang_code)
|
||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||
@@ -90,7 +96,7 @@ class KPipeline:
|
||||
if device is None:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
try:
|
||||
self.model = KModel().to(device).eval()
|
||||
self.model = KModel(repo_id=repo_id).to(device).eval()
|
||||
except RuntimeError as e:
|
||||
if device == 'cuda':
|
||||
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
|
||||
@@ -115,7 +121,10 @@ class KPipeline:
|
||||
elif lang_code == 'z':
|
||||
try:
|
||||
from misaki import zh
|
||||
self.g2p = zh.ZHG2P()
|
||||
self.g2p = zh.ZHG2P(
|
||||
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
|
||||
en_callable=en_callable
|
||||
)
|
||||
except ImportError:
|
||||
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
||||
raise
|
||||
@@ -130,7 +139,7 @@ class KPipeline:
|
||||
if voice.endswith('.pt'):
|
||||
f = voice
|
||||
else:
|
||||
f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
|
||||
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
|
||||
if not voice.startswith(self.lang_code):
|
||||
v = LANG_CODES.get(voice, voice)
|
||||
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||
@@ -157,13 +166,12 @@ class KPipeline:
|
||||
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||
return self.voices[voice]
|
||||
|
||||
@classmethod
|
||||
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
|
||||
@staticmethod
|
||||
def tokens_to_ps(tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def waterfall_last(
|
||||
cls,
|
||||
tokens: List[en.MToken],
|
||||
next_count: int,
|
||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||
@@ -176,12 +184,12 @@ class KPipeline:
|
||||
z += 1
|
||||
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||
z += 1
|
||||
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
|
||||
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
|
||||
return z
|
||||
return len(tokens)
|
||||
|
||||
@classmethod
|
||||
def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
|
||||
@staticmethod
|
||||
def tokens_to_text(tokens: List[en.MToken]) -> str:
|
||||
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||
|
||||
def en_tokenize(
|
||||
@@ -212,14 +220,15 @@ class KPipeline:
|
||||
ps = KPipeline.tokens_to_ps(tks)
|
||||
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def infer(
|
||||
cls,
|
||||
model: KModel,
|
||||
ps: str,
|
||||
pack: torch.FloatTensor,
|
||||
speed: float = 1
|
||||
speed: Union[float, Callable[[int], float]] = 1
|
||||
) -> KModel.Output:
|
||||
if callable(speed):
|
||||
speed = speed(len(ps))
|
||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||
|
||||
def generate_from_tokens(
|
||||
@@ -272,8 +281,8 @@ class KPipeline:
|
||||
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||
|
||||
@classmethod
|
||||
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||
@staticmethod
|
||||
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||
# We will count nice round half-frames, so the divisor is 80
|
||||
@@ -343,7 +352,7 @@ class KPipeline:
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
voice: Optional[str] = None,
|
||||
speed: float = 1,
|
||||
speed: Union[float, Callable[[int], float]] = 1,
|
||||
split_pattern: Optional[str] = r'\n+',
|
||||
model: Optional[KModel] = None
|
||||
) -> Generator['KPipeline.Result', None, None]:
|
||||
@@ -412,7 +421,7 @@ class KPipeline:
|
||||
if not chunk.strip():
|
||||
continue
|
||||
|
||||
ps = self.g2p(chunk)
|
||||
ps, _ = self.g2p(chunk)
|
||||
if not ps:
|
||||
continue
|
||||
elif len(ps) > 510:
|
||||
@@ -421,4 +430,3 @@ class KPipeline:
|
||||
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "kokoro"
|
||||
version = "0.7.16"
|
||||
version = "0.8.0"
|
||||
description = "TTS"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -20,7 +20,7 @@ requires-python = ">=3.10, <3.13"
|
||||
dependencies = [
|
||||
"huggingface_hub",
|
||||
"loguru",
|
||||
"misaki[en]>=0.7.16",
|
||||
"misaki[en]>=0.8.0",
|
||||
"numpy==1.26.4",
|
||||
"scipy",
|
||||
"torch",
|
||||
|
||||
Reference in New Issue
Block a user