mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-31 10:37:11 -04:00 
			
		
		
		
	disable algo caching
This commit is contained in:
		
							parent
							
								
									7ac30995a8
								
							
						
					
					
						commit
						f19cf206ba
					
				| @ -1,150 +0,0 @@ | ||||
| From 350e3237eadb738a0d96295a62f2eed96653c315 Mon Sep 17 00:00:00 2001 | ||||
| From: mertalev <101130780+mertalev@users.noreply.github.com> | ||||
| Date: Fri, 20 Dec 2024 00:59:21 -0500 | ||||
| Subject: [PATCH 1/1] fix: avoid race condition for rocm conv algo caching | ||||
| 
 | ||||
| ---
 | ||||
|  onnxruntime/core/providers/rocm/nn/conv.cc         |  8 ++++---- | ||||
|  onnxruntime/core/providers/rocm/nn/conv.h          | 14 ++++++++++++-- | ||||
|  .../core/providers/rocm/nn/conv_transpose.cc       |  8 ++++---- | ||||
|  3 files changed, 20 insertions(+), 10 deletions(-) | ||||
| 
 | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| index d7f47d07a8..98b6b69212 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| @@ -127,7 +127,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|   | ||||
|      if (w_dims_changed) { | ||||
|        s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -      s_.cached_benchmark_fwd_results.clear();
 | ||||
|      } | ||||
|   | ||||
|      ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); | ||||
| @@ -278,7 +277,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|        HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); | ||||
|      } | ||||
|   | ||||
| -    if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
 | ||||
| +    const std::size_t algo_key = HashConvAlgoKey(x_dims_miopen, w_dims);
 | ||||
| +    if (!s_.cached_benchmark_fwd_results.contains(algo_key)) {
 | ||||
|        miopenConvAlgoPerf_t perf; | ||||
|        int algo_count = 1; | ||||
|        const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider()); | ||||
| @@ -301,9 +301,9 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|            algo_search_workspace.get(), | ||||
|            max_ws_size, | ||||
|            false));  // Do not do exhaustive algo search. | ||||
| -      s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory});
 | ||||
| +      s_.cached_benchmark_fwd_results.insert(algo_key, {perf.fwd_algo, perf.memory});
 | ||||
|      } | ||||
| -    const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen);
 | ||||
| +    const auto& perf = s_.cached_benchmark_fwd_results.at(algo_key);
 | ||||
|      s_.fwd_algo = perf.fwd_algo; | ||||
|      s_.workspace_bytes = perf.memory; | ||||
|    } else { | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| index bc9846203e..b1ca5f8e4b 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| @@ -43,6 +43,11 @@ struct vector_hash {
 | ||||
|    } | ||||
|  }; | ||||
|   | ||||
| +inline std::size_t HashConvAlgoKey(const TensorShapeVector& x_dims, const TensorShapeVector& w_dims) {
 | ||||
| +  vector_hash vh;
 | ||||
| +  return vh(x_dims) ^ vh(w_dims);
 | ||||
| +}
 | ||||
| +
 | ||||
|  template <typename Key, typename T, | ||||
|            typename Hash = std::hash<Key>, | ||||
|            typename KeyEqual = std::equal_to<Key>, | ||||
| @@ -52,6 +57,7 @@ class lru_unordered_map {
 | ||||
|    lru_unordered_map(size_t max_size) : max_size_(max_size) {} | ||||
|   | ||||
|    void insert(const Key& key, const T& value) { | ||||
| +    std::lock_guard<std::mutex> guard(mutex_);
 | ||||
|      auto it = items_.find(key); | ||||
|      if (it != items_.end()) { | ||||
|        it->second.value = value; | ||||
| @@ -69,6 +75,7 @@ class lru_unordered_map {
 | ||||
|    } | ||||
|   | ||||
|    T& at(const Key& key) { | ||||
| +    std::lock_guard<std::mutex> guard(mutex_);
 | ||||
|      auto it = items_.find(key); | ||||
|      if (it == items_.end()) { | ||||
|        throw std::out_of_range("There is no such key in cache"); | ||||
| @@ -78,6 +85,7 @@ class lru_unordered_map {
 | ||||
|    } | ||||
|   | ||||
|    bool contains(const Key& key) const { | ||||
| +    std::lock_guard<std::mutex> guard(mutex_);
 | ||||
|      return items_.find(key) != items_.end(); | ||||
|    } | ||||
|   | ||||
| @@ -86,6 +94,7 @@ class lru_unordered_map {
 | ||||
|    } | ||||
|   | ||||
|    void clear() { | ||||
| +    std::lock_guard<std::mutex> guard(mutex_);
 | ||||
|      items_.clear(); | ||||
|      lru_list_.clear(); | ||||
|    } | ||||
| @@ -106,6 +115,7 @@ class lru_unordered_map {
 | ||||
|    size_t max_size_; | ||||
|    std::unordered_map<Key, value_type, Hash, KeyEqual, MapAllocator> items_; | ||||
|    list_type lru_list_; | ||||
| +  mutable std::mutex mutex_;
 | ||||
|  }; | ||||
|   | ||||
|  // cached miopen descriptors | ||||
| @@ -148,8 +158,8 @@ struct MiopenConvState {
 | ||||
|      decltype(AlgoPerfType().memory) memory; | ||||
|    }; | ||||
|   | ||||
| -  lru_unordered_map<TensorShapeVector, PerfFwdResultParams, vector_hash> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| -  lru_unordered_map<TensorShapeVector, PerfBwdResultParams, vector_hash> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| +  lru_unordered_map<std::size_t, PerfFwdResultParams> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| +  lru_unordered_map<std::size_t, PerfBwdResultParams> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
|   | ||||
|    // Some properties needed to support asymmetric padded Conv nodes | ||||
|    bool post_slicing_required; | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| index 7447113fdf..dea9bf2a05 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| @@ -76,7 +76,6 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        if (w_dims_changed) { | ||||
|          s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -        s_.cached_benchmark_bwd_results.clear();
 | ||||
|        } | ||||
|   | ||||
|        ConvTransposeAttributes::Prepare p; | ||||
| @@ -127,7 +126,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>()); | ||||
|   | ||||
| -      if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
 | ||||
| +      const std::size_t algo_key = HashConvAlgoKey(x_dims, w_dims);
 | ||||
| +      if (!s_.cached_benchmark_bwd_results.contains(algo_key)) {
 | ||||
|          IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream()); | ||||
|   | ||||
|          miopenConvAlgoPerf_t perf; | ||||
| @@ -147,10 +147,10 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|              algo_search_workspace.get(), | ||||
|              AlgoSearchWorkspaceSize, | ||||
|              false)); | ||||
| -        s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory});
 | ||||
| +        s_.cached_benchmark_bwd_results.insert(algo_key, {perf.bwd_data_algo, perf.memory});
 | ||||
|        } | ||||
|   | ||||
| -      const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims);
 | ||||
| +      const auto& perf = s_.cached_benchmark_bwd_results.at(algo_key);
 | ||||
|        s_.bwd_data_algo = perf.bwd_data_algo; | ||||
|        s_.workspace_bytes = perf.memory; | ||||
|      } | ||||
| -- 
 | ||||
| 2.43.0 | ||||
| 
 | ||||
| @ -15,15 +15,13 @@ RUN mkdir /opt/armnn && \ | ||||
|     cd /opt/ann && \ | ||||
|     sh build.sh | ||||
| 
 | ||||
| # Warning: 26.3Gb of disk space required to pull this image | ||||
| # https://github.com/microsoft/onnxruntime/blob/main/dockerfiles/Dockerfile.rocm | ||||
| # 6.2 or later fails to build as of writing | ||||
| # Warning: 25GiB+ disk space required to pull this image | ||||
| # TODO: find a way to reduce the image size | ||||
| FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS builder-rocm | ||||
| 
 | ||||
| WORKDIR /code | ||||
| 
 | ||||
| RUN apt-get update && apt-get install -y --no-install-recommends wget git python3.10-venv | ||||
| # Install same version as the Dockerfile provided by onnxruntime | ||||
| RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3.30.1-linux-x86_64.sh && \ | ||||
|     chmod +x cmake-3.30.1-linux-x86_64.sh && \ | ||||
|     mkdir -p /code/cmake-3.30.1-linux-x86_64 && \ | ||||
| @ -32,13 +30,12 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.30.1/cmake-3. | ||||
| 
 | ||||
| ENV PATH=/code/cmake-3.30.1-linux-x86_64/bin:${PATH} | ||||
| 
 | ||||
| # Prepare onnxruntime repository & build onnxruntime | ||||
| # 1.20.1 fails to build as of writing | ||||
| RUN git clone --single-branch --branch v1.20.1 --recursive "https://github.com/Microsoft/onnxruntime" onnxruntime | ||||
| WORKDIR /code/onnxruntime | ||||
| # Fix for multi-threading based on comments in https://github.com/microsoft/onnxruntime/pull/19567 | ||||
| COPY ./0001-fix-rocm-conv-thread-safety.patch /tmp/ | ||||
| RUN git apply /tmp/0001-fix-rocm-conv-thread-safety.patch | ||||
| # TODO: find a way to fix this without disabling algo caching | ||||
| COPY ./rocm-PR19567.patch /tmp/ | ||||
| RUN git apply /tmp/rocm-PR19567.patch | ||||
| 
 | ||||
| RUN /bin/sh ./dockerfiles/scripts/install_common_deps.sh | ||||
| # Note: the `parallel` setting uses a substantial amount of RAM | ||||
|  | ||||
							
								
								
									
										176
									
								
								machine-learning/rocm-PR19567.patch
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								machine-learning/rocm-PR19567.patch
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,176 @@ | ||||
| From a598a88db258f82a6e4bca75810921bd6bcee7e0 Mon Sep 17 00:00:00 2001 | ||||
| From: David Nieto <dmnieto@gmail.com> | ||||
| Date: Sat, 17 Feb 2024 11:23:12 -0800 | ||||
| Subject: [PATCH] Disable algo caching in ROCM EP | ||||
| 
 | ||||
| Similar to the work done by Liangxijun-1001 in | ||||
| https://github.com/apache/tvm/pull/16178 the ROCM spec mandates calling | ||||
| miopenFindConvolution*Algorithm() before using any Convolution API | ||||
| 
 | ||||
| This is the link to the porting guide describing this requirement | ||||
| https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html | ||||
| 
 | ||||
| Thus, this change disables the algo cache and enforces the official | ||||
| API semantics | ||||
| 
 | ||||
| Signed-off-by: David Nieto <dmnieto@gmail.com> | ||||
| ---
 | ||||
|  onnxruntime/core/providers/rocm/nn/conv.cc    | 61 +++++++++---------- | ||||
|  onnxruntime/core/providers/rocm/nn/conv.h     |  6 -- | ||||
|  .../core/providers/rocm/nn/conv_transpose.cc  | 17 +++--- | ||||
|  3 files changed, 36 insertions(+), 48 deletions(-) | ||||
| 
 | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| index 6214ec7bc0ea..b08aceca48b1 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| @@ -125,10 +125,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|      if (input_dims_changed) | ||||
|        s_.last_x_dims = gsl::make_span(x_dims); | ||||
|   | ||||
| -    if (w_dims_changed) {
 | ||||
| +    if (w_dims_changed)
 | ||||
|        s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -      s_.cached_benchmark_fwd_results.clear();
 | ||||
| -    }
 | ||||
|   | ||||
|      ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); | ||||
|   | ||||
| @@ -277,35 +275,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|        HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size)); | ||||
|        HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); | ||||
|      } | ||||
| -
 | ||||
| -    if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
 | ||||
| -      miopenConvAlgoPerf_t perf;
 | ||||
| -      int algo_count = 1;
 | ||||
| -      const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
 | ||||
| -      static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
 | ||||
| -      size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
 | ||||
| -                                                                   : AlgoSearchWorkspaceSize;
 | ||||
| -      IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
 | ||||
| -      MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
 | ||||
| -          GetMiopenHandle(context),
 | ||||
| -          s_.x_tensor,
 | ||||
| -          s_.x_data,
 | ||||
| -          s_.w_desc,
 | ||||
| -          s_.w_data,
 | ||||
| -          s_.conv_desc,
 | ||||
| -          s_.y_tensor,
 | ||||
| -          s_.y_data,
 | ||||
| -          1,            // requestedAlgoCount
 | ||||
| -          &algo_count,  // returnedAlgoCount
 | ||||
| -          &perf,
 | ||||
| -          algo_search_workspace.get(),
 | ||||
| -          max_ws_size,
 | ||||
| -          false));  // Do not do exhaustive algo search.
 | ||||
| -      s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory});
 | ||||
| -    }
 | ||||
| -    const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen);
 | ||||
| -    s_.fwd_algo = perf.fwd_algo;
 | ||||
| -    s_.workspace_bytes = perf.memory;
 | ||||
|    } else { | ||||
|      // set Y | ||||
|      s_.Y = context->Output(0, TensorShape(s_.y_dims)); | ||||
| @@ -319,6 +288,34 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|        s_.y_data = reinterpret_cast<HipT*>(s_.Y->MutableData<T>()); | ||||
|      } | ||||
|    } | ||||
| +  {
 | ||||
| +    /* FindConvolution must always be called by the runtime */
 | ||||
| +    TensorShapeVector x_dims_miopen{x_dims.begin(), x_dims.end()};
 | ||||
| +    miopenConvAlgoPerf_t perf;
 | ||||
| +    int algo_count = 1;
 | ||||
| +    const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
 | ||||
| +    static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
 | ||||
| +    size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
 | ||||
| +                                                                   : AlgoSearchWorkspaceSize;
 | ||||
| +    IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
 | ||||
| +    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
 | ||||
| +        GetMiopenHandle(context),
 | ||||
| +        s_.x_tensor,
 | ||||
| +        s_.x_data,
 | ||||
| +        s_.w_desc,
 | ||||
| +        s_.w_data,
 | ||||
| +        s_.conv_desc,
 | ||||
| +        s_.y_tensor,
 | ||||
| +        s_.y_data,
 | ||||
| +        1,            // requestedAlgoCount
 | ||||
| +        &algo_count,  // returnedAlgoCount
 | ||||
| +        &perf,
 | ||||
| +        algo_search_workspace.get(),
 | ||||
| +        max_ws_size,
 | ||||
| +        false));  // Do not do exhaustive algo search.
 | ||||
| +    s_.fwd_algo = perf.fwd_algo;
 | ||||
| +    s_.workspace_bytes = perf.memory;
 | ||||
| +  }
 | ||||
|    return Status::OK(); | ||||
|  } | ||||
|   | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| index bc9846203e57..d54218f25854 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| @@ -108,9 +108,6 @@ class lru_unordered_map {
 | ||||
|    list_type lru_list_; | ||||
|  }; | ||||
|   | ||||
| -// cached miopen descriptors
 | ||||
| -constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000;
 | ||||
| -
 | ||||
|  template <typename AlgoPerfType> | ||||
|  struct MiopenConvState { | ||||
|    // if x/w dims changed, update algo and miopenTensors | ||||
| @@ -148,9 +145,6 @@ struct MiopenConvState {
 | ||||
|      decltype(AlgoPerfType().memory) memory; | ||||
|    }; | ||||
|   | ||||
| -  lru_unordered_map<TensorShapeVector, PerfFwdResultParams, vector_hash> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| -  lru_unordered_map<TensorShapeVector, PerfBwdResultParams, vector_hash> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| -
 | ||||
|    // Some properties needed to support asymmetric padded Conv nodes | ||||
|    bool post_slicing_required; | ||||
|    TensorShapeVector slice_starts; | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| index 7447113fdf84..45ed4c8ac37a 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| @@ -76,7 +76,6 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        if (w_dims_changed) { | ||||
|          s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -        s_.cached_benchmark_bwd_results.clear();
 | ||||
|        } | ||||
|   | ||||
|        ConvTransposeAttributes::Prepare p; | ||||
| @@ -127,12 +126,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>()); | ||||
|   | ||||
| -      if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
 | ||||
| -        IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 | ||||
| -
 | ||||
| -        miopenConvAlgoPerf_t perf;
 | ||||
| -        int algo_count = 1;
 | ||||
| -        MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
 | ||||
| +    }
 | ||||
| +    // The following is required before calling convolution, we cannot cache the results
 | ||||
| +    {
 | ||||
| +      IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 | ||||
| +      miopenConvAlgoPerf_t perf;
 | ||||
| +      int algo_count = 1;
 | ||||
| +      MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
 | ||||
|              GetMiopenHandle(context), | ||||
|              s_.x_tensor, | ||||
|              x_data, | ||||
| @@ -147,10 +147,7 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|              algo_search_workspace.get(), | ||||
|              AlgoSearchWorkspaceSize, | ||||
|              false)); | ||||
| -        s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory});
 | ||||
| -      }
 | ||||
|   | ||||
| -      const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims);
 | ||||
|        s_.bwd_data_algo = perf.bwd_data_algo; | ||||
|        s_.workspace_bytes = perf.memory; | ||||
|      } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user