mirror of
https://github.com/immich-app/immich.git
synced 2025-07-07 10:14:08 -04:00
raise NotImplementedError for now
This commit is contained in:
parent
1653cd9cd7
commit
2b967ca358
@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from app.schemas import SessionNode
|
||||
@ -22,7 +21,6 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr
|
||||
class RknnSession:
|
||||
def __init__(self, model_path: Path | str):
|
||||
self.model_path = Path(str(model_path).replace("model", soc_name))
|
||||
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))
|
||||
|
||||
self.tpe = settings.rknn_threads
|
||||
|
||||
@ -33,27 +31,11 @@ class RknnSession:
|
||||
def __del__(self) -> None:
|
||||
self.rknnpool.release()
|
||||
|
||||
def _load_ort_session(self) -> None:
|
||||
self.ort_session = ort.InferenceSession(
|
||||
self.ort_model_path.as_posix(),
|
||||
)
|
||||
self.inputs: list[SessionNode] = self.ort_session.get_inputs()
|
||||
self.outputs: list[SessionNode] = self.ort_session.get_outputs()
|
||||
del self.ort_session
|
||||
|
||||
def get_inputs(self) -> list[SessionNode]:
|
||||
try:
|
||||
return self.inputs
|
||||
except AttributeError:
|
||||
self._load_ort_session()
|
||||
return self.inputs
|
||||
raise NotImplementedError
|
||||
|
||||
def get_outputs(self) -> list[SessionNode]:
|
||||
try:
|
||||
return self.outputs
|
||||
except AttributeError:
|
||||
self._load_ort_session()
|
||||
return self.outputs
|
||||
raise NotImplementedError
|
||||
|
||||
def run(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user