add onnx export.py (#112)

* Add files via upload

onnx export

* Add files via upload

KModelForONNX

* Add files via upload

* Delete export.py

* Add files via upload

* Add files via upload

Fix errors in the Chinese text

* Add files via upload

Add duration output

szsteven008 authored on 2025-03-01 03:01:34 +08:00, committed by GitHub
parent b15ef354b2, commit c87df60d4c
2 changed files with 149 additions and 1 deletion

examples/export.py Normal file

@@ -0,0 +1,148 @@
import argparse
import os

import torch
import onnx
import onnxruntime as ort
import sounddevice as sd

from kokoro import KModel, KPipeline
from kokoro.model import KModelForONNX


def export_onnx(model, output):
    onnx_file = output + "/" + "kokoro.onnx"
    # Dummy export inputs: a random token sequence wrapped in the 0 boundary
    # tokens, a 1x256 style vector, and an integer speed factor.
    input_ids = torch.randint(1, 100, (48,)).numpy()
    input_ids = torch.LongTensor([[0, *input_ids, 0]])
    style = torch.randn(1, 256)
    speed = torch.randint(1, 10, (1,)).int()
    torch.onnx.export(
        model,
        args = (input_ids, style, speed),
        f = onnx_file,
        export_params = True,
        verbose = True,
        input_names = [ 'input_ids', 'style', 'speed' ],
        output_names = [ 'waveform', 'duration' ],
        opset_version = 17,
        dynamic_axes = {
            'input_ids': { 1: 'input_ids_len' },
            'waveform': { 0: 'num_samples' },
        },
        do_constant_folding = True,
    )
    print('export kokoro.onnx ok!')
    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print('onnx check ok!')


def load_input_ids(pipeline, text):
    # Convert text to phonemes, then map phonemes to vocabulary ids.
    if pipeline.lang_code in 'ab':
        _, tokens = pipeline.g2p(text)
        for gs, ps, tks in pipeline.en_tokenize(tokens):
            if not ps:
                continue
    else:
        ps, _ = pipeline.g2p(text)
    if len(ps) > 510:
        ps = ps[:510]
    input_ids = list(filter(lambda i: i is not None, map(lambda p: pipeline.model.vocab.get(p), ps)))
    print(f"text: {text} -> phonemes: {ps} -> input_ids: {input_ids}")
    input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(pipeline.model.device)
    return ps, input_ids


def load_voice(pipeline, voice, phonemes):
    # The voice pack is indexed by phoneme sequence length.
    pack = pipeline.load_voice(voice).to('cpu')
    return pack[len(phonemes) - 1]


def load_sample(model):
    # English samples (superseded by the Chinese sample below, which is what
    # actually gets synthesized).
    pipeline = KPipeline(lang_code='a', model=model.kmodel, device='cpu')
    text = '''
    In today's fast-paced tech world, building software applications has never been easier — thanks to AI-powered coding assistants.
    '''
    text = '''
    The sky above the port was the color of television, tuned to a dead channel.
    '''
    voice = 'checkpoints/voices/af_heart.pt'

    # Chinese sample (a Maoyan Pro box-office report for "Ne Zha 2") and the
    # voice used for the test run.
    pipeline = KPipeline(lang_code='z', model=model.kmodel, device='cpu')
    text = '''
    2月15日晚猫眼专业版数据显示截至发稿《哪吒之魔童闹海》或称《哪吒2》今日票房已达7.8亿元累计票房含预售超过114亿元。
    '''
    voice = 'checkpoints/voices/zf_xiaoxiao.pt'

    phonemes, input_ids = load_input_ids(pipeline, text)
    style = load_voice(pipeline, voice, phonemes)
    speed = torch.IntTensor([1])
    return input_ids, style, speed


def inference_onnx(model, output):
    onnx_file = output + "/" + "kokoro.onnx"
    session = ort.InferenceSession(onnx_file)
    input_ids, style, speed = load_sample(model)
    outputs = session.run(None, {
        'input_ids': input_ids.numpy(),
        'style': style.numpy(),
        'speed': speed.numpy(),
    })
    # outputs[0] is the waveform; outputs[1] is the per-token duration.
    output = torch.from_numpy(outputs[0])
    print(f'output: {output.shape}')
    print(output)
    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


def check_model(model):
    # Run the PyTorch wrapper directly as a reference for the ONNX output.
    input_ids, style, speed = load_sample(model)
    output, duration = model(input_ids, style, speed)
    print(f'output: {output.shape}')
    print(f'duration: {duration.shape}')
    print(output)
    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Export kokoro Model to ONNX", add_help=True)
    parser.add_argument("--inference", "-t", help="test kokoro.onnx model", action="store_true")
    parser.add_argument("--check", "-m", help="check kokoro model", action="store_true")
    parser.add_argument(
        "--config_file", "-c", type=str, default="checkpoints/config.json", help="path to config file"
    )
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, default="checkpoints/kokoro-v1_0.pth", help="path to checkpoint file"
    )
    parser.add_argument(
        "--output_dir", "-o", type=str, default="onnx", help="output directory"
    )
    args = parser.parse_args()

    # cfg
    config_file = args.config_file  # path of the model config file
    checkpoint_path = args.checkpoint_path  # path of the model checkpoint
    output_dir = args.output_dir

    # make dir
    os.makedirs(output_dir, exist_ok=True)

    kmodel = KModel(config=config_file, model=checkpoint_path, disable_complex=True)
    model = KModelForONNX(kmodel).eval()

    if args.inference:
        inference_onnx(model, output_dir)
    elif args.check:
        check_model(model)
    else:
        export_onnx(model, output_dir)
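
For reference, a usage sketch based on the argparse flags above (not part of the diff; it assumes checkpoints/config.json, checkpoints/kokoro-v1_0.pth, and the voice .pt files are present as per the script's defaults):

    python examples/export.py                # default: export onnx/kokoro.onnx and run the ONNX checker
    python examples/export.py --inference    # run the exported kokoro.onnx through onnxruntime and play the audio
    python examples/export.py --check        # run the PyTorch KModelForONNX wrapper directly as a sanity check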

kokoro/model.py

@@ -148,4 +148,4 @@ class KModelForONNX(torch.nn.Module):
         speed: float = 1
     ) -> tuple[torch.FloatTensor, torch.LongTensor]:
         waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
-        return waveform
+        return waveform, duration
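
To illustrate what the added duration output enables on the ONNX side, here is a minimal consumer sketch (not part of this commit). It assumes onnx/kokoro.onnx was produced by examples/export.py above and feeds dummy inputs that mirror the shapes and dtypes used during export, so the phoneme ids are not meaningful:

import numpy as np
import onnxruntime as ort

# Assumption: onnx/kokoro.onnx exists, exported by examples/export.py above.
session = ort.InferenceSession("onnx/kokoro.onnx")

# Dummy inputs mirroring the export-time shapes/dtypes: int64 token ids padded
# with 0 at both ends, a float32 1x256 style vector, and an int32 speed factor.
ids = np.random.randint(1, 100, size=(48,), dtype=np.int64)
input_ids = np.concatenate(([0], ids, [0])).reshape(1, -1)
style = np.random.randn(1, 256).astype(np.float32)
speed = np.ones((1,), dtype=np.int32)

# Both named outputs are available; 'duration' is the one added by this commit.
waveform, duration = session.run(
    ['waveform', 'duration'],
    {'input_ids': input_ids, 'style': style, 'speed': speed},
)
print('waveform:', waveform.shape)   # dim 0 is the dynamic num_samples axis
print('duration:', duration.shape)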