Allow pipeline to take a voice style tensor directly. (#93)

This commit is contained in:
RobViren
2025-02-15 00:48:08 -06:00
committed by GitHub
parent 1145c0b7f6
commit 330d110c05
2 changed files with 12 additions and 1 deletion

View File

@@ -21,6 +21,7 @@ You can run this cell on [Google Colab](https://colab.research.google.com/). [Li
from kokoro import KPipeline from kokoro import KPipeline
from IPython.display import display, Audio from IPython.display import display, Audio
import soundfile as sf import soundfile as sf
import torch
# 🇺🇸 'a' => American English, 🇬🇧 'b' => British English # 🇺🇸 'a' => American English, 🇬🇧 'b' => British English
# 🇯🇵 'j' => Japanese: pip install misaki[ja] # 🇯🇵 'j' => Japanese: pip install misaki[ja]
# 🇨🇳 'z' => Mandarin Chinese: pip install misaki[zh] # 🇨🇳 'z' => Mandarin Chinese: pip install misaki[zh]
@@ -49,6 +50,14 @@ generator = pipeline(
text, voice='af_heart', # <= change voice here text, voice='af_heart', # <= change voice here
speed=1, split_pattern=r'\n+' speed=1, split_pattern=r'\n+'
) )
# Alternatively, load a voice tensor from disk and pass it in place of a voice name (this rebinds `generator` from the call above):
voice_tensor = torch.load('path/to/voice.pt', weights_only=True)
generator = pipeline(
text, voice=voice_tensor,
speed=1, split_pattern=r'\n+'
)
for i, (gs, ps, audio) in enumerate(generator): for i, (gs, ps, audio) in enumerate(generator):
print(i) # i => index print(i) # i => index
print(gs) # gs => graphemes/text print(gs) # gs => graphemes/text

View File

@@ -146,7 +146,9 @@ class KPipeline:
If multiple voices are requested, they are averaged. If multiple voices are requested, they are averaged.
Delimiter is optional and defaults to ','. Delimiter is optional and defaults to ','.
""" """
def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor: def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
if isinstance(voice, torch.FloatTensor):
return voice
if voice in self.voices: if voice in self.voices:
return self.voices[voice] return self.voices[voice]
logger.debug(f"Loading voice: {voice}") logger.debug(f"Loading voice: {voice}")