From 2fe9944ff3426886a8ad14b2ed207f2952f914a8 Mon Sep 17 00:00:00 2001 From: hexgrad Date: Mon, 27 Jan 2025 21:23:11 -0800 Subject: [PATCH] Use device --- kokoro/__init__.py | 2 +- kokoro/pipeline.py | 7 ++++--- setup.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/kokoro/__init__.py b/kokoro/__init__.py index c11f3fa..5fe3637 100644 --- a/kokoro/__init__.py +++ b/kokoro/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.2.0' +__version__ = '0.2.1' from .models import KModel from .pipeline import KPipeline diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py index 38a803e..31df88c 100644 --- a/kokoro/pipeline.py +++ b/kokoro/pipeline.py @@ -13,7 +13,7 @@ LANG_CODES = dict( REPO_ID = 'hexgrad/Kokoro-82M' class KPipeline: - def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False): + def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None): assert lang_code in LANG_CODES, (lang_code, LANG_CODES) self.lang_code = lang_code if config_path is None: @@ -31,7 +31,8 @@ class KPipeline: fallback = None self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback) self.vocab = config['vocab'] - self.model = KModel(config, model_path) + self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + self.model = KModel(config, model_path).to(self.device).eval() self.voices = {} def load_voice(self, voice): @@ -44,7 +45,7 @@ class KPipeline: print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.') voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt') assert os.path.exists(voice_path) - self.voices[voice] = torch.load(voice_path, weights_only=True) + self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device) @classmethod def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}): diff --git a/setup.py b/setup.py index 9f05033..64faef4 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name='kokoro', # Name of the package - version='0.2.0', # Initial version + version='0.2.1', # Initial version packages=find_packages(), # Automatically finds packages install_requires=[ # List your dependencies here 'huggingface_hub',