	feat(ml): introduce support of onnxruntime-rocm for AMD GPU
commit fe26ccd1b7 (parent 3f4bbab4eb)
.github/workflows/docker.yml (vendored) · 13 lines changed
							| @ -49,7 +49,7 @@ jobs: | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       matrix: | ||||
|         suffix: ["", "-cuda", "-openvino", "-armnn"] | ||||
|         suffix: ['', '-cuda', '-rocm', '-openvino', '-armnn'] | ||||
|     steps: | ||||
|       - name: Login to GitHub Container Registry | ||||
|         uses: docker/login-action@v3 | ||||
| @ -74,7 +74,7 @@ jobs: | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       matrix: | ||||
|         suffix: [""] | ||||
|         suffix: [''] | ||||
|     steps: | ||||
|       - name: Login to GitHub Container Registry | ||||
|         uses: docker/login-action@v3 | ||||
| @ -125,6 +125,11 @@ jobs: | ||||
|             device: openvino | ||||
|             suffix: -openvino | ||||
| 
 | ||||
|           - platforms: linux/amd64 | ||||
|             runner: mich | ||||
|             device: rocm | ||||
|             suffix: -rocm | ||||
| 
 | ||||
|           - platform: linux/arm64 | ||||
|             runner: ubuntu-24.04-arm | ||||
|             device: armnn | ||||
| @ -250,7 +255,7 @@ jobs: | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v5 | ||||
|         env: | ||||
|           DOCKER_METADATA_PR_HEAD_SHA: "true" | ||||
|           DOCKER_METADATA_PR_HEAD_SHA: 'true' | ||||
|         with: | ||||
|           flavor: | | ||||
|             # Disable latest tag | ||||
| @ -403,7 +408,7 @@ jobs: | ||||
|         id: meta | ||||
|         uses: docker/metadata-action@v5 | ||||
|         env: | ||||
|           DOCKER_METADATA_PR_HEAD_SHA: "true" | ||||
|           DOCKER_METADATA_PR_HEAD_SHA: 'true' | ||||
|         with: | ||||
|           flavor: | | ||||
|             # Disable latest tag | ||||
|  | ||||
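Each entry in the `suffix` matrix above becomes the tag suffix of the published `immich-machine-learning` image, so adding `-rocm` is what makes a ROCm variant pullable alongside `-cuda`, `-openvino`, and `-armnn`. A minimal sketch of that mapping, assuming the `ghcr.io/immich-app/immich-machine-learning` path used elsewhere in this commit (the real tags are produced by `docker/metadata-action`):

```python
# Sketch: how a workflow matrix suffix maps to a pullable image reference.
SUFFIXES = ["", "-cuda", "-rocm", "-openvino", "-armnn"]


def image_ref(version: str, suffix: str) -> str:
    # Registry path taken from the compose files in this commit.
    return f"ghcr.io/immich-app/immich-machine-learning:{version}{suffix}"


if __name__ == "__main__":
    for suffix in SUFFIXES:
        print(image_ref("release", suffix))
```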
| @ -95,12 +95,12 @@ services: | ||||
|     image: immich-machine-learning-dev:latest | ||||
|     # extends: | ||||
|     #   file: hwaccel.ml.yml | ||||
|     #   service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference | ||||
|     #   service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference | ||||
|     build: | ||||
|       context: ../machine-learning | ||||
|       dockerfile: Dockerfile | ||||
|       args: | ||||
|         - DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference | ||||
|         - DEVICE=cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference | ||||
|     ports: | ||||
|       - 3003:3003 | ||||
|     volumes: | ||||
|  | ||||
| @ -38,12 +38,12 @@ services: | ||||
|     image: immich-machine-learning:latest | ||||
|     # extends: | ||||
|     #   file: hwaccel.ml.yml | ||||
|     #   service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference | ||||
|     #   service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference | ||||
|     build: | ||||
|       context: ../machine-learning | ||||
|       dockerfile: Dockerfile | ||||
|       args: | ||||
|         - DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference | ||||
|         - DEVICE=cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference | ||||
|     ports: | ||||
|       - 3003:3003 | ||||
|     volumes: | ||||
|  | ||||
| @ -33,12 +33,12 @@ services: | ||||
| 
 | ||||
|   immich-machine-learning: | ||||
|     container_name: immich_machine_learning | ||||
|     # For hardware acceleration, add one of -[armnn, cuda, openvino] to the image tag. | ||||
|     # For hardware acceleration, add one of -[armnn, cuda, rocm, openvino] to the image tag. | ||||
|     # Example tag: ${IMMICH_VERSION:-release}-cuda | ||||
|     image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release} | ||||
|     # extends: # uncomment this section for hardware acceleration - see https://immich.app/docs/features/ml-hardware-acceleration | ||||
|     #   file: hwaccel.ml.yml | ||||
|     #   service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable | ||||
|     #   service: cpu # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable | ||||
|     volumes: | ||||
|       - model-cache:/cache | ||||
|     env_file: | ||||
|  | ||||
| @ -26,6 +26,13 @@ services: | ||||
|               capabilities: | ||||
|                 - gpu | ||||
| 
 | ||||
|   rocm: | ||||
|     group_add: | ||||
|       - video | ||||
|     devices: | ||||
|       - /dev/dri:/dev/dri | ||||
|       - /dev/kfd:/dev/kfd | ||||
| 
 | ||||
|   openvino: | ||||
|     device_cgroup_rules: | ||||
|       - 'c 189:* rmw' | ||||
|  | ||||
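The new `rocm` service in `hwaccel.ml.yml` only adds the `video` group and passes `/dev/dri` and `/dev/kfd` through to the container, so it only helps if the host actually exposes them. A hedged pre-flight check you could run on the host before switching to the `-rocm` image (device paths and group name taken from the service definition above):

```python
# Pre-flight check for the host prerequisites of the rocm hwaccel service.
import grp
import os

REQUIRED_DEVICES = ["/dev/kfd", "/dev/dri"]  # mapped by hwaccel.ml.yml
REQUIRED_GROUP = "video"                     # added via group_add


def check_rocm_host() -> bool:
    ok = True
    for dev in REQUIRED_DEVICES:
        if not os.path.exists(dev):
            print(f"missing {dev}: is the amdgpu kernel driver loaded?")
            ok = False
    try:
        grp.getgrnam(REQUIRED_GROUP)
    except KeyError:
        print(f"group '{REQUIRED_GROUP}' does not exist on this host")
        ok = False
    return ok


if __name__ == "__main__":
    raise SystemExit(0 if check_rocm_host() else 1)
```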
| @ -11,6 +11,7 @@ You do not need to redo any machine learning jobs after enabling hardware accele | ||||
| 
 | ||||
| - ARM NN (Mali) | ||||
| - CUDA (NVIDIA GPUs with [compute capability](https://developer.nvidia.com/cuda-gpus) 5.2 or higher) | ||||
| - ROCM (AMD GPUs) | ||||
| - OpenVINO (Intel GPUs such as Iris Xe and Arc) | ||||
| 
 | ||||
| ## Limitations | ||||
| @ -41,6 +42,10 @@ You do not need to redo any machine learning jobs after enabling hardware accele | ||||
| - The installed driver must be >= 535 (it must support CUDA 12.2). | ||||
| - On Linux (except for WSL2), you also need to have [NVIDIA Container Toolkit][nvct] installed. | ||||
| 
 | ||||
| #### ROCM | ||||
| 
 | ||||
| - The GPU must be supported by ROCM (or set `HSA_OVERRIDE_GFX_VERSION=<a supported version, e.g. 10.3.0>`) | ||||
| 
 | ||||
| #### OpenVINO | ||||
| 
 | ||||
| - Integrated GPUs are more likely to experience issues than discrete GPUs, especially for older processors or servers with low RAM. | ||||
| @ -51,12 +56,12 @@ You do not need to redo any machine learning jobs after enabling hardware accele | ||||
| 
 | ||||
| 1. If you do not already have it, download the latest [`hwaccel.ml.yml`][hw-file] file and ensure it's in the same folder as the `docker-compose.yml`. | ||||
| 2. In the `docker-compose.yml` under `immich-machine-learning`, uncomment the `extends` section and change `cpu` to the appropriate backend. | ||||
| 3. Still in `immich-machine-learning`, add one of -[armnn, cuda, openvino] to the `image` section's tag at the end of the line. | ||||
| 3. Still in `immich-machine-learning`, add one of -[armnn, cuda, rocm, openvino] to the `image` section's tag at the end of the line. | ||||
| 4. Redeploy the `immich-machine-learning` container with these updated settings. | ||||
| 
 | ||||
| ### Confirming Device Usage | ||||
| 
 | ||||
| You can confirm the device is being recognized and used by checking its utilization. There are many tools to display this, such as `nvtop` for NVIDIA or Intel and `intel_gpu_top` for Intel. | ||||
| You can confirm the device is being recognized and used by checking its utilization. There are many tools to display this, such as `nvtop` for NVIDIA or Intel, `intel_gpu_top` for Intel, and `radeontop` for AMD. | ||||
| 
 | ||||
| You can also check the logs of the `immich-machine-learning` container. When a Smart Search or Face Detection job begins, or when you search with text in Immich, you should either see a log for `Available ORT providers` containing the relevant provider (e.g. `CUDAExecutionProvider` in the case of CUDA), or a `Loaded ANN model` log entry without errors in the case of ARM NN. | ||||
| 
 | ||||
|  | ||||
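Beyond the log line mentioned above, you can also ask ONNX Runtime directly which execution providers the installed build exposes; with the ROCm image, `ROCMExecutionProvider` should appear in the list. A minimal sketch, assuming it is run with the container's Python environment:

```python
# Quick check that the ROCm execution provider is available to ONNX Runtime.
import onnxruntime as ort

providers = ort.get_available_providers()
print("available providers:", providers)
if "ROCMExecutionProvider" in providers:
    print("ROCm-enabled onnxruntime build detected")
else:
    print("no ROCm provider; inference would fall back to CPU")
```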
| @ -23,12 +23,12 @@ name: immich_remote_ml | ||||
| services: | ||||
|   immich-machine-learning: | ||||
|     container_name: immich_machine_learning | ||||
|     # For hardware acceleration, add one of -[armnn, cuda, openvino] to the image tag. | ||||
|     # For hardware acceleration, add one of -[armnn, cuda, rocm, openvino] to the image tag. | ||||
|     # Example tag: ${IMMICH_VERSION:-release}-cuda | ||||
|     image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release} | ||||
|     # extends: | ||||
|     #   file: hwaccel.ml.yml | ||||
|     #   service: # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable | ||||
|     #   service: # set to one of [armnn, cuda, rocm, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable | ||||
|     volumes: | ||||
|       - model-cache:/cache | ||||
|     restart: always | ||||
|  | ||||
| @ -15,6 +15,40 @@ RUN mkdir /opt/armnn && \ | ||||
|     cd /opt/ann && \ | ||||
|     sh build.sh | ||||
| 
 | ||||
| # Warning: pulling this image requires 26.3 GB of disk space | ||||
| # https://github.com/microsoft/onnxruntime/blob/main/dockerfiles/Dockerfile.rocm | ||||
| FROM rocm/dev-ubuntu-22.04:6.1.2-complete as builder-rocm | ||||
| 
 | ||||
| WORKDIR /code | ||||
| 
 | ||||
| RUN apt-get update && apt-get install -y --no-install-recommends wget git python3.10-venv | ||||
| # Install the same CMake version as the Dockerfile provided by onnxruntime | ||||
| RUN wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.sh && \ | ||||
|     chmod +x cmake-3.27.3-linux-x86_64.sh && \ | ||||
|     mkdir -p /code/cmake-3.27.3-linux-x86_64 && \ | ||||
|     ./cmake-3.27.3-linux-x86_64.sh --skip-license --prefix=/code/cmake-3.27.3-linux-x86_64 && \ | ||||
|     rm cmake-3.27.3-linux-x86_64.sh | ||||
| 
 | ||||
| ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} | ||||
| 
 | ||||
| # Prepare onnxruntime repository & build onnxruntime | ||||
| RUN git clone --single-branch --branch v1.18.1 --recursive "https://github.com/Microsoft/onnxruntime" onnxruntime | ||||
| WORKDIR /code/onnxruntime | ||||
| # EDIT PR | ||||
| # While the following PR is still open upstream, we need to build with its patch applied | ||||
| # https://github.com/microsoft/onnxruntime/pull/19567 | ||||
| COPY ./rocm-PR19567.patch /tmp/ | ||||
| RUN git apply /tmp/rocm-PR19567.patch | ||||
| # END EDIT PR | ||||
| RUN /bin/sh ./dockerfiles/scripts/install_common_deps.sh | ||||
| # Parallelizing the build too aggressively leads to a compilation error: | ||||
| # building onnxruntime with 12 threads needs more than 16 GB of RAM (all the RAM on the build machine), | ||||
| # and even then the compilation ran for more than 1.5 hours before failing. | ||||
| # Lowering the number of threads to 8 made the build succeed. | ||||
| RUN ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel 9 --cmake_extra_defines\ | ||||
|     ONNXRUNTIME_VERSION=1.18.1 --use_rocm --rocm_home=/opt/rocm | ||||
| RUN mv /code/onnxruntime/build/Linux/Release/dist/*.whl /opt/ | ||||
| 
 | ||||
| FROM builder-${DEVICE} AS builder | ||||
| 
 | ||||
| ARG DEVICE | ||||
| @ -32,6 +66,9 @@ RUN poetry config installer.max-workers 10 && \ | ||||
| RUN python3 -m venv /opt/venv | ||||
| 
 | ||||
| COPY poetry.lock pyproject.toml ./ | ||||
| RUN if [ "$DEVICE" = "rocm" ]; then \ | ||||
|     poetry add /opt/onnxruntime_rocm-*.whl; \ | ||||
|     fi | ||||
| RUN poetry install --sync --no-interaction --no-ansi --no-root --with ${DEVICE} --without dev | ||||
| 
 | ||||
| FROM python:3.11-slim-bookworm@sha256:614c8691ab74150465ec9123378cd4dde7a6e57be9e558c3108df40664667a4c AS prod-cpu | ||||
| @ -80,11 +117,15 @@ COPY --from=builder-armnn \ | ||||
|     /opt/ann/build.sh \ | ||||
|     /opt/armnn/ | ||||
| 
 | ||||
| FROM rocm/dev-ubuntu-22.04:6.1.2-complete AS prod-rocm | ||||
| 
 | ||||
| 
 | ||||
| FROM prod-${DEVICE} AS prod | ||||
| 
 | ||||
| ARG DEVICE | ||||
| 
 | ||||
| RUN apt-get update && \ | ||||
|     apt-get install -y --no-install-recommends tini $(if ! [ "$DEVICE" = "openvino" ]; then echo "libmimalloc2.0"; fi) && \ | ||||
|     apt-get install -y --no-install-recommends tini $(if ! [ "$DEVICE" = "openvino" ] && ! [ "$DEVICE" = "rocm" ]; then echo "libmimalloc2.0"; fi) && \ | ||||
|     apt-get autoremove -yqq && \ | ||||
|     apt-get clean && \ | ||||
|     rm -rf /var/lib/apt/lists/* | ||||
|  | ||||
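Since building the wheel takes well over an hour, it can be worth smoke-testing it before it is baked into the final image. A hedged sketch of such a check (the expected version matches the `v1.18.1` tag checked out above; the provider assertion is an assumption about what a wheel built with `--use_rocm` reports):

```python
# Smoke test for the locally built onnxruntime ROCm wheel.
import onnxruntime as ort

assert ort.__version__ == "1.18.1", f"unexpected version {ort.__version__}"
assert "ROCMExecutionProvider" in ort.get_available_providers(), (
    "wheel does not expose the ROCm execution provider"
)
print("onnxruntime", ort.__version__, "with ROCm support looks usable")
```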
| @ -7,7 +7,7 @@ | ||||
| 
 | ||||
| This project uses [Poetry](https://python-poetry.org/docs/#installation), so be sure to install it first. | ||||
| Running `poetry install --no-root --with dev --with cpu` will install everything you need in an isolated virtual environment. | ||||
| CUDA and OpenVINO are supported as acceleration APIs. To use them, you can replace `--with cpu` with either of `--with cuda` or `--with openvino`. In the case of CUDA, a [compute capability](https://developer.nvidia.com/cuda-gpus) of 5.2 or higher is required. | ||||
| CUDA, ROCM, and OpenVINO are supported as acceleration APIs. To use them, you can replace `--with cpu` with one of `--with cuda`, `--with rocm`, or `--with openvino`. In the case of CUDA, a [compute capability](https://developer.nvidia.com/cuda-gpus) of 5.2 or higher is required. | ||||
| 
 | ||||
| To add or remove dependencies, you can use the commands `poetry add $PACKAGE_NAME` and `poetry remove $PACKAGE_NAME`, respectively. | ||||
| Be sure to commit the `poetry.lock` and `pyproject.toml` files with `poetry lock --no-update` to reflect any changes in dependencies. | ||||
|  | ||||
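Because every acceleration group is imported as plain `onnxruntime`, it is easy to lose track of which flavor is installed in a dev environment. A hedged way to check, where the candidate distribution names are assumptions based on `pyproject.toml` and the wheel built in the Dockerfile:

```python
# Report which onnxruntime distribution backs the `onnxruntime` import.
from importlib import metadata

import onnxruntime as ort

# Assumed distribution names: cpu default, CUDA, OpenVINO, and the locally
# built ROCm wheel (onnxruntime_rocm-*.whl) added by the Dockerfile.
CANDIDATES = ["onnxruntime", "onnxruntime-gpu", "onnxruntime-openvino", "onnxruntime-rocm"]

installed = []
for name in CANDIDATES:
    try:
        installed.append(f"{name}=={metadata.version(name)}")
    except metadata.PackageNotFoundError:
        pass

print("installed distributions:", installed or "none of the assumed names found")
print("available providers:", ort.get_available_providers())
```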
| @ -63,7 +63,7 @@ _INSIGHTFACE_MODELS = { | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] | ||||
| SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "ROCMExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] | ||||
| 
 | ||||
| 
 | ||||
| def get_model_source(model_name: str) -> ModelSource | None: | ||||
|  | ||||
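For illustration, extending `SUPPORTED_PROVIDERS` matters because provider selection can be treated as an ordered intersection of what the code supports and what the installed ONNX Runtime build actually offers; with the ROCm build, `ROCMExecutionProvider` now survives that filter. A hedged sketch of the idea (not the project's actual selection code):

```python
# Illustration: ordered intersection of supported vs. available providers.
import onnxruntime as ort

SUPPORTED_PROVIDERS = [
    "CUDAExecutionProvider",
    "ROCMExecutionProvider",
    "OpenVINOExecutionProvider",
    "CPUExecutionProvider",
]


def pick_providers() -> list[str]:
    available = set(ort.get_available_providers())
    return [p for p in SUPPORTED_PROVIDERS if p in available]


print(pick_providers())  # e.g. ['ROCMExecutionProvider', 'CPUExecutionProvider']
```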
| @ -88,7 +88,7 @@ class OrtSession: | ||||
|             match provider: | ||||
|                 case "CPUExecutionProvider": | ||||
|                     options = {"arena_extend_strategy": "kSameAsRequested"} | ||||
|                 case "CUDAExecutionProvider": | ||||
|                 case "CUDAExecutionProvider" | "ROCMExecutionProvider": | ||||
|                     options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id} | ||||
|                 case "OpenVINOExecutionProvider": | ||||
|                     options = { | ||||
|  | ||||
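Concretely, the `match` arm above gives the ROCm provider the same options as CUDA: a `kSameAsRequested` arena strategy plus the configured device id. A minimal, hedged sketch of opening a session with those options outside the project (the model path and device id are placeholders):

```python
# Minimal ONNX Runtime session using the ROCm provider options shown above.
import onnxruntime as ort

model_path = "model.onnx"  # placeholder
providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
provider_options = [
    {"arena_extend_strategy": "kSameAsRequested", "device_id": "0"},
    {"arena_extend_strategy": "kSameAsRequested"},
]

session = ort.InferenceSession(
    model_path,
    providers=providers,
    provider_options=provider_options,
)
print(session.get_providers())
```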
							
								
								
									
machine-learning/poetry.lock (generated) · 46 lines changed
							| @ -1,4 +1,4 @@ | ||||
| # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. | ||||
| # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. | ||||
| 
 | ||||
| [[package]] | ||||
| name = "aiocache" | ||||
| @ -147,10 +147,6 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, | ||||
|     {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, | ||||
| @ -163,14 +159,8 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, | ||||
|     {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, | ||||
| @ -181,24 +171,8 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, | ||||
|     {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, | ||||
|     {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, | ||||
| @ -208,10 +182,6 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, | ||||
|     {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, | ||||
| @ -223,10 +193,6 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, | ||||
|     {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, | ||||
| @ -239,10 +205,6 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, | ||||
|     {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, | ||||
| @ -255,10 +217,6 @@ files = [ | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, | ||||
|     {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, | ||||
|     {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, | ||||
| @ -3735,4 +3693,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] | ||||
| [metadata] | ||||
| lock-version = "2.0" | ||||
| python-versions = ">=3.10,<4.0" | ||||
| content-hash = "b690d5fbd141da3947f4f1dc029aba1b95e7faafd723166f2c4bdc47a66c095e" | ||||
| content-hash = "271a6c2a76b1b6286e02b91489ffd0c42e92daf151ae932514f5416c7869f71d" | ||||
|  | ||||
| @ -47,6 +47,11 @@ optional = true | ||||
| [tool.poetry.group.cuda.dependencies] | ||||
| onnxruntime-gpu = {version = "^1.17.0", source = "cuda12"} | ||||
| 
 | ||||
| [tool.poetry.group.rocm] | ||||
| optional = true | ||||
| 
 | ||||
| [tool.poetry.group.rocm.dependencies] | ||||
| 
 | ||||
| [tool.poetry.group.openvino] | ||||
| optional = true | ||||
| 
 | ||||
|  | ||||
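The new `rocm` group is declared optional and intentionally has no dependencies: unlike `cuda`, which pulls `onnxruntime-gpu`, there is no ready-made ROCm build of ONNX Runtime to depend on here, which is presumably why the Dockerfile adds the locally built wheel with `poetry add` at image build time. A small sketch that lists the declared groups, assuming Python 3.11+ and a `pyproject.toml` in the working directory:

```python
# List Poetry dependency groups declared in pyproject.toml.
import tomllib

with open("pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

groups = pyproject["tool"]["poetry"].get("group", {})
for name, cfg in groups.items():
    optional = cfg.get("optional", False)
    deps = sorted(cfg.get("dependencies", {}))
    print(f"{name}: optional={optional}, dependencies={deps or 'none (rocm wheel is added at build time)'}")
```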
							
								
								
									
machine-learning/rocm-PR19567.patch (new file) · 176 lines
							| @ -0,0 +1,176 @@ | ||||
| From a598a88db258f82a6e4bca75810921bd6bcee7e0 Mon Sep 17 00:00:00 2001 | ||||
| From: David Nieto <dmnieto@gmail.com> | ||||
| Date: Sat, 17 Feb 2024 11:23:12 -0800 | ||||
| Subject: [PATCH] Disable algo caching in ROCM EP | ||||
| 
 | ||||
| Similar to the work done by Liangxijun-1001 in | ||||
| https://github.com/apache/tvm/pull/16178 the ROCM spec mandates calling | ||||
| miopenFindConvolution*Algorithm() before using any Convolution API | ||||
| 
 | ||||
| This is the link to the porting guide describing this requirement | ||||
| https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html | ||||
| 
 | ||||
| Thus, this change disables the algo cache and enforces the official | ||||
| API semantics | ||||
| 
 | ||||
| Signed-off-by: David Nieto <dmnieto@gmail.com> | ||||
| ---
 | ||||
|  onnxruntime/core/providers/rocm/nn/conv.cc    | 61 +++++++++---------- | ||||
|  onnxruntime/core/providers/rocm/nn/conv.h     |  6 -- | ||||
|  .../core/providers/rocm/nn/conv_transpose.cc  | 17 +++--- | ||||
|  3 files changed, 36 insertions(+), 48 deletions(-) | ||||
| 
 | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| index 6214ec7bc0ea..b08aceca48b1 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.cc
 | ||||
| @@ -125,10 +125,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|      if (input_dims_changed) | ||||
|        s_.last_x_dims = gsl::make_span(x_dims); | ||||
|   | ||||
| -    if (w_dims_changed) {
 | ||||
| +    if (w_dims_changed)
 | ||||
|        s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -      s_.cached_benchmark_fwd_results.clear();
 | ||||
| -    }
 | ||||
|   | ||||
|      ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); | ||||
|   | ||||
| @@ -277,35 +275,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|        HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size)); | ||||
|        HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); | ||||
|      } | ||||
| -
 | ||||
| -    if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) {
 | ||||
| -      miopenConvAlgoPerf_t perf;
 | ||||
| -      int algo_count = 1;
 | ||||
| -      const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
 | ||||
| -      static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
 | ||||
| -      size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
 | ||||
| -                                                                   : AlgoSearchWorkspaceSize;
 | ||||
| -      IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
 | ||||
| -      MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
 | ||||
| -          GetMiopenHandle(context),
 | ||||
| -          s_.x_tensor,
 | ||||
| -          s_.x_data,
 | ||||
| -          s_.w_desc,
 | ||||
| -          s_.w_data,
 | ||||
| -          s_.conv_desc,
 | ||||
| -          s_.y_tensor,
 | ||||
| -          s_.y_data,
 | ||||
| -          1,            // requestedAlgoCount
 | ||||
| -          &algo_count,  // returnedAlgoCount
 | ||||
| -          &perf,
 | ||||
| -          algo_search_workspace.get(),
 | ||||
| -          max_ws_size,
 | ||||
| -          false));  // Do not do exhaustive algo search.
 | ||||
| -      s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory});
 | ||||
| -    }
 | ||||
| -    const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen);
 | ||||
| -    s_.fwd_algo = perf.fwd_algo;
 | ||||
| -    s_.workspace_bytes = perf.memory;
 | ||||
|    } else { | ||||
|      // set Y | ||||
|      s_.Y = context->Output(0, TensorShape(s_.y_dims)); | ||||
| @@ -319,6 +288,34 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 | ||||
|        s_.y_data = reinterpret_cast<HipT*>(s_.Y->MutableData<T>()); | ||||
|      } | ||||
|    } | ||||
| +  {
 | ||||
| +    /* FindConvolution must always be called by the runtime */
 | ||||
| +    TensorShapeVector x_dims_miopen{x_dims.begin(), x_dims.end()};
 | ||||
| +    miopenConvAlgoPerf_t perf;
 | ||||
| +    int algo_count = 1;
 | ||||
| +    const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
 | ||||
| +    static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
 | ||||
| +    size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
 | ||||
| +                                                                   : AlgoSearchWorkspaceSize;
 | ||||
| +    IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
 | ||||
| +    MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
 | ||||
| +        GetMiopenHandle(context),
 | ||||
| +        s_.x_tensor,
 | ||||
| +        s_.x_data,
 | ||||
| +        s_.w_desc,
 | ||||
| +        s_.w_data,
 | ||||
| +        s_.conv_desc,
 | ||||
| +        s_.y_tensor,
 | ||||
| +        s_.y_data,
 | ||||
| +        1,            // requestedAlgoCount
 | ||||
| +        &algo_count,  // returnedAlgoCount
 | ||||
| +        &perf,
 | ||||
| +        algo_search_workspace.get(),
 | ||||
| +        max_ws_size,
 | ||||
| +        false));  // Do not do exhaustive algo search.
 | ||||
| +    s_.fwd_algo = perf.fwd_algo;
 | ||||
| +    s_.workspace_bytes = perf.memory;
 | ||||
| +  }
 | ||||
|    return Status::OK(); | ||||
|  } | ||||
|   | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| index bc9846203e57..d54218f25854 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv.h
 | ||||
| @@ -108,9 +108,6 @@ class lru_unordered_map {
 | ||||
|    list_type lru_list_; | ||||
|  }; | ||||
|   | ||||
| -// cached miopen descriptors
 | ||||
| -constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000;
 | ||||
| -
 | ||||
|  template <typename AlgoPerfType> | ||||
|  struct MiopenConvState { | ||||
|    // if x/w dims changed, update algo and miopenTensors | ||||
| @@ -148,9 +145,6 @@ struct MiopenConvState {
 | ||||
|      decltype(AlgoPerfType().memory) memory; | ||||
|    }; | ||||
|   | ||||
| -  lru_unordered_map<TensorShapeVector, PerfFwdResultParams, vector_hash> cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| -  lru_unordered_map<TensorShapeVector, PerfBwdResultParams, vector_hash> cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS};
 | ||||
| -
 | ||||
|    // Some properties needed to support asymmetric padded Conv nodes | ||||
|    bool post_slicing_required; | ||||
|    TensorShapeVector slice_starts; | ||||
| diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| index 7447113fdf84..45ed4c8ac37a 100644
 | ||||
| --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc
 | ||||
| @@ -76,7 +76,6 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        if (w_dims_changed) { | ||||
|          s_.last_w_dims = gsl::make_span(w_dims); | ||||
| -        s_.cached_benchmark_bwd_results.clear();
 | ||||
|        } | ||||
|   | ||||
|        ConvTransposeAttributes::Prepare p; | ||||
| @@ -127,12 +126,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|   | ||||
|        y_data = reinterpret_cast<HipT*>(p.Y->MutableData<T>()); | ||||
|   | ||||
| -      if (!s_.cached_benchmark_bwd_results.contains(x_dims)) {
 | ||||
| -        IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 | ||||
| -
 | ||||
| -        miopenConvAlgoPerf_t perf;
 | ||||
| -        int algo_count = 1;
 | ||||
| -        MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
 | ||||
| +    }
 | ||||
| +    // The following is required before calling convolution, we cannot cache the results
 | ||||
| +    {
 | ||||
| +      IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 | ||||
| +      miopenConvAlgoPerf_t perf;
 | ||||
| +      int algo_count = 1;
 | ||||
| +      MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm(
 | ||||
|              GetMiopenHandle(context), | ||||
|              s_.x_tensor, | ||||
|              x_data, | ||||
| @@ -147,10 +147,7 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
 | ||||
|              algo_search_workspace.get(), | ||||
|              AlgoSearchWorkspaceSize, | ||||
|              false)); | ||||
| -        s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory});
 | ||||
| -      }
 | ||||
|   | ||||
| -      const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims);
 | ||||
|        s_.bwd_data_algo = perf.bwd_data_algo; | ||||
|        s_.workspace_bytes = perf.memory; | ||||
|      } | ||||