add onnx export.py (#112)
* Add files via upload: onnx export
* Add files via upload: KModelForONNX
* Add files via upload
* Delete export.py
* Add files via upload
* Add files via upload: fix errors in the Chinese text
* Add files via upload: add duration output
examples/export.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Export the kokoro KModelForONNX wrapper to ONNX, then optionally sanity-check
the exported graph with onnxruntime or the eager PyTorch model."""

import argparse
import os

import torch
import onnx
import onnxruntime as ort
import sounddevice as sd

from kokoro import KModel, KPipeline
from kokoro.model import KModelForONNX


def export_onnx(model, output):
    onnx_file = os.path.join(output, "kokoro.onnx")

    # Dummy inputs only pin down dtypes and ranks for tracing;
    # the token axis is declared dynamic below.
    input_ids = torch.randint(1, 100, (48,)).numpy()
    input_ids = torch.LongTensor([[0, *input_ids, 0]])  # add BOS/EOS padding
    style = torch.randn(1, 256)
    speed = torch.randint(1, 10, (1,)).int()

    torch.onnx.export(
        model,
        args=(input_ids, style, speed),
        f=onnx_file,
        export_params=True,
        verbose=True,
        input_names=['input_ids', 'style', 'speed'],
        output_names=['waveform', 'duration'],
        opset_version=17,
        dynamic_axes={
            'input_ids': {1: 'input_ids_len'},
            'waveform': {0: 'num_samples'},
        },
        do_constant_folding=True,
    )

    print('export kokoro.onnx ok!')

    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)
    print('onnx check ok!')


def load_input_ids(pipeline, text):
    # Phonemize the text and map each phoneme to its vocabulary id.
    if pipeline.lang_code in 'ab':
        # English pipelines ('a'/'b') tokenize the text in chunks.
        _, tokens = pipeline.g2p(text)
        for gs, ps, tks in pipeline.en_tokenize(tokens):
            if not ps:
                continue
    else:
        ps, _ = pipeline.g2p(text)

    if len(ps) > 510:
        ps = ps[:510]

    input_ids = list(filter(lambda i: i is not None, map(lambda p: pipeline.model.vocab.get(p), ps)))
    print(f"text: {text} -> phonemes: {ps} -> input_ids: {input_ids}")
    input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(pipeline.model.device)
    return ps, input_ids


def load_voice(pipeline, voice, phonemes):
    # Voice packs are indexed by phoneme length: pick the matching style vector.
    pack = pipeline.load_voice(voice).to('cpu')
    return pack[len(phonemes) - 1]


def load_sample(model):
    # Note: the English sample below is set up first and then overridden by the
    # Chinese sample; keep whichever block you want to test with.
    pipeline = KPipeline(lang_code='a', model=model.kmodel, device='cpu')
    text = '''
In today's fast-paced tech world, building software applications has never been easier — thanks to AI-powered coding assistants.
'''
    text = '''
The sky above the port was the color of television, tuned to a dead channel.
'''
    voice = 'checkpoints/voices/af_heart.pt'

    pipeline = KPipeline(lang_code='z', model=model.kmodel, device='cpu')
    text = '''
2月15日晚,猫眼专业版数据显示,截至发稿,《哪吒之魔童闹海》(或称《哪吒2》)今日票房已达7.8亿元,累计票房(含预售)超过114亿元。
'''
    voice = 'checkpoints/voices/zf_xiaoxiao.pt'

    phonemes, input_ids = load_input_ids(pipeline, text)
    style = load_voice(pipeline, voice, phonemes)
    speed = torch.IntTensor([1])

    return input_ids, style, speed


def inference_onnx(model, output):
    # Run the exported graph with onnxruntime and play the result.
    onnx_file = os.path.join(output, "kokoro.onnx")
    session = ort.InferenceSession(onnx_file)

    input_ids, style, speed = load_sample(model)

    outputs = session.run(None, {
        'input_ids': input_ids.numpy(),
        'style': style.numpy(),
        'speed': speed.numpy(),
    })

    # outputs[0] is the waveform; outputs[1] holds the predicted durations.
    output = torch.from_numpy(outputs[0])
    print(f'output: {output.shape}')
    print(output)

    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


def check_model(model):
    # Run the eager PyTorch model as a reference.
    input_ids, style, speed = load_sample(model)
    output, duration = model(input_ids, style, speed)

    print(f'output: {output.shape}')
    print(f'duration: {duration.shape}')
    print(output)

    audio = output.numpy()
    sd.play(audio, 24000)
    sd.wait()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Export kokoro Model to ONNX", add_help=True)
    parser.add_argument("--inference", "-t", help="test kokoro.onnx model", action="store_true")
    parser.add_argument("--check", "-m", help="check kokoro model", action="store_true")
    parser.add_argument(
        "--config_file", "-c", type=str, default="checkpoints/config.json", help="path to config file"
    )
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, default="checkpoints/kokoro-v1_0.pth", help="path to checkpoint file"
    )
    parser.add_argument(
        "--output_dir", "-o", type=str, default="onnx", help="output directory"
    )

    args = parser.parse_args()

    # config
    config_file = args.config_file          # path to the model config file
    checkpoint_path = args.checkpoint_path  # path to the model checkpoint
    output_dir = args.output_dir

    # make the output directory
    os.makedirs(output_dir, exist_ok=True)

    kmodel = KModel(config=config_file, model=checkpoint_path, disable_complex=True)
    model = KModelForONNX(kmodel).eval()

    if args.inference:
        inference_onnx(model, output_dir)
    elif args.check:
        check_model(model)
    else:
        export_onnx(model, output_dir)
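Not part of the diff: as a quick smoke test of the exported graph, a sketch like the one below should work. It assumes the default output path onnx/kokoro.onnx and mirrors the dummy shapes that export_onnx traces with; random token ids produce meaningless audio, so this only verifies that the graph loads and that both declared outputs ('waveform', 'duration') come back with plausible shapes.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("onnx/kokoro.onnx")  # default --output_dir is "onnx"

# same ranks/dtypes as the dummy export inputs: (1, tokens) int64, (1, 256) float32, (1,) int32
input_ids = np.random.randint(1, 100, size=(1, 50)).astype(np.int64)
input_ids[0, 0] = input_ids[0, -1] = 0               # BOS/EOS padding, as in export_onnx
style = np.random.randn(1, 256).astype(np.float32)
speed = np.array([1], dtype=np.int32)

waveform, duration = session.run(['waveform', 'duration'], {
    'input_ids': input_ids,
    'style': style,
    'speed': speed,
})
print(waveform.shape, duration.shape)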
@@ -148,4 +148,4 @@ class KModelForONNX(torch.nn.Module):
         speed: float = 1
     ) -> tuple[torch.FloatTensor, torch.LongTensor]:
         waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
-        return waveform
+        return waveform, duration
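The hunk above (in kokoro/model.py, where KModelForONNX lives, judging by the import in export.py) changes KModelForONNX.forward to return the predicted duration alongside the waveform, matching the declared tuple[torch.FloatTensor, torch.LongTensor] return type and the 'duration' output name used during export. As a rough parity check between the eager model and the exported graph, something like the sketch below could be appended to export.py's __main__ block after the export step; names such as session are assumptions, not part of the PR. Small float differences in the waveform are expected, and an occasional rounding flip in the durations is possible.

import numpy as np

session = ort.InferenceSession(os.path.join(output_dir, "kokoro.onnx"))
input_ids, style, speed = load_sample(model)

with torch.no_grad():
    torch_wave, torch_dur = model(input_ids, style, speed)

onnx_wave, onnx_dur = session.run(None, {
    'input_ids': input_ids.numpy(),
    'style': style.numpy(),
    'speed': speed.numpy(),
})

print('max |waveform diff|:', np.abs(torch_wave.numpy() - np.asarray(onnx_wave)).max())
print('durations equal:', np.array_equal(torch_dur.numpy(), np.asarray(onnx_dur)))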