* Added batch support Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> * file rename Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> --------- Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> Co-authored-by: Siddharth Tyagi <siddhartht@nvidia.com>
150 lines
4.7 KiB
Python
150 lines
4.7 KiB
Python
import argparse
|
||
import os
|
||
import torch
|
||
import onnx
|
||
import onnxruntime as ort
|
||
import sounddevice as sd
|
||
|
||
from kokoro import KModel, KPipeline
|
||
from kokoro.model import KModelForONNX
|
||
|
||
def export_onnx(model, output):
    """Export ``model`` to ``<output>/kokoro.onnx`` and validate the file.

    The model is traced with randomly generated dummy inputs, written with
    ``torch.onnx.export`` and then reloaded and passed to
    ``onnx.checker.check_model``.

    Args:
        model: Module handed to ``torch.onnx.export``; it is traced with the
            positional inputs ``(input_ids, style, speed)``.
        output: Directory (as a string) that receives ``kokoro.onnx``.
    """
    target_path = "/".join((output, "kokoro.onnx"))

    # Dummy tracing inputs: 48 random ids wrapped in a leading and trailing 0,
    # a random (1, 256) style tensor and a random integer speed.
    random_ids = torch.randint(1, 100, (48,)).tolist()
    dummy_ids = torch.LongTensor([[0, *random_ids, 0]])
    dummy_style = torch.randn(1, 256)
    dummy_speed = torch.randint(1, 10, (1,)).int()

    # Axis 0 of every input is exported as a dynamic "batch_size" axis;
    # axis 1 of input_ids is dynamic as well.
    axes = {
        'input_ids': {0: "batch_size", 1: 'input_ids_len'},
        'style': {0: "batch_size"},
        "speed": {0: "batch_size"},
    }

    torch.onnx.export(
        model,
        args=(dummy_ids, dummy_style, dummy_speed),
        f=target_path,
        export_params=True,
        verbose=True,
        input_names=['input_ids', 'style', 'speed'],
        output_names=['waveform', 'duration'],
        opset_version=17,
        dynamic_axes=axes,
        do_constant_folding=True,
    )
    print('export kokoro.onnx ok!')

    # Reload the written file and run the ONNX checker on it.
    onnx.checker.check_model(onnx.load(target_path))
    print('onnx check ok!')
|
||
|
||
def load_input_ids(pipeline, text):
    """Phonemize ``text`` and convert the phonemes to a token-id tensor.

    Args:
        pipeline: Object providing ``lang_code``, ``g2p``, ``model.vocab`` and
            ``model.device``; ``en_tokenize`` is used only when ``lang_code``
            is contained in ``'ab'``.
        text: Text to phonemize.

    Returns:
        A ``(phonemes, input_ids)`` tuple. ``phonemes`` is the phoneme
        sequence truncated to 510 entries, and ``input_ids`` is a
        ``LongTensor`` of shape ``(1, n + 2)``: the ``n`` ids found in the
        vocabulary, wrapped in a leading and trailing ``0``, moved to
        ``pipeline.model.device``.
    """
    if pipeline.lang_code in 'ab':
        _, tokens = pipeline.g2p(text)
        # Start from an empty result so a tokenizer that yields nothing does
        # not leave the name unbound, and keep the last NON-empty chunk so a
        # trailing empty chunk cannot overwrite real phonemes.
        phonemes = ''
        for _, chunk_phonemes, _ in pipeline.en_tokenize(tokens):
            if chunk_phonemes:
                phonemes = chunk_phonemes
    else:
        phonemes, _ = pipeline.g2p(text)

    # NOTE(review): 510 presumably leaves room for the two boundary tokens
    # added below within a 512-token context -- confirm against the model.
    phonemes = phonemes[:510]

    # Phonemes that are missing from the vocabulary are silently dropped.
    vocab = pipeline.model.vocab
    ids = [token_id for token_id in map(vocab.get, phonemes) if token_id is not None]
    print(f"text: {text} -> phonemes: {phonemes} -> input_ids: {ids}")

    input_ids = torch.LongTensor([[0, *ids, 0]]).to(pipeline.model.device)
    return phonemes, input_ids
|
||
|
||
def load_voice(pipeline, voice, phonemes):
    """Return the voice-pack entry selected by the phoneme count.

    Args:
        pipeline: Object whose ``load_voice(voice)`` returns an indexable
            object with a ``.to(device)`` method.
        voice: Voice identifier or path forwarded to ``pipeline.load_voice``.
        phonemes: Sized sequence; its length selects the pack entry.

    Returns:
        The entry at index ``len(phonemes) - 1`` of the pack after it has
        been moved to the CPU.
    """
    cpu_pack = pipeline.load_voice(voice).to('cpu')
    entry_index = len(phonemes) - 1
    return cpu_pack[entry_index]
|
||
|
||
# Built-in demo inputs keyed by pipeline language code: (text, voice file).
_SAMPLES = {
    'a': (
        "\nThe sky above the port was the color of television, tuned to a dead channel.\n",
        'checkpoints/voices/af_heart.pt',
    ),
    'z': (
        "\n2月15日晚,猫眼专业版数据显示,截至发稿,《哪吒之魔童闹海》(或称《哪吒2》)今日票房已达7.8亿元,累计票房(含预售)超过114亿元。\n",
        'checkpoints/voices/zf_xiaoxiao.pt',
    ),
}


def load_sample(model, lang_code='z'):
    """Build one set of sample inputs ``(input_ids, style, speed)``.

    Only the pipeline for the requested language is constructed, so no
    pipeline is created just to be thrown away.

    Args:
        model: Wrapper exposing the underlying model as ``model.kmodel``;
            that object is handed to ``KPipeline``.
        lang_code: Key of the built-in sample to use (``'a'`` or ``'z'``).
            Defaults to ``'z'``, the sample this function has always used.

    Returns:
        ``(input_ids, style, speed)``: the token-id tensor from
        ``load_input_ids``, the style entry from ``load_voice`` and
        ``torch.IntTensor([1])`` as the speed.

    Raises:
        ValueError: If ``lang_code`` has no built-in sample.
    """
    try:
        text, voice = _SAMPLES[lang_code]
    except KeyError:
        known = ', '.join(sorted(_SAMPLES))
        raise ValueError(
            f"no built-in sample for lang_code {lang_code!r}; choose one of: {known}"
        ) from None

    pipeline = KPipeline(lang_code=lang_code, model=model.kmodel, device='cpu')

    phonemes, input_ids = load_input_ids(pipeline, text)
    style = load_voice(pipeline, voice, phonemes)
    speed = torch.IntTensor([1])

    return input_ids, style, speed
|
||
|
||
def inference_onnx(model, output):
    """Run ``<output>/kokoro.onnx`` on the built-in sample and play the audio.

    Args:
        model: Forwarded to ``load_sample`` to build the sample inputs.
        output: Directory (as a string) containing ``kokoro.onnx``.
    """
    model_path = "/".join((output, "kokoro.onnx"))
    session = ort.InferenceSession(model_path)

    input_ids, style, speed = load_sample(model)
    # Feed names match the input names used when the model was exported.
    feeds = {
        'input_ids': input_ids.numpy(),
        'style': style.numpy(),
        'speed': speed.numpy(),
    }
    results = session.run(None, feeds)

    # Only the first session output is used; it is wrapped in a tensor so the
    # shape and values print in torch's format.
    waveform = torch.from_numpy(results[0])
    print(f'output: {waveform.shape}')
    print(waveform)

    # 24000 is passed to sounddevice as the playback sample rate.
    sd.play(waveform.numpy(), 24000)
    sd.wait()
|
||
|
||
def check_model(model):
    """Run ``model`` directly on the built-in sample and play the result.

    Args:
        model: Callable invoked as ``model(input_ids, style, speed)`` and
            expected to return an ``(output, duration)`` pair; it is also
            forwarded to ``load_sample`` to build those inputs.
    """
    sample_inputs = load_sample(model)
    waveform, duration = model(*sample_inputs)

    print(f'output: {waveform.shape}')
    print(f'duration: {duration.shape}')
    print(waveform)

    # 24000 is passed to sounddevice as the playback sample rate.
    samples = waveform.numpy()
    sd.play(samples, 24000)
    sd.wait()
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point. With no mode flag the model is exported to
    # ONNX; --inference runs the exported file and --check runs the PyTorch
    # model directly. When both flags are given, --inference wins.
    #
    # The text is passed as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name in the
    # usage line.
    parser = argparse.ArgumentParser(description="Export kokoro Model to ONNX", add_help=True)
    parser.add_argument("--inference", "-t", help="test kokoro.onnx model", action="store_true")
    parser.add_argument("--check", "-m", help="check kokoro model", action="store_true")
    parser.add_argument(
        "--config_file", "-c", type=str, default="checkpoints/config.json", help="path to config file"
    )
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, default="checkpoints/kokoro-v1_0.pth", help="path to checkpoint file"
    )
    parser.add_argument(
        "--output_dir", "-o", type=str, default="onnx", help="output directory"
    )

    args = parser.parse_args()

    # cfg
    config_file = args.config_file  # change the path of the model config file
    checkpoint_path = args.checkpoint_path  # change the path of the model
    output_dir = args.output_dir

    # make dir
    os.makedirs(output_dir, exist_ok=True)

    # The model is loaded for every mode, wrapped for ONNX and put in eval mode.
    kmodel = KModel(config=config_file, model=checkpoint_path, disable_complex=True)
    model = KModelForONNX(kmodel).eval()

    if args.inference:
        inference_onnx(model, output_dir)
    elif args.check:
        check_model(model)
    else:
        export_onnx(model, output_dir)
|