From dc179ad8c617b24b6da67e180a6c2a490d4bfd6c Mon Sep 17 00:00:00 2001
From: scott
Date: Mon, 6 Apr 2026 20:11:57 -0400
Subject: [PATCH] Fix input dtype mismatches; drop MIGraphXExecutionProvider
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Drop MIGraphXExecutionProvider — symbol mismatch with apt migraphx, ROCMExecutionProvider handles GPU execution fine without it
- Add _ort_type_to_np() helper to read expected dtypes from session metadata
- prepare_voice: cast audio to session's declared input dtype (float32, not float16)
- _generate: read embed_dtype and kv_dtype from LM session metadata

Co-Authored-By: Claude Sonnet 4.6
---
NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Hunk line counts were used to place blank context
lines; indentation and inline-comment spacing follow Python convention and
should be confirmed against engine.py (or apply with --ignore-whitespace).

 engine.py | 39 ++++++++++++++++++++++++++++++++-------
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/engine.py b/engine.py
index 113e9bf..3b7fe93 100644
--- a/engine.py
+++ b/engine.py
@@ -43,8 +43,8 @@ def _get_providers(device: str) -> list:
     import onnxruntime as ort
     available = set(ort.get_available_providers())
     providers = []
-    if "MIGraphXExecutionProvider" in available:
-        providers.append("MIGraphXExecutionProvider")
+    # MIGraphXExecutionProvider excluded — symbol mismatch between onnxruntime-rocm
+    # and apt migraphx; ROCMExecutionProvider covers GPU execution adequately.
     if "ROCMExecutionProvider" in available:
         providers.append(("ROCMExecutionProvider", {"device_id": 0}))
     providers.append("CPUExecutionProvider")
@@ -52,6 +52,22 @@ def _get_providers(device: str) -> list:
     return providers
 
 
+def _ort_type_to_np(ort_type: str):
+    """Convert ORT type string (e.g. 'tensor(float16)') to numpy dtype."""
+    mapping = {
+        "tensor(float)": np.float32,
+        "tensor(float16)": np.float16,
+        "tensor(double)": np.float64,
+        "tensor(int64)": np.int64,
+        "tensor(int32)": np.int32,
+        "tensor(int16)": np.int16,
+        "tensor(int8)": np.int8,
+        "tensor(uint8)": np.uint8,
+        "tensor(bool)": np.bool_,
+    }
+    return mapping.get(ort_type, np.float32)
+
+
 def load_model() -> bool:
     global _sessions, _tokenizer
 
@@ -111,10 +127,13 @@ def prepare_voice(audio_prompt_path: str) -> None:
     logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
     try:
         audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
-        audio = audio[np.newaxis, :].astype(np.float16)  # [1, T]
 
         session = _sessions["speech_encoder"]
-        input_name = session.get_inputs()[0].name
+        input_meta = session.get_inputs()[0]
+        input_dtype = _ort_type_to_np(input_meta.type)
+        audio = audio[np.newaxis, :].astype(input_dtype)  # [1, T]
+
+        input_name = input_meta.name
         output_names = [o.name for o in session.get_outputs()]
 
         outputs = session.run(None, {input_name: audio})
@@ -171,12 +190,18 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
 
     embed_input_name = embed_sess.get_inputs()[0].name
 
-    # Discover KV cache slot names from session metadata
+    # Discover KV cache slot names and dtypes from session metadata
     past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
     present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
     lm_out_names = [o.name for o in lm_sess.get_outputs()]
 
-    kv_dtype = np.float16
+    # Read expected embed dtype from language_model's inputs_embeds slot
+    embeds_meta = next((i for i in lm_sess.get_inputs() if "embeds" in i.name), None)
+    embed_dtype = _ort_type_to_np(embeds_meta.type) if embeds_meta else np.float32
+
+    # Read KV cache dtype from first past_key_values slot
+    kv_meta = next((i for i in lm_sess.get_inputs() if "past_key_values" in i.name), None)
+    kv_dtype = _ort_type_to_np(kv_meta.type) if kv_meta else np.float16
 
     # Embed full text sequence
     text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0]  # [1, seq, hidden]
@@ -201,7 +226,7 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
 
     for _ in range(MAX_NEW_TOKENS):
         feed = {
-            "inputs_embeds": inputs_embeds.astype(kv_dtype),
+            "inputs_embeds": inputs_embeds.astype(embed_dtype),
             "attention_mask": attention_mask,
             "position_ids": position_ids,
         }