Backward compatible KModel.Output and KPipeline.Result dataclasses (#40)
* Backward compatible KModel.Output and KPipeline.Result dataclasses * Typo: Bool => bool * Allow voice=None for quiet pipelines * Specify class names * Fixed and tested * Update README.md
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = '0.3.4'
|
||||
__version__ = '0.3.5'
|
||||
|
||||
from loguru import logger
|
||||
import sys
|
||||
@@ -17,7 +17,7 @@ logger.add(
|
||||
)
|
||||
|
||||
# Disable before release or as needed
|
||||
# logger.disable("kokoro")
|
||||
logger.disable("kokoro")
|
||||
|
||||
from .model import KModel
|
||||
from .pipeline import KPipeline
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from .istftnet import Decoder
|
||||
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from numbers import Number
|
||||
from transformers import AlbertConfig
|
||||
from typing import Dict, Optional, Union
|
||||
from loguru import logger
|
||||
import json
|
||||
import torch
|
||||
|
||||
@@ -65,8 +66,19 @@ class KModel(torch.nn.Module):
|
||||
def device(self):
|
||||
return self.bert.device
|
||||
|
||||
@dataclass
|
||||
class Output:
|
||||
audio: torch.FloatTensor
|
||||
pred_dur: Optional[torch.LongTensor] = None
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, phonemes: str, ref_s: torch.FloatTensor, speed: Number = 1) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self,
|
||||
phonemes: str,
|
||||
ref_s: torch.FloatTensor,
|
||||
speed: Number = 1,
|
||||
return_output: bool = False # MARK: BACKWARD COMPAT
|
||||
) -> Union['KModel.Output', torch.FloatTensor]:
|
||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
|
||||
logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
|
||||
assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
|
||||
@@ -82,16 +94,15 @@ class KModel(torch.nn.Module):
|
||||
x, _ = self.predictor.lstm(d)
|
||||
duration = self.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
|
||||
logger.debug(f"pred_dur: {pred_dur}")
|
||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
||||
c_frame += pred_dur[0,i].item()
|
||||
indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
|
||||
pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
|
||||
pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
|
||||
t_en = self.text_encoder(input_ids, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
return self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
||||
audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
||||
return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from .model import KModel
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import hf_hub_download
|
||||
from loguru import logger
|
||||
from misaki import en, espeak
|
||||
from numbers import Number
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from loguru import logger
|
||||
import re
|
||||
import torch
|
||||
|
||||
@@ -134,6 +135,7 @@ class KPipeline:
|
||||
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
|
||||
if voice in self.voices:
|
||||
return self.voices[voice]
|
||||
logger.debug(f"Loading voice: {voice}")
|
||||
packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
|
||||
if len(packs) == 1:
|
||||
return packs[0]
|
||||
@@ -194,26 +196,52 @@ class KPipeline:
|
||||
@classmethod
|
||||
def infer(
|
||||
cls,
|
||||
model: Optional[KModel],
|
||||
model: KModel,
|
||||
ps: str,
|
||||
pack: torch.FloatTensor,
|
||||
speed: Number
|
||||
) -> Optional[torch.FloatTensor]:
|
||||
return model(ps, pack[len(ps)-1], speed) if model else None
|
||||
speed: Number = 1
|
||||
) -> KModel.Output:
|
||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
graphemes: str
|
||||
phonemes: str
|
||||
output: Optional[KModel.Output] = None
|
||||
|
||||
@property
|
||||
def audio(self) -> Optional[torch.FloatTensor]:
|
||||
return None if self.output is None else self.output.audio
|
||||
|
||||
@property
|
||||
def pred_dur(self) -> Optional[torch.LongTensor]:
|
||||
return None if self.output is None else self.output.pred_dur
|
||||
|
||||
### MARK: BEGIN BACKWARD COMPAT ###
|
||||
def __iter__(self):
|
||||
yield self.graphemes
|
||||
yield self.phonemes
|
||||
yield self.audio
|
||||
|
||||
def __getitem__(self, index):
|
||||
return [self.graphemes, self.phonemes, self.audio][index]
|
||||
|
||||
def __len__(self):
|
||||
return 3
|
||||
#### MARK: END BACKWARD COMPAT ####
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
voice: str,
|
||||
voice: Optional[str] = None,
|
||||
speed: Number = 1,
|
||||
split_pattern: Optional[str] = r'\n+',
|
||||
model: Optional[KModel] = None
|
||||
) -> Generator[Tuple[str, str, Optional[torch.FloatTensor]], None, None]:
|
||||
logger.debug(f"Loading voice: {voice}")
|
||||
pack = self.load_voice(voice)
|
||||
) -> Generator['KPipeline.Result', None, None]:
|
||||
model = model or self.model
|
||||
pack = pack.to(model.device) if model else pack
|
||||
logger.debug(f"Voice loaded on device: {pack.device if hasattr(pack, 'device') else 'N/A'}")
|
||||
if model and voice is None:
|
||||
raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
|
||||
pack = self.load_voice(voice).to(model.device) if model else None
|
||||
if isinstance(text, str):
|
||||
text = re.split(split_pattern, text.strip()) if split_pattern else [text]
|
||||
for graphemes in text:
|
||||
@@ -227,7 +255,8 @@ class KPipeline:
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
|
||||
ps = ps[:510]
|
||||
yield gs, ps, KPipeline.infer(model, ps, pack, speed)
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
yield self.Result(graphemes=gs, phonemes=ps, output=output)
|
||||
else:
|
||||
ps = self.g2p(graphemes)
|
||||
if not ps:
|
||||
@@ -235,4 +264,5 @@ class KPipeline:
|
||||
elif len(ps) > 510:
|
||||
logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
|
||||
ps = ps[:510]
|
||||
yield graphemes, ps, KPipeline.infer(model, ps, pack, speed)
|
||||
output = KPipeline.infer(model, ps, pack, speed) if model else None
|
||||
yield self.Result(graphemes=graphemes, phonemes=ps, output=output)
|
||||
|
||||
Reference in New Issue
Block a user