Replace individual ROCm libs with rocm-libs meta-package; graceful provider fallback
All checks were successful
Build ROCm Image / build (push) Successful in 14m51s
All checks were successful
Build ROCm Image / build (push) Successful in 14m51s
- Dockerfile: rocm-libs installs all ROCm compute libraries at once (hipblas, hipfft, hipsparse, rocblas, miopen, etc.), avoiding whack-a-mole with individual missing .so files
- engine.py: query ort.get_available_providers() at runtime and only request providers that actually loaded — falls back to CPU instead of crashing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
git \
|
git \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
libsndfile1 \
|
libsndfile1 \
|
||||||
hipblas \
|
rocm-libs \
|
||||||
migraphx \
|
migraphx \
|
||||||
migraphx-dev \
|
migraphx-dev \
|
||||||
half \
|
half \
|
||||||
|
|||||||
23
engine.py
23
engine.py
@@ -28,7 +28,8 @@ _cond_cache: Dict[str, Dict[str, np.ndarray]] = {}
|
|||||||
def detect_device() -> str:
|
def detect_device() -> str:
|
||||||
try:
|
try:
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
if "ROCMExecutionProvider" in ort.get_available_providers():
|
available = ort.get_available_providers()
|
||||||
|
if "ROCMExecutionProvider" in available or "MIGraphXExecutionProvider" in available:
|
||||||
return "rocm"
|
return "rocm"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -36,13 +37,19 @@ def detect_device() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _get_providers(device: str) -> list:
|
def _get_providers(device: str) -> list:
|
||||||
if device in ("rocm", "cuda"):
|
if device not in ("rocm", "cuda"):
|
||||||
return [
|
return ["CPUExecutionProvider"]
|
||||||
"MIGraphXExecutionProvider",
|
|
||||||
("ROCMExecutionProvider", {"device_id": 0}),
|
import onnxruntime as ort
|
||||||
"CPUExecutionProvider",
|
available = set(ort.get_available_providers())
|
||||||
]
|
providers = []
|
||||||
return ["CPUExecutionProvider"]
|
if "MIGraphXExecutionProvider" in available:
|
||||||
|
providers.append("MIGraphXExecutionProvider")
|
||||||
|
if "ROCMExecutionProvider" in available:
|
||||||
|
providers.append(("ROCMExecutionProvider", {"device_id": 0}))
|
||||||
|
providers.append("CPUExecutionProvider")
|
||||||
|
logger.info(f"Available ORT providers: {available}")
|
||||||
|
return providers
|
||||||
|
|
||||||
|
|
||||||
def load_model() -> bool:
|
def load_model() -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user