From 9dd62138314076d4e1e4fa370fe8437c57e266a4 Mon Sep 17 00:00:00 2001 From: scott Date: Thu, 2 Apr 2026 20:57:03 -0400 Subject: [PATCH] Add ROCm build for faster-whisper with Wyoming endpoint Multi-stage Dockerfile builds CTranslate2 v4.0.0 with HIP/ROCm support targeting gfx1030/gfx1031 (RX 6000 series), then installs faster-whisper and wyoming-faster-whisper on top for a Wyoming ASR server on port 10300. Co-Authored-By: Claude Sonnet 4.6 --- Dockerfile | 89 +++ docker-compose.yml | 29 + patches/ct2_4.0.0_rocm.patch | 1163 ++++++++++++++++++++++++++++++++++ 3 files changed, 1281 insertions(+) create mode 100644 Dockerfile create mode 100644 docker-compose.yml create mode 100755 patches/ct2_4.0.0_rocm.patch diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7973712 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,89 @@ +ARG ROCM_VERSION=6.2 +ARG GPU_ARCH="gfx1030;gfx1031" +ARG ONEDNN_VERSION=3.1.1 + +# ── Stage 1: Build CTranslate2 wheel ───────────────────────────────────────── +FROM rocm/dev-ubuntu-22.04:${ROCM_VERSION} AS builder + +ARG GPU_ARCH +ARG ONEDNN_VERSION + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git cmake cmake-curses-gui \ + libopenblas-dev wget \ + python3-dev python3-pip \ + && rm -rf /var/lib/apt/lists/* + +# Build oneDNN as a static library (CPU fallback / conv ops) +WORKDIR /build +RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VERSION}.tar.gz && \ + tar xf v${ONEDNN_VERSION}.tar.gz && \ + rm v${ONEDNN_VERSION}.tar.gz && \ + cd oneDNN-${ONEDNN_VERSION} && \ + cmake -DCMAKE_BUILD_TYPE=Release \ + -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + "-DONEDNN_ENABLE_PRIMITIVE=CONVOLUTION;REORDER" \ + -DONEDNN_BUILD_GRAPH=OFF . && \ + make -j$(nproc) install && \ + cd /build && rm -rf oneDNN-${ONEDNN_VERSION} + +# Build CTranslate2 with ROCm/HIP +ENV CTRANSLATE2_ROOT=/opt/ctranslate2 +COPY patches/ct2_4.0.0_rocm.patch /tmp/ + +RUN git clone --branch v4.0.0 --depth 1 \ + https://github.com/OpenNMT/CTranslate2.git /build/CTranslate2 && \ + cd /build/CTranslate2 && \ + git submodule update --init --recursive && \ + git apply /tmp/ct2_4.0.0_rocm.patch && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \ + -DWITH_CUDA=ON \ + -DWITH_CUDNN=ON \ + -DWITH_MKL=OFF \ + -DWITH_DNNL=ON \ + -DOPENMP_RUNTIME=COMP \ + -DCMAKE_HIP_ARCHITECTURES="${GPU_ARCH}" \ + -DGPU_TARGETS="${GPU_ARCH}" \ + -DAMDGPU_TARGETS="${GPU_ARCH}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_RUNTIME=HIP \ + -DWITH_OPENBLAS=ON \ + -DENABLE_CPU_DISPATCH=OFF \ + -DBUILD_TESTS=OFF .. 
&& \ + make -j$(nproc) install + +# Build Python wheel +RUN cd /build/CTranslate2/python && \ + pip install --no-cache-dir -r install_requirements.txt && \ + python3 setup.py bdist_wheel --dist-dir /wheels + +# ── Stage 2: Runtime ────────────────────────────────────────────────────────── +FROM rocm/dev-ubuntu-22.04:${ROCM_VERSION} + +ENV CTRANSLATE2_ROOT=/opt/ctranslate2 +ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CTRANSLATE2_ROOT}/lib" +# ROCm HIP allocator settings for stable inference +ENV CT2_CUDA_ALLOCATOR=cub_caching +ENV CT2_CUDA_CACHING_ALLOCATOR_CONFIG=4,3,12,419430400 + +COPY --from=builder /opt/ctranslate2 /opt/ctranslate2 +COPY --from=builder /wheels /wheels + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-pip ffmpeg \ + && rm -rf /var/lib/apt/lists/* \ + && pip install --no-cache-dir /wheels/*.whl \ + && pip install --no-cache-dir \ + "faster-whisper>=1.2.1,<2" \ + "wyoming-faster-whisper>=3.1.1" + +RUN mkdir /data +VOLUME /data +EXPOSE 10300 + +ENTRYPOINT ["python3", "-m", "wyoming_faster_whisper"] +CMD ["--uri", "tcp://0.0.0.0:10300", "--data-dir", "/data", "--model", "Systran/faster-distil-whisper-small.en", "--device", "cuda", "--compute-type", "float16"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..2afc29e --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,29 @@ +services: + wyoming-faster-whisper: + build: + context: . + args: + ROCM_VERSION: "6.2" + GPU_ARCH: "gfx1030;gfx1031" + image: rocm-faster-whisper:latest + devices: + - /dev/kfd + - /dev/dri + group_add: + - video + - render + ports: + - "10300:10300" + volumes: + - ./data:/data + environment: + # Override GFX version if ROCm doesn't natively support your exact chip. + # gfx1031 (RX 6000-series variants) can be told to behave as gfx1030. + - HSA_OVERRIDE_GFX_VERSION=10.3.0 + command: > + --uri tcp://0.0.0.0:10300 + --data-dir /data + --model Systran/faster-distil-whisper-small.en + --device cuda + --compute-type float16 + restart: unless-stopped diff --git a/patches/ct2_4.0.0_rocm.patch b/patches/ct2_4.0.0_rocm.patch new file mode 100755 index 0000000..d5f354d --- /dev/null +++ b/patches/ct2_4.0.0_rocm.patch @@ -0,0 +1,1163 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 1089106c..d47f504d 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -1,4 +1,4 @@ +-cmake_minimum_required(VERSION 3.7) ++cmake_minimum_required(VERSION 3.21 FATAL_ERROR) + + # Set policy for setting the MSVC runtime library for static MSVC builds + if(POLICY CMP0091) +@@ -42,12 +42,12 @@ else() + endif() + set(INTEL_ROOT ${INTEL_ROOT_DEFAULT} CACHE FILEPATH "Path to Intel root directory") + set(OPENMP_RUNTIME "INTEL" CACHE STRING "OpenMP runtime (INTEL, COMP, NONE)") +- ++set_property(CACHE OPENMP_RUNTIME PROPERTY STRINGS INTEL COMP NONE) + # Set Release build type by default to get sane performance. + if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) + endif(NOT CMAKE_BUILD_TYPE) +- ++set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release RelWithDebInfo MinSizeRel) + # Set CXX flags. 
+ set(CMAKE_CXX_STANDARD 17) + +@@ -106,7 +106,7 @@ set(SOURCES + src/cpu/primitives.cc + src/decoding.cc + src/decoding_utils.cc +- src/devices.cc ++ #src/devices.cc + src/dtw.cc + src/encoder.cc + src/env.cc +@@ -120,13 +120,13 @@ set(SOURCES + src/layers/whisper.cc + src/logging.cc + src/models/language_model.cc +- src/models/model.cc ++ #src/models/model.cc + src/models/model_factory.cc + src/models/model_reader.cc + src/models/sequence_to_sequence.cc + src/models/transformer.cc +- src/models/wav2vec2.cc +- src/models/whisper.cc ++ #src/models/wav2vec2.cc ++ #src/models/whisper.cc + src/ops/activation.cc + src/ops/add.cc + src/ops/alibi_add.cc +@@ -184,11 +184,11 @@ set(SOURCES + src/random.cc + src/sampling.cc + src/scoring.cc +- src/storage_view.cc ++ #src/storage_view.cc + src/thread_pool.cc + src/translator.cc +- src/types.cc +- src/utils.cc ++ #src/types.cc ++ #src/utils.cc + src/vocabulary.cc + src/vocabulary_map.cc + ) +@@ -416,10 +416,89 @@ if (WITH_RUY) + unset(CMAKE_POSITION_INDEPENDENT_CODE) + list(APPEND LIBRARIES ruy) + endif() ++set(GPU_RUNTIME "HIP" CACHE STRING "GPU_RUNTIME (HIP and CUDA)") ++set_property(CACHE GPU_RUNTIME PROPERTY STRINGS CUDA HIP) ++if(WIN32) ++ set(ROCM_ROOT "$ENV{HIP_PATH}" CACHE PATH "Root directory of the ROCm installation") ++else() ++ if (NOT DEFINED ROCM_PATH ) ++ if (NOT DEFINED ENV{ROCM_PATH} ) ++ set(ROCM_PATH "/opt/rocm" CACHE PATH "Root directory of the ROCm installation") ++ else() ++ set(ROCM_PATH $ENV{ROCM_PATH} CACHE PATH "Root directory of the ROCm installation") ++ endif() ++ endif() ++ set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake) ++endif() + + if (WITH_CUDA) +- find_package(CUDA 11.0 REQUIRED) ++ set(GPU_SOURCES ++ src/cuda/allocator.cc ++ src/cuda/primitives.cu ++ src/cuda/random.cu ++ src/cuda/utils.cc ++ src/ops/alibi_add_gpu.cu ++ src/ops/bias_add_gpu.cu ++ src/ops/concat_split_slide_gpu.cu ++ src/ops/conv1d_gpu.cu ++ src/ops/dequantize_gpu.cu ++ src/ops/gather_gpu.cu ++ src/ops/gumbel_max_gpu.cu ++ src/ops/layer_norm_gpu.cu ++ src/ops/mean_gpu.cu ++ src/ops/multinomial_gpu.cu ++ src/ops/rms_norm_gpu.cu ++ src/ops/rotary_gpu.cu ++ src/ops/softmax_gpu.cu ++ src/ops/tile_gpu.cu ++ src/ops/topk_gpu.cu ++ src/ops/topp_mask_gpu.cu ++ src/ops/quantize_gpu.cu ++ ) + add_definitions(-DCT2_WITH_CUDA) ++ list(APPEND GPU_SOURCES ++ src/devices.cc ++ src/types.cc ++ src/storage_view.cc ++ src/utils.cc ++ src/models/model.cc ++ src/models/wav2vec2.cc ++ src/models/whisper.cc ++ ) ++ if(GPU_RUNTIME STREQUAL "HIP") ++ enable_language(${GPU_RUNTIME}) ++ set(CMAKE_${GPU_RUNTIME}_STANDARD 17) ++ set(CMAKE_${GPU_RUNTIME}_EXTENSIONS OFF) ++ set(CMAKE_${GPU_RUNTIME}_STANDARD_REQUIRED ON) ++ add_definitions(-D__HIP_PLATFORM_AMD__) ++ find_package(hipblas REQUIRED) ++ find_package(hip REQUIRED) ++ ++ if(WITH_CUDNN) ++ find_package(miopen REQUIRED) ++ add_definitions(-DCT2_WITH_CUDNN) ++ endif() ++ # list(APPEND GPU_SOURCES src/devices.cc src/types.cc src/utils.cc src/models/model.cc src/models/whisper.cc) ++ add_library(${PROJECT_NAME} SHARED ++ ${SOURCES} ++ ${GPU_SOURCES} ++ ) ++ set_source_files_properties(${GPU_SOURCES} PROPERTIES LANGUAGE ${GPU_RUNTIME}) ++ set_source_files_properties(${SOURCES} PROPERTIES LANGUAGE CXX) ++ target_link_libraries(${PROJECT_NAME} PUBLIC roc::hipblas) ++ if(WITH_CUDNN) ++ # find_library(LIBRT NAMES "librt.so.1" PATHS /lib/x86_64-linux-gnu/ NO_DEFAULT_PATH ) ++ find_package(miopen REQUIRED) ++ find_library(LIBRT rt) ++ if(LIBRT) ++ message(STATUS "Librt: " ${LIBRT}) ++ target_link_libraries(${PROJECT_NAME} PUBLIC 
MIOpen ${LIBRT} ) ++ else() ++ target_link_libraries(${PROJECT_NAME} PUBLIC MIOpen ) ++ endif() ++ endif() ++ else() ++ find_package(CUDA 11.0 REQUIRED) + if(MSVC) + if(BUILD_SHARED_LIBS) + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/MD$<$:d>") +@@ -501,31 +580,20 @@ if (WITH_CUDA) + set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) + cuda_add_library(${PROJECT_NAME} + ${SOURCES} +- src/cuda/allocator.cc +- src/cuda/primitives.cu +- src/cuda/random.cu +- src/cuda/utils.cc +- src/ops/alibi_add_gpu.cu +- src/ops/bias_add_gpu.cu +- src/ops/concat_split_slide_gpu.cu +- src/ops/conv1d_gpu.cu +- src/ops/dequantize_gpu.cu +- src/ops/gather_gpu.cu +- src/ops/gumbel_max_gpu.cu +- src/ops/layer_norm_gpu.cu +- src/ops/mean_gpu.cu +- src/ops/multinomial_gpu.cu +- src/ops/rms_norm_gpu.cu +- src/ops/rotary_gpu.cu +- src/ops/softmax_gpu.cu +- src/ops/tile_gpu.cu +- src/ops/topk_gpu.cu +- src/ops/topp_mask_gpu.cu +- src/ops/quantize_gpu.cu ++ ${GPU_SOURCES} + ) ++ endif() + elseif(WITH_CUDNN) + message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON") + else() ++ list(APPEND SOURCES ++ src/devices.cc ++ src/types.cc ++ src/storage_view.cc ++ src/utils.cc ++ src/models/model.cc ++ src/models/wav2vec2.cc ++ src/models/whisper.cc) + add_library(${PROJECT_NAME} ${SOURCES}) + endif() + +diff --git a/cli/CMakeLists.txt b/cli/CMakeLists.txt +index 3311ad33..e1dfe246 100644 +--- a/cli/CMakeLists.txt ++++ b/cli/CMakeLists.txt +@@ -10,9 +10,24 @@ add_executable(translator + target_include_directories(translator + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cxxopts/include + ) +-target_link_libraries(translator +- PRIVATE ${PROJECT_NAME} +-) ++if(WITH_CUDA) ++ if(GPU_RUNTIME STREQUAL "HIP") ++ find_package(hipblas REQUIRED) ++ find_package(miopen REQUIRED) ++ target_link_libraries(translator ++ PRIVATE ${PROJECT_NAME} MIOpen #roc::hipblas ++ ) ++ else() ++ target_link_libraries(translator ++ PRIVATE ${PROJECT_NAME} ++ ) ++ endif() ++ ++else() ++ target_link_libraries(translator ++ PRIVATE ${PROJECT_NAME} ++ ) ++endif() + + set_target_properties(translator PROPERTIES OUTPUT_NAME ct2-translator) + +diff --git a/src/cuda/allocator.cc b/src/cuda/allocator.cc +index 2311bd00..0ed42fd0 100644 +--- a/src/cuda/allocator.cc ++++ b/src/cuda/allocator.cc +@@ -6,9 +6,13 @@ + #include "ctranslate2/utils.h" + #include "cuda/utils.h" + #include "env.h" +- +-#include +-#include ++#ifdef __HIP_PLATFORM_AMD__ ++ #include ++ #include ++#else ++ #include ++ #include ++#endif + #include + + namespace ctranslate2 { +@@ -63,7 +67,7 @@ namespace ctranslate2 { + class CudaAsyncAllocator : public Allocator { + public: + void* allocate(size_t size, int device_index) override { +-#if CUDA_VERSION >= 11020 ++#if (CUDA_VERSION >= 11020) || defined (__HIP_PLATFORM_AMD__) + int prev_device_index = -1; + if (device_index >= 0) { + CUDA_CHECK(cudaGetDevice(&prev_device_index)); +@@ -86,7 +90,7 @@ namespace ctranslate2 { + } + + void free(void* ptr, int device_index) override { +-#if CUDA_VERSION >= 11020 ++#if (CUDA_VERSION >= 11020) || defined (__HIP_PLATFORM_AMD__) + int prev_device_index = -1; + if (device_index >= 0) { + CUDA_CHECK(cudaGetDevice(&prev_device_index)); +@@ -107,12 +111,16 @@ namespace ctranslate2 { + }; + + static bool support_cuda_malloc_async() { +-#if CUDA_VERSION < 11020 ++#if CUDA_VERSION < 11020 && !defined(__HIP_PLATFORM_AMD__) + return false; + #else + for (int i = 0; i < get_gpu_count(); ++i) { + int supported = 0; ++#ifdef __HIP_PLATFORM_AMD__ ++ supported = 1; ++#else + cudaDeviceGetAttribute(&supported, 
cudaDevAttrMemoryPoolsSupported, i); ++#endif + if (!supported) + return false; + } +diff --git a/src/cuda/helpers.h b/src/cuda/helpers.h +index a34d5d89..62c54df1 100644 +--- a/src/cuda/helpers.h ++++ b/src/cuda/helpers.h +@@ -2,21 +2,23 @@ + + #include + #include +- +-#include +-#include +- ++#ifdef __HIP_PLATFORM_AMD__ ++ #include ++#else ++ #include ++ #include ++#endif + #include "ctranslate2/types.h" + + #include "utils.h" + +-#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 ++#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 ||defined(__HIP_PLATFORM_AMD__) + # define CUDA_CAN_USE_HALF 1 + #else + # define CUDA_CAN_USE_HALF 0 + #endif + +-#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) ++#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) || defined(__HIP_PLATFORM_AMD__) + # define CUDA_CAN_USE_BF16_MATH 1 + #else + # define CUDA_CAN_USE_BF16_MATH 0 +@@ -362,9 +364,11 @@ namespace ctranslate2 { + // The following kernels are adapted from: + // https://github.com/pytorch/pytorch/blob/40eff454ce5638fbff638a7f4502e29ffb9a2f0d/aten/src/ATen/native/cuda/SoftMax.cu + // They help define row-wise reduction where each block handles a single row. +- +-#define C10_WARP_SIZE 32 +- ++#ifdef __HIP_PLATFORM_AMD__ ++ #define C10_WARP_SIZE 64 ++#else ++ #define C10_WARP_SIZE 32 ++#endif + template + inline dim3 get_block_size(index_t dim_size) { + index_t block_size = 1; +@@ -400,7 +404,9 @@ namespace ctranslate2 { + for (index_t i = 0; i < C10_WARP_SIZE; ++i) { + warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]); + } ++#ifndef __HIP_PLATFORM_AMD__ + __syncwarp(mask); ++#endif + smem[lane] = warpVal; + } + } +diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu +index 149e10db..8fc2825f 100644 +--- a/src/cuda/primitives.cu ++++ b/src/cuda/primitives.cu +@@ -1,7 +1,11 @@ + #include "ctranslate2/primitives.h" +- +-#include +-#include ++#ifndef __HIP_PLATFORM_AMD__ ++ #include ++ #include ++#else ++ #include ++ #include ++#endif + #include + + #include "cuda/helpers.h" +@@ -240,7 +244,11 @@ namespace ctranslate2 { + }; + + template +- __global__ void penalize_previous_tokens_kernel(T* scores, ++ __global__ void ++#ifdef __HIP_PLATFORM_AMD__ ++ __launch_bounds__(64) ++#endif ++ penalize_previous_tokens_kernel(T* scores, + const T* previous_scores, + const int32_t* previous_ids, + float penalty, +@@ -265,7 +273,11 @@ namespace ctranslate2 { + dim_t batch_size, + dim_t length, + dim_t vocabulary_size) { ++#ifndef __HIP_PLATFORM_AMD__ + dim3 block(32); ++#else ++ dim3 block(64); ++#endif + dim3 grid((batch_size * length + block.x - 1) / block.x); + penalize_previous_tokens_kernel<<>>( + cuda::device_cast(scores), +@@ -478,12 +490,12 @@ namespace ctranslate2 { + + const void* alpha_ptr = &alpha_h; + const void* beta_ptr = &beta_h; +- cudaDataType_t compute_type = CUDA_R_16F; ++ cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; //cudaDataType_t compute_type = CUDA_R_16F; + + if (!cuda::use_true_fp16_gemm()) { + alpha_ptr = α + beta_ptr = β +- compute_type = CUDA_R_32F; ++ compute_type = CUBLAS_COMPUTE_32F; //compute_type = CUDA_R_32F; + } + + // cuBLAS assumes column-major storage, so swap a and b accordingly. 
+@@ -521,7 +533,7 @@ namespace ctranslate2 { + a, CUDA_R_16BF, lda, + &beta, + c, CUDA_R_16BF, ldc, +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, //CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + +@@ -549,7 +561,7 @@ namespace ctranslate2 { + a, CUDA_R_8I, lda, + &beta_i, + c, CUDA_R_32I, ldc, +- CUDA_R_32I, ++ CUBLAS_COMPUTE_32I, //CUDA_R_32I, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + +@@ -591,12 +603,12 @@ namespace ctranslate2 { + + const void* alpha_ptr = &alpha_h; + const void* beta_ptr = &beta_h; +- cudaDataType_t compute_type = CUDA_R_16F; ++ cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; //cudaDataType_t compute_type = CUDA_R_16F; + + if (!cuda::use_true_fp16_gemm()) { + alpha_ptr = α + beta_ptr = β +- compute_type = CUDA_R_32F; ++ compute_type = CUBLAS_COMPUTE_32F; //compute_type = CUDA_R_32F; + } + + // cuBLAS assumes column-major storage, so swap a and b accordingly. +@@ -635,7 +647,7 @@ namespace ctranslate2 { + &beta, + c, CUDA_R_16BF, ldc, stridec, + batch_size, +- CUDA_R_32F, ++ CUBLAS_COMPUTE_32F, //CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + +diff --git a/src/cuda/random.cu b/src/cuda/random.cu +index f016bb44..1d081303 100644 +--- a/src/cuda/random.cu ++++ b/src/cuda/random.cu +@@ -22,7 +22,11 @@ namespace ctranslate2 { + ScopedCurandStates(size_t num_states) + : _allocator(get_allocator()) + { ++#ifdef __HIP_PLATFORM_AMD__ ++ constexpr size_t num_init_threads = 64; ++#else + constexpr size_t num_init_threads = 32; ++#endif + const size_t blocks = ceil_divide(num_states, num_init_threads); + _num_states = blocks * num_init_threads; + _states = static_cast(_allocator.allocate(_num_states * sizeof (curandState))); +diff --git a/src/cuda/random.h b/src/cuda/random.h +index e12ae20f..20fe7436 100644 +--- a/src/cuda/random.h ++++ b/src/cuda/random.h +@@ -1,7 +1,10 @@ + #pragma once +- +-#include +- ++#ifdef __HIP_PLATFORM_AMD__ ++ #include ++ #include ++#else ++ #include ++#endif + namespace ctranslate2 { + namespace cuda { + +diff --git a/src/cuda/utils.h b/src/cuda/utils.h +index 29bc99a3..e14f1aab 100644 +--- a/src/cuda/utils.h ++++ b/src/cuda/utils.h +@@ -2,12 +2,22 @@ + + #include + +-#include +-#include ++#ifdef __HIP_PLATFORM_AMD__ ++ #include ++ #include ++ #include ++#else ++ #include ++ #include ++#endif + #include + + #ifdef CT2_WITH_CUDNN +-# include ++ #ifdef __HIP_PLATFORM_AMD__ ++ #include ++ #else ++ #include ++ #endif + #endif + + #include "ctranslate2/types.h" +@@ -79,7 +89,10 @@ namespace ctranslate2 { + }; + + // Convenience macro to call Thrust functions with a default execution policy. +-#define THRUST_CALL(FUN, ...) FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__) +- ++#ifdef __HIP_PLATFORM_AMD__ ++ #define THRUST_CALL(FUN, ...) FUN(thrust::hip::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__) ++#else ++ #define THRUST_CALL(FUN, ...) 
FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
++#endif
+ }
+ }
+diff --git a/src/cuda2hip_macros.hpp b/src/cuda2hip_macros.hpp
+new file mode 100644
+index 00000000..48be99c6
+--- /dev/null
++++ b/src/cuda2hip_macros.hpp
+@@ -0,0 +1,137 @@
++#pragma once
++#ifdef __HIP_PLATFORM_AMD__
++  #include <hip/hip_runtime.h>
++  #include <hip/hip_bfloat16.h>
++  #define __nv_bfloat16 hip_bfloat16
++  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
++  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
++  __device__ hip_bfloat16 hlog(const hip_bfloat16 h) {
++    return hip_bfloat16(__ocml_log_f32(float(h)));
++  }
++  __device__ hip_bfloat16 hsin(const hip_bfloat16 h) {
++    return hip_bfloat16(__ocml_sin_f32(float(h)));
++  }
++  __device__ hip_bfloat16 hcos(const hip_bfloat16 h) {
++    return hip_bfloat16(__ocml_cos_f32(float(h)));
++  }
++  __device__ hip_bfloat16 hexp(const hip_bfloat16 h) {
++    return hip_bfloat16(__ocml_exp_f32(float(h)));
++  }
++
++  __device__ hip_bfloat16 __habs(const hip_bfloat16 a) {
++    auto ret = a;
++    ret.data &= 0x7FFF;
++    return ret;
++  }
++  __device__ hip_bfloat16 __hmax(const hip_bfloat16 a, const hip_bfloat16 b) {
++    return hip_bfloat16(__ocml_fmax_f32(float(a), float(b)));
++  }
++  #define curandStatePhilox4_32_10_t hiprandStatePhilox4_32_10_t
++  #define cublasStatus_t hipblasStatus_t
++  #define cublasHandle_t hipblasHandle_t
++  #define curand_init hiprand_init
++  #define cudaDeviceProp hipDeviceProp_t
++  #define cudaDeviceSynchronize hipDeviceSynchronize
++  #define cudaErrorInsufficientDriver hipErrorInsufficientDriver
++  #define cudaErrorNoDevice hipErrorNoDevice
++  #define cudaError_t hipError_t
++  #define cudaEventCreate hipEventCreate
++  #define cudaEventElapsedTime hipEventElapsedTime
++  #define cudaEventRecord hipEventRecord
++  #define cudaEventSynchronize hipEventSynchronize
++  #define cudaEvent_t hipEvent_t
++  #define cudaFree hipFree
++  #define cudaGetDevice hipGetDevice
++  #define cudaGetDeviceCount hipGetDeviceCount
++  #define cudaGetDeviceProperties hipGetDeviceProperties
++  #define cudaGetErrorString hipGetErrorString
++  #define cudaGetLastError hipGetLastError
++  #define cudaLaunchKernelGGL hipLaunchKernelGGL
++  #define cudaMalloc hipMalloc
++  #define cudaMemcpy hipMemcpy
++  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
++  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
++  #define cudaSuccess hipSuccess
++  #define cudaMallocHost hipHostMalloc
++  #define cudaStream_t hipStream_t
++  #define cudaStreamCreate hipStreamCreate
++  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
++  #define cudaStreamNonBlocking hipStreamNonBlocking
++  #define cudaStreamDestroy hipStreamDestroy
++  #define cudaSetDevice hipSetDevice
++  #define cudaMemcpyToSymbol hipMemcpyToSymbol
++  #define cudaMemcpyAsync hipMemcpyAsync
++  #define cudaFreeHost hipHostFree
++  #define cudaDeviceReset hipDeviceReset
++  #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
++  #define cudaStreamSynchronize hipStreamSynchronize
++  #define cudaMallocAsync hipMallocAsync
++  #define cudaFreeAsync hipFreeAsync
++  #define cudaDeviceGetAttribute hipDeviceGetAttribute
++  #define cudaDevAttrMemoryPoolsSupported hipDeviceAttributeMemoryPoolsSupported
++  #define CUBLAS_OP_T HIPBLAS_OP_T
++  #define CUBLAS_OP_N HIPBLAS_OP_N
++  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
++  #define cublasSgemmStridedBatched hipblasSgemmStridedBatched
++  #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
++  #define cublasSgemm hipblasSgemm
++  #define cublasGemmEx hipblasGemmEx
++  #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F //new added, HIPBLAS_R_16F deprecated
++  #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F //new added
++  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F //new added
++  #define CUBLAS_COMPUTE_32F_FAST_16BF HIPBLAS_COMPUTE_32F_FAST_16BF //new added
++  #define CUBLAS_COMPUTE_32I HIPBLAS_COMPUTE_32I //new added
++  #define CUDA_R_16F HIP_R_16F //HIPBLAS_R_16F deprecated
++  #define CUDA_R_32F HIP_R_32F //HIPBLAS_R_32F deprecated
++  #define CUDA_R_16B HIP_R_16BF //HIPBLAS_R_16B deprecated
++  #define CUDA_R_16BF HIP_R_16BF //HIPBLAS_R_16B deprecated
++  #define CUDA_R_32I HIP_R_32I //HIPBLAS_R_32I deprecated
++  #define CUDA_R_8I HIP_R_8I //HIPBLAS_R_8I deprecated
++  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
++  #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
++  #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
++  #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
++  #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
++  #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
++  #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
++  #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
++  #define CUBLAS_STATUS_LICENSE_ERROR HIPBLAS_STATUS_UNKNOWN
++  #define cublasCreate hipblasCreate
++  #define cublasSetStream hipblasSetStream
++  #define cublasDestroy hipblasDestroy
++  #define cublasComputeType_t hipblasComputeType_t //hipblasDatatype_t & hipblasLtComputeType_t deprecated
++  #define cudaDataType_t hipDataType //hipblasDatatype_t deprecated, use hipDataType for rocm7
++  #define cub hipcub
++  #define cudaStreamDefault hipStreamDefault
++  #define curand_uniform hiprand_uniform
++//cudnn vs miopen
++  #define cudnnHandle_t miopenHandle_t
++  #define cudnnDataType_t miopenDataType_t
++  #define cudnnStatus_t miopenStatus_t
++  #define cudnnGetErrorString miopenGetErrorString
++  #define cudnnCreate miopenCreate
++  #define cudnnSetStream miopenSetStream
++  #define cudnnDestroy miopenDestroy
++  #define CUDNN_STATUS_SUCCESS miopenStatusSuccess
++  #define CUDNN_DATA_FLOAT miopenFloat
++  #define CUDNN_DATA_HALF miopenHalf
++  #define CUDNN_DATA_BFLOAT16 miopenBFloat16
++  #define CUDNN_DATA_INT32 miopenInt32
++  #define CUDNN_DATA_INT8 miopenInt8
++  #define cudnnCreateTensorDescriptor miopenCreateTensorDescriptor
++  #define cudnnTensorDescriptor_t miopenTensorDescriptor_t
++  #define cudnnSetTensor4dDescriptor miopenSet4dTensorDescriptor
++  #define cudnnFilterDescriptor_t miopenTensorDescriptor_t
++  #define cudnnCreateFilterDescriptor miopenCreateTensorDescriptor
++  #define cudnnSetFilter4dDescriptor miopenSet4dTensorDescriptor
++  #define cudnnActivationDescriptor_t miopenActivationDescriptor_t
++  #define cudnnCreateActivationDescriptor miopenCreateActivationDescriptor
++  #define cudnnConvolutionDescriptor_t miopenConvolutionDescriptor_t
++  #define cudnnCreateConvolutionDescriptor miopenCreateConvolutionDescriptor
++  #define cudnnDestroyActivationDescriptor miopenDestroyActivationDescriptor
++  #define cudnnDestroyTensorDescriptor miopenDestroyTensorDescriptor
++  #define cudnnDestroyConvolutionDescriptor miopenDestroyConvolutionDescriptor
++  #define cudnnConvolutionBiasActivationForward miopenConvolutionBiasActivationForward
++  #define cudnnDestroyFilterDescriptor miopenDestroyTensorDescriptor
++#endif
++
+diff --git a/src/ops/bias_add_gpu.cu b/src/ops/bias_add_gpu.cu
+index 8f53bcf6..47222b02
100644 +--- a/src/ops/bias_add_gpu.cu ++++ b/src/ops/bias_add_gpu.cu +@@ -7,9 +7,9 @@ namespace ctranslate2 { + namespace ops { + + template +- __global__ void bias_add_kernel(const T* value, +- const T* bias, +- T* output, ++ __global__ void bias_add_kernel(const T* __restrict__ value, ++ const T* __restrict__ bias, ++ T* __restrict__ output, + cuda::index_t depth, + const AddFunc& add_func, + const Epilogue& epilogue) { +diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu +index 3b389358..be1fa170 100644 +--- a/src/ops/conv1d_gpu.cu ++++ b/src/ops/conv1d_gpu.cu +@@ -30,40 +30,66 @@ namespace ctranslate2 { + + cudnnTensorDescriptor_t input_desc; + CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); +- CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, data_type, +- batch_size, in_channels, 1, input_length)); ++ CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc, ++#ifndef __HIP_PLATFORM_AMD__ ++ CUDNN_TENSOR_NCHW, ++#endif ++ data_type, batch_size, in_channels, 1, input_length)); + + cudnnTensorDescriptor_t output_desc; + CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); +- CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, data_type, +- batch_size, out_channels, 1, output_length)); ++ CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc, ++#ifndef __HIP_PLATFORM_AMD__ ++ CUDNN_TENSOR_NCHW, ++#endif ++ data_type, batch_size, out_channels, 1, output_length)); + + cudnnFilterDescriptor_t weight_desc; + CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc)); +- CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW, +- out_channels, in_channels, 1, kernel_size)); ++ CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, ++#ifndef __HIP_PLATFORM_AMD__ ++ CUDNN_TENSOR_NCHW, ++#endif ++ out_channels, in_channels, 1, kernel_size)); + + cudnnConvolutionDescriptor_t conv_desc; + CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); ++#ifndef __HIP_PLATFORM_AMD__ + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, + /*pad_h=*/0, /*pad_w=*/_padding, + /*stride_h=*/1, /*stride_w=*/_stride, + /*dilation_h=*/1, /*dilation_w=*/_dilation, + CUDNN_CROSS_CORRELATION, + CUDNN_DATA_FLOAT)); +- ++#else ++ CUDNN_CHECK(miopenInitConvolutionDescriptor(conv_desc, ++ miopenConvolution, ++ /*pad_h=*/0, /*pad_w=*/_padding, ++ /*stride_h=*/1, /*stride_w=*/_stride, ++ /*dilation_h=*/1, /*dilation_w=*/_dilation) ); ++#endif ++#ifndef __HIP_PLATFORM_AMD__ + CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); + if (data_type == CUDNN_DATA_HALF) + CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH)); +- ++#endif + cudnnHandle_t handle = cuda::get_cudnn_handle(); + ++#ifndef __HIP_PLATFORM_AMD__ + cudnnConvolutionFwdAlgo_t algo = (bias + ? CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM); ++#else ++ miopenConvFwdAlgorithm_t algo = (bias ++ ? 
miopenConvolutionFwdAlgoGEMM
++                                     : miopenConvolutionFwdAlgoImplicitGEMM);
++#endif
++
+
+   size_t workspace_size = 0;
+   void* workspace = nullptr;
++#ifndef __HIP_PLATFORM_AMD__
++
+   CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle,
+                                                       input_desc,
+                                                       weight_desc,
+@@ -74,23 +100,85 @@ namespace ctranslate2 {
+
+   if (workspace_size > 0)
+     workspace = get_allocator().allocate(workspace_size);
+-
++#else
++  std::size_t count;
++  CUDNN_CHECK(miopenConvolutionForwardGetSolutionCount(handle,
++                                                       weight_desc,
++                                                       input_desc,
++                                                       conv_desc,
++                                                       output_desc,
++                                                       &count));
++  if(count <1){
++    std::cout<<"count: "<<count<<std::endl;
++    throw std::runtime_error("MIOpen found no solution for the forward convolution");
++  }
++  auto solutions = std::vector<miopenConvSolution_t>(count);
++  CUDNN_CHECK(miopenConvolutionForwardGetSolution(handle,
++                                                  weight_desc,
++                                                  input_desc,
++                                                  conv_desc,
++                                                  output_desc,
++                                                  count,
++                                                  &count,
++                                                  solutions.data()));
++  const miopenConvSolution_t* selected = &solutions.front();
++  CUDNN_CHECK(miopenConvolutionForwardGetSolutionWorkspaceSize(handle,
++                                                               weight_desc,
++                                                               input_desc,
++                                                               conv_desc,
++                                                               output_desc,
++                                                               selected->solution_id,
++                                                               &workspace_size));
++  if (workspace_size > 0){
++    workspace = get_allocator().allocate(workspace_size);
++  }
++  CUDNN_CHECK(miopenConvolutionForwardCompileSolution(handle,
++                                                      weight_desc,
++                                                      input_desc,
++                                                      conv_desc,
++                                                      output_desc,
++                                                      selected->solution_id));
++
++  CUDNN_CHECK(miopenConvolutionForwardImmediate(handle,
++                                                weight_desc,
++                                                weight.buffer(),
++                                                input_desc,
++                                                input.buffer(),
++                                                conv_desc,
++                                                output_desc,
++                                                output.buffer(),
++                                                workspace,
++                                                workspace_size,
++                                                selected->solution_id));
++#endif
+   float alpha = 1;
+   float beta = 0;
+
+   if (bias) {
+     cudnnTensorDescriptor_t bias_desc;
+     CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
+-    CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW, data_type,
+-                                           1, out_channels, 1, 1));
++    CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc,
++#ifndef __HIP_PLATFORM_AMD__
++                                           CUDNN_TENSOR_NCHW,
++#endif
++                                           data_type, 1, out_channels, 1, 1));
++
+
+     cudnnActivationDescriptor_t activation_desc;
+     CUDNN_CHECK(cudnnCreateActivationDescriptor(&activation_desc));
++#ifndef __HIP_PLATFORM_AMD__
+     CUDNN_CHECK(cudnnSetActivationDescriptor(activation_desc,
+                                              CUDNN_ACTIVATION_IDENTITY,
+                                              CUDNN_NOT_PROPAGATE_NAN,
+                                              /*coef=*/0));
+-
++#else
++    CUDNN_CHECK(miopenSetActivationDescriptor(activation_desc,
++                                              miopenActivationPASTHRU,
++                                              0,
++                                              0,
++                                              0));
++#endif
++#ifndef __HIP_PLATFORM_AMD__
+     CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle,
+                                                       &alpha,
+                                                       input_desc,
+@@ -109,11 +197,21 @@ namespace ctranslate2 {
+                                                       activation_desc,
+                                                       output_desc,
+                                                       output.buffer()));
++#else
++    CUDNN_CHECK(miopenConvolutionForwardBias(handle,
++                                             &alpha,
++                                             bias_desc,
++                                             bias->buffer(),
++                                             &beta,
++                                             output_desc,
++                                             output.buffer()));
++#endif
+
+     CUDNN_CHECK(cudnnDestroyActivationDescriptor(activation_desc));
+     CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));
+
+   } else {
++#ifndef __HIP_PLATFORM_AMD__
+     CUDNN_CHECK(cudnnConvolutionForward(handle,
+                                         &alpha,
+                                         input_desc,
+@@ -127,6 +225,7 @@ namespace ctranslate2 {
+                                         &beta,
+                                         output_desc,
+                                         output.buffer()));
++#endif
+   }
+
+   if (workspace)
+diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu
+index 241b3acd..99de40b9 100644
+--- a/src/ops/dequantize_gpu.cu
++++ b/src/ops/dequantize_gpu.cu
+@@ -4,6 +4,16 @@
+
+ namespace ctranslate2 {
+   namespace ops {
++#ifdef __HIP_PLATFORM_AMD__
++    template <typename InT, typename OutT>
++    struct dequantize_func {
++      __device__ __forceinline__
++      OutT operator()(float scale, InT x) const {
++        return OutT(__fdividef(float(x), scale));
++      }
++    };
++
++#else
+
+     template <typename InT, typename OutT>
+     struct dequantize_func {
+@@ -12,7 +22,7 @@ namespace ctranslate2 {
+       return __fdividef(static_cast<float>(x), scale);
+     }
+   };
+-
++#endif
+   template <Device D, typename InT, typename OutT>
+   void Dequantize::dequantize(const StorageView& input,
+                               const StorageView& scale,
+diff --git a/src/ops/gumbel_max_gpu.cu b/src/ops/gumbel_max_gpu.cu
+index 160390c4..b50b5343 100644
+--- a/src/ops/gumbel_max_gpu.cu
++++ b/src/ops/gumbel_max_gpu.cu
+@@ -17,7 +17,11 @@ namespace ctranslate2 {
+     template <typename DataType, typename IndexType>
+     __device__ DataType operator()(DataType value, IndexType id) const {
+       const float z = -logf(curand_uniform(_states + id));
++#ifdef __HIP_PLATFORM_AMD__
++      return value + DataType(z);
++#else
+       return float(value) + z;
++#endif
+     }
+
+   private:
+diff --git a/src/ops/layer_norm_gpu.cu b/src/ops/layer_norm_gpu.cu
+index 8c644d87..7dfcae9a 100644
+--- a/src/ops/layer_norm_gpu.cu
++++ b/src/ops/layer_norm_gpu.cu
+@@ -139,14 +139,17 @@ namespace ctranslate2 {
+    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+    POSSIBILITY OF SUCH DAMAGE.
+ */
+-
+-#include <cub/block/block_reduce.cuh>
+-
++#ifndef __HIP_PLATFORM_AMD__
++  #include <cub/block/block_reduce.cuh>
++#else
++  #include <hipcub/hipcub.hpp>
++  #include <hip/hip_runtime.h>
++#endif
+ namespace at {
+ namespace native {
+
+   template <typename T, typename SizeT>
+-  __global__ void LayerNormForwardCUDAKernel(SizeT N,
++  __global__ void __launch_bounds__(CUDA_NUM_THREADS) LayerNormForwardCUDAKernel(SizeT N,
+                                              float eps,
+                                              const T* X,
+                                              const T* gamma,
+diff --git a/src/ops/mean_gpu.cu b/src/ops/mean_gpu.cu
+index 5125924c..508d9706 100644
+--- a/src/ops/mean_gpu.cu
++++ b/src/ops/mean_gpu.cu
+@@ -1,7 +1,10 @@
+ #include "ctranslate2/ops/mean.h"
+-
+-#include <cub/block/block_reduce.cuh>
+-
++#ifndef __HIP_PLATFORM_AMD__
++  #include <cub/block/block_reduce.cuh>
++#else
++#include <hipcub/hipcub.hpp>
++#include <hip/hip_runtime.h>
++#endif
+ #include "type_dispatch.h"
+ #include "cuda/helpers.h"
+
+@@ -11,7 +14,7 @@ namespace ctranslate2 {
+     constexpr dim_t num_threads = 256;
+
+     template <typename T>
+-    __global__ void mean_kernel(const T* input,
++    __global__ void __launch_bounds__(num_threads) mean_kernel(const T* input,
+                                 const cuda::index_t outer_size,
+                                 const cuda::index_t axis_size,
+                                 const cuda::index_t inner_size,
+diff --git a/src/ops/multinomial_gpu.cu b/src/ops/multinomial_gpu.cu
+index 90f36377..e2425836 100644
+--- a/src/ops/multinomial_gpu.cu
++++ b/src/ops/multinomial_gpu.cu
+@@ -1,8 +1,12 @@
+ #include "ctranslate2/ops/multinomial.h"
+-
+-#include <cub/block/block_reduce.cuh>
+-#include <curand_kernel.h>
+-
++#ifndef __HIP_PLATFORM_AMD__
++  #include <cub/block/block_reduce.cuh>
++  #include <curand_kernel.h>
++#else
++  #include <hipcub/hipcub.hpp>
++  #include <hiprand/hiprand_kernel.h>
++  #include <hip/hip_runtime.h>
++#endif
+ #include "cuda/helpers.h"
+ #include "cuda/random.h"
+
+@@ -24,7 +28,7 @@ namespace ctranslate2 {
+     constexpr dim_t num_threads = 256;
+
+     template <typename In, typename Out>
+-    __global__ void multinomial_kernel(const In* probs,
++    __global__ void __launch_bounds__(num_threads) multinomial_kernel(const In* probs,
+                                        cuda::index_t class_size,
+                                        Out* output,
+                                        curandStatePhilox4_32_10_t* states) {
+diff --git a/src/ops/rms_norm_gpu.cu b/src/ops/rms_norm_gpu.cu
+index 41130cd9..f2734ede 100644
+--- a/src/ops/rms_norm_gpu.cu
++++ b/src/ops/rms_norm_gpu.cu
+@@ -1,7 +1,10 @@
+ #include "ctranslate2/ops/rms_norm.h"
+-
+-#include <cub/block/block_reduce.cuh>
+-
++#ifndef __HIP_PLATFORM_AMD__
++  #include <cub/block/block_reduce.cuh>
++#else
++  #include <hipcub/hipcub.hpp>
++  #include <hip/hip_runtime.h>
++#endif
+ #include "cuda/helpers.h"
+ #include "cuda/utils.h"
+
+@@ -11,7 +14,7 @@ namespace ctranslate2 {
+     constexpr dim_t num_threads = 512;
+
+     template <typename T>
+-    __global__ void rms_norm_kernel(const T* input,
++    __global__ void __launch_bounds__(num_threads) rms_norm_kernel(const T* input,
+                                     const T* gamma,
+                                     T* output,
+                                     cuda::index_t depth,
+diff --git a/src/ops/topk_gpu.cu b/src/ops/topk_gpu.cu
+index ad010fb4..ffbe192e 100644
+--- a/src/ops/topk_gpu.cu
++++
b/src/ops/topk_gpu.cu +@@ -115,9 +115,12 @@ namespace ctranslate2 { + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + */ +- +-#include +- ++#ifndef __HIP_PLATFORM_AMD__ ++ #include ++#else ++ #include ++ #include ++#endif + namespace fastertransformer { + + #define NOT_FOUND -1 +diff --git a/src/ops/topp_mask_gpu.cu b/src/ops/topp_mask_gpu.cu +index a4e6cb6e..4d6661b0 100644 +--- a/src/ops/topp_mask_gpu.cu ++++ b/src/ops/topp_mask_gpu.cu +@@ -1,6 +1,10 @@ + #include "ctranslate2/ops/topp_mask.h" +- +-#include ++#ifndef __HIP_PLATFORM_AMD__ ++ #include ++#else ++ #include ++ #include ++#endif + + #include "cuda/helpers.h" + +@@ -10,7 +14,7 @@ namespace ctranslate2 { + constexpr dim_t num_threads = 256; + + template +- __global__ void topp_mask_kernel(const T* input, ++ __global__ void __launch_bounds__(num_threads) topp_mask_kernel(const T* input, + const T* probs, + T* output, + const float p, +diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt +index 283c49db..87906a7e 100644 +--- a/tests/CMakeLists.txt ++++ b/tests/CMakeLists.txt +@@ -18,18 +18,45 @@ add_executable(ctranslate2_test + target_include_directories(ctranslate2_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ) +-target_link_libraries(ctranslate2_test +- ${PROJECT_NAME} +- gtest_main ++if(WITH_CUDA) ++ if(GPU_RUNTIME STREQUAL "HIP") ++ find_package(hipblas REQUIRED) ++ find_package(miopen REQUIRED) ++ target_link_libraries(ctranslate2_test ++ ${PROJECT_NAME} ++ gtest_main # roc::hipblas ++ MIOpen ++ ) ++ else() ++ target_link_libraries(ctranslate2_test ++ ${PROJECT_NAME} ++ gtest_main ++ ) ++ endif() ++else() ++ ++ target_link_libraries(ctranslate2_test ++ ${PROJECT_NAME} ++ gtest_main + ) ++endif() + + add_executable(benchmark_ops + benchmark_ops.cc + ) ++target_include_directories(benchmark_ops PRIVATE ++ ${CMAKE_CURRENT_SOURCE_DIR}/../include ++ ) + target_link_libraries(benchmark_ops +- ${PROJECT_NAME} +- ) ++ ${PROJECT_NAME} ++ ) + + if(WITH_CUDA) +- target_link_libraries(benchmark_ops ${CUDA_LIBRARIES}) ++ set_source_files_properties(benchmark_ops.cc PROPERTIES LANGUAGE ${GPU_RUNTIME}) ++ if(GPU_RUNTIME STREQUAL "HIP") ++ find_package(hipblas REQUIRED) ++ target_link_libraries(benchmark_ops roc::hipblas) ++ else() ++ target_link_libraries(benchmark_ops ${CUDA_LIBRARIES}) ++ endif() + endif() +diff --git a/tests/benchmark_utils.h b/tests/benchmark_utils.h +index 55ee77b0..c35da837 100644 +--- a/tests/benchmark_utils.h ++++ b/tests/benchmark_utils.h +@@ -5,8 +5,13 @@ + #include + + #ifdef CT2_WITH_CUDA +-# include +-# define SYNCHRONIZE cudaDeviceSynchronize() ++ #ifndef __HIP_PLATFORM_AMD__ ++ #include ++ #define SYNCHRONIZE cudaDeviceSynchronize() ++ #else ++ #include ++ #define SYNCHRONIZE hipDeviceSynchronize() ++ #endif + #else + # define SYNCHRONIZE do {} while (false) + #endif
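
Two reviewer notes, neither of which is part of the patch itself:

1. CT2_CUDA_CACHING_ALLOCATOR_CONFIG in the Dockerfile is CTranslate2's
   "bin_growth,min_bin,max_bin,max_cached_bytes" tuple for the cub/hipcub
   caching allocator, so "4,3,12,419430400" caps each device's cache at
   400 MiB.

2. The sketch below is one way to smoke-test the running container. It is
   an untested sketch that assumes the `wyoming` PyPI package (the protocol
   library wyoming-faster-whisper is built on) and a 16-bit PCM mono WAV
   file; the host, port, and script/chunk-size choices mirror
   docker-compose.yml but are otherwise arbitrary.

#!/usr/bin/env python3
"""Untested sketch: stream a WAV file to the Wyoming endpoint, print the text."""
import asyncio
import sys
import wave

from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient

CHUNK_FRAMES = 1024  # arbitrary; any reasonable chunk size works


async def transcribe(path: str, host: str = "127.0.0.1", port: int = 10300) -> str:
    # Pull the raw PCM out of the WAV container.
    with wave.open(path, "rb") as wav:
        rate, width, channels = wav.getframerate(), wav.getsampwidth(), wav.getnchannels()
        frames = wav.readframes(wav.getnframes())

    client = AsyncTcpClient(host, port)
    await client.connect()
    try:
        # Announce an ASR session, then stream the audio as chunk events.
        await client.write_event(Transcribe().event())
        await client.write_event(AudioStart(rate=rate, width=width, channels=channels).event())
        step = CHUNK_FRAMES * width * channels
        for i in range(0, len(frames), step):
            await client.write_event(
                AudioChunk(rate=rate, width=width, channels=channels,
                           audio=frames[i:i + step]).event())
        await client.write_event(AudioStop().event())

        # The server answers with a single transcript event.
        while True:
            event = await client.read_event()
            if event is None:
                raise ConnectionError("server closed the connection")
            if Transcript.is_type(event.type):
                return Transcript.from_event(event).text
    finally:
        await client.disconnect()


if __name__ == "__main__":
    print(asyncio.run(transcribe(sys.argv[1])))

Expected result: running e.g. `python3 smoke_test.py sample.wav` (hypothetical
file names) prints the decoded text on stdout, while the container logs should
show the model loading on the HIP device.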