# Global build arguments. Because they are declared before the first FROM they
# are only usable in FROM lines unless a stage re-declares them with a bare ARG.
# ROCm release; used as the tag of the rocm/dev-ubuntu-22.04 base image.
ARG ROCM_VERSION=7.0
# Semicolon-separated AMD GPU targets handed to the CTranslate2 CMake configure
# step (CMAKE_HIP_ARCHITECTURES / GPU_TARGETS / AMDGPU_TARGETS).
ARG GPU_ARCH="gfx1030;gfx1031"
# oneDNN release tag that the builder stage downloads and compiles.
ARG ONEDNN_VERSION=3.1.1

# ── Stage 1: Build CTranslate2 wheel ─────────────────────────────────────────
FROM rocm/dev-ubuntu-22.04:${ROCM_VERSION} AS builder

# Re-declare the global build arguments so their values are visible inside
# this stage's RUN instructions.
ARG GPU_ARCH
ARG ONEDNN_VERSION

# Build tooling, Python headers/pip and the ROCm development packages used by
# the oneDNN and CTranslate2 builds below. One package per line, sorted, so
# additions and removals show up as single-line diffs. The apt lists are
# removed in the same layer to keep it small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        cmake \
        cmake-curses-gui \
        git \
        hipblas-dev \
        hipcub-dev \
        hiprand-dev \
        libopenblas-dev \
        miopen-hip-dev \
        python3-dev \
        python3-pip \
        rocrand-dev \
        rocthrust-dev \
        wget \
    && rm -rf /var/lib/apt/lists/*

# Build oneDNN as a static library (CPU fallback / conv ops).
# The release tarball for ${ONEDNN_VERSION} is downloaded, unpacked, configured
# in-tree with only the inference CONVOLUTION and REORDER primitives enabled,
# installed via `make install`, and the source tree is removed again in the
# same layer. `set -e` aborts the step on the first failing command.
WORKDIR /build
RUN set -e; \
    onednn_tgz="v${ONEDNN_VERSION}.tar.gz"; \
    onednn_src="oneDNN-${ONEDNN_VERSION}"; \
    wget -q "https://github.com/oneapi-src/oneDNN/archive/refs/tags/${onednn_tgz}"; \
    tar xf "${onednn_tgz}"; \
    rm "${onednn_tgz}"; \
    cd "${onednn_src}"; \
    cmake \
        -DCMAKE_BUILD_TYPE=Release \
        -DONEDNN_LIBRARY_TYPE=STATIC \
        -DONEDNN_ENABLE_WORKLOAD=INFERENCE \
        "-DONEDNN_ENABLE_PRIMITIVE=CONVOLUTION;REORDER" \
        -DONEDNN_BUILD_GRAPH=OFF \
        -DONEDNN_BUILD_EXAMPLES=OFF \
        -DONEDNN_BUILD_TESTS=OFF \
        .; \
    make -j"$(nproc)" install; \
    cd /build; \
    rm -rf "${onednn_src}"

# Build CTranslate2 with ROCm/HIP.
# CTRANSLATE2_ROOT is the CMake install prefix for the compiled library.
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
# ROCm patch applied on top of the upstream v4.0.0 sources.
COPY patches/ct2_4.0.0_rocm.patch /tmp/

# Shallow-clone the v4.0.0 tag, fetch its submodules, apply the patch, then
# configure in a separate build/ directory and install into
# ${CTRANSLATE2_ROOT}. `set -e` aborts the step on the first failing command.
RUN set -e; \
    git clone --depth 1 --branch v4.0.0 \
        https://github.com/OpenNMT/CTranslate2.git /build/CTranslate2; \
    cd /build/CTranslate2; \
    git submodule update --init --recursive; \
    git apply /tmp/ct2_4.0.0_rocm.patch; \
    mkdir build; \
    cd build; \
    cmake \
        -DCMAKE_BUILD_TYPE=Release \
        -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \
        -DCMAKE_PREFIX_PATH=/opt/rocm \
        -DGPU_RUNTIME=HIP \
        -DCMAKE_HIP_ARCHITECTURES="${GPU_ARCH}" \
        -DGPU_TARGETS="${GPU_ARCH}" \
        -DAMDGPU_TARGETS="${GPU_ARCH}" \
        -DWITH_CUDA=ON \
        -DWITH_CUDNN=ON \
        -DWITH_DNNL=ON \
        -DWITH_OPENBLAS=ON \
        -DWITH_MKL=OFF \
        -DOPENMP_RUNTIME=COMP \
        -DENABLE_CPU_DISPATCH=OFF \
        -DBUILD_TESTS=OFF \
        ..; \
    make -j"$(nproc)" install

# Build the Python wheel from the python/ directory of the patched checkout:
# install the packages listed in install_requirements.txt, then write the
# resulting wheel to /wheels so a later stage can copy it.
RUN set -e; \
    cd /build/CTranslate2/python; \
    pip install --no-cache-dir -r install_requirements.txt; \
    python3 setup.py bdist_wheel --dist-dir /wheels

# ── Stage 2: Runtime ──────────────────────────────────────────────────────────
FROM rocm/dev-ubuntu-22.04:${ROCM_VERSION}

# Location of the CTranslate2 install tree that is copied in from the builder
# stage further down.
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
# Expose the CTranslate2 shared libraries to the dynamic loader. The `:+`
# expansion emits the inherited value plus its ':' separator only when
# LD_LIBRARY_PATH is already set. Prepending "${LD_LIBRARY_PATH}:"
# unconditionally would leave an empty leading entry when it is unset, and the
# loader treats an empty entry as the current working directory.
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}${CTRANSLATE2_ROOT}/lib"
# ROCm HIP allocator settings for stable inference
ENV CT2_CUDA_ALLOCATOR=cub_caching
ENV CT2_CUDA_CACHING_ALLOCATOR_CONFIG=4,3,12,419430400

# Bring over the compiled CTranslate2 install tree and the Python wheel
# produced by the builder stage.
COPY --from=builder /opt/ctranslate2 /opt/ctranslate2
COPY --from=builder /wheels /wheels

# Install the runtime OS packages, then the Python packages.
# The locally built ctranslate2 wheel is installed in the SAME pip invocation
# as faster-whisper / wyoming-faster-whisper so the resolver sees it as the
# ctranslate2 candidate: if a dependency ever requires an incompatible
# ctranslate2 version the build fails, instead of a second pip run quietly
# replacing the ROCm build with a different ctranslate2 from the package index.
RUN apt-get update && apt-get install -y --no-install-recommends \
        ffmpeg \
        hipblas \
        libopenblas0 \
        miopen-hip \
        python3 \
        python3-pip \
        rocrand \
    && rm -rf /var/lib/apt/lists/* \
    && pip install --no-cache-dir \
        /wheels/*.whl \
        "faster-whisper>=1.2.1,<2" \
        "wyoming-faster-whisper>=3.1.0"

# Create the data directory before declaring it as a volume; it is the path
# passed as --data-dir in the default arguments below.
RUN mkdir /data
VOLUME ["/data"]
# TCP port used by the default --uri below (EXPOSE is documentation only).
EXPOSE 10405/tcp

# ENTRYPOINT fixes the program; CMD supplies default arguments that can be
# replaced wholesale at `docker run` time.
ENTRYPOINT ["python3", "-m", "wyoming_faster_whisper"]
CMD ["--uri", "tcp://0.0.0.0:10405", \
     "--data-dir", "/data", \
     "--model", "Systran/faster-distil-whisper-small.en", \
     "--device", "cuda", \
     "--compute-type", "float16"]
