diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1089106c..d47f504d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.7)
+cmake_minimum_required(VERSION 3.21 FATAL_ERROR)
 
 # Set policy for setting the MSVC runtime library for static MSVC builds
 if(POLICY CMP0091)
@@ -42,12 +42,12 @@ else()
 endif()
 set(INTEL_ROOT ${INTEL_ROOT_DEFAULT} CACHE FILEPATH "Path to Intel root directory")
 set(OPENMP_RUNTIME "INTEL" CACHE STRING "OpenMP runtime (INTEL, COMP, NONE)")
-
+set_property(CACHE OPENMP_RUNTIME PROPERTY STRINGS INTEL COMP NONE)
 # Set Release build type by default to get sane performance.
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release)
 endif(NOT CMAKE_BUILD_TYPE)
-
+set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release RelWithDebInfo MinSizeRel)
 # Set CXX flags.
 set(CMAKE_CXX_STANDARD 17)
@@ -106,7 +106,7 @@ set(SOURCES
   src/cpu/primitives.cc
   src/decoding.cc
   src/decoding_utils.cc
-  src/devices.cc
+  #src/devices.cc
   src/dtw.cc
   src/encoder.cc
   src/env.cc
@@ -120,13 +120,13 @@ set(SOURCES
   src/layers/whisper.cc
   src/logging.cc
   src/models/language_model.cc
-  src/models/model.cc
+  #src/models/model.cc
   src/models/model_factory.cc
   src/models/model_reader.cc
   src/models/sequence_to_sequence.cc
   src/models/transformer.cc
-  src/models/wav2vec2.cc
-  src/models/whisper.cc
+  #src/models/wav2vec2.cc
+  #src/models/whisper.cc
   src/ops/activation.cc
   src/ops/add.cc
   src/ops/alibi_add.cc
@@ -184,11 +184,11 @@ set(SOURCES
   src/random.cc
   src/sampling.cc
   src/scoring.cc
-  src/storage_view.cc
+  #src/storage_view.cc
   src/thread_pool.cc
   src/translator.cc
-  src/types.cc
-  src/utils.cc
+  #src/types.cc
+  #src/utils.cc
   src/vocabulary.cc
   src/vocabulary_map.cc
 )
@@ -416,10 +416,89 @@ if (WITH_RUY)
   unset(CMAKE_POSITION_INDEPENDENT_CODE)
   list(APPEND LIBRARIES ruy)
 endif()
+set(GPU_RUNTIME "HIP" CACHE STRING "GPU runtime (HIP or CUDA)")
+set_property(CACHE GPU_RUNTIME PROPERTY STRINGS CUDA HIP)
+if(WIN32)
+  set(ROCM_ROOT "$ENV{HIP_PATH}" CACHE PATH "Root directory of the ROCm installation")
+else()
+  if(NOT DEFINED ROCM_PATH)
+    if(NOT DEFINED ENV{ROCM_PATH})
+      set(ROCM_PATH "/opt/rocm" CACHE PATH "Root directory of the ROCm installation")
+    else()
+      set(ROCM_PATH $ENV{ROCM_PATH} CACHE PATH "Root directory of the ROCm installation")
+    endif()
+  endif()
+  set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake)
+endif()
 if (WITH_CUDA)
-  find_package(CUDA 11.0 REQUIRED)
+  set(GPU_SOURCES
+    src/cuda/allocator.cc
+    src/cuda/primitives.cu
+    src/cuda/random.cu
+    src/cuda/utils.cc
+    src/ops/alibi_add_gpu.cu
+    src/ops/bias_add_gpu.cu
+    src/ops/concat_split_slide_gpu.cu
+    src/ops/conv1d_gpu.cu
+    src/ops/dequantize_gpu.cu
+    src/ops/gather_gpu.cu
+    src/ops/gumbel_max_gpu.cu
+    src/ops/layer_norm_gpu.cu
+    src/ops/mean_gpu.cu
+    src/ops/multinomial_gpu.cu
+    src/ops/rms_norm_gpu.cu
+    src/ops/rotary_gpu.cu
+    src/ops/softmax_gpu.cu
+    src/ops/tile_gpu.cu
+    src/ops/topk_gpu.cu
+    src/ops/topp_mask_gpu.cu
+    src/ops/quantize_gpu.cu
+  )
   add_definitions(-DCT2_WITH_CUDA)
+  list(APPEND GPU_SOURCES
+    src/devices.cc
+    src/types.cc
+    src/storage_view.cc
+    src/utils.cc
+    src/models/model.cc
+    src/models/wav2vec2.cc
+    src/models/whisper.cc
+  )
+  if(GPU_RUNTIME STREQUAL "HIP")
+    enable_language(${GPU_RUNTIME})
+    set(CMAKE_${GPU_RUNTIME}_STANDARD 17)
+    set(CMAKE_${GPU_RUNTIME}_EXTENSIONS OFF)
+    set(CMAKE_${GPU_RUNTIME}_STANDARD_REQUIRED ON)
+    add_definitions(-D__HIP_PLATFORM_AMD__)
+    find_package(hipblas REQUIRED)
+    find_package(hip REQUIRED)
+
+    if(WITH_CUDNN)
+      find_package(miopen REQUIRED)
+      add_definitions(-DCT2_WITH_CUDNN)
+    endif()
+    # list(APPEND GPU_SOURCES src/devices.cc src/types.cc src/utils.cc src/models/model.cc src/models/whisper.cc)
+    add_library(${PROJECT_NAME} SHARED
+      ${SOURCES}
+      ${GPU_SOURCES}
+    )
+    set_source_files_properties(${GPU_SOURCES} PROPERTIES LANGUAGE ${GPU_RUNTIME})
+    set_source_files_properties(${SOURCES} PROPERTIES LANGUAGE CXX)
+    target_link_libraries(${PROJECT_NAME} PUBLIC roc::hipblas)
+    if(WITH_CUDNN)
+      # find_library(LIBRT NAMES "librt.so.1" PATHS /lib/x86_64-linux-gnu/ NO_DEFAULT_PATH)
+      find_package(miopen REQUIRED)
+      find_library(LIBRT rt)
+      if(LIBRT)
+        message(STATUS "Librt: " ${LIBRT})
+        target_link_libraries(${PROJECT_NAME} PUBLIC MIOpen ${LIBRT})
+      else()
+        target_link_libraries(${PROJECT_NAME} PUBLIC MIOpen)
+      endif()
+    endif()
+  else()
+    find_package(CUDA 11.0 REQUIRED)
   if(MSVC)
     if(BUILD_SHARED_LIBS)
       list(APPEND CUDA_NVCC_FLAGS "-Xcompiler=/MD$<$<CONFIG:Debug>:d>")
@@ -501,31 +580,20 @@ if (WITH_CUDA)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
   cuda_add_library(${PROJECT_NAME}
     ${SOURCES}
-    src/cuda/allocator.cc
-    src/cuda/primitives.cu
-    src/cuda/random.cu
-    src/cuda/utils.cc
-    src/ops/alibi_add_gpu.cu
-    src/ops/bias_add_gpu.cu
-    src/ops/concat_split_slide_gpu.cu
-    src/ops/conv1d_gpu.cu
-    src/ops/dequantize_gpu.cu
-    src/ops/gather_gpu.cu
-    src/ops/gumbel_max_gpu.cu
-    src/ops/layer_norm_gpu.cu
-    src/ops/mean_gpu.cu
-    src/ops/multinomial_gpu.cu
-    src/ops/rms_norm_gpu.cu
-    src/ops/rotary_gpu.cu
-    src/ops/softmax_gpu.cu
-    src/ops/tile_gpu.cu
-    src/ops/topk_gpu.cu
-    src/ops/topp_mask_gpu.cu
-    src/ops/quantize_gpu.cu
+    ${GPU_SOURCES}
     )
+  endif()
 elseif(WITH_CUDNN)
   message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
 else()
+  list(APPEND SOURCES
+    src/devices.cc
+    src/types.cc
+    src/storage_view.cc
+    src/utils.cc
+    src/models/model.cc
+    src/models/wav2vec2.cc
+    src/models/whisper.cc)
   add_library(${PROJECT_NAME} ${SOURCES})
 endif()
diff --git a/cli/CMakeLists.txt b/cli/CMakeLists.txt
index 3311ad33..e1dfe246 100644
--- a/cli/CMakeLists.txt
+++ b/cli/CMakeLists.txt
@@ -10,9 +10,24 @@ add_executable(translator
 target_include_directories(translator PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cxxopts/include
   )
-target_link_libraries(translator
-  PRIVATE ${PROJECT_NAME}
-  )
+if(WITH_CUDA)
+  if(GPU_RUNTIME STREQUAL "HIP")
+    find_package(hipblas REQUIRED)
+    find_package(miopen REQUIRED)
+    target_link_libraries(translator
+      PRIVATE ${PROJECT_NAME} MIOpen #roc::hipblas
+      )
+  else()
+    target_link_libraries(translator
+      PRIVATE ${PROJECT_NAME}
+      )
+  endif()
+
+else()
+  target_link_libraries(translator
+    PRIVATE ${PROJECT_NAME}
+    )
+endif()
 set_target_properties(translator PROPERTIES OUTPUT_NAME ct2-translator)
diff --git a/src/cuda/allocator.cc b/src/cuda/allocator.cc
index 2311bd00..0ed42fd0 100644
--- a/src/cuda/allocator.cc
+++ b/src/cuda/allocator.cc
@@ -6,9 +6,13 @@
 #include "ctranslate2/utils.h"
 #include "cuda/utils.h"
 #include "env.h"
-
-#include <cub/util_allocator.cuh>
-#include <cuda_runtime.h>
+#ifdef __HIP_PLATFORM_AMD__
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#else
+  #include <cub/util_allocator.cuh>
+  #include <cuda_runtime.h>
+#endif
 #include <spdlog/spdlog.h>
 
 namespace ctranslate2 {
@@ -63,7 +67,7 @@ namespace ctranslate2 {
     class CudaAsyncAllocator : public Allocator {
     public:
       void* allocate(size_t size, int device_index) override {
-#if CUDA_VERSION >= 11020
+#if (CUDA_VERSION >= 11020) || defined(__HIP_PLATFORM_AMD__)
         int prev_device_index = -1;
         if (device_index >= 0) {
           CUDA_CHECK(cudaGetDevice(&prev_device_index));
@@ -86,7 +90,7 @@ namespace ctranslate2 {
       }
 
       void free(void* ptr, int device_index) override {
-#if CUDA_VERSION >= 11020
+#if (CUDA_VERSION >= 11020) || defined(__HIP_PLATFORM_AMD__)
        int prev_device_index = -1;
        if (device_index >= 0) {
          CUDA_CHECK(cudaGetDevice(&prev_device_index));
@@ -107,12 +111,16 @@ namespace ctranslate2 {
    };

    static bool support_cuda_malloc_async() {
-#if CUDA_VERSION < 11020
+#if CUDA_VERSION < 11020 && !defined(__HIP_PLATFORM_AMD__)
      return false;
 #else
      for (int i = 0; i < get_gpu_count(); ++i) {
        int supported = 0;
+#ifdef __HIP_PLATFORM_AMD__
+        supported = 1;
+#else
        cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, i);
+#endif
        if (!supported)
          return false;
      }
diff --git a/src/cuda/helpers.h b/src/cuda/helpers.h
index a34d5d89..62c54df1 100644
--- a/src/cuda/helpers.h
+++ b/src/cuda/helpers.h
@@ -2,21 +2,23 @@
 
 #include <limits>
 #include <type_traits>
-
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
+#ifdef __HIP_PLATFORM_AMD__
+  #include <hip/hip_fp16.h>
+#else
+  #include <cuda_fp16.h>
+  #include <cuda_bf16.h>
+#endif
 #include "ctranslate2/types.h"
 
 #include "utils.h"
 
-#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_AMD__)
 #  define CUDA_CAN_USE_HALF 1
 #else
 #  define CUDA_CAN_USE_HALF 0
 #endif
 
-#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) || defined(__HIP_PLATFORM_AMD__)
 #  define CUDA_CAN_USE_BF16_MATH 1
 #else
 #  define CUDA_CAN_USE_BF16_MATH 0
@@ -362,9 +364,11 @@ namespace ctranslate2 {
    // The following kernels are adapted from:
    // https://github.com/pytorch/pytorch/blob/40eff454ce5638fbff638a7f4502e29ffb9a2f0d/aten/src/ATen/native/cuda/SoftMax.cu
    // They help define row-wise reduction where each block handles a single row.
-
-#define C10_WARP_SIZE 32
-
+#ifdef __HIP_PLATFORM_AMD__
+  #define C10_WARP_SIZE 64
+#else
+  #define C10_WARP_SIZE 32
+#endif
 template <typename index_t>
 inline dim3 get_block_size(index_t dim_size) {
   index_t block_size = 1;
@@ -400,7 +404,9 @@ namespace ctranslate2 {
       for (index_t i = 0; i < C10_WARP_SIZE; ++i) {
         warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
       }
+#ifndef __HIP_PLATFORM_AMD__
       __syncwarp(mask);
+#endif
       smem[lane] = warpVal;
     }
   }
diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu
index 149e10db..8fc2825f 100644
--- a/src/cuda/primitives.cu
+++ b/src/cuda/primitives.cu
@@ -1,7 +1,11 @@
 #include "ctranslate2/primitives.h"
-
-#include <cublas_v2.h>
-#include <cuda_fp16.h>
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cublas_v2.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hipblas/hipblas.h>
+  #include <hip/hip_fp16.h>
+#endif
 #include <cmath>
 
 #include "cuda/helpers.h"
@@ -240,7 +244,11 @@ namespace ctranslate2 {
    };

    template <typename T>
-    __global__ void penalize_previous_tokens_kernel(T* scores,
+    __global__ void
+#ifdef __HIP_PLATFORM_AMD__
+    __launch_bounds__(64)
+#endif
+    penalize_previous_tokens_kernel(T* scores,
                                                    const T* previous_scores,
                                                    const int32_t* previous_ids,
                                                    float penalty,
@@ -265,7 +273,11 @@ namespace ctranslate2 {
                                        dim_t batch_size,
                                        dim_t length,
                                        dim_t vocabulary_size) {
+#ifndef __HIP_PLATFORM_AMD__
      dim3 block(32);
+#else
+      dim3 block(64);
+#endif
      dim3 grid((batch_size * length + block.x - 1) / block.x);
      penalize_previous_tokens_kernel<<<grid, block, 0, cuda::get_cuda_stream()>>>(
        cuda::device_cast(scores),
@@ -478,12 +490,12 @@ namespace ctranslate2 {
      const void* alpha_ptr = &alpha_h;
      const void* beta_ptr = &beta_h;
-      cudaDataType_t compute_type = CUDA_R_16F;
+      cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; //cudaDataType_t compute_type = CUDA_R_16F;

      if (!cuda::use_true_fp16_gemm()) {
        alpha_ptr = &alpha;
        beta_ptr = &beta;
-        compute_type = CUDA_R_32F;
+        compute_type = CUBLAS_COMPUTE_32F; //compute_type = CUDA_R_32F;
      }

      // cuBLAS assumes column-major storage, so swap a and b accordingly.
@@ -521,7 +533,7 @@ namespace ctranslate2 {
                               a, CUDA_R_16BF, lda,
                               &beta,
                               c, CUDA_R_16BF, ldc,
-                               CUDA_R_32F,
+                               CUBLAS_COMPUTE_32F, //CUDA_R_32F,
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
@@ -549,7 +561,7 @@ namespace ctranslate2 {
                               a, CUDA_R_8I, lda,
                               &beta_i,
                               c, CUDA_R_32I, ldc,
-                               CUDA_R_32I,
+                               CUBLAS_COMPUTE_32I, //CUDA_R_32I,
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
@@ -591,12 +603,12 @@ namespace ctranslate2 {
      const void* alpha_ptr = &alpha_h;
      const void* beta_ptr = &beta_h;
-      cudaDataType_t compute_type = CUDA_R_16F;
+      cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; //cudaDataType_t compute_type = CUDA_R_16F;

      if (!cuda::use_true_fp16_gemm()) {
        alpha_ptr = &alpha;
        beta_ptr = &beta;
-        compute_type = CUDA_R_32F;
+        compute_type = CUBLAS_COMPUTE_32F; //compute_type = CUDA_R_32F;
      }

      // cuBLAS assumes column-major storage, so swap a and b accordingly.
@@ -635,7 +647,7 @@ namespace ctranslate2 {
                                     &beta,
                                     c, CUDA_R_16BF, ldc, stridec,
                                     batch_size,
-                                     CUDA_R_32F,
+                                     CUBLAS_COMPUTE_32F, //CUDA_R_32F,
                                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
diff --git a/src/cuda/random.cu b/src/cuda/random.cu
index f016bb44..1d081303 100644
--- a/src/cuda/random.cu
+++ b/src/cuda/random.cu
@@ -22,7 +22,11 @@ namespace ctranslate2 {
      ScopedCurandStates(size_t num_states)
        : _allocator(get_allocator())
      {
+#ifdef __HIP_PLATFORM_AMD__
+        constexpr size_t num_init_threads = 64;
+#else
        constexpr size_t num_init_threads = 32;
+#endif
        const size_t blocks = ceil_divide(num_states, num_init_threads);
        _num_states = blocks * num_init_threads;
        _states = static_cast<curandState*>(_allocator.allocate(_num_states * sizeof (curandState)));
diff --git a/src/cuda/random.h b/src/cuda/random.h
index e12ae20f..20fe7436 100644
--- a/src/cuda/random.h
+++ b/src/cuda/random.h
@@ -1,7 +1,10 @@
 #pragma once
-
-#include <curand_kernel.h>
-
+#ifdef __HIP_PLATFORM_AMD__
+  #include <hiprand/hiprand.h>
+  #include <hiprand/hiprand_kernel.h>
+#else
+  #include <curand_kernel.h>
+#endif
 namespace ctranslate2 {
   namespace cuda {
diff --git a/src/cuda/utils.h b/src/cuda/utils.h
index 29bc99a3..e14f1aab 100644
--- a/src/cuda/utils.h
+++ b/src/cuda/utils.h
@@ -2,12 +2,22 @@
 
 #include <string>
 
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
+#ifdef __HIP_PLATFORM_AMD__
+  #include <hipblas/hipblas.h>
+  #include <hip/hip_runtime.h>
+  #include <cuda2hip_macros.hpp>
+#else
+  #include <cublas_v2.h>
+  #include <cuda_runtime.h>
+#endif
 #include <thrust/execution_policy.h>
 
 #ifdef CT2_WITH_CUDNN
-#  include <cudnn.h>
+  #ifdef __HIP_PLATFORM_AMD__
+    #include <miopen/miopen.h>
+  #else
+    #include <cudnn.h>
+  #endif
 #endif
 
 #include "ctranslate2/types.h"
@@ -79,7 +89,10 @@ namespace ctranslate2 {
    };
 
 // Convenience macro to call Thrust functions with a default execution policy.
-#define THRUST_CALL(FUN, ...) FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
-
+#ifdef __HIP_PLATFORM_AMD__
+  #define THRUST_CALL(FUN, ...) FUN(thrust::hip::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
+#else
+  #define THRUST_CALL(FUN, ...) FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
+#endif
 }
 }
diff --git a/src/cuda2hip_macros.hpp b/src/cuda2hip_macros.hpp
new file mode 100644
index 00000000..48be99c6
--- /dev/null
+++ b/src/cuda2hip_macros.hpp
@@ -0,0 +1,136 @@
+#pragma once
+#ifdef __HIP_PLATFORM_AMD__
+  #include <hip/hip_runtime.h>
+  #include <hip/hip_bfloat16.h>
+  #define __nv_bfloat16 hip_bfloat16
+  #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+  __device__ hip_bfloat16 hlog(const hip_bfloat16 h) {
+    return hip_bfloat16(__ocml_log_f32(float(h)));
+  }
+  __device__ hip_bfloat16 hsin(const hip_bfloat16 h) {
+    return hip_bfloat16(__ocml_sin_f32(float(h)));
+  }
+  __device__ hip_bfloat16 hcos(const hip_bfloat16 h) {
+    return hip_bfloat16(__ocml_cos_f32(float(h)));
+  }
+  __device__ hip_bfloat16 hexp(const hip_bfloat16 h) {
+    return hip_bfloat16(__ocml_exp_f32(float(h)));
+  }
+
+  __device__ hip_bfloat16 __habs(const hip_bfloat16 a) {
+    auto ret = a;
+    ret.data &= 0x7FFF;
+    return ret;
+  }
+  __device__ hip_bfloat16 __hmax(const hip_bfloat16 a, const hip_bfloat16 b) {
+    return hip_bfloat16(__ocml_fmax_f32(float(a), float(b)));
+  }
+  #define curandStatePhilox4_32_10_t hiprandStatePhilox4_32_10_t
+  #define cublasStatus_t hipblasStatus_t
+  #define cublasHandle_t hipblasHandle_t
+  #define curand_init hiprand_init
+  #define cudaDeviceProp hipDeviceProp_t
+  #define cudaDeviceSynchronize hipDeviceSynchronize
+  #define cudaErrorInsufficientDriver hipErrorInsufficientDriver
+  #define cudaErrorNoDevice hipErrorNoDevice
+  #define cudaError_t hipError_t
+  #define cudaEventCreate hipEventCreate
+  #define cudaEventElapsedTime hipEventElapsedTime
+  #define cudaEventRecord hipEventRecord
+  #define cudaEventSynchronize hipEventSynchronize
+  #define cudaEvent_t hipEvent_t
+  #define cudaFree hipFree
+  #define cudaGetDevice hipGetDevice
+  #define cudaGetDeviceCount hipGetDeviceCount
+  #define cudaGetDeviceProperties hipGetDeviceProperties
+  #define cudaGetErrorString hipGetErrorString
+  #define cudaGetLastError hipGetLastError
+  #define cudaLaunchKernelGGL hipLaunchKernelGGL
+  #define cudaMalloc hipMalloc
+  #define cudaMemcpy hipMemcpy
+  #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+  #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+  #define cudaSuccess hipSuccess
+  #define cudaMallocHost hipHostMalloc
+  #define cudaStream_t hipStream_t
+  #define cudaStreamCreate hipStreamCreate
+  #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+  #define cudaStreamNonBlocking hipStreamNonBlocking
+  #define cudaStreamDestroy hipStreamDestroy
+  #define cudaSetDevice hipSetDevice
+  #define cudaMemcpyToSymbol hipMemcpyToSymbol
+  #define cudaMemcpyAsync hipMemcpyAsync
+  #define cudaFreeHost hipHostFree
+  #define cudaDeviceReset hipDeviceReset
+  #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+  #define cudaStreamSynchronize hipStreamSynchronize
+  #define cudaMallocAsync hipMallocAsync
+  #define cudaFreeAsync hipFreeAsync
+  #define cudaDeviceGetAttribute hipDeviceGetAttribute
+  #define cudaDevAttrMemoryPoolsSupported hipDeviceAttributeMemoryPoolsSupported
+  #define CUBLAS_OP_T HIPBLAS_OP_T
+  #define CUBLAS_OP_N HIPBLAS_OP_N
+  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+  #define cublasSgemmStridedBatched hipblasSgemmStridedBatched
+  #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+  #define cublasSgemm hipblasSgemm
+  #define cublasGemmEx hipblasGemmEx
+  #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F //new added, HIPBLAS_R_16F deprecated
+  #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F //new added
+  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F //new added
+  #define CUBLAS_COMPUTE_32F_FAST_16BF HIPBLAS_COMPUTE_32F_FAST_16BF //new added
+  #define CUBLAS_COMPUTE_32I HIPBLAS_COMPUTE_32I //new added
+  #define CUDA_R_16F HIP_R_16F //HIPBLAS_R_16F deprecated
+  #define CUDA_R_32F HIP_R_32F //HIPBLAS_R_32F deprecated
+  #define CUDA_R_16BF HIP_R_16BF //HIPBLAS_R_16BF deprecated
+  #define CUDA_R_32I HIP_R_32I //HIPBLAS_R_32I deprecated
+  #define CUDA_R_8I HIP_R_8I //HIPBLAS_R_8I deprecated
+  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+  #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+  #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+  #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+  #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+  #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+  #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+  #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+  #define CUBLAS_STATUS_LICENSE_ERROR HIPBLAS_STATUS_UNKNOWN
+  #define cublasCreate hipblasCreate
+  #define cublasSetStream hipblasSetStream
+  #define cublasDestroy hipblasDestroy
+  #define cublasComputeType_t hipblasComputeType_t //hipblasDatatype_t & hipblasLtComputeType_t deprecated
+  #define cudaDataType_t hipDataType //hipblasDatatype_t deprecated, use hipDataType for rocm7
+  #define cub hipcub
+  #define cudaStreamDefault hipStreamDefault
+  #define curand_uniform hiprand_uniform
+//cudnn vs miopen
+  #define cudnnHandle_t miopenHandle_t
+  #define cudnnDataType_t miopenDataType_t
+  #define cudnnStatus_t miopenStatus_t
+  #define cudnnGetErrorString miopenGetErrorString
+  #define cudnnCreate miopenCreate
+  #define cudnnSetStream miopenSetStream
+  #define cudnnDestroy miopenDestroy
+  #define CUDNN_STATUS_SUCCESS miopenStatusSuccess
+  #define CUDNN_DATA_FLOAT miopenFloat
+  #define CUDNN_DATA_HALF miopenHalf
+  #define CUDNN_DATA_BFLOAT16 miopenBFloat16
+  #define CUDNN_DATA_INT32 miopenInt32
+  #define CUDNN_DATA_INT8 miopenInt8
+  #define cudnnCreateTensorDescriptor miopenCreateTensorDescriptor
+  #define cudnnTensorDescriptor_t miopenTensorDescriptor_t
+  #define cudnnSetTensor4dDescriptor miopenSet4dTensorDescriptor
+  #define cudnnFilterDescriptor_t miopenTensorDescriptor_t
+  #define cudnnCreateFilterDescriptor miopenCreateTensorDescriptor
+  #define cudnnSetFilter4dDescriptor miopenSet4dTensorDescriptor
+  #define cudnnActivationDescriptor_t miopenActivationDescriptor_t
+  #define cudnnCreateActivationDescriptor miopenCreateActivationDescriptor
+  #define cudnnConvolutionDescriptor_t miopenConvolutionDescriptor_t
+  #define cudnnCreateConvolutionDescriptor miopenCreateConvolutionDescriptor
+  #define cudnnDestroyActivationDescriptor miopenDestroyActivationDescriptor
+  #define cudnnDestroyTensorDescriptor miopenDestroyTensorDescriptor
+  #define cudnnDestroyConvolutionDescriptor miopenDestroyConvolutionDescriptor
+  #define cudnnConvolutionBiasActivationForward miopenConvolutionBiasActivationForward
+  #define cudnnDestroyFilterDescriptor miopenDestroyTensorDescriptor
+#endif
diff --git a/src/ops/bias_add_gpu.cu b/src/ops/bias_add_gpu.cu
index 8f53bcf6..47222b02 100644
--- a/src/ops/bias_add_gpu.cu
+++ b/src/ops/bias_add_gpu.cu
@@ -7,9 +7,9 @@ namespace ctranslate2 {
  namespace ops {

    template <typename T, typename AddFunc, typename Epilogue>
-    __global__ void bias_add_kernel(const T* value,
-                                    const T* bias,
-                                    T* output,
+    __global__ void bias_add_kernel(const T* __restrict__ value,
+                                    const T* __restrict__ bias,
+                                    T* __restrict__ output,
                                    cuda::index_t depth,
                                    const AddFunc& add_func,
                                    const Epilogue& epilogue) {
diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu
index 3b389358..be1fa170 100644
--- a/src/ops/conv1d_gpu.cu
+++ b/src/ops/conv1d_gpu.cu
@@ -30,40 +30,66 @@ namespace ctranslate2 {
      cudnnTensorDescriptor_t input_desc;
      CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
-      CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, data_type,
-                                             batch_size, in_channels, 1, input_length));
+      CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc,
+#ifndef __HIP_PLATFORM_AMD__
+                                             CUDNN_TENSOR_NCHW,
+#endif
+                                             data_type, batch_size, in_channels, 1, input_length));

      cudnnTensorDescriptor_t output_desc;
      CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
-      CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, data_type,
-                                             batch_size, out_channels, 1, output_length));
+      CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc,
+#ifndef __HIP_PLATFORM_AMD__
+                                             CUDNN_TENSOR_NCHW,
+#endif
+                                             data_type, batch_size, out_channels, 1, output_length));

      cudnnFilterDescriptor_t weight_desc;
      CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc));
-      CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW,
-                                             out_channels, in_channels, 1, kernel_size));
+      CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type,
+#ifndef __HIP_PLATFORM_AMD__
+                                             CUDNN_TENSOR_NCHW,
+#endif
+                                             out_channels, in_channels, 1, kernel_size));

      cudnnConvolutionDescriptor_t conv_desc;
      CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));
+#ifndef __HIP_PLATFORM_AMD__
      CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc,
                                                  /*pad_h=*/0, /*pad_w=*/_padding,
                                                  /*stride_h=*/1, /*stride_w=*/_stride,
                                                  /*dilation_h=*/1, /*dilation_w=*/_dilation,
                                                  CUDNN_CROSS_CORRELATION,
                                                  CUDNN_DATA_FLOAT));
-
+#else
+      CUDNN_CHECK(miopenInitConvolutionDescriptor(conv_desc,
+                                                  miopenConvolution,
+                                                  /*pad_h=*/0, /*pad_w=*/_padding,
+                                                  /*stride_h=*/1, /*stride_w=*/_stride,
+                                                  /*dilation_h=*/1, /*dilation_w=*/_dilation));
+#endif
+#ifndef __HIP_PLATFORM_AMD__
      CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

      if (data_type == CUDNN_DATA_HALF)
        CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
-
+#endif
      cudnnHandle_t handle = cuda::get_cudnn_handle();
+#ifndef __HIP_PLATFORM_AMD__
      cudnnConvolutionFwdAlgo_t algo = (bias
                                        ? CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
                                        : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM);
+#else
+      miopenConvFwdAlgorithm_t algo = (bias
+                                       ? miopenConvolutionFwdAlgoGEMM
+                                       : miopenConvolutionFwdAlgoImplicitGEMM);
+#endif
+
      size_t workspace_size = 0;
      void* workspace = nullptr;
+#ifndef __HIP_PLATFORM_AMD__
      CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle,
                                                          input_desc,
                                                          weight_desc,
@@ -74,23 +100,85 @@ namespace ctranslate2 {
      if (workspace_size > 0)
        workspace = get_allocator().allocate(workspace_size);
-
+#else
+      std::size_t count;
+      CUDNN_CHECK(miopenConvolutionForwardGetSolutionCount(handle,
+                                                           weight_desc,
+                                                           input_desc,
+                                                           conv_desc,
+                                                           output_desc,
+                                                           &count));
+      if (count < 1) {
+        std::cout << "count: " << count << std::endl;
+      }
+      auto solutions = std::vector<miopenConvSolution_t>(count);
+      CUDNN_CHECK(miopenConvolutionForwardGetSolution(handle,
+                                                      weight_desc,
+                                                      input_desc,
+                                                      conv_desc,
+                                                      output_desc,
+                                                      count,
+                                                      &count,
+                                                      solutions.data()));
+      const miopenConvSolution_t* selected = &solutions.front();
+      CUDNN_CHECK(miopenConvolutionForwardGetSolutionWorkspaceSize(handle,
+                                                                   weight_desc,
+                                                                   input_desc,
+                                                                   conv_desc,
+                                                                   output_desc,
+                                                                   selected->solution_id,
+                                                                   &workspace_size));
+      if (workspace_size > 0) {
+        workspace = get_allocator().allocate(workspace_size);
+      }
+      CUDNN_CHECK(miopenConvolutionForwardCompileSolution(handle,
+                                                          weight_desc,
+                                                          input_desc,
+                                                          conv_desc,
+                                                          output_desc,
+                                                          selected->solution_id));
+
+      CUDNN_CHECK(miopenConvolutionForwardImmediate(handle,
+                                                    weight_desc,
+                                                    weight.buffer(),
+                                                    input_desc,
+                                                    input.buffer(),
+                                                    conv_desc,
+                                                    output_desc,
+                                                    output.buffer(),
+                                                    workspace,
+                                                    workspace_size,
+                                                    selected->solution_id));
+#endif
      float alpha = 1;
      float beta = 0;

      if (bias) {
        cudnnTensorDescriptor_t bias_desc;
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc));
-        CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW, data_type,
-                                               1, out_channels, 1, 1));
+        CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc,
+#ifndef __HIP_PLATFORM_AMD__
+                                               CUDNN_TENSOR_NCHW,
+#endif
+                                               data_type, 1, out_channels, 1, 1));
+
        cudnnActivationDescriptor_t activation_desc;
        CUDNN_CHECK(cudnnCreateActivationDescriptor(&activation_desc));
+#ifndef __HIP_PLATFORM_AMD__
        CUDNN_CHECK(cudnnSetActivationDescriptor(activation_desc,
                                                 CUDNN_ACTIVATION_IDENTITY,
                                                 CUDNN_NOT_PROPAGATE_NAN,
                                                 /*coef=*/0));
-
+#else
+        CUDNN_CHECK(miopenSetActivationDescriptor(activation_desc,
+                                                  miopenActivationPASTHRU,
+                                                  0,
+                                                  0,
+                                                  0));
+#endif
+#ifndef __HIP_PLATFORM_AMD__
        CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle,
                                                          &alpha,
                                                          input_desc,
@@ -109,11 +197,21 @@ namespace ctranslate2 {
                                                          activation_desc,
                                                          output_desc,
                                                          output.buffer()));
+#else
+        CUDNN_CHECK(miopenConvolutionForwardBias(handle,
+                                                 &alpha,
+                                                 bias_desc,
+                                                 bias->buffer(),
+                                                 &beta,
+                                                 output_desc,
+                                                 output.buffer()));
+#endif

        CUDNN_CHECK(cudnnDestroyActivationDescriptor(activation_desc));
        CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc));

      } else {
+#ifndef __HIP_PLATFORM_AMD__
        CUDNN_CHECK(cudnnConvolutionForward(handle,
                                            &alpha,
                                            input_desc,
@@ -127,6 +225,7 @@ namespace ctranslate2 {
                                            &beta,
                                            output_desc,
                                            output.buffer()));
+#endif
      }

      if (workspace)
diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu
index 241b3acd..99de40b9 100644
--- a/src/ops/dequantize_gpu.cu
+++ b/src/ops/dequantize_gpu.cu
@@ -4,6 +4,16 @@ namespace ctranslate2 {
  namespace ops {

+#ifdef __HIP_PLATFORM_AMD__
+    template <typename InT, typename OutT>
+    struct dequantize_func {
+      __device__ __forceinline__
+      OutT operator()(float scale, InT x) const {
+        return OutT(__fdividef(float(x), scale));
+      }
+    };
+
+#else
    template <typename InT, typename OutT>
    struct dequantize_func {
      __device__ __forceinline__
      OutT operator()(float scale, InT x) const {
@@ -12,7 +22,7 @@ namespace ctranslate2 {
        return __fdividef(static_cast<float>(x), scale);
      }
    };
-
+#endif
    template <Device D, typename InT, typename OutT>
    void Dequantize::dequantize(const StorageView& input,
                                const StorageView& scale,
diff --git a/src/ops/gumbel_max_gpu.cu b/src/ops/gumbel_max_gpu.cu
index 160390c4..b50b5343 100644
--- a/src/ops/gumbel_max_gpu.cu
+++ b/src/ops/gumbel_max_gpu.cu
@@ -17,7 +17,11 @@ namespace ctranslate2 {
      template <typename DataType, typename IndexType>
      __device__ DataType operator()(DataType value, IndexType id) const {
        const float z = -logf(curand_uniform(_states + id));
+#ifdef __HIP_PLATFORM_AMD__
+        return value + DataType(z);
+#else
        return float(value) + z;
+#endif
      }

    private:
diff --git a/src/ops/layer_norm_gpu.cu b/src/ops/layer_norm_gpu.cu
index 8c644d87..7dfcae9a 100644
--- a/src/ops/layer_norm_gpu.cu
+++ b/src/ops/layer_norm_gpu.cu
@@ -139,14 +139,17 @@ namespace ctranslate2 {
   ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
   POSSIBILITY OF SUCH DAMAGE.
 */
-
-#include <cub/block/block_reduce.cuh>
-
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/block/block_reduce.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#endif
 namespace at {
  namespace native {

    template <typename T, typename SizeT>
-    __global__ void LayerNormForwardCUDAKernel(SizeT N,
+    __global__ void __launch_bounds__(CUDA_NUM_THREADS) LayerNormForwardCUDAKernel(SizeT N,
                                               float eps,
                                               const T* X,
                                               const T* gamma,
diff --git a/src/ops/mean_gpu.cu b/src/ops/mean_gpu.cu
index 5125924c..508d9706 100644
--- a/src/ops/mean_gpu.cu
+++ b/src/ops/mean_gpu.cu
@@ -1,7 +1,10 @@
 #include "ctranslate2/ops/mean.h"
-
-#include <cub/block/block_reduce.cuh>
-
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/block/block_reduce.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#endif
 #include "type_dispatch.h"
 #include "cuda/helpers.h"
@@ -11,7 +14,7 @@ namespace ctranslate2 {
    constexpr dim_t num_threads = 256;

    template <typename T>
-    __global__ void mean_kernel(const T* input,
+    __global__ void __launch_bounds__(num_threads) mean_kernel(const T* input,
                                const cuda::index_t outer_size,
                                const cuda::index_t axis_size,
                                const cuda::index_t inner_size,
diff --git a/src/ops/multinomial_gpu.cu b/src/ops/multinomial_gpu.cu
index 90f36377..e2425836 100644
--- a/src/ops/multinomial_gpu.cu
+++ b/src/ops/multinomial_gpu.cu
@@ -1,8 +1,12 @@
 #include "ctranslate2/ops/multinomial.h"
-
-#include <cub/block/block_scan.cuh>
-#include <curand_kernel.h>
-
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/block/block_scan.cuh>
+  #include <curand_kernel.h>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hiprand/hiprand_kernel.h>
+  #include <hip/hip_runtime.h>
+#endif
 #include "cuda/helpers.h"
 #include "cuda/random.h"
@@ -24,7 +28,7 @@ namespace ctranslate2 {
    constexpr dim_t num_threads = 256;

    template <typename In, typename Out>
-    __global__ void multinomial_kernel(const In* probs,
+    __global__ void __launch_bounds__(num_threads) multinomial_kernel(const In* probs,
                                       cuda::index_t class_size,
                                       Out* output,
                                       curandStatePhilox4_32_10_t* states) {
diff --git a/src/ops/rms_norm_gpu.cu b/src/ops/rms_norm_gpu.cu
index 41130cd9..f2734ede 100644
--- a/src/ops/rms_norm_gpu.cu
+++ b/src/ops/rms_norm_gpu.cu
@@ -1,7 +1,10 @@
 #include "ctranslate2/ops/rms_norm.h"
-
-#include <cub/block/block_reduce.cuh>
-
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/block/block_reduce.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#endif
 #include "cuda/helpers.h"
 #include "cuda/utils.h"
@@ -11,7 +14,7 @@ namespace ctranslate2 {
    constexpr dim_t num_threads = 512;

    template <typename T>
-    __global__ void rms_norm_kernel(const T* input,
+    __global__ void __launch_bounds__(num_threads) rms_norm_kernel(const T* input,
                                    const T* gamma,
                                    T* output,
                                    cuda::index_t depth,
diff --git a/src/ops/topk_gpu.cu b/src/ops/topk_gpu.cu
index ad010fb4..ffbe192e 100644
--- a/src/ops/topk_gpu.cu
+++ b/src/ops/topk_gpu.cu
@@ -115,9 +115,12 @@ namespace ctranslate2 {
   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
   SOFTWARE.
 */
-
-#include <cub/cub.cuh>
-
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/cub.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#endif
 namespace fastertransformer {

 #define NOT_FOUND -1
diff --git a/src/ops/topp_mask_gpu.cu b/src/ops/topp_mask_gpu.cu
index a4e6cb6e..4d6661b0 100644
--- a/src/ops/topp_mask_gpu.cu
+++ b/src/ops/topp_mask_gpu.cu
@@ -1,6 +1,10 @@
 #include "ctranslate2/ops/topp_mask.h"
-
-#include <cub/block/block_scan.cuh>
+#ifndef __HIP_PLATFORM_AMD__
+  #include <cub/block/block_scan.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+  #include <hip/hip_runtime.h>
+#endif
 
 #include "cuda/helpers.h"
 
@@ -10,7 +14,7 @@ namespace ctranslate2 {
    constexpr dim_t num_threads = 256;

    template <typename T>
-    __global__ void topp_mask_kernel(const T* input,
+    __global__ void __launch_bounds__(num_threads) topp_mask_kernel(const T* input,
                                     const T* probs,
                                     T* output,
                                     const float p,
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 283c49db..87906a7e 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -18,18 +18,45 @@ add_executable(ctranslate2_test
 target_include_directories(ctranslate2_test PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/../src
   )
-target_link_libraries(ctranslate2_test
-  ${PROJECT_NAME}
-  gtest_main
+if(WITH_CUDA)
+  if(GPU_RUNTIME STREQUAL "HIP")
+    find_package(hipblas REQUIRED)
+    find_package(miopen REQUIRED)
+    target_link_libraries(ctranslate2_test
+      ${PROJECT_NAME}
+      gtest_main # roc::hipblas
+      MIOpen
+      )
+  else()
+    target_link_libraries(ctranslate2_test
+      ${PROJECT_NAME}
+      gtest_main
+      )
+  endif()
+else()
+
+  target_link_libraries(ctranslate2_test
+    ${PROJECT_NAME}
+    gtest_main
   )
+endif()
 
 add_executable(benchmark_ops
   benchmark_ops.cc
   )
+target_include_directories(benchmark_ops PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../include
+  )
 target_link_libraries(benchmark_ops
-  ${PROJECT_NAME}
-  )
+  ${PROJECT_NAME}
+  )
 if(WITH_CUDA)
-  target_link_libraries(benchmark_ops ${CUDA_LIBRARIES})
+  set_source_files_properties(benchmark_ops.cc PROPERTIES LANGUAGE ${GPU_RUNTIME})
+  if(GPU_RUNTIME STREQUAL "HIP")
+    find_package(hipblas REQUIRED)
+    target_link_libraries(benchmark_ops roc::hipblas)
+  else()
+    target_link_libraries(benchmark_ops ${CUDA_LIBRARIES})
+  endif()
 endif()
diff --git a/tests/benchmark_utils.h b/tests/benchmark_utils.h
index 55ee77b0..c35da837 100644
--- a/tests/benchmark_utils.h
+++ b/tests/benchmark_utils.h
@@ -5,8 +5,13 @@
 #include <chrono>
 
 #ifdef CT2_WITH_CUDA
-#  include <cuda_runtime.h>
-#  define SYNCHRONIZE cudaDeviceSynchronize()
+  #ifndef __HIP_PLATFORM_AMD__
+    #include <cuda_runtime.h>
+    #define SYNCHRONIZE cudaDeviceSynchronize()
+  #else
+    #include <hip/hip_runtime.h>
+    #define SYNCHRONIZE hipDeviceSynchronize()
+  #endif
 #else
 #  define SYNCHRONIZE do {} while (false)
 #endif
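
--
Editorial note, not part of the patch: the cuda2hip_macros.hpp technique above lets the existing CUDA-flavoured sources compile unchanged for ROCm by rewriting CUDA names to their HIP equivalents at preprocessing time. A minimal, self-contained sketch of the same idea, assuming it is built with hipcc (which defines __HIP_PLATFORM_AMD__ for AMD targets); the file name and kernel are illustrative only:

    // demo.hip.cpp -- illustrative sketch; build: hipcc demo.hip.cpp -o demo
    #include "cuda2hip_macros.hpp"  // rewrites cudaMalloc -> hipMalloc, etc.

    __global__ void scale_kernel(float* data, float factor, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n)
        data[i] *= factor;
    }

    int main() {
      constexpr int n = 1024;
      float* d = nullptr;
      // Written against the CUDA runtime API; under __HIP_PLATFORM_AMD__ the
      // macro header turns these calls into hipMalloc / hipDeviceSynchronize /
      // hipFree before compilation.
      cudaMalloc(&d, n * sizeof(float));
      scale_kernel<<<(n + 255) / 256, 256>>>(d, 2.0f, n);
      cudaDeviceSynchronize();
      cudaFree(d);
      return 0;
    }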
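
And a minimal configure/build sketch for the HIP path of this patch, assuming ROCm is installed under /opt/rocm; gfx90a is an example GPU target, adjust for your hardware. Note that the HIP branch is nested under WITH_CUDA in CMakeLists.txt, so that option stays ON:

    cmake -S . -B build \
          -DWITH_CUDA=ON -DWITH_CUDNN=ON \
          -DGPU_RUNTIME=HIP \
          -DROCM_PATH=/opt/rocm \
          -DCMAKE_HIP_ARCHITECTURES=gfx90a \
          -DCMAKE_BUILD_TYPE=Release
    cmake --build build --parallel

With GPU_RUNTIME=HIP the GPU sources are routed through enable_language(HIP) and linked against roc::hipblas (plus MIOpen when WITH_CUDNN=ON); any other value falls back to the original find_package(CUDA 11.0 REQUIRED) path.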