# Reconstructed from patch 8cec8005 ("Add generate_from_tokens method, example (#53)")
# examples/phoneme_example.py — demonstrates KPipeline.generate_from_tokens with
# both a raw phoneme string and pre-processed MTokens.
from kokoro import KPipeline, KModel
import torch
from scipy.io import wavfile


def save_audio(audio: torch.Tensor, filename: str):
    """Save an audio tensor to *filename* as a 24 kHz WAV file.

    Args:
        audio: Audio tensor produced by the pipeline, or None.
        filename: Destination path for the WAV file.
    """
    if audio is not None:
        # wavfile.write needs a numpy array; move off any GPU device first.
        audio_cpu = audio.cpu().numpy()
        # Kokoro synthesizes at a fixed 24 kHz sample rate.
        wavfile.write(filename, 24000, audio_cpu)
        # BUG FIX: the original printed a literal "(unknown)" placeholder
        # instead of interpolating the actual path.
        print(f"Audio saved as '{filename}'")
    else:
        print("No audio was generated")


def main():
    # Initialize pipeline with American English.
    pipeline = KPipeline(lang_code='a')

    # The phoneme string for:
    # "How are you today? I am doing reasonably well, thank you for asking"
    phonemes = "hˌW ɑɹ ju tədˈA? ˌI ɐm dˈuɪŋ ɹˈizənəbli wˈɛl, θˈæŋk ju fɔɹ ˈæskɪŋ"

    try:
        print("\nExample 1: Using generate_from_tokens with raw phonemes")
        results = list(pipeline.generate_from_tokens(
            tokens=phonemes,
            voice="af_bella",
            speed=1.0
        ))
        if results:
            save_audio(results[0].audio, 'phoneme_output_new.wav')

        # Example 2: Using generate_from_tokens with pre-processed tokens
        print("\nExample 2: Using generate_from_tokens with pre-processed tokens")
        # Get the tokens through G2P (or any other method).
        text = "How are you today? I am doing reasonably well, thank you for asking"
        _, tokens = pipeline.g2p(text)

        # BUG FIX: the original named output files with hash(result.phonemes),
        # which varies between interpreter runs due to string-hash
        # randomization (PYTHONHASHSEED); a running index is deterministic.
        for i, result in enumerate(pipeline.generate_from_tokens(
            tokens=tokens,
            voice="af_bella",
            speed=1.0
        )):
            # Each result may contain timestamps if available.
            if result.tokens:
                for token in result.tokens:
                    if hasattr(token, 'start_ts') and hasattr(token, 'end_ts'):
                        print(f"Token: {token.text} ({token.start_ts:.2f}s - {token.end_ts:.2f}s)")
            save_audio(result.audio, f'token_output_{i}.wav')

    except Exception as e:
        # Broad catch is acceptable at this top-level example boundary.
        print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
def generate_from_tokens(
    self,
    tokens: Union[str, List[en.MToken]],
    voice: str,
    speed: Number = 1,
    model: Optional[KModel] = None
) -> Generator['KPipeline.Result', None, None]:
    """Generate audio from either raw phonemes or pre-processed tokens.

    Args:
        tokens: Either a phoneme string or a list of pre-processed MTokens.
        voice: The voice to use for synthesis.
        speed: Speech speed modifier (default: 1).
        model: Optional KModel instance (uses the pipeline's model if not provided).

    Yields:
        KPipeline.Result containing the input tokens/phonemes and generated audio.

    Raises:
        ValueError: If no voice is provided, or a raw phoneme string exceeds
            the model's 510-phoneme limit.
    """
    model = model or self.model
    # BUG FIX: the original condition was `if model and voice is None`, so a
    # missing voice was only detected when a model was loaded — contradicting
    # this docstring's contract and failing later inside load_voice with a
    # confusing error. Validate the voice unconditionally.
    if voice is None:
        raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')

    # Voice packs live on the model's device; skip loading when running model-less.
    pack = self.load_voice(voice).to(model.device) if model else None

    # Handle raw phoneme string: a single segment, no per-token metadata.
    if isinstance(tokens, str):
        logger.debug("Processing phonemes from raw string")
        if len(tokens) > 510:
            raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
        output = KPipeline.infer(model, tokens, pack, speed) if model else None
        yield self.Result(graphemes='', phonemes=tokens, output=output)
        return

    logger.debug("Processing MTokens")
    # Handle pre-processed tokens: re-chunk through en_tokenize and
    # synthesize each chunk independently.
    for gs, ps, tks in self.en_tokenize(tokens):
        if not ps:
            continue
        elif len(ps) > 510:
            # Keep the stream going on oversized chunks rather than aborting,
            # but warn loudly since the tail of the chunk is dropped.
            logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
            logger.warning("Truncating to 510 characters")
            ps = ps[:510]
        output = KPipeline.infer(model, ps, pack, speed) if model else None
        if output is not None and output.pred_dur is not None:
            # Attach per-token timestamps derived from predicted durations.
            KPipeline.join_timestamps(tks, output.pred_dur)
        yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)