Feat: batch support for ONNX export and Triton compatibility (#239)

* Added batch support
* File rename

Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com>
Co-authored-by: Siddharth Tyagi <siddhartht@nvidia.com>
This commit is contained in:
@@ -26,8 +26,9 @@ def export_onnx(model, output):
|
|||||||
output_names = [ 'waveform', 'duration' ],
|
output_names = [ 'waveform', 'duration' ],
|
||||||
opset_version = 17,
|
opset_version = 17,
|
||||||
dynamic_axes = {
|
dynamic_axes = {
|
||||||
'input_ids': { 1: 'input_ids_len' },
|
'input_ids': {0: "batch_size", 1: 'input_ids_len' },
|
||||||
'waveform': { 0: 'num_samples' },
|
'style': {0: "batch_size"},
|
||||||
|
"speed": {0: "batch_size"}
|
||||||
},
|
},
|
||||||
do_constant_folding = True,
|
do_constant_folding = True,
|
||||||
)
|
)
|
||||||
|
|||||||
92
examples/make_triton_compatible.py
Normal file
92
examples/make_triton_compatible.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
This script makes the ONNX model compatible with Triton inference server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
import onnx_graphsurgeon as gs
|
||||||
|
|
||||||
|
|
||||||
|
def add_squeeze(graph, speed_input, speed_unsqueezed):
    """
    Insert a Squeeze node so the "speed" input is reshaped from
    [batch_size, 1] back to [batch_size] before its consumers run.

    Args:
        graph: onnx_graphsurgeon Graph to modify in place.
        speed_input: original "speed" graph input variable.
        speed_unsqueezed: replacement [batch_size, 1] input variable.

    Returns:
        The modified graph (same object, returned for convenience).
    """
    squeeze_node = gs.Node(
        op="Squeeze",
        name="speed_squeeze",
        inputs=[speed_unsqueezed],
        outputs=[gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)],
    )

    # Find the first node that consumes the unsqueezed speed input so the
    # squeeze can be inserted before it.  BUGFIX: the original detected a
    # match with `if insert_idx != 0`, which silently missed a consumer at
    # node index 0; an explicit found-flag handles that case.
    insert_idx = 0
    found = False
    for idx, node in enumerate(graph.nodes):
        for input_var in node.inputs:
            if input_var.name == speed_unsqueezed.name:
                insert_idx = idx
                found = True
                break
        if found:
            break

    # Insert just before the first consumer, clamped to the start of the
    # node list.  BUGFIX: the original used min(0, insert_idx - 1), which
    # always evaluates to 0 or -1 and could place the node at the wrong
    # position (e.g. near the END of the list when insert_idx was 0).
    insert_idx = max(0, insert_idx - 1)
    graph.nodes.insert(insert_idx, squeeze_node)

    # Rewire every consumer of the original speed input to read the
    # squeezed output instead (skipping the squeeze node itself).
    for node in graph.nodes:
        for i, input_var in enumerate(node.inputs):
            if input_var.name == speed_input.name and node.name != "speed_squeeze":
                node.inputs[i] = squeeze_node.outputs[0]

    return graph
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Make an ONNX TTS model compatible with Triton Inference Server.

    Loads the model given on the command line, changes its "speed" input to
    shape [batch_size, 1], inserts a Squeeze so downstream nodes still see
    [batch_size], validates the result, and saves it as <name>_triton.onnx.

    Exits with status 1 on bad usage; onnx.checker raises if either the
    input or the modified model is invalid.
    """
    if len(sys.argv) != 2:
        print("Usage: python make_triton_compatible.py <onnx_model_path>")
        sys.exit(1)

    onnx_model_path = sys.argv[1]
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)
    print("Model is valid")

    graph = gs.import_onnx(onnx_model)

    # Locate the "speed" graph input.
    speed_idx, speed = None, None
    for idx, input_ in enumerate(graph.inputs):
        if input_.name == "speed":
            speed_idx = idx
            speed = input_
            break

    if speed is not None:
        print(f"Found speed input: {speed.name}")
        print(f"Found speed input shape: {speed.shape}")
        print(f"Found speed input dtype: {speed.dtype}")

        # BUGFIX: this [batch_size, 1] reshape originally ran BEFORE the
        # None-check, so a model without a "speed" input crashed with an
        # AttributeError instead of reaching the "not found" message below.
        speed_unsqueezed = gs.Variable(
            name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1]
        )
        graph.inputs[speed_idx] = speed_unsqueezed

        # Restore [batch_size] for the rest of the graph via a Squeeze node.
        graph = add_squeeze(graph, speed, speed_unsqueezed)

        # Export the modified graph back to ONNX and validate it.
        modified_model = gs.export_onnx(graph)
        onnx.checker.check_model(modified_model)

        # Save the modified model next to the original.
        output_path = onnx_model_path.replace('.onnx', '_triton.onnx')
        onnx.save(modified_model, output_path)
        print(f"Modified model saved to: {output_path}")
    else:
        print("Speed input not found in the model")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user