Feat: batch support for onnx and triton compatibility (#239)
* Added batch support Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> * file rename Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> --------- Signed-off-by: Siddharth Tyagi <siddhartht@nvidia.com> Co-authored-by: Siddharth Tyagi <siddhartht@nvidia.com>
This commit is contained in:
92
examples/make_triton_compatible.py
Normal file
92
examples/make_triton_compatible.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
This script makes the ONNX model compatible with Triton inference server.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import onnx_graphsurgeon as gs
|
||||
|
||||
|
||||
def add_squeeze(graph, speed_input, speed_unsqueezed):
|
||||
"""
|
||||
Add squeeze operation to the speed input to change shape from [batch_size, 1] to [batch_size]
|
||||
"""
|
||||
# Create a squeeze node
|
||||
squeeze_node = gs.Node(
|
||||
op="Squeeze",
|
||||
name="speed_squeeze",
|
||||
inputs=[speed_unsqueezed],
|
||||
outputs=[gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)]
|
||||
)
|
||||
|
||||
## Find first node that has speed_unsqueezed as input
|
||||
insert_idx = 0
|
||||
for idx, node in enumerate(graph.nodes):
|
||||
for i, input_name in enumerate(node.inputs):
|
||||
if input_name.name == speed_unsqueezed.name:
|
||||
insert_idx = idx
|
||||
break
|
||||
if insert_idx != 0:
|
||||
break
|
||||
|
||||
## Add squeeze node to the graph
|
||||
insert_idx = min(0, insert_idx - 1)
|
||||
graph.nodes.insert(insert_idx, squeeze_node)
|
||||
|
||||
# Update the speed input to point to the squeezed output
|
||||
for node in graph.nodes:
|
||||
for i, input_name in enumerate(node.inputs):
|
||||
if input_name.name == speed_input.name and not node.name == "speed_squeeze":
|
||||
node.inputs[i] = squeeze_node.outputs[0]
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python make_triton_compatible.py <onnx_model_path>")
|
||||
sys.exit(1)
|
||||
|
||||
onnx_model_path = sys.argv[1]
|
||||
onnx_model = onnx.load(onnx_model_path)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print("Model is valid")
|
||||
|
||||
graph = gs.import_onnx(onnx_model)
|
||||
|
||||
## get input_id for speed
|
||||
speed_idx, speed = None, None
|
||||
for idx, input_ in enumerate(graph.inputs):
|
||||
if input_.name=="speed":
|
||||
speed_idx = idx
|
||||
speed = input_
|
||||
|
||||
# Update the speed input to have shape [batch_size, 1]
|
||||
speed_unsqueezed = gs.Variable(name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1])
|
||||
graph.inputs[speed_idx] = speed_unsqueezed
|
||||
|
||||
## Add squeeze to change speed shape from [batch_size, 1] to [batch_size]
|
||||
if speed is not None:
|
||||
print(f"Found speed input: {speed.name}")
|
||||
print(f"Found speed input shape: {speed.shape}")
|
||||
print(f"Found speed input dtype: {speed.dtype}")
|
||||
print(f"Found speed input: {speed}")
|
||||
print(f"Found speed input: {type(speed)}")
|
||||
graph = add_squeeze(graph, speed, speed_unsqueezed)
|
||||
|
||||
# Export the modified graph back to ONNX
|
||||
modified_model = gs.export_onnx(graph)
|
||||
onnx.checker.check_model(modified_model)
|
||||
|
||||
# Save the modified model
|
||||
output_path = onnx_model_path.replace('.onnx', '_triton.onnx')
|
||||
onnx.save(modified_model, output_path)
|
||||
print(f"Modified model saved to: {output_path}")
|
||||
else:
|
||||
print("Speed input not found in the model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user