diff --git a/examples/export.py b/examples/export.py
index 3e42c00..db3e1b9 100644
--- a/examples/export.py
+++ b/examples/export.py
@@ -26,8 +26,9 @@ def export_onnx(model, output):
         output_names = [ 'waveform', 'duration' ],
         opset_version = 17,
         dynamic_axes = {
-            'input_ids': { 1: 'input_ids_len' },
-            'waveform': { 0: 'num_samples' },
+            'input_ids': {0: "batch_size", 1: 'input_ids_len' },
+            'style': {0: "batch_size"},
+            "speed": {0: "batch_size"}
         },
         do_constant_folding = True,
     )
diff --git a/examples/make_triton_compatible.py b/examples/make_triton_compatible.py
new file mode 100644
index 0000000..b769efd
--- /dev/null
+++ b/examples/make_triton_compatible.py
@@ -0,0 +1,98 @@
+"""
+This script makes the ONNX model compatible with Triton inference server.
+"""
+
+import sys
+import numpy as np
+import onnx
+import onnxruntime as ort
+import onnx_graphsurgeon as gs
+
+
+def add_squeeze(graph, speed_input, speed_unsqueezed):
+    """
+    Insert a Squeeze node so the "speed" input is reshaped from
+    [batch_size, 1] back to the [batch_size] shape the graph expects.
+    """
+    # Squeeze axes must be a constant *input* from opset 13 onward; pin
+    # axis 1 explicitly so a batch of size 1 is not squeezed away too.
+    squeeze_axes = gs.Constant("speed_squeeze_axes", values=np.array([1], dtype=np.int64))
+    squeeze_node = gs.Node(
+        op="Squeeze",
+        name="speed_squeeze",
+        inputs=[speed_unsqueezed, squeeze_axes],
+        outputs=[gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)]
+    )
+
+    # Find the first node that consumes the "speed" tensor; the Squeeze
+    # node must be inserted before it to keep the node list in
+    # topological order. A flag (not `insert_idx != 0`) is needed so the
+    # search also stops when the consumer is the very first node.
+    insert_idx = 0
+    found = False
+    for idx, node in enumerate(graph.nodes):
+        for tensor in node.inputs:
+            if tensor.name == speed_unsqueezed.name:
+                insert_idx = idx
+                found = True
+                break
+        if found:
+            break
+
+    # Insert the Squeeze node directly before its first consumer.
+    graph.nodes.insert(insert_idx, squeeze_node)
+
+    # Rewire every consumer of the original speed tensor to read the
+    # squeezed output instead (skipping the Squeeze node itself).
+    for node in graph.nodes:
+        for i, tensor in enumerate(node.inputs):
+            if tensor.name == speed_input.name and node.name != "speed_squeeze":
+                node.inputs[i] = squeeze_node.outputs[0]
+
+    return graph
+
+
+def main():
+    if len(sys.argv) != 2:
+        print("Usage: python make_triton_compatible.py <model.onnx>")
+        sys.exit(1)
+
+    onnx_model_path = sys.argv[1]
+    onnx_model = onnx.load(onnx_model_path)
+    onnx.checker.check_model(onnx_model)
+    print("Model is valid")
+
+    graph = gs.import_onnx(onnx_model)
+
+    # Locate the "speed" graph input.
+    speed_idx, speed = None, None
+    for idx, input_ in enumerate(graph.inputs):
+        if input_.name == "speed":
+            speed_idx = idx
+            speed = input_
+
+    if speed is not None:
+        print(f"Found speed input: {speed.name}")
+        print(f"Found speed input shape: {speed.shape}")
+        print(f"Found speed input dtype: {speed.dtype}")
+
+        # Re-declare the graph input with shape [batch_size, 1] so dim 0
+        # can serve as Triton's batch dimension. This must happen inside
+        # the None guard: `speed.dtype` would raise if the input is absent.
+        speed_unsqueezed = gs.Variable(name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1])
+        graph.inputs[speed_idx] = speed_unsqueezed
+
+        # Squeeze it back to [batch_size] for the rest of the graph.
+        graph = add_squeeze(graph, speed, speed_unsqueezed)
+
+        # Export the modified graph back to ONNX and validate it.
+        modified_model = gs.export_onnx(graph)
+        onnx.checker.check_model(modified_model)
+
+        # Save the modified model next to the original.
+        output_path = onnx_model_path.replace('.onnx', '_triton.onnx')
+        onnx.save(modified_model, output_path)
+        print(f"Modified model saved to: {output_path}")
+    else:
+        print("Speed input not found in the model")
+
+
+if __name__ == "__main__":
+    main()