From a598a88db258f82a6e4bca75810921bd6bcee7e0 Mon Sep 17 00:00:00 2001 From: David Nieto Date: Sat, 17 Feb 2024 11:23:12 -0800 Subject: [PATCH] Disable algo caching in ROCM EP Similar to the work done by Liangxijun-1001 in https://github.com/apache/tvm/pull/16178 , the ROCM spec mandates calling miopenFindConvolution*Algorithm() before using any Convolution API. The porting guide describing this requirement is available at https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html . Thus, this change disables the algo cache and enforces the official API semantics. Signed-off-by: David Nieto --- onnxruntime/core/providers/rocm/nn/conv.cc | 61 +++++++++---------- onnxruntime/core/providers/rocm/nn/conv.h | 6 -- .../core/providers/rocm/nn/conv_transpose.cc | 17 +++--- 3 files changed, 36 insertions(+), 48 deletions(-) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6214ec7bc0ea..b08aceca48b1 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -125,10 +125,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (input_dims_changed) s_.last_x_dims = gsl::make_span(x_dims); - if (w_dims_changed) { + if (w_dims_changed) s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_fwd_results.clear(); - } ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); @@ -277,35 +275,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size)); HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); } - - if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { - miopenConvAlgoPerf_t perf; - int algo_count = 1; - const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); - static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; - size_t 
max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos) - : AlgoSearchWorkspaceSize; - IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( - GetMiopenHandle(context), - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.y_tensor, - s_.y_data, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf, - algo_search_workspace.get(), - max_ws_size, - false)); // Do not do exhaustive algo search. - s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory}); - } - const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen); - s_.fwd_algo = perf.fwd_algo; - s_.workspace_bytes = perf.memory; } else { // set Y s_.Y = context->Output(0, TensorShape(s_.y_dims)); @@ -319,6 +288,34 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.y_data = reinterpret_cast(s_.Y->MutableData()); } } + { + /* FindConvolution must always be called by the runtime */ + TensorShapeVector x_dims_miopen{x_dims.begin(), x_dims.end()}; + miopenConvAlgoPerf_t perf; + int algo_count = 1; + const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); + static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( + GetMiopenHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size, + false)); // Do not do exhaustive algo search. 
+ s_.fwd_algo = perf.fwd_algo; + s_.workspace_bytes = perf.memory; + } return Status::OK(); } diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index bc9846203e57..d54218f25854 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -108,9 +108,6 @@ class lru_unordered_map { list_type lru_list_; }; -// cached miopen descriptors -constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000; - template struct MiopenConvState { // if x/w dims changed, update algo and miopenTensors @@ -148,9 +145,6 @@ struct MiopenConvState { decltype(AlgoPerfType().memory) memory; }; - lru_unordered_map cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - lru_unordered_map cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - // Some properties needed to support asymmetric padded Conv nodes bool post_slicing_required; TensorShapeVector slice_starts; diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 7447113fdf84..45ed4c8ac37a 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -76,7 +76,6 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_bwd_results.clear(); } ConvTransposeAttributes::Prepare p; @@ -127,12 +126,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy y_data = reinterpret_cast(p.Y->MutableData()); - if (!s_.cached_benchmark_bwd_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - miopenConvAlgoPerf_t perf; - int algo_count = 1; - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( + } + // The following is required before calling convolution, we cannot cache the results + { 
+ IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + miopenConvAlgoPerf_t perf; + int algo_count = 1; + MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( GetMiopenHandle(context), s_.x_tensor, x_data, @@ -147,10 +147,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy algo_search_workspace.get(), AlgoSearchWorkspaceSize, false)); - s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory}); - } - const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims); s_.bwd_data_algo = perf.bwd_data_algo; s_.workspace_bytes = perf.memory; }