add onnx export.py (#112)
* Add files via upload onnx export * Add files via upload KModelForONNX * Add files via upload * Delete export.py * Add files via upload * Add files via upload 修正中文的错误 * Add files via upload 增加duration的输出
This commit is contained in:
148
examples/export.py
Normal file
148
examples/export.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import sounddevice as sd
|
||||
|
||||
from kokoro import KModel, KPipeline
|
||||
from kokoro.model import KModelForONNX
|
||||
|
||||
def export_onnx(model, output):
|
||||
onnx_file = output + "/" + "kokoro.onnx"
|
||||
|
||||
input_ids = torch.randint(1, 100, (48,)).numpy()
|
||||
input_ids = torch.LongTensor([[0, *input_ids, 0]])
|
||||
style = torch.randn(1, 256)
|
||||
speed = torch.randint(1, 10, (1,)).int()
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
args = (input_ids, style, speed),
|
||||
f = onnx_file,
|
||||
export_params = True,
|
||||
verbose = True,
|
||||
input_names = [ 'input_ids', 'style', 'speed' ],
|
||||
output_names = [ 'waveform', 'duration' ],
|
||||
opset_version = 17,
|
||||
dynamic_axes = {
|
||||
'input_ids': { 1: 'input_ids_len' },
|
||||
'waveform': { 0: 'num_samples' },
|
||||
},
|
||||
do_constant_folding = True,
|
||||
)
|
||||
|
||||
print('export kokoro.onnx ok!')
|
||||
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print('onnx check ok!')
|
||||
|
||||
def load_input_ids(pipeline, text):
|
||||
if pipeline.lang_code in 'ab':
|
||||
_, tokens = pipeline.g2p(text)
|
||||
for gs, ps, tks in pipeline.en_tokenize(tokens):
|
||||
if not ps:
|
||||
continue
|
||||
else:
|
||||
ps, _ = pipeline.g2p(text)
|
||||
|
||||
if len(ps) > 510:
|
||||
ps = ps[:510]
|
||||
|
||||
input_ids = list(filter(lambda i: i is not None, map(lambda p: pipeline.model.vocab.get(p), ps)))
|
||||
print(f"text: {text} -> phonemes: {ps} -> input_ids: {input_ids}")
|
||||
input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(pipeline.model.device)
|
||||
return ps, input_ids
|
||||
|
||||
def load_voice(pipeline, voice, phonemes):
|
||||
pack = pipeline.load_voice(voice).to('cpu')
|
||||
return pack[len(phonemes) - 1]
|
||||
|
||||
def load_sample(model):
|
||||
pipeline = KPipeline(lang_code='a', model=model.kmodel, device='cpu')
|
||||
text = '''
|
||||
In today's fast-paced tech world, building software applications has never been easier — thanks to AI-powered coding assistants.'
|
||||
'''
|
||||
text = '''
|
||||
The sky above the port was the color of television, tuned to a dead channel.
|
||||
'''
|
||||
voice = 'checkpoints/voices/af_heart.pt'
|
||||
|
||||
pipeline = KPipeline(lang_code='z', model=model.kmodel, device='cpu')
|
||||
text = '''
|
||||
2月15日晚,猫眼专业版数据显示,截至发稿,《哪吒之魔童闹海》(或称《哪吒2》)今日票房已达7.8亿元,累计票房(含预售)超过114亿元。
|
||||
'''
|
||||
voice = 'checkpoints/voices/zf_xiaoxiao.pt'
|
||||
|
||||
phonemes, input_ids = load_input_ids(pipeline, text)
|
||||
style = load_voice(pipeline, voice, phonemes)
|
||||
speed = torch.IntTensor([1])
|
||||
|
||||
return input_ids, style, speed
|
||||
|
||||
def inference_onnx(model, output):
|
||||
onnx_file = output + "/" + "kokoro.onnx"
|
||||
session = ort.InferenceSession(onnx_file)
|
||||
|
||||
input_ids, style, speed = load_sample(model)
|
||||
|
||||
outputs = session.run(None, {
|
||||
'input_ids': input_ids.numpy(),
|
||||
'style': style.numpy(),
|
||||
'speed': speed.numpy(),
|
||||
})
|
||||
|
||||
output = torch.from_numpy(outputs[0])
|
||||
print(f'output: {output.shape}')
|
||||
print(output)
|
||||
|
||||
audio = output.numpy()
|
||||
sd.play(audio, 24000)
|
||||
sd.wait()
|
||||
|
||||
def check_model(model):
|
||||
input_ids, style, speed = load_sample(model)
|
||||
output, duration = model(input_ids, style, speed)
|
||||
|
||||
print(f'output: {output.shape}')
|
||||
print(f'duration: {duration.shape}')
|
||||
print(output)
|
||||
|
||||
audio = output.numpy()
|
||||
sd.play(audio, 24000)
|
||||
sd.wait()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Export kokoro Model to ONNX", add_help=True)
|
||||
parser.add_argument("--inference", "-t", help="test kokoro.onnx model", action="store_true")
|
||||
parser.add_argument("--check", "-m", help="check kokoro model", action="store_true")
|
||||
parser.add_argument(
|
||||
"--config_file", "-c", type=str, default="checkpoints/config.json", help="path to config file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", "-p", type=str, default="checkpoints/kokoro-v1_0.pth", help="path to checkpoint file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", "-o", type=str, default="onnx", help="output directory"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# cfg
|
||||
config_file = args.config_file # change the path of the model config file
|
||||
checkpoint_path = args.checkpoint_path # change the path of the model
|
||||
output_dir = args.output_dir
|
||||
|
||||
# make dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
kmodel = KModel(config=config_file, model=checkpoint_path, disable_complex=True)
|
||||
model = KModelForONNX(kmodel).eval()
|
||||
|
||||
if args.inference:
|
||||
inference_onnx(model, output_dir)
|
||||
elif args.check:
|
||||
check_model(model)
|
||||
else:
|
||||
export_onnx(model, output_dir)
|
||||
Reference in New Issue
Block a user