Fix input dtype mismatches; drop MIGraphXExecutionProvider
All checks were successful
Build ROCm Image / build (push) Successful in 14m40s
- Drop MIGraphXExecutionProvider — symbol mismatch with apt migraphx; ROCMExecutionProvider handles GPU execution fine without it
- Add _ort_type_to_np() helper to read expected dtypes from session metadata
- prepare_voice: cast audio to the session's declared input dtype (float32, not float16)
- _generate: read embed_dtype and kv_dtype from LM session metadata

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
39
engine.py
39
engine.py
@@ -43,8 +43,8 @@ def _get_providers(device: str) -> list:
|
||||
import onnxruntime as ort
|
||||
available = set(ort.get_available_providers())
|
||||
providers = []
|
||||
if "MIGraphXExecutionProvider" in available:
|
||||
providers.append("MIGraphXExecutionProvider")
|
||||
# MIGraphXExecutionProvider excluded — symbol mismatch between onnxruntime-rocm
|
||||
# and apt migraphx; ROCMExecutionProvider covers GPU execution adequately.
|
||||
if "ROCMExecutionProvider" in available:
|
||||
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
|
||||
providers.append("CPUExecutionProvider")
|
||||
@@ -52,6 +52,22 @@ def _get_providers(device: str) -> list:
|
||||
return providers
|
||||
|
||||
|
||||
def _ort_type_to_np(ort_type: str):
|
||||
"""Convert ORT type string (e.g. 'tensor(float16)') to numpy dtype."""
|
||||
mapping = {
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(double)": np.float64,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(bool)": np.bool_,
|
||||
}
|
||||
return mapping.get(ort_type, np.float32)
|
||||
|
||||
|
||||
def load_model() -> bool:
|
||||
global _sessions, _tokenizer
|
||||
|
||||
@@ -111,10 +127,13 @@ def prepare_voice(audio_prompt_path: str) -> None:
|
||||
logger.info(f"Preparing voice conditionals for '{audio_prompt_path}'")
|
||||
try:
|
||||
audio, _ = librosa.load(audio_prompt_path, sr=SAMPLE_RATE, mono=True)
|
||||
audio = audio[np.newaxis, :].astype(np.float16) # [1, T]
|
||||
|
||||
session = _sessions["speech_encoder"]
|
||||
input_name = session.get_inputs()[0].name
|
||||
input_meta = session.get_inputs()[0]
|
||||
input_dtype = _ort_type_to_np(input_meta.type)
|
||||
audio = audio[np.newaxis, :].astype(input_dtype) # [1, T]
|
||||
|
||||
input_name = input_meta.name
|
||||
output_names = [o.name for o in session.get_outputs()]
|
||||
outputs = session.run(None, {input_name: audio})
|
||||
|
||||
@@ -171,12 +190,18 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
|
||||
embed_input_name = embed_sess.get_inputs()[0].name
|
||||
|
||||
# Discover KV cache slot names from session metadata
|
||||
# Discover KV cache slot names and dtypes from session metadata
|
||||
past_names = [i.name for i in lm_sess.get_inputs() if "past_key_values" in i.name]
|
||||
present_names = [o.name for o in lm_sess.get_outputs() if "present" in o.name]
|
||||
lm_out_names = [o.name for o in lm_sess.get_outputs()]
|
||||
|
||||
kv_dtype = np.float16
|
||||
# Read expected embed dtype from language_model's inputs_embeds slot
|
||||
embeds_meta = next((i for i in lm_sess.get_inputs() if "embeds" in i.name), None)
|
||||
embed_dtype = _ort_type_to_np(embeds_meta.type) if embeds_meta else np.float32
|
||||
|
||||
# Read KV cache dtype from first past_key_values slot
|
||||
kv_meta = next((i for i in lm_sess.get_inputs() if "past_key_values" in i.name), None)
|
||||
kv_dtype = _ort_type_to_np(kv_meta.type) if kv_meta else np.float16
|
||||
|
||||
# Embed full text sequence
|
||||
text_embeds = embed_sess.run(None, {embed_input_name: input_ids})[0] # [1, seq, hidden]
|
||||
@@ -201,7 +226,7 @@ def _generate(input_ids: np.ndarray, cond: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
|
||||
for _ in range(MAX_NEW_TOKENS):
|
||||
feed = {
|
||||
"inputs_embeds": inputs_embeds.astype(kv_dtype),
|
||||
"inputs_embeds": inputs_embeds.astype(embed_dtype),
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user