diff --git a/README.md b/README.md
index 2734d53..d19f1c6 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ You can run this cell on [Google Colab](https://colab.research.google.com/). [Li
 from kokoro import KPipeline
 from IPython.display import display, Audio
 import soundfile as sf
+import torch
 # 🇺🇸 'a' => American English, 🇬🇧 'b' => British English
 # 🇯🇵 'j' => Japanese: pip install misaki[ja]
 # 🇨🇳 'z' => Mandarin Chinese: pip install misaki[zh]
@@ -49,6 +50,14 @@ generator = pipeline(
     text, voice='af_heart', # <= change voice here
     speed=1, split_pattern=r'\n+'
 )
+
+# Alternatively, load voice tensor directly:
+voice_tensor = torch.load('path/to/voice.pt', weights_only=True)
+generator = pipeline(
+    text, voice=voice_tensor,
+    speed=1, split_pattern=r'\n+'
+)
+
 for i, (gs, ps, audio) in enumerate(generator):
     print(i)  # i => index
     print(gs) # gs => graphemes/text
diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py
index bc099fd..66f4c01 100644
--- a/kokoro/pipeline.py
+++ b/kokoro/pipeline.py
@@ -146,7 +146,10 @@ class KPipeline:
     If multiple voices are requested, they are averaged.
     Delimiter is optional and defaults to ','.
     """
-    def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
+    def load_voice(self, voice: Union[str, torch.Tensor], delimiter: str = ",") -> torch.FloatTensor:
+        # Pre-loaded voice tensors are used as-is. Check against torch.Tensor,
+        # not the legacy torch.FloatTensor class: a tensor loaded onto CUDA or
+        # in half precision is NOT a torch.FloatTensor and would otherwise fall
+        # through to the string-lookup path below and fail.
+        if isinstance(voice, torch.Tensor):
+            return voice
         if voice in self.voices:
             return self.voices[voice]
         logger.debug(f"Loading voice: {voice}")