Match misaki==0.8.0 dev branch (#114)
* Match misaki==0.8.0 dev branch
* en_callable, speed callable
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
__version__ = '0.7.16'
|
__version__ = '0.8.0'
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from numbers import Number
|
|
||||||
from transformers import AlbertConfig
|
from transformers import AlbertConfig
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
import json
|
import json
|
||||||
@@ -24,14 +23,27 @@ class KModel(torch.nn.Module):
|
|||||||
so there is no need to repeatedly download config.json outside of KModel.
|
so there is no need to repeatedly download config.json outside of KModel.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
MODEL_NAMES = {
|
||||||
|
'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
|
||||||
|
'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None, disable_complex: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
repo_id: Optional[str] = None,
|
||||||
|
config: Union[Dict, str, None] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
disable_complex: bool = False
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if repo_id is None:
|
||||||
|
repo_id = 'hexgrad/Kokoro-82M'
|
||||||
|
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||||
|
self.repo_id = repo_id
|
||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
if not config:
|
if not config:
|
||||||
logger.debug("No config provided, downloading from HF")
|
logger.debug("No config provided, downloading from HF")
|
||||||
config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
|
config = hf_hub_download(repo_id=repo_id, filename='config.json')
|
||||||
with open(config, 'r', encoding='utf-8') as r:
|
with open(config, 'r', encoding='utf-8') as r:
|
||||||
config = json.load(r)
|
config = json.load(r)
|
||||||
logger.debug(f"Loaded config: {config}")
|
logger.debug(f"Loaded config: {config}")
|
||||||
@@ -52,7 +64,7 @@ class KModel(torch.nn.Module):
|
|||||||
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
|
dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
|
||||||
)
|
)
|
||||||
if not model:
|
if not model:
|
||||||
model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
|
model = hf_hub_download(repo_id=repo_id, filename=MODEL_NAMES[repo_id])
|
||||||
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
|
for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
|
||||||
assert hasattr(self, key), key
|
assert hasattr(self, key), key
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +88,7 @@ class KModel(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
ref_s: torch.FloatTensor,
|
ref_s: torch.FloatTensor,
|
||||||
speed: Number = 1
|
speed: float = 1
|
||||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||||
input_lengths = torch.full(
|
input_lengths = torch.full(
|
||||||
(input_ids.shape[0],),
|
(input_ids.shape[0],),
|
||||||
@@ -110,7 +122,7 @@ class KModel(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
phonemes: str,
|
phonemes: str,
|
||||||
ref_s: torch.FloatTensor,
|
ref_s: torch.FloatTensor,
|
||||||
speed: Number = 1,
|
speed: float = 1,
|
||||||
return_output: bool = False
|
return_output: bool = False
|
||||||
) -> Union['KModel.Output', torch.FloatTensor]:
|
) -> Union['KModel.Output', torch.FloatTensor]:
|
||||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||||
@@ -133,7 +145,7 @@ class KModelForONNX(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
ref_s: torch.FloatTensor,
|
ref_s: torch.FloatTensor,
|
||||||
speed: Number = 1
|
speed: float = 1
|
||||||
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
) -> tuple[torch.FloatTensor, torch.LongTensor]:
|
||||||
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
|
waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
|
||||||
return waveform
|
return waveform
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from misaki import en, espeak
|
from misaki import en, espeak
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Callable, Generator, List, Optional, Tuple, Union
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -63,8 +63,10 @@ class KPipeline:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
lang_code: str,
|
lang_code: str,
|
||||||
|
repo_id: Optional[str] = None,
|
||||||
model: Union[KModel, bool] = True,
|
model: Union[KModel, bool] = True,
|
||||||
trf: bool = False,
|
trf: bool = False,
|
||||||
|
en_callable: Optional[Callable[[str], str]] = None,
|
||||||
device: Optional[str] = None
|
device: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""Initialize a KPipeline.
|
"""Initialize a KPipeline.
|
||||||
@@ -77,6 +79,10 @@ class KPipeline:
|
|||||||
If None, will auto-select cuda if available
|
If None, will auto-select cuda if available
|
||||||
If 'cuda' and not available, will explicitly raise an error
|
If 'cuda' and not available, will explicitly raise an error
|
||||||
"""
|
"""
|
||||||
|
if repo_id is None:
|
||||||
|
repo_id = 'hexgrad/Kokoro-82M'
|
||||||
|
print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
|
||||||
|
self.repo_id = repo_id
|
||||||
lang_code = lang_code.lower()
|
lang_code = lang_code.lower()
|
||||||
lang_code = ALIASES.get(lang_code, lang_code)
|
lang_code = ALIASES.get(lang_code, lang_code)
|
||||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||||
@@ -90,7 +96,7 @@ class KPipeline:
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
try:
|
try:
|
||||||
self.model = KModel().to(device).eval()
|
self.model = KModel(repo_id=repo_id).to(device).eval()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
|
raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
|
||||||
@@ -115,7 +121,10 @@ class KPipeline:
|
|||||||
elif lang_code == 'z':
|
elif lang_code == 'z':
|
||||||
try:
|
try:
|
||||||
from misaki import zh
|
from misaki import zh
|
||||||
self.g2p = zh.ZHG2P()
|
self.g2p = zh.ZHG2P(
|
||||||
|
version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
|
||||||
|
en_callable=en_callable
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
|
||||||
raise
|
raise
|
||||||
@@ -130,7 +139,7 @@ class KPipeline:
|
|||||||
if voice.endswith('.pt'):
|
if voice.endswith('.pt'):
|
||||||
f = voice
|
f = voice
|
||||||
else:
|
else:
|
||||||
f = hf_hub_download(repo_id=KModel.REPO_ID, filename=f'voices/{voice}.pt')
|
f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
|
||||||
if not voice.startswith(self.lang_code):
|
if not voice.startswith(self.lang_code):
|
||||||
v = LANG_CODES.get(voice, voice)
|
v = LANG_CODES.get(voice, voice)
|
||||||
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
p = LANG_CODES.get(self.lang_code, self.lang_code)
|
||||||
@@ -157,13 +166,12 @@ class KPipeline:
|
|||||||
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
|
||||||
return self.voices[voice]
|
return self.voices[voice]
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def tokens_to_ps(cls, tokens: List[en.MToken]) -> str:
|
def tokens_to_ps(tokens: List[en.MToken]) -> str:
|
||||||
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def waterfall_last(
|
def waterfall_last(
|
||||||
cls,
|
|
||||||
tokens: List[en.MToken],
|
tokens: List[en.MToken],
|
||||||
next_count: int,
|
next_count: int,
|
||||||
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
waterfall: List[str] = ['!.?…', ':;', ',—'],
|
||||||
@@ -176,12 +184,12 @@ class KPipeline:
|
|||||||
z += 1
|
z += 1
|
||||||
if z < len(tokens) and tokens[z].phonemes in bumps:
|
if z < len(tokens) and tokens[z].phonemes in bumps:
|
||||||
z += 1
|
z += 1
|
||||||
if next_count - len(cls.tokens_to_ps(tokens[:z])) <= 510:
|
if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
|
||||||
return z
|
return z
|
||||||
return len(tokens)
|
return len(tokens)
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def tokens_to_text(cls, tokens: List[en.MToken]) -> str:
|
def tokens_to_text(tokens: List[en.MToken]) -> str:
|
||||||
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
return ''.join(t.text + t.whitespace for t in tokens).strip()
|
||||||
|
|
||||||
def en_tokenize(
|
def en_tokenize(
|
||||||
@@ -212,14 +220,15 @@ class KPipeline:
|
|||||||
ps = KPipeline.tokens_to_ps(tks)
|
ps = KPipeline.tokens_to_ps(tks)
|
||||||
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
yield ''.join(text).strip(), ''.join(ps).strip(), tks
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def infer(
|
def infer(
|
||||||
cls,
|
|
||||||
model: KModel,
|
model: KModel,
|
||||||
ps: str,
|
ps: str,
|
||||||
pack: torch.FloatTensor,
|
pack: torch.FloatTensor,
|
||||||
speed: float = 1
|
speed: Union[float, Callable[[int], float]] = 1
|
||||||
) -> KModel.Output:
|
) -> KModel.Output:
|
||||||
|
if callable(speed):
|
||||||
|
speed = speed(len(ps))
|
||||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||||
|
|
||||||
def generate_from_tokens(
|
def generate_from_tokens(
|
||||||
@@ -272,8 +281,8 @@ class KPipeline:
|
|||||||
KPipeline.join_timestamps(tks, output.pred_dur)
|
KPipeline.join_timestamps(tks, output.pred_dur)
|
||||||
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||||
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||||
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
# Equivalent to dividing pred_dur frames by 40 to get timestamp in seconds
|
||||||
# We will count nice round half-frames, so the divisor is 80
|
# We will count nice round half-frames, so the divisor is 80
|
||||||
@@ -343,7 +352,7 @@ class KPipeline:
|
|||||||
self,
|
self,
|
||||||
text: Union[str, List[str]],
|
text: Union[str, List[str]],
|
||||||
voice: Optional[str] = None,
|
voice: Optional[str] = None,
|
||||||
speed: float = 1,
|
speed: Union[float, Callable[[int], float]] = 1,
|
||||||
split_pattern: Optional[str] = r'\n+',
|
split_pattern: Optional[str] = r'\n+',
|
||||||
model: Optional[KModel] = None
|
model: Optional[KModel] = None
|
||||||
) -> Generator['KPipeline.Result', None, None]:
|
) -> Generator['KPipeline.Result', None, None]:
|
||||||
@@ -412,7 +421,7 @@ class KPipeline:
|
|||||||
if not chunk.strip():
|
if not chunk.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ps = self.g2p(chunk)
|
ps, _ = self.g2p(chunk)
|
||||||
if not ps:
|
if not ps:
|
||||||
continue
|
continue
|
||||||
elif len(ps) > 510:
|
elif len(ps) > 510:
|
||||||
@@ -421,4 +430,3 @@ class KPipeline:
|
|||||||
|
|
||||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||||
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
|
yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "kokoro"
|
name = "kokoro"
|
||||||
version = "0.7.16"
|
version = "0.8.0"
|
||||||
description = "TTS"
|
description = "TTS"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@@ -20,7 +20,7 @@ requires-python = ">=3.10, <3.13"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"loguru",
|
"loguru",
|
||||||
"misaki[en]>=0.7.16",
|
"misaki[en]>=0.8.0",
|
||||||
"numpy==1.26.4",
|
"numpy==1.26.4",
|
||||||
"scipy",
|
"scipy",
|
||||||
"torch",
|
"torch",
|
||||||
|
|||||||
Reference in New Issue
Block a user