Add generate_from_tokens method, example (#53)
This commit is contained in:
62
examples/phoneme_example.py
Normal file
62
examples/phoneme_example.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from kokoro import KPipeline, KModel
|
||||||
|
import torch
|
||||||
|
from scipy.io import wavfile
|
||||||
|
|
||||||
|
def save_audio(audio: torch.Tensor, filename: str):
    """Save a mono audio tensor to *filename* as a 24 kHz WAV file.

    Args:
        audio: 1-D waveform tensor produced by the Kokoro model, or None
            when generation produced no audio.
        filename: Destination path for the WAV file.

    If ``audio`` is None, nothing is written and a notice is printed.
    """
    if audio is not None:
        # Detach (in case grad is being tracked), move to CPU, and convert
        # to a numpy array so scipy can write it.
        audio_cpu = audio.detach().cpu().numpy()

        # Save using scipy.io.wavfile
        wavfile.write(
            filename,
            24000,  # Kokoro uses 24kHz sample rate
            audio_cpu
        )
        # BUG FIX: previously printed a literal placeholder instead of the
        # actual destination path.
        print(f"Audio saved as '{filename}'")
    else:
        print("No audio was generated")
|
def main():
    """Demonstrate KPipeline.generate_from_tokens with raw phonemes and MTokens."""
    # American English pipeline ('a' is Kokoro's US-English language code)
    pipeline = KPipeline(lang_code='a')

    # Phoneme transcription of:
    # "How are you today? I am doing reasonably well, thank you for asking"
    phonemes = "hˌW ɑɹ ju tədˈA? ˌI ɐm dˈuɪŋ ɹˈizənəbli wˈɛl, θˈæŋk ju fɔɹ ˈæskɪŋ"

    try:
        print("\nExample 1: Using generate_from_tokens with raw phonemes")
        results = [
            r for r in pipeline.generate_from_tokens(
                tokens=phonemes,
                voice="af_bella",
                speed=1.0
            )
        ]
        if results:
            save_audio(results[0].audio, 'phoneme_output_new.wav')

        # Example 2: Using generate_from_tokens with pre-processed tokens
        print("\nExample 2: Using generate_from_tokens with pre-processed tokens")
        # Obtain MTokens through the pipeline's G2P (any tokenizer would do)
        text = "How are you today? I am doing reasonably well, thank you for asking"
        _, tokens = pipeline.g2p(text)

        # Then synthesize from the pre-processed tokens
        for result in pipeline.generate_from_tokens(
            tokens=tokens,
            voice="af_bella",
            speed=1.0
        ):
            # Results may carry per-token timestamps when the model provides them
            if result.tokens:
                for token in result.tokens:
                    if hasattr(token, 'start_ts') and hasattr(token, 'end_ts'):
                        print(f"Token: {token.text} ({token.start_ts:.2f}s - {token.end_ts:.2f}s)")
            save_audio(result.audio, f'token_output_{hash(result.phonemes)}.wav')

    except Exception as e:
        print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
@@ -221,6 +221,56 @@ class KPipeline:
|
|||||||
) -> KModel.Output:
|
) -> KModel.Output:
|
||||||
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
return model(ps, pack[len(ps)-1], speed, return_output=True)
|
||||||
|
|
||||||
|
def generate_from_tokens(
    self,
    tokens: Union[str, List[en.MToken]],
    voice: str,
    speed: Number = 1,
    model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
    """Synthesize audio directly from phonemes or pre-tokenized input.

    Args:
        tokens: A raw phoneme string (single utterance), or a list of
            pre-processed MTokens (e.g. from G2P).
        voice: Voice identifier used to load the style pack.
        speed: Speech speed multiplier (default: 1).
        model: KModel to run inference with; falls back to this
            pipeline's model when omitted.

    Yields:
        KPipeline.Result pairing the input phonemes/tokens with the
        generated audio (audio is None when no model is available).

    Raises:
        ValueError: When a model is available but no voice was supplied,
            or when a raw phoneme string is longer than 510 characters.
            (Over-long pre-tokenized chunks are truncated with a warning
            instead of raising.)
    """
    model = model or self.model
    if model and voice is None:
        raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')

    pack = self.load_voice(voice).to(model.device) if model else None

    if isinstance(tokens, str):
        # Raw phoneme string: one utterance, no grapheme alignment possible
        logger.debug("Processing phonemes from raw string")
        if len(tokens) > 510:
            raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
        output = None
        if model:
            output = KPipeline.infer(model, tokens, pack, speed)
        yield self.Result(graphemes='', phonemes=tokens, output=output)
        return

    logger.debug("Processing MTokens")
    # Pre-processed tokens: chunk them and synthesize each chunk
    for graphemes, phoneme_str, chunk_tokens in self.en_tokenize(tokens):
        if not phoneme_str:
            continue
        if len(phoneme_str) > 510:
            logger.warning(f"Unexpected len(ps) == {len(phoneme_str)} > 510 and ps == '{phoneme_str}'")
            logger.warning("Truncating to 510 characters")
            phoneme_str = phoneme_str[:510]
        output = None
        if model:
            output = KPipeline.infer(model, phoneme_str, pack, speed)
        if output is not None and output.pred_dur is not None:
            # Attach per-token start/end timestamps from predicted durations
            KPipeline.join_timestamps(chunk_tokens, output.pred_dur)
        yield self.Result(graphemes=graphemes, phonemes=phoneme_str, tokens=chunk_tokens, output=output)
||||||
@classmethod
|
@classmethod
|
||||||
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
def join_timestamps(cls, tokens: List[en.MToken], pred_dur: torch.LongTensor):
|
||||||
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
# Multiply by 600 to go from pred_dur frames to sample_rate 24000
|
||||||
|
|||||||
Reference in New Issue
Block a user