diff --git a/kokoro/pipeline.py b/kokoro/pipeline.py
index ad5f2a3..1737104 100644
--- a/kokoro/pipeline.py
+++ b/kokoro/pipeline.py
@@ -83,7 +83,7 @@ class KPipeline:
             print(f"⚠️ WARNING: Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
             self.g2p = espeak.EspeakG2P(language=language)
 
-    def load_voice(self, voice: str) -> torch.FloatTensor:
+    def load_single_voice(self, voice: str) -> torch.FloatTensor:
         if voice in self.voices:
             return self.voices[voice]
         if voice.endswith('.pt'):
@@ -98,6 +98,30 @@ class KPipeline:
         self.voices[voice] = pack
         return pack
 
+    def load_voice(self, voice: str, delimiter: str = ",") -> torch.FloatTensor:
+        """Load a voice, or the average of several voices.
+
+        A single voice (e.g. 'af_bella') is delegated to load_single_voice.
+        Several voices joined by ``delimiter`` (e.g. 'af_bella,af_jessica')
+        are loaded one by one and averaged; the blend is cached in
+        self.voices under the full request string.
+
+        Raises:
+            ValueError: if ``voice`` contains no voice name at all.
+        """
+        if voice in self.voices:
+            return self.voices[voice]
+        # Tolerate padding and stray delimiters, e.g. 'af_bella, af_jessica,'.
+        names = [name.strip() for name in voice.split(delimiter)]
+        names = [name for name in names if name]
+        if not names:
+            raise ValueError(f"No voice name found in {voice!r}")
+        packs = [self.load_single_voice(name) for name in names]
+        if len(packs) == 1:
+            return packs[0]
+        self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
+        return self.voices[voice]
+
     @classmethod
     def waterfall_last(
         cls,