Use device
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = '0.2.0'
|
||||
__version__ = '0.2.1'
|
||||
|
||||
from .models import KModel
|
||||
from .pipeline import KPipeline
|
||||
|
||||
@@ -13,7 +13,7 @@ LANG_CODES = dict(
|
||||
REPO_ID = 'hexgrad/Kokoro-82M'
|
||||
|
||||
class KPipeline:
|
||||
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False):
|
||||
def __init__(self, lang_code='a', config_path=None, model_path=None, trf=False, device=None):
|
||||
assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
|
||||
self.lang_code = lang_code
|
||||
if config_path is None:
|
||||
@@ -31,7 +31,8 @@ class KPipeline:
|
||||
fallback = None
|
||||
self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback)
|
||||
self.vocab = config['vocab']
|
||||
self.model = KModel(config, model_path)
|
||||
self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
self.model = KModel(config, model_path).to(self.device).eval()
|
||||
self.voices = {}
|
||||
|
||||
def load_voice(self, voice):
|
||||
@@ -44,7 +45,7 @@ class KPipeline:
|
||||
print(f'WARNING: Loading {v} voice into {p} pipeline. Phonemes may be mismatched.')
|
||||
voice_path = voice if voice.endswith('.pt') else hf_hub_download(repo_id=REPO_ID, filename=f'voices/{voice}.pt')
|
||||
assert os.path.exists(voice_path)
|
||||
self.voices[voice] = torch.load(voice_path, weights_only=True)
|
||||
self.voices[voice] = torch.load(voice_path, weights_only=True).to(self.device)
|
||||
|
||||
@classmethod
|
||||
def waterfall_last(cls, pairs, next_count, waterfall=['!.?…', ':;', ',—'], bumps={')', '”'}):
|
||||
|
||||
Reference in New Issue
Block a user