Replace individual ROCm libs with rocm-libs meta-package; graceful provider fallback
All checks were successful
Build ROCm Image / build (push) Successful in 14m51s

- Dockerfile: install the rocm-libs meta-package, which pulls in all ROCm
  compute libraries at once (hipblas, hipfft, hipsparse, rocblas, miopen,
  etc.), instead of playing whack-a-mole with individual missing .so files
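- engine.py: query ort.get_available_providers() at runtime and request only
  the providers it reports as available (MIGraphX, then ROCm, then CPU), so
  session creation falls back to CPU instead of crashing when a GPU provider
  is missing; detect_device() now also returns "rocm" when only
  MIGraphXExecutionProvider is available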

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 19:34:28 -04:00
parent f0ab3c1d59
commit 0df6a6cc8f
2 changed files with 16 additions and 9 deletions

View File

@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
git \ git \
ffmpeg \ ffmpeg \
libsndfile1 \ libsndfile1 \
hipblas \ rocm-libs \
migraphx \ migraphx \
migraphx-dev \ migraphx-dev \
half \ half \

View File

@@ -28,7 +28,8 @@ _cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
def detect_device() -> str: def detect_device() -> str:
try: try:
import onnxruntime as ort import onnxruntime as ort
if "ROCMExecutionProvider" in ort.get_available_providers(): available = ort.get_available_providers()
if "ROCMExecutionProvider" in available or "MIGraphXExecutionProvider" in available:
return "rocm" return "rocm"
except Exception: except Exception:
pass pass
@@ -36,13 +37,19 @@ def detect_device() -> str:
def _get_providers(device: str) -> list: def _get_providers(device: str) -> list:
if device in ("rocm", "cuda"): if device not in ("rocm", "cuda"):
return [ return ["CPUExecutionProvider"]
"MIGraphXExecutionProvider",
("ROCMExecutionProvider", {"device_id": 0}), import onnxruntime as ort
"CPUExecutionProvider", available = set(ort.get_available_providers())
] providers = []
return ["CPUExecutionProvider"] if "MIGraphXExecutionProvider" in available:
providers.append("MIGraphXExecutionProvider")
if "ROCMExecutionProvider" in available:
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
providers.append("CPUExecutionProvider")
logger.info(f"Available ORT providers: {available}")
return providers
def load_model() -> bool: def load_model() -> bool: