From e74290bf5a47dbf579d6e25e8095e3893d5e7198 Mon Sep 17 00:00:00 2001 From: remsky Date: Fri, 31 Jan 2025 11:09:48 -0700 Subject: [PATCH] feat: add device examples and pipeline updates (#27) --- examples/device_examples.py | 48 +++++++++++++++++++++++++++++++++++++ kokoro/pipeline.py | 31 +++++++++++++++++++++--- 2 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 examples/device_examples.py diff --git a/examples/device_examples.py b/examples/device_examples.py new file mode 100644 index 0000000..83fb61d --- /dev/null +++ b/examples/device_examples.py @@ -0,0 +1,48 @@ +""" +Quick example to show how device selection can be controlled, and was checked +""" +# import warnings, time +import time +# import warnings +from kokoro import KPipeline +# warnings.filterwarnings('ignore') +# import torch; torch.set_warn_always(False) + +def generate_audio(pipeline, text): + for _, _, audio in pipeline(text, voice='af_bella'): + samples = audio.shape[0] if audio is not None else 0 + assert samples > 0, "No audio generated" + return samples + +def time_synthesis(device=None): + try: + start = time.perf_counter() + pipeline = KPipeline(lang_code='a', device=device) + samples = generate_audio(pipeline, "The quick brown fox jumps over the lazy dog.") + ms = (time.perf_counter() - start) * 1000 + print(f"✓ {device or 'auto':<6} | {ms:>5.1f}ms total | {samples:>6,d} samples") + except RuntimeError as e: + print(f"✗ {'cuda' if 'CUDA' in str(e) else device or 'auto':<6} | {'not available' if 'CUDA' in str(e) else str(e)}") + +def compare_shared_model(): + try: + start = time.perf_counter() + en_us = KPipeline(lang_code='a') + en_uk = KPipeline(lang_code='a', model=en_us.model) + + for pipeline in [en_us, en_uk]: + generate_audio(pipeline, "Testing model reuse.") + + ms = (time.perf_counter() - start) * 1000 + print(f"✓ reuse | {ms:>5.1f}ms for both models") + except Exception as e: + print(f"✗ reuse | {str(e)}") + +if __name__ == '__main__': + print("\nDevice Selection & Performance:") + print("----------------------------------------") + time_synthesis() + time_synthesis('cuda') + time_synthesis('cpu') + print("----------------------------------------") + compare_shared_model() \ No newline at end of file diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py index 1737104..0683133 100644 --- a/kokoro/pipeline.py +++ b/kokoro/pipeline.py @@ -47,15 +47,40 @@ class KPipeline: A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio). ''' - def __init__(self, lang_code: str, model: Union[KModel, bool] = True, trf: bool = False): + def __init__( + self, + lang_code: str, + model: Union[KModel, bool] = True, + trf: bool = False, + device: Optional[str] = None + ): + """Initialize a KPipeline. + + Args: + lang_code: Language code for G2P processing + model: KModel instance, True to create new model, False for no model + trf: Whether to use transformer-based G2P + device: Override default device selection ('cuda' or 'cpu', or None for auto) + If None, will auto-select cuda if available + If 'cuda' and not available, will explicitly raise an error + """ assert lang_code in LANG_CODES, (lang_code, LANG_CODES) self.lang_code = lang_code self.model = None if isinstance(model, KModel): self.model = model elif model: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = KModel().to(device).eval() + if device == 'cuda' and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + try: + self.model = KModel().to(device).eval() + except RuntimeError as e: + if device == 'cuda': + raise RuntimeError(f"""Failed to initialize model on CUDA: {e}. + Try setting device='cpu' or check CUDA installation.""") + raise self.voices = {} if lang_code in 'ab': try: