Use device

This commit is contained in:
hexgrad
2025-01-27 21:23:11 -08:00
parent de2acfcc8a
commit 2fe9944ff3
3 changed files with 6 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
__version__ = '0.2.0'
__version__ = '0.2.1'
from .models import KModel
from .pipeline import KPipeline

View File

@@ -13,7 +13,7 @@ LANG_CODES = dict(
REPO_ID = 'hexgrad/Kokoro-82M'
class KPipeline:
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False):
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None):
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
self.lang_code = lang_code
if config_path is None:
@@ -31,7 +31,8 @@ class KPipeline:
fallback = None
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
self.vocab = config['vocab']
self.model = KModel(config, model_path)
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.model = KModel(config, model_path).to(self.device).eval()
self.voices = {}
def load_voice(self, voice):
@@ -44,7 +45,7 @@ class KPipeline:
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
assert os.path.exists(voice_path)
self.voices[voice] = torch.load(voice_path, weights_only=True)
self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device)
@classmethod
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', ''}):

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name='kokoro', # Name of the package
version='0.2.0', # Initial version
version='0.2.1', # Initial version
packages=find_packages(), # Automatically finds packages
install_requires=[ # List your dependencies here
'huggingface_hub',