Fix input dtype mismatches; drop MIGraphXExecutionProvider
All checks were successful
Build ROCm Image / build (push) Successful in 14m40s
All checks were successful
Build ROCm Image / build (push) Successful in 14m40s
- Drop MIGraphXExecutionProvider — symbol mismatch with apt migraphx; ROCMExecutionProvider handles GPU execution fine without it
- Add _ort_type_to_np() helper to read expected dtypes from session metadata
- prepare_voice: cast audio to the session's declared input dtype (float32, not float16)
- _generate: read embed_dtype and kv_dtype from LM session metadata

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
39
engine.py
39
engine.py
@@ -43,8 +43,8 @@ def _get_providers(device: str) -> list:
|
|||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
available = set(ort.get_available_providers())
|
available = set(ort.get_available_providers())
|
||||||
providers = []
|
providers = []
|
||||||
if "MIGraphXExecutionProvider" in available:
|
# MIGraphXExecutionProvider excluded — symbol mismatch between onnxruntime-rocm
|
||||||
providers.append("MIGraphXExecutionProvider")
|
# and apt migraphx; ROCMExecutionProvider covers GPU execution adequately.
|
||||||
if "ROCMExecutionProvider" in available:
|
if "ROCMExecutionProvider" in available:
|
||||||
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
|
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
|
||||||
providers.append("CPUExecutionProvider")
|
providers.append("CPUExecutionProvider")
|
||||||
@@ -52,6 +52,22 @@ def _get_providers(device: str) -> list:
|
|||||||
return providers
|
return providers
|
||||||
|
|
||||||
|
|
||||||
|
def _ort_type_to_np(ort_type: str):
    """Convert an ORT type string (e.g. 'tensor(float16)') to a numpy dtype.

    Args:
        ort_type: Type string as reported by session input/output metadata,
            such as ``'tensor(float)'`` or ``'tensor(int64)'``.

    Returns:
        The matching numpy scalar type. Type strings that are not in the
        table fall back to ``np.float32``.
    """
    mapping = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(int16)": np.int16,
        "tensor(int8)": np.int8,
        # Unsigned widths are listed explicitly so they are not silently
        # coerced to the float32 fallback below.
        "tensor(uint64)": np.uint64,
        "tensor(uint32)": np.uint32,
        "tensor(uint16)": np.uint16,
        "tensor(uint8)": np.uint8,
        "tensor(bool)": np.bool_,
    }
    # Unknown type strings keep the historical float32 default so existing
    # callers are unaffected.
    return mapping.get(ort_type, np.float32)
|
||||||
|
|
||||||
|
|
||||||
def load_model() -> bool:
|
def load_model() -> bool:
|
||||||
global _sessions, _tokenizer
|
global _sessions, _tokenizer
|
||||||
|
|
||||||
@@ -111,10 +127,13 @@ def prepare_voice(audio_prompt_path: str) -> None:
|
|||||||
logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
|
logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
|
||||||
try:
|
try:
|
||||||
audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
|
audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
|
||||||
audio = audio[np.newaxis, :].astype(np.float16) # [1, T]
|
|
||||||
|
|
||||||
session = _sessions["speech_encoder"]
|
session = _sessions["speech_encoder"]
|
||||||
input_name = session.get_inputs()[0].name
|
input_meta = session.get_inputs()[0]
|
||||||
|
input_dtype = _ort_type_to_np(input_meta.type)
|
||||||
|
audio = audio[np.newaxis, :].astype(input_dtype) # [1, T]
|
||||||
|
|
||||||
|
input_name = input_meta.name
|
||||||
output_names = [o.name for o in session.get_outputs()]
|
output_names = [o.name for o in session.get_outputs()]
|
||||||
outputs = session.run(None, {input_name: audio})
|
outputs = session.run(None, {input_name: audio})
|
||||||
|
|
||||||
@@ -171,12 +190,18 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
|||||||
|
|
||||||
embed_input_name = embed_sess.get_inputs()[0].name
|
embed_input_name = embed_sess.get_inputs()[0].name
|
||||||
|
|
||||||
# Discover KV cache slot names from session metadata
|
# Discover KV cache slot names and dtypes from session metadata
|
||||||
past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
|
past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
|
||||||
present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
|
present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
|
||||||
lm_out_names = [o.name for o in lm_sess.get_outputs()]
|
lm_out_names = [o.name for o in lm_sess.get_outputs()]
|
||||||
|
|
||||||
kv_dtype = np.float16
|
# Read expected embed dtype from language_model's inputs_embeds slot
|
||||||
|
embeds_meta = next((i for i in lm_sess.get_inputs() if "embeds" in i.name), None)
|
||||||
|
embed_dtype = _ort_type_to_np(embeds_meta.type) if embeds_meta else np.float32
|
||||||
|
|
||||||
|
# Read KV cache dtype from first past_key_values slot
|
||||||
|
kv_meta = next((i for i in lm_sess.get_inputs() if "past_key_values" in i.name), None)
|
||||||
|
kv_dtype = _ort_type_to_np(kv_meta.type) if kv_meta else np.float16
|
||||||
|
|
||||||
# Embed full text sequence
|
# Embed full text sequence
|
||||||
text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden]
|
text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden]
|
||||||
@@ -201,7 +226,7 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
|||||||
|
|
||||||
for _ in range(MAX_NEW_TOKENS):
|
for _ in range(MAX_NEW_TOKENS):
|
||||||
feed = {
|
feed = {
|
||||||
"inputs_embeds": inputs_embeds.astype(kv_dtype),
|
"inputs_embeds": inputs_embeds.astype(embed_dtype),
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user