Mirror of https://github.com/immich-app/immich.git (synced 2025-10-26 08:12:33 -04:00)

	feat(ml): ARMNN acceleration (#5667)

* feat(ml): ARMNN acceleration for CLIP
* wrap ANN as ONNX session
* strict typing
* normalize ARMNN CLIP embedding
* mutex to handle concurrent execution
* make inputs contiguous
* fine-grained locking; concurrent network execution

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
parent 29747437f6
commit 753292956e

docker/mlaccel-armnn.yml  (new file, 11 lines)
@@ -0,0 +1,11 @@
+version: "3.8"
+
+# ML acceleration on supported Mali ARM GPUs using ARM-NN
+
+services:
+  mlaccel:
+    devices:
+      - /dev/mali0:/dev/mali0
+    volumes:
+      - /lib/firmware/mali_csffw.bin:/lib/firmware/mali_csffw.bin:ro # Mali firmware for your chipset (not always required depending on the driver)
+      - /usr/lib/libmali.so:/usr/lib/libmali.so:ro # Mali driver for your chipset (always required)
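
This override only declares the device node and driver mounts that Mali GPU inference needs; the service here is named mlaccel, so it is presumably combined with the main compose file (for example via a second -f override file or compose's extends mechanism; the exact wiring is not shown in this commit), with the host paths for libmali.so and the mali_csffw.bin firmware adjusted to the chipset in use.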
machine-learning/Dockerfile  (modified)

@@ -13,17 +13,40 @@ ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
 COPY poetry.lock pyproject.toml ./
 RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
 
-FROM python:3.11-slim-bookworm@sha256:8f64a67710f3d981cf3008d6f9f1dbe61accd7927f165f4e37ea3f8b883ccc3f
+ARG TARGETPLATFORM
+ENV ARMNN_PATH=/opt/armnn
+COPY ann /opt/ann
+RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
+    mkdir /opt/armnn && \
+    curl -SL "https://github.com/ARM-software/armnn/releases/download/v23.11/ArmNN-linux-aarch64.tar.gz" | tar -zx -C /opt/armnn && \
+    cd /opt/ann && \
+    sh build.sh; \
+  else \
+    mkdir /opt/armnn; \
+  fi
+
 
+FROM python:3.11-slim-bookworm@sha256:8f64a67710f3d981cf3008d6f9f1dbe61accd7927f165f4e37ea3f8b883ccc3f
+ARG TARGETPLATFORM
 RUN apt-get update && apt-get install -y --no-install-recommends tini libmimalloc2.0 && rm -rf /var/lib/apt/lists/*
 
+RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
+    apt-get update && apt-get install -y --no-install-recommends ocl-icd-libopencl1 mesa-opencl-icd && \
+    rm -rf /var/lib/apt/lists/* && \
+    mkdir --parents /etc/OpenCL/vendors && \
+    echo "/usr/lib/libmali.so" > /etc/OpenCL/vendors/mali.icd && \
+    mkdir /opt/armnn; \
+  fi
+
 WORKDIR /usr/src/app
 ENV NODE_ENV=production \
   TRANSFORMERS_CACHE=/cache \
   PYTHONDONTWRITEBYTECODE=1 \
   PYTHONUNBUFFERED=1 \
   PATH="/opt/venv/bin:$PATH" \
-  PYTHONPATH=/usr/src
+  PYTHONPATH=/usr/src \
+  LD_LIBRARY_PATH=/opt/armnn
 
 # prevent core dumps
 RUN echo "hard core 0" >> /etc/security/limits.conf && \
@@ -31,7 +54,10 @@ RUN echo "hard core 0" >> /etc/security/limits.conf && \
     echo 'ulimit -S -c 0 > /dev/null 2>&1' >> /etc/profile
 
 COPY --from=builder /opt/venv /opt/venv
+COPY --from=builder /opt/armnn/libarmnn.so.?? /opt/armnn/libarmnnOnnxParser.so.?? /opt/armnn/libarmnnDeserializer.so.?? /opt/armnn/libarmnnTfLiteParser.so.?? /opt/armnn/libprotobuf.so.?.??.?.? /opt/ann/libann.s[o] /opt/ann/build.sh /opt/armnn
+COPY ann/ann.py /usr/src/ann/ann.py
 COPY start.sh log_conf.json ./
 COPY app .
+
 ENTRYPOINT ["tini", "--"]
 CMD ["./start.sh"]
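
Two details in this Dockerfile are easy to miss (my reading, not stated in the commit). The runtime stage copies /opt/ann/libann.s[o] with a bracket glob, presumably so the COPY still succeeds on non-arm64 builds where libann.so was never produced: a pattern that matches nothing is skipped as long as another listed source matches, and build.sh always exists. And LD_LIBRARY_PATH=/opt/armnn is what lets the ctypes CDLL("libann.so") call in ann.py resolve the library at runtime without an ldconfig step.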
machine-learning/ann/__init__.py  (new file, 1 line)

@@ -0,0 +1 @@
+from .ann import Ann, is_available
machine-learning/ann/ann.cpp  (new file, 281 lines)

@@ -0,0 +1,281 @@
+#include <fstream>
+#include <mutex>
+#include <atomic>
+
+#include "armnn/IRuntime.hpp"
+#include "armnn/INetwork.hpp"
+#include "armnn/Types.hpp"
+#include "armnnDeserializer/IDeserializer.hpp"
+#include "armnnTfLiteParser/ITfLiteParser.hpp"
+#include "armnnOnnxParser/IOnnxParser.hpp"
+
+using namespace armnn;
+
+struct IOInfos
+{
+    std::vector<BindingPointInfo> inputInfos;
+    std::vector<BindingPointInfo> outputInfos;
+};
+
+// from https://rigtorp.se/spinlock/
+struct SpinLock
+{
+    std::atomic<bool> lock_ = {false};
+
+    void lock()
+    {
+        for (;;)
+        {
+            if (!lock_.exchange(true, std::memory_order_acquire))
+            {
+                break;
+            }
+            while (lock_.load(std::memory_order_relaxed))
+                ;
+        }
+    }
+
+    void unlock() { lock_.store(false, std::memory_order_release); }
+};
+
+class Ann
+{
+
+public:
+    int load(const char *modelPath,
+             bool fastMath,
+             bool fp16,
+             bool saveCachedNetwork,
+             const char *cachedNetworkPath)
+    {
+        INetworkPtr network = loadModel(modelPath);
+        IOptimizedNetworkPtr optNet = OptimizeNetwork(network.get(), fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
+        const IOInfos infos = getIOInfos(optNet.get());
+        NetworkId netId;
+        mutex.lock();
+        Status status = runtime->LoadNetwork(netId, std::move(optNet));
+        mutex.unlock();
+        if (status != Status::Success)
+        {
+            return -1;
+        }
+        spinLock.lock();
+        ioInfos[netId] = infos;
+        mutexes.emplace(netId, std::make_unique<std::mutex>());
+        spinLock.unlock();
+        return netId;
+    }
+
+    void execute(NetworkId netId, const void **inputData, void **outputData)
+    {
+        spinLock.lock();
+        const IOInfos *infos = &ioInfos[netId];
+        auto m = mutexes[netId].get();
+        spinLock.unlock();
+        InputTensors inputTensors;
+        inputTensors.reserve(infos->inputInfos.size());
+        size_t i = 0;
+        for (const BindingPointInfo &info : infos->inputInfos)
+            inputTensors.emplace_back(info.first, ConstTensor(info.second, inputData[i++]));
+        OutputTensors outputTensors;
+        outputTensors.reserve(infos->outputInfos.size());
+        i = 0;
+        for (const BindingPointInfo &info : infos->outputInfos)
+            outputTensors.emplace_back(info.first, Tensor(info.second, outputData[i++]));
+        m->lock();
+        runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
+        m->unlock();
+    }
+
+    void unload(NetworkId netId)
+    {
+        mutex.lock();
+        runtime->UnloadNetwork(netId);
+        mutex.unlock();
+    }
+
+    int tensors(NetworkId netId, bool isInput = false)
+    {
+        spinLock.lock();
+        const IOInfos *infos = &ioInfos[netId];
+        spinLock.unlock();
+        return (int)(isInput ? infos->inputInfos.size() : infos->outputInfos.size());
+    }
+
+    unsigned long shape(NetworkId netId, bool isInput = false, int index = 0)
+    {
+        spinLock.lock();
+        const IOInfos *infos = &ioInfos[netId];
+        spinLock.unlock();
+        const TensorShape shape = (isInput ? infos->inputInfos : infos->outputInfos)[index].second.GetShape();
+        unsigned long s = 0;
+        for (unsigned int d = 0; d < shape.GetNumDimensions(); d++)
+            s |= ((unsigned long)shape[d]) << (d * 16); // stores up to 4 16-bit values in a 64-bit value
+        return s;
+    }
+
+    Ann(int tuningLevel, const char *tuningFile)
+    {
+        IRuntime::CreationOptions runtimeOptions;
+        BackendOptions backendOptions{"GpuAcc",
+                                      {
+                                          {"TuningLevel", tuningLevel},
+                                          {"MemoryOptimizerStrategy", "ConstantMemoryStrategy"}, // SingleAxisPriorityList or ConstantMemoryStrategy
+                                      }};
+        if (tuningFile)
+            backendOptions.AddOption({"TuningFile", tuningFile});
+        runtimeOptions.m_BackendOptions.emplace_back(backendOptions);
+        runtime = IRuntime::CreateRaw(runtimeOptions);
+    };
+    ~Ann()
+    {
+        IRuntime::Destroy(runtime);
+    };
+
+private:
+    INetworkPtr loadModel(const char *modelPath)
+    {
+        const auto path = std::string(modelPath);
+        if (path.rfind(".tflite") == path.length() - 7) // endsWith()
+        {
+            auto parser = armnnTfLiteParser::ITfLiteParser::CreateRaw();
+            return parser->CreateNetworkFromBinaryFile(modelPath);
+        }
+        else if (path.rfind(".onnx") == path.length() - 5) // endsWith()
+        {
+            auto parser = armnnOnnxParser::IOnnxParser::CreateRaw();
+            return parser->CreateNetworkFromBinaryFile(modelPath);
+        }
+        else
+        {
+            std::ifstream ifs(path, std::ifstream::in | std::ifstream::binary);
+            auto parser = armnnDeserializer::IDeserializer::CreateRaw();
+            return parser->CreateNetworkFromBinary(ifs);
+        }
+    }
+
+    static BindingPointInfo getInputTensorInfo(LayerBindingId inputBindingId, TensorInfo info)
+    {
+        const auto newInfo = TensorInfo{info.GetShape(), info.GetDataType(),
+                                        info.GetQuantizationScale(),
+                                        info.GetQuantizationOffset(),
+                                        true};
+        return {inputBindingId, newInfo};
+    }
+
+    IOptimizedNetworkPtr OptimizeNetwork(INetwork *network, bool fastMath, bool fp16, bool saveCachedNetwork, const char *cachedNetworkPath)
+    {
+        const bool allowExpandedDims = false;
+        const ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly;
+
+        OptimizerOptionsOpaque options;
+        options.SetReduceFp32ToFp16(fp16);
+        options.SetShapeInferenceMethod(shapeInferenceMethod);
+        options.SetAllowExpandedDims(allowExpandedDims);
+
+        BackendOptions gpuAcc("GpuAcc", {{"FastMathEnabled", fastMath}});
+        if (cachedNetworkPath)
+        {
+            gpuAcc.AddOption({"SaveCachedNetwork", saveCachedNetwork});
+            gpuAcc.AddOption({"CachedNetworkFilePath", cachedNetworkPath});
+        }
+        options.AddModelOption(gpuAcc);
+
+        // No point in using ARMNN for CPU, use ONNX (quantized) instead.
+        // BackendOptions cpuAcc("CpuAcc",
+        //                       {
+        //                           {"FastMathEnabled", fastMath},
+        //                           {"NumberOfThreads", 0},
+        //                       });
+        // options.AddModelOption(cpuAcc);
+
+        BackendOptions allowExDimOpt("AllowExpandedDims",
+                                     {{"AllowExpandedDims", allowExpandedDims}});
+        options.AddModelOption(allowExDimOpt);
+        BackendOptions shapeInferOpt("ShapeInferenceMethod",
+                                     {{"InferAndValidate", shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate}});
+        options.AddModelOption(shapeInferOpt);
+
+        std::vector<BackendId> backends = {
+            BackendId("GpuAcc"),
+            // BackendId("CpuAcc"),
+            // BackendId("CpuRef"),
+        };
+        return Optimize(*network, backends, runtime->GetDeviceSpec(), options);
+    }
+
+    IOInfos getIOInfos(IOptimizedNetwork *optNet)
+    {
+        struct InfoStrategy : IStrategy
+        {
+            void ExecuteStrategy(const IConnectableLayer *layer,
+                                 const BaseDescriptor &descriptor,
+                                 const std::vector<ConstTensor> &constants,
+                                 const char *name,
+                                 const LayerBindingId id = 0) override
+            {
+                IgnoreUnused(descriptor, constants, id);
+                const LayerType lt = layer->GetType();
+                if (lt == LayerType::Input)
+                    ioInfos.inputInfos.push_back(getInputTensorInfo(id, layer->GetOutputSlot(0).GetTensorInfo()));
+                else if (lt == LayerType::Output)
+                    ioInfos.outputInfos.push_back({id, layer->GetInputSlot(0).GetTensorInfo()});
+            }
+            IOInfos ioInfos;
+        };
+
+        InfoStrategy infoStrategy;
+        optNet->ExecuteStrategy(infoStrategy);
+        return infoStrategy.ioInfos;
+    }
+
+    IRuntime *runtime;
+    std::map<NetworkId, IOInfos> ioInfos;
+    std::map<NetworkId, std::unique_ptr<std::mutex>> mutexes; // mutex per network to not execute the same network concurrently
+    std::mutex mutex; // global mutex for load/unload calls to the runtime
+    SpinLock spinLock; // fast spin lock to guard access to the ioInfos and mutexes maps
+};
+
+extern "C" void *init(int logLevel, int tuningLevel, const char *tuningFile)
+{
+    LogSeverity level = static_cast<LogSeverity>(logLevel);
+    ConfigureLogging(true, true, level);
+
+    Ann *ann = new Ann(tuningLevel, tuningFile);
+    return ann;
+}
+
+extern "C" void destroy(void *ann)
+{
+    delete ((Ann *)ann);
+}
+
+extern "C" int load(void *ann,
+                    const char *path,
+                    bool fastMath,
+                    bool fp16,
+                    bool saveCachedNetwork,
+                    const char *cachedNetworkPath)
+{
+    return ((Ann *)ann)->load(path, fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
+}
+
+extern "C" void unload(void *ann, NetworkId netId)
+{
+    ((Ann *)ann)->unload(netId);
+}
+
+extern "C" void execute(void *ann, NetworkId netId, const void **inputData, void **outputData)
+{
+    ((Ann *)ann)->execute(netId, inputData, outputData);
+}
+
+extern "C" unsigned long shape(void *ann, NetworkId netId, bool isInput, int index)
+{
+    return ((Ann *)ann)->shape(netId, isInput, index);
+}
+
+extern "C" int tensors(void *ann, NetworkId netId, bool isInput)
+{
+    return ((Ann *)ann)->tensors(netId, isInput);
+}
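
The shape() method above packs each dimension into 16 bits of its 64-bit return value, least-significant dimension first, so a full tensor shape crosses the C ABI as a single scalar; this caps shapes at four dimensions of size below 65536, which holds for the models used here. A minimal round-trip sketch of that encoding (mirroring the decoder in ann.py below):

    # Round-trip of the 16-bits-per-dimension packing used by Ann::shape().
    # Assumes at most 4 dimensions, each non-zero and < 2**16.
    shape = (1, 3, 224, 224)

    packed = 0
    for d, size in enumerate(shape):
        packed |= size << (d * 16)  # dimension d lands in bits [16*d, 16*d + 16)

    decoded = []
    while packed:
        decoded.append(packed & 0xFFFF)  # low 16 bits hold the next dimension
        packed >>= 16

    assert tuple(decoded) == shape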
machine-learning/ann/ann.py  (new file, 162 lines)

@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+from ctypes import CDLL, Array, c_bool, c_char_p, c_int, c_ulong, c_void_p
+from os.path import exists
+from typing import Any, Generic, Protocol, Type, TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+from app.config import log
+
+try:
+    CDLL("libmali.so")  # fail if libmali.so is not mounted into container
+    libann = CDLL("libann.so")
+    libann.init.argtypes = c_int, c_int, c_char_p
+    libann.init.restype = c_void_p
+    libann.load.argtypes = c_void_p, c_char_p, c_bool, c_bool, c_bool, c_char_p
+    libann.load.restype = c_int
+    libann.execute.argtypes = c_void_p, c_int, Array[c_void_p], Array[c_void_p]
+    libann.unload.argtypes = c_void_p, c_int
+    libann.destroy.argtypes = (c_void_p,)
+    libann.shape.argtypes = c_void_p, c_int, c_bool, c_int
+    libann.shape.restype = c_ulong
+    libann.tensors.argtypes = c_void_p, c_int, c_bool
+    libann.tensors.restype = c_int
+    is_available = True
+except OSError as e:
+    log.debug("Could not load ANN shared libraries, using ONNX: %s", e)
+    is_available = False
+
+T = TypeVar("T", covariant=True)
+
+
+class Newable(Protocol[T]):
+    def new(self) -> None:
+        ...
+
+
+class _Singleton(type, Newable[T]):
+    _instances: dict[_Singleton[T], Newable[T]] = {}
+
+    def __call__(cls, *args: Any, **kwargs: Any) -> Newable[T]:
+        if cls not in cls._instances:
+            obj: Newable[T] = super(_Singleton, cls).__call__(*args, **kwargs)
+            cls._instances[cls] = obj
+        else:
+            obj = cls._instances[cls]
+            obj.new()
+        return obj
+
+
+class Ann(metaclass=_Singleton):
+    def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
+        if not is_available:
+            raise RuntimeError("libann is not available!")
+        if tuning_file and not exists(tuning_file):
+            raise ValueError("tuning_file must point to an existing (possibly empty) file!")
+        if tuning_level == 0 and tuning_file is None:
+            raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
+        if tuning_level < 0 or tuning_level > 3:
+            raise ValueError("tuning_level must be 0 (load from tuning_file), 1, 2 or 3.")
+        if log_level < 0 or log_level > 5:
+            raise ValueError("log_level must be 0 (trace), 1 (debug), 2 (info), 3 (warning), 4 (error) or 5 (fatal)")
+        self.log_level = log_level
+        self.tuning_level = tuning_level
+        self.tuning_file = tuning_file
+        self.output_shapes: dict[int, tuple[tuple[int], ...]] = {}
+        self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
+        self.ann: int | None = None
+        self.new()
+
+    def new(self) -> None:
+        if self.ann is None:
+            self.ann = libann.init(
+                self.log_level,
+                self.tuning_level,
+                self.tuning_file.encode() if self.tuning_file is not None else None,
+            )
+            self.ref_count = 0
+
+        self.ref_count += 1
+
+    def destroy(self) -> None:
+        self.ref_count -= 1
+        if self.ref_count <= 0 and self.ann is not None:
+            libann.destroy(self.ann)
+            self.ann = None
+
+    def __del__(self) -> None:
+        if self.ann is not None:
+            libann.destroy(self.ann)
+            self.ann = None
+
+    def load(
+        self,
+        model_path: str,
+        fast_math: bool = True,
+        fp16: bool = False,
+        save_cached_network: bool = False,
+        cached_network_path: str | None = None,
+    ) -> int:
+        if not model_path.endswith((".armnn", ".tflite", ".onnx")):
+            raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
+        if not exists(model_path):
+            raise ValueError("model_path must point to an existing file!")
+        if cached_network_path is not None and not exists(cached_network_path):
+            raise ValueError("cached_network_path must point to an existing (possibly empty) file!")
+        if save_cached_network and cached_network_path is None:
+            raise ValueError("save_cached_network is True, cached_network_path must be specified!")
+        net_id: int = libann.load(
+            self.ann,
+            model_path.encode(),
+            fast_math,
+            fp16,
+            save_cached_network,
+            cached_network_path.encode() if cached_network_path is not None else None,
+        )
+
+        self.input_shapes[net_id] = tuple(
+            self.shape(net_id, input=True, index=i) for i in range(self.tensors(net_id, input=True))
+        )
+        self.output_shapes[net_id] = tuple(
+            self.shape(net_id, input=False, index=i) for i in range(self.tensors(net_id, input=False))
+        )
+        return net_id
+
+    def unload(self, network_id: int) -> None:
+        libann.unload(self.ann, network_id)
+        del self.output_shapes[network_id]
+
+    def execute(self, network_id: int, input_tensors: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
+        if not isinstance(input_tensors, list):
+            raise ValueError("input_tensors needs to be a list!")
+        net_input_shapes = self.input_shapes[network_id]
+        if len(input_tensors) != len(net_input_shapes):
+            raise ValueError(f"input_tensors lengths {len(input_tensors)} != network inputs {len(net_input_shapes)}")
+        for net_input_shape, input_tensor in zip(net_input_shapes, input_tensors):
+            if net_input_shape != input_tensor.shape:
+                raise ValueError(f"input_tensor shape {input_tensor.shape} != network input shape {net_input_shape}")
+            if not input_tensor.flags.c_contiguous:
+                raise ValueError("input_tensors must be c_contiguous numpy ndarrays")
+        output_tensors: list[NDArray[np.float32]] = [
+            np.ndarray(s, dtype=np.float32) for s in self.output_shapes[network_id]
+        ]
+        input_type = c_void_p * len(input_tensors)
+        inputs = input_type(*[t.ctypes.data_as(c_void_p) for t in input_tensors])
+        output_type = c_void_p * len(output_tensors)
+        outputs = output_type(*[t.ctypes.data_as(c_void_p) for t in output_tensors])
+        libann.execute(self.ann, network_id, inputs, outputs)
+        return output_tensors
+
+    def shape(self, network_id: int, input: bool = False, index: int = 0) -> tuple[int]:
+        s = libann.shape(self.ann, network_id, input, index)
+        a = []
+        while s != 0:
+            a.append(s & 0xFFFF)
+            s >>= 16
+        return tuple(a)
+
+    def tensors(self, network_id: int, input: bool = False) -> int:
+        tensors: int = libann.tensors(self.ann, network_id, input)
+        return tensors
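
End to end, a hypothetical use of this wrapper looks like the following; the model path is illustrative, and it assumes libmali.so and libann.so were loadable so that is_available is True:

    import numpy as np

    from ann.ann import Ann, is_available

    if is_available:
        ann = Ann(tuning_level=1)  # singleton: a second Ann() returns the same object and bumps its refcount
        net_id = ann.load("/cache/clip_vision.armnn")  # hypothetical model file
        (input_shape,) = ann.input_shapes[net_id]
        image = np.ascontiguousarray(np.random.rand(*input_shape).astype(np.float32))
        outputs = ann.execute(net_id, [image])  # list of inputs in, list of outputs out
        ann.unload(net_id)
        ann.destroy()  # drops the refcount taken above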
machine-learning/ann/build.sh  (new file, 1 line)

@@ -0,0 +1 @@
+g++ -shared -O3 -o libann.so -fuse-ld=gold -std=c++17 -I$ARMNN_PATH/include -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -L$ARMNN_PATH ann.cpp
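
build.sh assumes ARMNN_PATH points at an extracted ArmNN release (headers plus shared libraries); in the Dockerfile above it is set to /opt/armnn, and the release tarball is unpacked there, before this script runs.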
machine-learning/ann/export/.gitignore  (new file, vendored, 2 lines)

@@ -0,0 +1,2 @@
+armnn*
+output/
machine-learning/ann/export/build-converter.sh  (new executable file, 4 lines)

@@ -0,0 +1,4 @@
+#!/bin/sh
+
+cd armnn-23.11/
+g++ -o ../armnnconverter -O1 -DARMNN_ONNX_PARSER -DARMNN_SERIALIZER -DARMNN_TF_LITE_PARSER -fuse-ld=gold -std=c++17 -Iinclude -Isrc/armnnUtils -Ithird-party -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -larmnnSerializer -L../armnn src/armnnConverter/ArmnnConverter.cpp
machine-learning/ann/export/download-armnn.sh  (new executable file, 8 lines)

@@ -0,0 +1,8 @@
+#!/bin/sh
+
+# binaries
+mkdir armnn
+curl -SL "https://github.com/ARM-software/armnn/releases/download/v23.11/ArmNN-linux-x86_64.tar.gz" | tar -zx -C armnn
+
+# source to build ArmnnConverter
+curl -SL "https://github.com/ARM-software/armnn/archive/refs/tags/v23.11.tar.gz" | tar -zx
machine-learning/ann/export/env.yaml  (new file, 201 lines)

@@ -0,0 +1,201 @@
+name: annexport
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+dependencies:
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=2_kmp_llvm
+  - aiohttp=3.9.1=py310h2372a71_0
+  - aiosignal=1.3.1=pyhd8ed1ab_0
+  - arpack=3.8.0=nompi_h0baa96a_101
+  - async-timeout=4.0.3=pyhd8ed1ab_0
+  - attrs=23.1.0=pyh71513ae_1
+  - aws-c-auth=0.7.3=h28f7589_1
+  - aws-c-cal=0.6.1=hc309b26_1
+  - aws-c-common=0.9.0=hd590300_0
+  - aws-c-compression=0.2.17=h4d4d85c_2
+  - aws-c-event-stream=0.3.1=h2e3709c_4
+  - aws-c-http=0.7.11=h00aa349_4
+  - aws-c-io=0.13.32=he9a53bd_1
+  - aws-c-mqtt=0.9.3=hb447be9_1
+  - aws-c-s3=0.3.14=hf3aad02_1
+  - aws-c-sdkutils=0.1.12=h4d4d85c_1
+  - aws-checksums=0.1.17=h4d4d85c_1
+  - aws-crt-cpp=0.21.0=hb942446_5
+  - aws-sdk-cpp=1.10.57=h85b1a90_19
+  - blas=2.120=openblas
+  - blas-devel=3.9.0=20_linux64_openblas
+  - brotli-python=1.0.9=py310hd8f1fbe_9
+  - bzip2=1.0.8=hd590300_5
+  - c-ares=1.23.0=hd590300_0
+  - ca-certificates=2023.11.17=hbcca054_0
+  - certifi=2023.11.17=pyhd8ed1ab_0
+  - charset-normalizer=3.3.2=pyhd8ed1ab_0
+  - click=8.1.7=unix_pyh707e725_0
+  - colorama=0.4.6=pyhd8ed1ab_0
+  - coloredlogs=15.0.1=pyhd8ed1ab_3
+  - cuda-cudart=11.7.99=0
+  - cuda-cupti=11.7.101=0
+  - cuda-libraries=11.7.1=0
+  - cuda-nvrtc=11.7.99=0
+  - cuda-nvtx=11.7.91=0
+  - cuda-runtime=11.7.1=0
+  - dataclasses=0.8=pyhc8e2a94_3
+  - datasets=2.14.7=pyhd8ed1ab_0
+  - dill=0.3.7=pyhd8ed1ab_0
+  - filelock=3.13.1=pyhd8ed1ab_0
+  - flatbuffers=23.5.26=h59595ed_1
+  - freetype=2.12.1=h267a509_2
+  - frozenlist=1.4.0=py310h2372a71_1
+  - fsspec=2023.10.0=pyhca7485f_0
+  - ftfy=6.1.3=pyhd8ed1ab_0
+  - gflags=2.2.2=he1b5a44_1004
+  - glog=0.6.0=h6f12383_0
+  - glpk=5.0=h445213a_0
+  - gmp=6.3.0=h59595ed_0
+  - gmpy2=2.1.2=py310h3ec546c_1
+  - huggingface_hub=0.17.3=pyhd8ed1ab_0
+  - humanfriendly=10.0=pyhd8ed1ab_6
+  - icu=73.2=h59595ed_0
+  - idna=3.6=pyhd8ed1ab_0
+  - importlib-metadata=7.0.0=pyha770c72_0
+  - importlib_metadata=7.0.0=hd8ed1ab_0
+  - joblib=1.3.2=pyhd8ed1ab_0
+  - keyutils=1.6.1=h166bdaf_0
+  - krb5=1.21.2=h659d440_0
+  - lcms2=2.15=h7f713cb_2
+  - ld_impl_linux-64=2.40=h41732ed_0
+  - lerc=4.0.0=h27087fc_0
+  - libabseil=20230125.3=cxx17_h59595ed_0
+  - libarrow=12.0.1=hb87d912_8_cpu
+  - libblas=3.9.0=20_linux64_openblas
+  - libbrotlicommon=1.0.9=h166bdaf_9
+  - libbrotlidec=1.0.9=h166bdaf_9
+  - libbrotlienc=1.0.9=h166bdaf_9
+  - libcblas=3.9.0=20_linux64_openblas
+  - libcrc32c=1.1.2=h9c3ff4c_0
+  - libcublas=11.10.3.66=0
+  - libcufft=10.7.2.124=h4fbf590_0
+  - libcufile=1.8.1.2=0
+  - libcurand=10.3.4.101=0
+  - libcurl=8.5.0=hca28451_0
+  - libcusolver=11.4.0.1=0
+  - libcusparse=11.7.4.91=0
+  - libdeflate=1.19=hd590300_0
+  - libedit=3.1.20191231=he28a2e2_2
+  - libev=4.33=hd590300_2
+  - libevent=2.1.12=hf998b51_1
+  - libffi=3.4.2=h7f98852_5
+  - libgcc-ng=13.2.0=h807b86a_3
+  - libgfortran-ng=13.2.0=h69a702a_3
+  - libgfortran5=13.2.0=ha4646dd_3
+  - libgoogle-cloud=2.12.0=hac9eb74_1
+  - libgrpc=1.54.3=hb20ce57_0
+  - libhwloc=2.9.3=default_h554bfaf_1009
+  - libiconv=1.17=hd590300_1
+  - libjpeg-turbo=2.1.5.1=hd590300_1
+  - liblapack=3.9.0=20_linux64_openblas
+  - liblapacke=3.9.0=20_linux64_openblas
+  - libnghttp2=1.58.0=h47da74e_1
+  - libnpp=11.7.4.75=0
+  - libnsl=2.0.1=hd590300_0
+  - libnuma=2.0.16=h0b41bf4_1
+  - libnvjpeg=11.8.0.2=0
+  - libopenblas=0.3.25=pthreads_h413a1c8_0
+  - libpng=1.6.39=h753d276_0
+  - libprotobuf=3.21.12=hfc55251_2
+  - libsentencepiece=0.1.99=h180e1df_0
+  - libsqlite=3.44.2=h2797004_0
+  - libssh2=1.11.0=h0841786_0
+  - libstdcxx-ng=13.2.0=h7e041cc_3
+  - libthrift=0.18.1=h8fd135c_2
+  - libtiff=4.6.0=h29866fb_1
+  - libutf8proc=2.8.0=h166bdaf_0
+  - libuuid=2.38.1=h0b41bf4_0
+  - libwebp-base=1.3.2=hd590300_0
+  - libxcb=1.15=h0b41bf4_0
+  - libxml2=2.11.6=h232c23b_0
+  - libzlib=1.2.13=hd590300_5
+  - llvm-openmp=17.0.6=h4dfa4b3_0
+  - lz4-c=1.9.4=hcb278e6_0
+  - mkl=2022.2.1=h84fe81f_16997
+  - mkl-devel=2022.2.1=ha770c72_16998
+  - mkl-include=2022.2.1=h84fe81f_16997
+  - mpc=1.3.1=hfe3b2da_0
+  - mpfr=4.2.1=h9458935_0
+  - mpmath=1.3.0=pyhd8ed1ab_0
+  - multidict=6.0.4=py310h2372a71_1
+  - multiprocess=0.70.15=py310h2372a71_1
+  - ncurses=6.4=h59595ed_2
+  - numpy=1.26.2=py310hb13e2d6_0
+  - onnx=1.14.0=py310ha3deec4_1
+  - onnx2torch=1.5.13=pyhd8ed1ab_0
+  - onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
+  - open-clip-torch=2.23.0=pyhd8ed1ab_1
+  - openblas=0.3.25=pthreads_h7a3da1a_0
+  - openjpeg=2.5.0=h488ebb8_3
+  - openssl=3.2.0=hd590300_1
+  - orc=1.9.0=h2f23424_1
+  - packaging=23.2=pyhd8ed1ab_0
+  - pandas=2.1.4=py310hcc13569_0
+  - pillow=10.0.1=py310h29da1c1_1
+  - pip=23.3.1=pyhd8ed1ab_0
+  - protobuf=4.21.12=py310heca2aa9_0
+  - pthread-stubs=0.4=h36c2ea0_1001
+  - pyarrow=12.0.1=py310h0576679_8_cpu
+  - pyarrow-hotfix=0.6=pyhd8ed1ab_0
+  - pysocks=1.7.1=pyha2e5f31_6
+  - python=3.10.13=hd12c33a_0_cpython
+  - python-dateutil=2.8.2=pyhd8ed1ab_0
+  - python-flatbuffers=23.5.26=pyhd8ed1ab_0
+  - python-tzdata=2023.3=pyhd8ed1ab_0
+  - python-xxhash=3.4.1=py310h2372a71_0
+  - python_abi=3.10=4_cp310
+  - pytorch=1.13.1=cpu_py310hd11e9c7_1
+  - pytorch-cuda=11.7=h778d358_5
+  - pytorch-mutex=1.0=cuda
+  - pytz=2023.3.post1=pyhd8ed1ab_0
+  - pyyaml=6.0.1=py310h2372a71_1
+  - rdma-core=28.9=h59595ed_1
+  - re2=2023.03.02=h8c504da_0
+  - readline=8.2=h8228510_1
+  - regex=2023.10.3=py310h2372a71_0
+  - requests=2.31.0=pyhd8ed1ab_0
+  - s2n=1.3.49=h06160fa_0
+  - sacremoses=0.0.53=pyhd8ed1ab_0
+  - safetensors=0.3.3=py310hcb5633a_1
+  - sentencepiece=0.1.99=hff52083_0
+  - sentencepiece-python=0.1.99=py310hebdb9f0_0
+  - sentencepiece-spm=0.1.99=h180e1df_0
+  - setuptools=68.2.2=pyhd8ed1ab_0
+  - six=1.16.0=pyh6c4a22f_0
+  - sleef=3.5.1=h9b69904_2
+  - snappy=1.1.10=h9fff704_0
+  - sympy=1.12=pypyh9d50eac_103
+  - tbb=2021.11.0=h00ab1b0_0
+  - texttable=1.7.0=pyhd8ed1ab_0
+  - timm=0.9.12=pyhd8ed1ab_0
+  - tk=8.6.13=noxft_h4845f30_101
+  - tokenizers=0.14.1=py310h320607d_2
+  - torchvision=0.14.1=cpu_py310hd3d2ac3_1
+  - tqdm=4.66.1=pyhd8ed1ab_0
+  - transformers=4.35.2=pyhd8ed1ab_0
+  - typing-extensions=4.9.0=hd8ed1ab_0
+  - typing_extensions=4.9.0=pyha770c72_0
+  - tzdata=2023c=h71feb2d_0
+  - ucx=1.14.1=h64cca9d_5
+  - urllib3=2.1.0=pyhd8ed1ab_0
+  - wcwidth=0.2.12=pyhd8ed1ab_0
+  - wheel=0.42.0=pyhd8ed1ab_0
+  - xorg-libxau=1.0.11=hd590300_0
+  - xorg-libxdmcp=1.1.3=h7f98852_0
+  - xxhash=0.8.2=hd590300_0
+  - xz=5.2.6=h166bdaf_0
+  - yaml=0.2.5=h7f98852_2
+  - yarl=1.9.3=py310h2372a71_0
+  - zipp=3.17.0=pyhd8ed1ab_0
+  - zlib=1.2.13=hd590300_5
+  - zstd=1.5.5=hfc55251_0
+  - pip:
+      - git+https://github.com/fyfrey/TinyNeuralNetwork.git
machine-learning/ann/export/run.py  (new file, 157 lines)

@@ -0,0 +1,157 @@
+import logging
+import os
+import platform
+import subprocess
+from abc import abstractmethod
+
+import onnx
+import open_clip
+import torch
+from onnx2torch import convert
+from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
+from tinynn.converter import TFLiteConverter
+
+
+class ExportBase(torch.nn.Module):
+    input_shape: tuple[int, ...]
+
+    def __init__(self, device: torch.device, name: str):
+        super().__init__()
+        self.device = device
+        self.name = name
+        self.optimize = 5
+        self.nchw_transpose = False
+
+    @abstractmethod
+    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
+        pass
+
+    def dummy_input(self) -> torch.FloatTensor:
+        return torch.rand((1, 3, 224, 224), device=self.device)
+
+
+class ArcFace(ExportBase):
+    input_shape = (1, 3, 112, 112)
+
+    def __init__(self, onnx_model_path: str, device: torch.device):
+        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
+        super().__init__(device, name)
+        onnx_model = onnx.load_model(onnx_model_path)
+        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
+        fix_output_shapes(onnx_model)
+        self.model = convert(onnx_model).to(device)
+        if self.device.type == "cuda":
+            self.model = self.model.half()
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
+        embedding: torch.FloatTensor = self.model(
+            input_tensor.half() if self.device.type == "cuda" else input_tensor
+        ).float()
+        assert isinstance(embedding, torch.FloatTensor)
+        return embedding
+
+    def dummy_input(self) -> torch.FloatTensor:
+        return torch.rand(self.input_shape, device=self.device)
+
+
+class RetinaFace(ExportBase):
+    input_shape = (1, 3, 640, 640)
+
+    def __init__(self, onnx_model_path: str, device: torch.device):
+        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
+        super().__init__(device, name)
+        self.optimize = 3
+        self.model = convert(onnx_model_path).eval().to(device)
+        if self.device.type == "cuda":
+            self.model = self.model.half()
+
+    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
+        out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
+        return tuple(o.float() for o in out)
+
+    def dummy_input(self) -> torch.FloatTensor:
+        return torch.rand(self.input_shape, device=self.device)
+
+
+class ClipVision(ExportBase):
+    input_shape = (1, 3, 224, 224)
+
+    def __init__(self, model_name: str, weights: str, device: torch.device):
+        super().__init__(device, model_name + "__" + weights)
+        self.model = open_clip.create_model(
+            model_name,
+            weights,
+            precision="fp16" if device.type == "cuda" else "fp32",
+            jit=False,
+            require_pretrained=True,
+            device=device,
+        )
+
+    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
+        embedding: torch.Tensor = self.model.encode_image(
+            input_tensor.half() if self.device.type == "cuda" else input_tensor,
+            normalize=True,
+        ).float()
+        return embedding
+
+
+def export(model: ExportBase) -> None:
+    model.eval()
+    for param in model.parameters():
+        param.requires_grad = False
+    dummy_input = model.dummy_input()
+    model(dummy_input)
+    jit = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]
+    tflite_model_path = f"output/{model.name}.tflite"
+    os.makedirs("output", exist_ok=True)
+
+    converter = TFLiteConverter(
+        jit,
+        dummy_input,
+        tflite_model_path,
+        optimize=model.optimize,
+        nchw_transpose=model.nchw_transpose,
+    )
+    # segfaults on ARM, must run on x86_64 / AMD64
+    converter.convert()
+
+    armnn_model_path = f"output/{model.name}.armnn"
+    os.environ["LD_LIBRARY_PATH"] = "armnn"
+    subprocess.run(
+        [
+            "./armnnconverter",
+            "-f",
+            "tflite-binary",
+            "-m",
+            tflite_model_path,
+            "-i",
+            "input_tensor",
+            "-o",
+            "output_tensor",
+            "-p",
+            armnn_model_path,
+        ]
+    )
+
+
+def main() -> None:
+    if platform.machine() not in ("x86_64", "AMD64"):
+        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device.type != "cuda":
+        logging.warning(
+            "No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
+        )
+    models = [
+        ClipVision("ViT-B-32", "openai", device),
+        ArcFace("buffalo_l_rec.onnx", device),
+        RetinaFace("buffalo_l_det.onnx", device),
+    ]
+    for model in models:
+        export(model)
+
+
+if __name__ == "__main__":
+    with torch.no_grad():
+        main()
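
The implied export flow (not spelled out in the commit): on an x86_64 machine, run download-armnn.sh and then build-converter.sh to obtain the ArmNN libraries and the armnnconverter binary, place buffalo_l_rec.onnx and buffalo_l_det.onnx next to run.py, and execute python run.py. Each model is traced with TorchScript, converted to TFLite (the step that segfaults on ARM, per the comment in export()), and finally fed through armnnconverter to produce the .armnn files under output/.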
machine-learning/app/config.py  (modified)

@@ -26,6 +26,7 @@ class Settings(BaseSettings):
     request_threads: int = os.cpu_count() or 4
     model_inter_op_threads: int = 1
     model_intra_op_threads: int = 2
+    ann: bool = True
 
     class Config:
         env_prefix = "MACHINE_LEARNING_"
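
Given the MACHINE_LEARNING_ env_prefix shown here, the new flag can presumably be turned off with MACHINE_LEARNING_ANN=false, forcing the plain ONNX Runtime path even when an .armnn model and the Mali driver are present.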
machine-learning/app/models/ann.py  (new file, 68 lines)

@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, NamedTuple
+
+from numpy import ascontiguousarray
+
+from ann.ann import Ann
+from app.schemas import ndarray_f32, ndarray_i32
+
+from ..config import log, settings
+
+
+class AnnSession:
+    """
+    Wrapper for ANN to be drop-in replacement for ONNX session.
+    """
+
+    def __init__(self, model_path: Path):
+        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
+        with tuning_file.open(mode="a"):
+            # make sure tuning file exists (without clearing contents)
+            # once filled, the tuning file reduces the cost/time of the first
+            # inference after model load by 10s of seconds
+            pass
+        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())
+        log.info("Loading ANN model %s ...", model_path)
+        cache_file = model_path.with_suffix(".anncache")
+        save = False
+        if not cache_file.is_file():
+            save = True
+            with cache_file.open(mode="a"):
+                # create empty model cache file
+                pass
+
+        self.model = self.ann.load(
+            model_path.as_posix(),
+            save_cached_network=save,
+            cached_network_path=cache_file.as_posix(),
+        )
+        log.info("Loaded ANN model with ID %d", self.model)
+
+    def __del__(self) -> None:
+        self.ann.unload(self.model)
+        log.info("Unloaded ANN model %d", self.model)
+        self.ann.destroy()
+
+    def get_inputs(self) -> list[AnnNode]:
+        shapes = self.ann.input_shapes[self.model]
+        return [AnnNode(None, s) for s in shapes]
+
+    def get_outputs(self) -> list[AnnNode]:
+        shapes = self.ann.output_shapes[self.model]
+        return [AnnNode(None, s) for s in shapes]
+
+    def run(
+        self,
+        output_names: list[str] | None,
+        input_feed: dict[str, ndarray_f32] | dict[str, ndarray_i32],
+        run_options: Any = None,
+    ) -> list[ndarray_f32]:
+        inputs: list[ndarray_f32] = [ascontiguousarray(v) for v in input_feed.values()]
+        return self.ann.execute(self.model, inputs)
+
+
+class AnnNode(NamedTuple):
+    name: str | None
+    shape: tuple[int, ...]
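
AnnSession only has to mirror the small slice of the onnxruntime.InferenceSession API that the app calls (get_inputs, get_outputs, run), which is what makes it a drop-in replacement. An illustrative sketch under assumed paths (the input_feed keys are ignored, only value order matters, and a single-output model is assumed):

    from pathlib import Path

    import numpy as np

    from app.models.ann import AnnSession

    session = AnnSession(Path("/cache/clip/ViT-B-32__openai/visual.armnn"))  # hypothetical path
    (input_node,) = session.get_inputs()
    pixels = np.random.rand(*input_node.shape).astype(np.float32)
    (embedding,) = session.run(None, {"pixel_values": pixels})  # key name is arbitrary here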
machine-learning/app/models/base.py  (modified)

@@ -10,8 +10,11 @@ import onnxruntime as ort
 from huggingface_hub import snapshot_download
 from typing_extensions import Buffer
 
+import ann.ann
+
 from ..config import get_cache_dir, get_hf_model_name, log, settings
 from ..schemas import ModelType
+from .ann import AnnSession
 
 
 class InferenceModel(ABC):
@@ -138,6 +141,21 @@ class InferenceModel(ABC):
             self.cache_dir.unlink()
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
+    def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
+        armnn_path = model_path.with_suffix(".armnn")
+        if settings.ann and ann.ann.is_available and armnn_path.is_file():
+            session = AnnSession(armnn_path)
+        elif model_path.is_file():
+            session = ort.InferenceSession(
+                model_path.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+        else:
+            raise ValueError(f"the file model_path='{model_path}' does not exist")
+        return session
+
 
 # HF deep copies configs, so we need to make session options picklable
 class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
| @ -6,7 +6,6 @@ from pathlib import Path | |||||||
| from typing import Any, Literal | from typing import Any, Literal | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| import onnxruntime as ort |  | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from tokenizers import Encoding, Tokenizer | from tokenizers import Encoding, Tokenizer | ||||||
| 
 | 
 | ||||||
| @ -33,24 +32,12 @@ class BaseCLIPEncoder(InferenceModel): | |||||||
|     def _load(self) -> None: |     def _load(self) -> None: | ||||||
|         if self.mode == "text" or self.mode is None: |         if self.mode == "text" or self.mode is None: | ||||||
|             log.debug(f"Loading clip text model '{self.model_name}'") |             log.debug(f"Loading clip text model '{self.model_name}'") | ||||||
| 
 |             self.text_model = self._make_session(self.textual_path) | ||||||
|             self.text_model = ort.InferenceSession( |  | ||||||
|                 self.textual_path.as_posix(), |  | ||||||
|                 sess_options=self.sess_options, |  | ||||||
|                 providers=self.providers, |  | ||||||
|                 provider_options=self.provider_options, |  | ||||||
|             ) |  | ||||||
|             log.debug(f"Loaded clip text model '{self.model_name}'") |             log.debug(f"Loaded clip text model '{self.model_name}'") | ||||||
| 
 | 
 | ||||||
|         if self.mode == "vision" or self.mode is None: |         if self.mode == "vision" or self.mode is None: | ||||||
|             log.debug(f"Loading clip vision model '{self.model_name}'") |             log.debug(f"Loading clip vision model '{self.model_name}'") | ||||||
| 
 |             self.vision_model = self._make_session(self.visual_path) | ||||||
|             self.vision_model = ort.InferenceSession( |  | ||||||
|                 self.visual_path.as_posix(), |  | ||||||
|                 sess_options=self.sess_options, |  | ||||||
|                 providers=self.providers, |  | ||||||
|                 provider_options=self.provider_options, |  | ||||||
|             ) |  | ||||||
|             log.debug(f"Loaded clip vision model '{self.model_name}'") |             log.debug(f"Loaded clip vision model '{self.model_name}'") | ||||||
| 
 | 
 | ||||||
|     def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32: |     def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32: | ||||||
| @ -61,12 +48,10 @@ class BaseCLIPEncoder(InferenceModel): | |||||||
|             case Image.Image(): |             case Image.Image(): | ||||||
|                 if self.mode == "text": |                 if self.mode == "text": | ||||||
|                     raise TypeError("Cannot encode image as text-only model") |                     raise TypeError("Cannot encode image as text-only model") | ||||||
| 
 |  | ||||||
|                 outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0] |                 outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0] | ||||||
|             case str(): |             case str(): | ||||||
|                 if self.mode == "vision": |                 if self.mode == "vision": | ||||||
|                     raise TypeError("Cannot encode text as vision-only model") |                     raise TypeError("Cannot encode text as vision-only model") | ||||||
| 
 |  | ||||||
|                 outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0] |                 outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0] | ||||||
|             case _: |             case _: | ||||||
|                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") |                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") | ||||||
|  | |||||||
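_predict dispatches on the runtime type of its argument with a match statement rather than on the mode string alone, so a text-only model rejects images (and a vision-only model rejects strings) with a clear error. A self-contained sketch of the same pattern, assuming only Pillow:

    from PIL import Image

    def input_kind(image_or_text: Image.Image | str) -> str:
        # Class patterns such as Image.Image() match by isinstance, so a
        # PIL image routes to the vision branch and a str to the text one.
        match image_or_text:
            case Image.Image():
                return "vision"
            case str():
                return "text"
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

    assert input_kind(Image.new("RGB", (1, 1))) == "vision"
    assert input_kind("a photo of a cat") == "text"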
| @ -3,7 +3,6 @@ from typing import Any | |||||||
| 
 | 
 | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
| import onnxruntime as ort |  | ||||||
| from insightface.model_zoo import ArcFaceONNX, RetinaFace | from insightface.model_zoo import ArcFaceONNX, RetinaFace | ||||||
| from insightface.utils.face_align import norm_crop | from insightface.utils.face_align import norm_crop | ||||||
| 
 | 
 | ||||||
| @ -27,23 +26,8 @@ class FaceRecognizer(InferenceModel): | |||||||
|         super().__init__(clean_name(model_name), cache_dir, **model_kwargs) |         super().__init__(clean_name(model_name), cache_dir, **model_kwargs) | ||||||
| 
 | 
 | ||||||
|     def _load(self) -> None: |     def _load(self) -> None: | ||||||
|         self.det_model = RetinaFace( |         self.det_model = RetinaFace(session=self._make_session(self.det_file)) | ||||||
|             session=ort.InferenceSession( |         self.rec_model = ArcFaceONNX(self.rec_file.as_posix(), session=self._make_session(self.rec_file)) | ||||||
|                 self.det_file.as_posix(), |  | ||||||
|                 sess_options=self.sess_options, |  | ||||||
|                 providers=self.providers, |  | ||||||
|                 provider_options=self.provider_options, |  | ||||||
|             ), |  | ||||||
|         ) |  | ||||||
|         self.rec_model = ArcFaceONNX( |  | ||||||
|             self.rec_file.as_posix(), |  | ||||||
|             session=ort.InferenceSession( |  | ||||||
|                 self.rec_file.as_posix(), |  | ||||||
|                 sess_options=self.sess_options, |  | ||||||
|                 providers=self.providers, |  | ||||||
|                 provider_options=self.provider_options, |  | ||||||
|             ), |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         self.det_model.prepare( |         self.det_model.prepare( | ||||||
|             ctx_id=0, |             ctx_id=0, | ||||||
|  | |||||||
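The FaceRecognizer rewrite works because insightface's RetinaFace and ArcFaceONNX both accept an externally constructed session, so anything exposing the ONNX-Runtime-style session surface can be injected, AnnSession included. A toy stand-in under that assumption (the wrappers are taken to need only get_inputs/get_outputs/run; the tensor names and shapes below are made up):

    import numpy as np

    class StubNode:
        # Mimics the .name/.shape attributes a session reports per tensor.
        def __init__(self, name: str, shape: tuple[int, ...]) -> None:
            self.name = name
            self.shape = shape

    class StubSession:
        # Hypothetical drop-in: just enough surface for a caller that
        # inspects inputs/outputs and then invokes run().
        def get_inputs(self) -> list[StubNode]:
            return [StubNode("input.1", (1, 3, 640, 640))]

        def get_outputs(self) -> list[StubNode]:
            return [StubNode("output.1", (1, 512))]

        def run(self, output_names, input_feed):
            # Output shape chosen arbitrarily for the sketch.
            return [np.zeros((1, 512), dtype=np.float32)]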
| @ -13,7 +13,7 @@ from PIL import Image | |||||||
| from pytest_mock import MockerFixture | from pytest_mock import MockerFixture | ||||||
| 
 | 
 | ||||||
| from .config import settings | from .config import settings | ||||||
| from .models.base import PicklableSessionOptions | from .models.base import InferenceModel, PicklableSessionOptions | ||||||
| from .models.cache import ModelCache | from .models.cache import ModelCache | ||||||
| from .models.clip import OpenCLIPEncoder | from .models.clip import OpenCLIPEncoder | ||||||
| from .models.facial_recognition import FaceRecognizer | from .models.facial_recognition import FaceRecognizer | ||||||
| @ -36,9 +36,10 @@ class TestCLIP: | |||||||
|         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) |         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) | ||||||
|         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) |         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) | ||||||
|         mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg) |         mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg) | ||||||
|  | 
 | ||||||
|  |         mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value | ||||||
|  |         mocked.run.return_value = [[self.embedding]] | ||||||
|         mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True) |         mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True) | ||||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) |  | ||||||
|         mocked.return_value.run.return_value = [[self.embedding]] |  | ||||||
| 
 | 
 | ||||||
|         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") |         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") | ||||||
|         embedding = clip_encoder.predict(pil_image) |         embedding = clip_encoder.predict(pil_image) | ||||||
| @ -47,7 +48,7 @@ class TestCLIP: | |||||||
|         assert isinstance(embedding, np.ndarray) |         assert isinstance(embedding, np.ndarray) | ||||||
|         assert embedding.shape[0] == clip_model_cfg["embed_dim"] |         assert embedding.shape[0] == clip_model_cfg["embed_dim"] | ||||||
|         assert embedding.dtype == np.float32 |         assert embedding.dtype == np.float32 | ||||||
|         clip_encoder.vision_model.run.assert_called_once() |         mocked.run.assert_called_once() | ||||||
| 
 | 
 | ||||||
|     def test_basic_text( |     def test_basic_text( | ||||||
|         self, |         self, | ||||||
| @ -60,9 +61,10 @@ class TestCLIP: | |||||||
|         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) |         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) | ||||||
|         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) |         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) | ||||||
|         mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg) |         mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg) | ||||||
|  | 
 | ||||||
|  |         mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value | ||||||
|  |         mocked.run.return_value = [[self.embedding]] | ||||||
|         mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True) |         mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True) | ||||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) |  | ||||||
|         mocked.return_value.run.return_value = [[self.embedding]] |  | ||||||
| 
 | 
 | ||||||
|         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") |         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") | ||||||
|         embedding = clip_encoder.predict("test search query") |         embedding = clip_encoder.predict("test search query") | ||||||
| @ -71,7 +73,7 @@ class TestCLIP: | |||||||
|         assert isinstance(embedding, np.ndarray) |         assert isinstance(embedding, np.ndarray) | ||||||
|         assert embedding.shape[0] == clip_model_cfg["embed_dim"] |         assert embedding.shape[0] == clip_model_cfg["embed_dim"] | ||||||
|         assert embedding.dtype == np.float32 |         assert embedding.dtype == np.float32 | ||||||
|         clip_encoder.text_model.run.assert_called_once() |         mocked.run.assert_called_once() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestFaceRecognition: | class TestFaceRecognition: | ||||||
|  | |||||||
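On the test changes: the old patch target app.models.clip.ort.InferenceSession disappears once clip.py stops importing onnxruntime, so the tests patch the factory method on the base class instead; every subclass then receives the same mock no matter which runtime _make_session would have picked. A sketch of the pattern with pytest-mock, assuming the imports shown earlier in the test module (numpy, MockerFixture, InferenceModel):

    def test_session_is_mocked(mocker: MockerFixture) -> None:
        # Patching the ABC method covers OpenCLIPEncoder and FaceRecognizer
        # alike; .return_value is the session object the models receive.
        mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
        embedding = np.random.rand(512).astype(np.float32)
        mocked.run.return_value = [[embedding]]
        # Any model built after this point calls mocked.run() for inference,
        # so assertions like mocked.run.assert_called_once() hold per test.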