Fix input dtype mismatches; drop MIGraphXExecutionProvider
All checks were successful
Build ROCm Image / build (push) Successful in 14m40s

- Drop MIGraphXExecutionProvider — symbol mismatch with apt migraphx,
  ROCMExecutionProvider handles GPU execution fine without it
- Add _ort_type_to_np() helper to read expected dtypes from session metadata
- prepare_voice: cast audio to session's declared input dtype (float32, not float16)
- _generate: read embed_dtype and kv_dtype from LM session metadata

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 20:11:57 -04:00
parent 37e7f082fd
commit dc179ad8c6

View File

@@ -43,8 +43,8 @@ def _get_providers(device: str) -> list:
import onnxruntime as ort import onnxruntime as ort
available = set(ort.get_available_providers()) available = set(ort.get_available_providers())
providers = [] providers = []
if "MIGraphXExecutionProvider" in available: # MIGraphXExecutionProvider excluded — symbol mismatch between onnxruntime-rocm
providers.append("MIGraphXExecutionProvider") # and apt migraphx; ROCMExecutionProvider covers GPU execution adequately.
if "ROCMExecutionProvider" in available: if "ROCMExecutionProvider" in available:
providers.append(("ROCMExecutionProvider", {"device_id": 0})) providers.append(("ROCMExecutionProvider", {"device_id": 0}))
providers.append("CPUExecutionProvider") providers.append("CPUExecutionProvider")
@@ -52,6 +52,22 @@ def _get_providers(device: str) -> list:
return providers return providers
def _ort_type_to_np(ort_type: str):
"""Convert ORT type string (e.g. 'tensor(float16)') to numpy dtype."""
mapping = {
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
"tensor(double)": np.float64,
"tensor(int64)": np.int64,
"tensor(int32)": np.int32,
"tensor(int16)": np.int16,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(bool)": np.bool_,
}
return mapping.get(ort_type, np.float32)
def load_model() -> bool: def load_model() -> bool:
global _sessions, _tokenizer global _sessions, _tokenizer
@@ -111,10 +127,13 @@ def prepare_voice(audio_prompt_path: str) -> None:
logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'") logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
try: try:
audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True) audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
audio = audio[np.newaxis, :].astype(np.float16) # [1, T]
session = _sessions["speech_encoder"] session = _sessions["speech_encoder"]
input_name = session.get_inputs()[0].name input_meta = session.get_inputs()[0]
input_dtype = _ort_type_to_np(input_meta.type)
audio = audio[np.newaxis, :].astype(input_dtype) # [1, T]
input_name = input_meta.name
output_names = [o.name for o in session.get_outputs()] output_names = [o.name for o in session.get_outputs()]
outputs = session.run(None, {input_name: audio}) outputs = session.run(None, {input_name: audio})
@@ -171,12 +190,18 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
embed_input_name = embed_sess.get_inputs()[0].name embed_input_name = embed_sess.get_inputs()[0].name
# Discover KV cache slot names from session metadata # Discover KV cache slot names and dtypes from session metadata
past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name] past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name] present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
lm_out_names = [o.name for o in lm_sess.get_outputs()] lm_out_names = [o.name for o in lm_sess.get_outputs()]
kv_dtype = np.float16 # Read expected embed dtype from language_model's inputs_embeds slot
embeds_meta = next((i for i in lm_sess.get_inputs() if "embeds" in i.name), None)
embed_dtype = _ort_type_to_np(embeds_meta.type) if embeds_meta else np.float32
# Read KV cache dtype from first past_key_values slot
kv_meta = next((i for i in lm_sess.get_inputs() if "past_key_values" in i.name), None)
kv_dtype = _ort_type_to_np(kv_meta.type) if kv_meta else np.float16
# Embed full text sequence # Embed full text sequence
text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden] text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden]
@@ -201,7 +226,7 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
for _ in range(MAX_NEW_TOKENS): for _ in range(MAX_NEW_TOKENS):
feed = { feed = {
"inputs_embeds": inputs_embeds.astype(kv_dtype), "inputs_embeds": inputs_embeds.astype(embed_dtype),
"attention_mask": attention_mask, "attention_mask": attention_mask,
"position_ids": position_ids, "position_ids": position_ids,
} }