Match misaki==0.8.0 dev branch (#114)

* Match misaki==0.8.0 dev branch

* en_callable, speed callable
Author: hexgrad
Date: 2025-02-26 17:30:50 -08:00
Committed by: GitHub
Parent: 52f7eb740b
Commit: efa91a8a3f
4 changed files with 50 additions and 30 deletions

File: kokoro/__init__.py

@@ -1,4 +1,4 @@
-__version__ = '0.7.16'
+__version__ = '0.8.0'
from loguru import logger
import sys

File: kokoro/model.py

@@ -3,7 +3,6 @@ from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
-from numbers import Number
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
@@ -24,14 +23,27 @@ class KModel(torch.nn.Module):
    so there is no need to repeatedly download config.json outside of KModel.
    '''
-    REPO_ID = 'hexgrad/Kokoro-82M'
+    MODEL_NAMES = {
+        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
+        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
+    }
-    def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False):
+    def __init__(
+        self,
+        repo_id: Optional[str] = None,
+        config: Union[Dict, str, None] = None,
+        model: Optional[str] = None,
+        disable_complex: bool = False
+    ):
        super().__init__()
+        if repo_id is None:
+            repo_id = 'hexgrad/Kokoro-82M'
+            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
+        self.repo_id = repo_id
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
-                config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
+                config = hf_hub_download(repo_id=repo_id, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
            logger.debug(f"Loaded config: {config}")
@@ -52,7 +64,7 @@ class KModel(torch.nn.Module):
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
        )
        if not model:
-            model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
+            model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
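Taken together, the two hunks above change how weights are resolved: `repo_id` now selects both `config.json` and the matching checkpoint via `MODEL_NAMES`. A minimal usage sketch, not part of this commit, assuming only the names shown in the hunks:

```python
# Hedged sketch of the new loading path.
from kokoro import KModel

# Explicit repo_id resolves config.json plus MODEL_NAMES['hexgrad/Kokoro-82M']
# ('kokoro-v1_0.pth') from that repo, with no defaulting warning.
model = KModel(repo_id='hexgrad/Kokoro-82M')

# Omitting repo_id still works but prints the new WARNING and falls back
# to 'hexgrad/Kokoro-82M'.
legacy = KModel()
```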
@@ -76,7 +88,7 @@ class KModel(torch.nn.Module):
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
-        speed: Number = 1
+        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        input_lengths = torch.full(
            (input_ids.shape[0],),
@@ -110,7 +122,7 @@ class KModel(torch.nn.Module):
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
-        speed: Number = 1,
+        speed: float = 1,
        return_output: bool = False
    ) -> Union['KModel.Output', torch.FloatTensor]:
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
@@ -133,7 +145,7 @@ class KModelForONNX(torch.nn.Module):
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
-        speed: Number = 1
+        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
        return waveform

File: kokoro/pipeline.py

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from misaki import en, espeak
-from typing import Generator, List, Optional, Tuple, Union
+from typing import Callable, Generator, List, Optional, Tuple, Union
import re
import torch
@@ -63,8 +63,10 @@ class KPipeline:
    def __init__(
        self,
        lang_code: str,
+        repo_id: Optional[str] = None,
        model: Union[KModel, bool] = True,
        trf: bool = False,
+        en_callable: Optional[Callable[[str], str]] = None,
        device: Optional[str] = None
    ):
        """Initialize a KPipeline.
@@ -77,6 +79,10 @@ class KPipeline:
                If None, will auto-select cuda if available
                If 'cuda' and not available, will explicitly raise an error
        """
+        if repo_id is None:
+            repo_id = 'hexgrad/Kokoro-82M'
+            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
+        self.repo_id = repo_id
        lang_code = lang_code.lower()
        lang_code = ALIASES.get(lang_code, lang_code)
        assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
@@ -90,7 +96,7 @@ class KPipeline:
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
-            self.model = KModel().to(device).eval()
+            self.model = KModel(repo_id=repo_id).to(device).eval()
        except RuntimeError as e:
            if device == 'cuda':
                raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
@@ -115,7 +121,10 @@ class KPipeline:
        elif lang_code == 'z':
            try:
                from misaki import zh
-                self.g2p = zh.ZHG2P()
+                self.g2p = zh.ZHG2P(
+                    version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
+                    en_callable=en_callable
+                )
            except ImportError:
                logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
                raise
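The commit message's `en_callable` lands here: `zh.ZHG2P` now receives a callable it can invoke on English spans embedded in Chinese text, and picks its `version` from the `repo_id` suffix. A hedged sketch; the callable body is a placeholder, not kokoro or misaki code:

```python
from kokoro import KPipeline

def en_callable(text: str) -> str:
    # Placeholder: a real implementation would return phonemes for the
    # embedded English span (e.g. via an English G2P).
    return text

zh_pipeline = KPipeline(
    lang_code='z',
    repo_id='hexgrad/Kokoro-82M-v1.1-zh',  # does not end with '/Kokoro-82M', so ZHG2P gets version='1.1'
    en_callable=en_callable,
)
```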
@@ -130,7 +139,7 @@ class KPipeline:
        if voice.endswith('.pt'):
            f = voice
        else:
-            f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
+            f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
        if not voice.startswith(self.lang_code):
            v = LANG_CODES.get(voice, voice)
            p = LANG_CODES.get(self.lang_code, self.lang_code)
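With `KModel.REPO_ID` gone, voice packs resolve against the pipeline's own `repo_id`. A short sketch; `load_voice` and 'af_heart' (a known Kokoro-82M voice) are used purely for illustration:

```python
# Voices now download from the repo the pipeline was constructed with.
pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')
pack = pipeline.load_voice('af_heart')  # fetches voices/af_heart.pt from that repo
```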
@@ -157,13 +166,12 @@ class KPipeline:
            self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
        return self.voices[voice]

-    @classmethod
-    def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
+    @staticmethod
+    def tokens_to_ps(tokens: List[en.MToken]) -> str:
        return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()

-    @classmethod
+    @staticmethod
    def waterfall_last(
-        cls,
        tokens: List[en.MToken],
        next_count: int,
        waterfall: List[str] = ['!.?…', ':;', ',—'],
@@ -176,12 +184,12 @@ class KPipeline:
                z += 1
            if z < len(tokens) and tokens[z].phonemes in bumps:
                z += 1
-            if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
+            if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
                return z
        return len(tokens)

-    @classmethod
-    def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
+    @staticmethod
+    def tokens_to_text(tokens: List[en.MToken]) -> str:
        return ''.join(t.text + t.whitespace for t in tokens).strip()

    def en_tokenize(
@@ -212,14 +220,15 @@ class KPipeline:
            ps = KPipeline.tokens_to_ps(tks)
            yield ''.join(text).strip(), ''.join(ps).strip(), tks

-    @classmethod
+    @staticmethod
    def infer(
-        cls,
        model: KModel,
        ps: str,
        pack: torch.FloatTensor,
-        speed: float = 1
+        speed: Union[float, Callable[[int], float]] = 1
    ) -> KModel.Output:
+        if callable(speed):
+            speed = speed(len(ps))
        return model(ps, pack[len(ps)-1], speed, return_output=True)

    def generate_from_tokens(
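This is the commit message's 'speed callable': `speed` may now be a function of the phoneme count, evaluated once per chunk before inference. A hedged sketch; the threshold and factor are made up:

```python
# Illustrative policy: slow down long chunks, keep short ones at 1.0.
def speed_fn(num_phonemes: int) -> float:
    return 0.8 if num_phonemes > 400 else 1.0

# Reusing the pipeline from the earlier sketch:
for result in pipeline('Some long text to synthesize...', voice='af_heart', speed=speed_fn):
    print(result.phonemes)
```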
@@ -272,8 +281,8 @@ class KPipeline:
                KPipeline.join_timestamps(tks, output.pred_dur)
            yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)

-    @classmethod
-    def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
+    @staticmethod
+    def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
        # Multiply by 600 to go from pred_dur frames to sample_rate 24000
        # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
        # We will count nice round half-frames, so the divisor is 80
@@ -343,7 +352,7 @@ class KPipeline:
        self,
        text: Union[str, List[str]],
        voice: Optional[str] = None,
-        speed: float = 1,
+        speed: Union[float, Callable[[int], float]] = 1,
        split_pattern: Optional[str] = r'\n+',
        model: Optional[KModel] = None
    ) -> Generator['KPipeline.Result', None, None]:
@@ -412,7 +421,7 @@ class KPipeline:
            if not chunk.strip():
                continue
-            ps = self.g2p(chunk)
+            ps, _ = self.g2p(chunk)
            if not ps:
                continue
            elif len(ps) > 510:
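The tuple unpack reflects misaki 0.8.0, where calling a G2P returns a `(phonemes, tokens)` pair rather than a bare string. A hedged sketch of direct use; the constructor arguments are assumed, mirroring how KPipeline builds its English G2P:

```python
from misaki import en

g2p = en.G2P(trf=False, british=False)  # kwargs assumed from the pipeline's usage
phonemes, tokens = g2p('Hello, world!')  # 0.8.0: returns a pair, hence `ps, _ = ...` above
print(phonemes)
```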
@@ -421,4 +430,3 @@ class KPipeline:
            output = KPipeline.infer(model, ps, pack, speed) if model else None
            yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)

File: pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "kokoro"
version = "0.7.16"
version = "0.8.0"
description = "TTS"
readme = "README.md"
authors = [
@@ -20,7 +20,7 @@ requires-python = ">=3.10, <3.13"
dependencies = [
"huggingface_hub",
"loguru",
"misaki[en]>=0.7.16",
"misaki[en]>=0.8.0",
"numpy==1.26.4",
"scipy",
"torch",