Match misaki==0.8.0 dev branch (#114)

* Match misaki==0.8.0 dev branch

* en_callable, speed callable
This commit is contained in:
hexgrad
2025-02-26 17:30:50 -08:00
committed by GitHub
parent 52f7eb740b
commit efa91a8a3f
4 changed files with 50 additions and 30 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.7.16' __version__ = '0.8.0'
from loguru import logger from loguru import logger
import sys import sys

View File

@@ -3,7 +3,6 @@ from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass from dataclasses import dataclass
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from loguru import logger from loguru import logger
from numbers import Number
from transformers import AlbertConfig from transformers import AlbertConfig
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
import json import json
@@ -24,14 +23,27 @@ class KModel(torch.nn.Module):
so there is no need to repeatedly download config.json outside of KModel. so there is no need to repeatedly download config.json outside of KModel.
''' '''
REPO_ID = 'hexgrad/Kokoro-82M' MODEL_NAMES = {
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
}
def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False): def __init__(
self,
repo_id: Optional[str] = None,
config: Union[Dict, str, None] = None,
model: Optional[str] = None,
disable_complex: bool = False
):
super().__init__() super().__init__()
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
self.repo_id = repo_id
if not isinstance(config, dict): if not isinstance(config, dict):
if not config: if not config:
logger.debug("No config provided, downloading from HF") logger.debug("No config provided, downloading from HF")
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json') config = hf_hub_download(repo_id=repo_id, filename='config.json')
with open(config, 'r', encoding='utf-8') as r: with open(config, 'r', encoding='utf-8') as r:
config = json.load(r) config = json.load(r)
logger.debug(f"Loaded config: {config}") logger.debug(f"Loaded config: {config}")
@@ -52,7 +64,7 @@ class KModel(torch.nn.Module):
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet'] dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
) )
if not model: if not model:
model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth') model = hf_hub_download(repo_id=repo_id, filename=MODEL_NAMES[repo_id])
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items(): for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
assert hasattr(self, key), key assert hasattr(self, key), key
try: try:
@@ -76,7 +88,7 @@ class KModel(torch.nn.Module):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
ref_s: torch.FloatTensor, ref_s: torch.FloatTensor,
speed: Number = 1 speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]: ) -> tuple[torch.FloatTensor, torch.LongTensor]:
input_lengths = torch.full( input_lengths = torch.full(
(input_ids.shape[0],), (input_ids.shape[0],),
@@ -110,7 +122,7 @@ class KModel(torch.nn.Module):
self, self,
phonemes: str, phonemes: str,
ref_s: torch.FloatTensor, ref_s: torch.FloatTensor,
speed: Number = 1, speed: float = 1,
return_output: bool = False return_output: bool = False
) -> Union['KModel.Output', torch.FloatTensor]: ) -> Union['KModel.Output', torch.FloatTensor]:
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
@@ -133,7 +145,7 @@ class KModelForONNX(torch.nn.Module):
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
ref_s: torch.FloatTensor, ref_s: torch.FloatTensor,
speed: Number = 1 speed: float = 1
) -> tuple[torch.FloatTensor, torch.LongTensor]: ) -> tuple[torch.FloatTensor, torch.LongTensor]:
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed) waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
return waveform return waveform

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from loguru import logger from loguru import logger
from misaki import en, espeak from misaki import en, espeak
from typing import Generator, List, Optional, Tuple, Union from typing import Callable, Generator, List, Optional, Tuple, Union
import re import re
import torch import torch
@@ -63,8 +63,10 @@ class KPipeline:
def __init__( def __init__(
self, self,
lang_code: str, lang_code: str,
repo_id: Optional[str] = None,
model: Union[KModel, bool] = True, model: Union[KModel, bool] = True,
trf: bool = False, trf: bool = False,
en_callable: Optional[Callable[[str], str]] = None,
device: Optional[str] = None device: Optional[str] = None
): ):
"""Initialize a KPipeline. """Initialize a KPipeline.
@@ -77,6 +79,10 @@ class KPipeline:
If None, will auto-select cuda if available If None, will auto-select cuda if available
If 'cuda' and not available, will explicitly raise an error If 'cuda' and not available, will explicitly raise an error
""" """
if repo_id is None:
repo_id = 'hexgrad/Kokoro-82M'
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
self.repo_id = repo_id
lang_code = lang_code.lower() lang_code = lang_code.lower()
lang_code = ALIASES.get(lang_code, lang_code) lang_code = ALIASES.get(lang_code, lang_code)
assert lang_code in LANG_CODES, (lang_code, LANG_CODES) assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
@@ -90,7 +96,7 @@ class KPipeline:
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
try: try:
self.model = KModel().to(device).eval() self.model = KModel(repo_id=repo_id).to(device).eval()
except RuntimeError as e: except RuntimeError as e:
if device == 'cuda': if device == 'cuda':
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}. raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
@@ -115,7 +121,10 @@ class KPipeline:
elif lang_code == 'z': elif lang_code == 'z':
try: try:
from misaki import zh from misaki import zh
self.g2p = zh.ZHG2P() self.g2p = zh.ZHG2P(
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
en_callable=en_callable
)
except ImportError: except ImportError:
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'") logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
raise raise
@@ -130,7 +139,7 @@ class KPipeline:
if voice.endswith('.pt'): if voice.endswith('.pt'):
f = voice f = voice
else: else:
f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt') f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
if not voice.startswith(self.lang_code): if not voice.startswith(self.lang_code):
v = LANG_CODES.get(voice, voice) v = LANG_CODES.get(voice, voice)
p = LANG_CODES.get(self.lang_code, self.lang_code) p = LANG_CODES.get(self.lang_code, self.lang_code)
@@ -157,13 +166,12 @@ class KPipeline:
self.voices[voice] = torch.mean(torch.stack(packs), dim=0) self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
return self.voices[voice] return self.voices[voice]
@classmethod @staticmethod
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str: def tokens_to_ps(tokens: List[en.MToken]) -> str:
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip() return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
@classmethod @staticmethod
def waterfall_last( def waterfall_last(
cls,
tokens: List[en.MToken], tokens: List[en.MToken],
next_count: int, next_count: int,
waterfall: List[str] = ['!.?…', ':;', ',—'], waterfall: List[str] = ['!.?…', ':;', ',—'],
@@ -176,12 +184,12 @@ class KPipeline:
z += 1 z += 1
if z < len(tokens) and tokens[z].phonemes in bumps: if z < len(tokens) and tokens[z].phonemes in bumps:
z += 1 z += 1
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510: if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
return z return z
return len(tokens) return len(tokens)
@classmethod @staticmethod
def tokens_to_text(cls, tokens: List[en.MToken]) -> str: def tokens_to_text(tokens: List[en.MToken]) -> str:
return ''.join(t.text + t.whitespace for t in tokens).strip() return ''.join(t.text + t.whitespace for t in tokens).strip()
def en_tokenize( def en_tokenize(
@@ -212,14 +220,15 @@ class KPipeline:
ps = KPipeline.tokens_to_ps(tks) ps = KPipeline.tokens_to_ps(tks)
yield ''.join(text).strip(), ''.join(ps).strip(), tks yield ''.join(text).strip(), ''.join(ps).strip(), tks
@classmethod @staticmethod
def infer( def infer(
cls,
model: KModel, model: KModel,
ps: str, ps: str,
pack: torch.FloatTensor, pack: torch.FloatTensor,
speed: float = 1 speed: Union[float, Callable[[int], float]] = 1
) -> KModel.Output: ) -> KModel.Output:
if callable(speed):
speed = speed(len(ps))
return model(ps, pack[len(ps)-1], speed, return_output=True) return model(ps, pack[len(ps)-1], speed, return_output=True)
def generate_from_tokens( def generate_from_tokens(
@@ -272,8 +281,8 @@ class KPipeline:
KPipeline.join_timestamps(tks, output.pred_dur) KPipeline.join_timestamps(tks, output.pred_dur)
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output) yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
@classmethod @staticmethod
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor): def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
# Multiply by 600 to go from pred_dur frames to sample_rate 24000 # Multiply by 600 to go from pred_dur frames to sample_rate 24000
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds # Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
# We will count nice round half-frames, so the divisor is 80 # We will count nice round half-frames, so the divisor is 80
@@ -343,7 +352,7 @@ class KPipeline:
self, self,
text: Union[str, List[str]], text: Union[str, List[str]],
voice: Optional[str] = None, voice: Optional[str] = None,
speed: float = 1, speed: Union[float, Callable[[int], float]] = 1,
split_pattern: Optional[str] = r'\n+', split_pattern: Optional[str] = r'\n+',
model: Optional[KModel] = None model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]: ) -> Generator['KPipeline.Result', None, None]:
@@ -412,7 +421,7 @@ class KPipeline:
if not chunk.strip(): if not chunk.strip():
continue continue
ps = self.g2p(chunk) ps, _ = self.g2p(chunk)
if not ps: if not ps:
continue continue
elif len(ps) > 510: elif len(ps) > 510:
@@ -421,4 +430,3 @@ class KPipeline:
output = KPipeline.infer(model, ps, pack, speed) if model else None output = KPipeline.infer(model, ps, pack, speed) if model else None
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index) yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "kokoro" name = "kokoro"
version = "0.7.16" version = "0.8.0"
description = "TTS" description = "TTS"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@@ -20,7 +20,7 @@ requires-python = ">=3.10, <3.13"
dependencies = [ dependencies = [
"huggingface_hub", "huggingface_hub",
"loguru", "loguru",
"misaki[en]>=0.7.16", "misaki[en]>=0.8.0",
"numpy==1.26.4", "numpy==1.26.4",
"scipy", "scipy",
"torch", "torch",