Use device
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
__version__ = '0.2.0'
|
__version__ = '0.2.1'
|
||||||
|
|
||||||
from .models import KModel
|
from .models import KModel
|
||||||
from .pipeline import KPipeline
|
from .pipeline import KPipeline
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ LANG_CODES = dict(
|
|||||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||||
|
|
||||||
class KPipeline:
|
class KPipeline:
|
||||||
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False):
|
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None):
|
||||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||||
self.lang_code = lang_code
|
self.lang_code = lang_code
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
@@ -31,7 +31,8 @@ class KPipeline:
|
|||||||
fallback = None
|
fallback = None
|
||||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||||
self.vocab = config['vocab']
|
self.vocab = config['vocab']
|
||||||
self.model = KModel(config, model_path)
|
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||||
|
self.model = KModel(config, model_path).to(self.device).eval()
|
||||||
self.voices = {}
|
self.voices = {}
|
||||||
|
|
||||||
def load_voice(self, voice):
|
def load_voice(self, voice):
|
||||||
@@ -44,7 +45,7 @@ class KPipeline:
|
|||||||
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
|
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
|
||||||
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
|
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
|
||||||
assert os.path.exists(voice_path)
|
assert os.path.exists(voice_path)
|
||||||
self.voices[voice] = torch.load(voice_path, weights_only=True)
|
self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}):
|
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}):
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='kokoro', # Name of the package
|
name='kokoro', # Name of the package
|
||||||
version='0.2.0', # Initial version
|
version='0.2.1', # Initial version
|
||||||
packages=find_packages(), # Automatically finds packages
|
packages=find_packages(), # Automatically finds packages
|
||||||
install_requires=[ # List your dependencies here
|
install_requires=[ # List your dependencies here
|
||||||
'huggingface_hub',
|
'huggingface_hub',
|
||||||
|
|||||||
Reference in New Issue
Block a user