From 0df6a6cc8f0e9ae5630a3d9865af53da606550e3 Mon Sep 17 00:00:00 2001 From: scott Date: Mon, 6 Apr 2026 19:34:28 -0400 Subject: [PATCH] Replace individual ROCm libs with rocm-libs meta-package; graceful provider fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Dockerfile: rocm-libs installs all ROCm compute libraries at once (hipblas, hipfft, hipsparse, rocblas, miopen, etc.) avoiding whack-a-mole with individual missing .so files - engine.py: query ort.get_available_providers() at runtime and only request providers that actually loaded — falls back to CPU instead of crashing Co-Authored-By: Claude Sonnet 4.6 --- Dockerfile.rocm | 2 +- engine.py | 23 +++++++++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 3add5cc..39fa3f6 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git \ ffmpeg \ libsndfile1 \ - hipblas \ + rocm-libs \ migraphx \ migraphx-dev \ half \ diff --git a/engine.py b/engine.py index 928fa66..113e9bf 100644 --- a/engine.py +++ b/engine.py @@ -28,7 +28,8 @@ _cond_cache: Dict[str, Dict[str, np.ndarray]] = {} def detect_device() -> str: try: import onnxruntime as ort - if "ROCMExecutionProvider" in ort.get_available_providers(): + available = ort.get_available_providers() + if "ROCMExecutionProvider" in available or "MIGraphXExecutionProvider" in available: return "rocm" except Exception: pass @@ -36,13 +37,19 @@ def detect_device() -> str: def _get_providers(device: str) -> list: - if device in ("rocm", "cuda"): - return [ - "MIGraphXExecutionProvider", - ("ROCMExecutionProvider", {"device_id": 0}), - "CPUExecutionProvider", - ] - return ["CPUExecutionProvider"] + if device not in ("rocm", "cuda"): + return ["CPUExecutionProvider"] + + import onnxruntime as ort + available = set(ort.get_available_providers()) + providers = [] + if "MIGraphXExecutionProvider" in available: + providers.append("MIGraphXExecutionProvider") + if "ROCMExecutionProvider" in available: + providers.append(("ROCMExecutionProvider", {"device_id": 0})) + providers.append("CPUExecutionProvider") + logger.info(f"Available ORT providers: {available}") + return providers def load_model() -> bool: