diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 3add5cc..39fa3f6 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git \ ffmpeg \ libsndfile1 \ - hipblas \ + rocm-libs \ migraphx \ migraphx-dev \ half \ diff --git a/engine.py b/engine.py index 928fa66..113e9bf 100644 --- a/engine.py +++ b/engine.py @@ -28,7 +28,8 @@ _cond_cache: Dict[str, Dict[str, np.ndarray]] = {} def detect_device() -> str: try: import onnxruntime as ort - if "ROCMExecutionProvider" in ort.get_available_providers(): + available = ort.get_available_providers() + if "ROCMExecutionProvider" in available or "MIGraphXExecutionProvider" in available: return "rocm" except Exception: pass @@ -36,13 +37,19 @@ def detect_device() -> str: def _get_providers(device: str) -> list: - if device in ("rocm", "cuda"): - return [ - "MIGraphXExecutionProvider", - ("ROCMExecutionProvider", {"device_id": 0}), - "CPUExecutionProvider", - ] - return ["CPUExecutionProvider"] + if device not in ("rocm", "cuda"): + return ["CPUExecutionProvider"] + + import onnxruntime as ort + available = set(ort.get_available_providers()) + providers = [] + if "MIGraphXExecutionProvider" in available: + providers.append("MIGraphXExecutionProvider") + if "ROCMExecutionProvider" in available: + providers.append(("ROCMExecutionProvider", {"device_id": 0})) + providers.append("CPUExecutionProvider") + logger.info(f"Available ORT providers: {available}") + return providers def load_model() -> bool: