raise NotImplementedError for now

This commit is contained in:
yoni13 2025-01-19 12:02:08 +08:00
parent 1653cd9cd7
commit 2b967ca358

View File

@ -4,7 +4,6 @@ from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray from numpy.typing import NDArray
from app.schemas import SessionNode from app.schemas import SessionNode
@ -22,7 +21,6 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr
class RknnSession: class RknnSession:
def __init__(self, model_path: Path | str): def __init__(self, model_path: Path | str):
self.model_path = Path(str(model_path).replace("model", soc_name)) self.model_path = Path(str(model_path).replace("model", soc_name))
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))
self.tpe = settings.rknn_threads self.tpe = settings.rknn_threads
@ -33,27 +31,11 @@ class RknnSession:
def __del__(self) -> None: def __del__(self) -> None:
self.rknnpool.release() self.rknnpool.release()
def _load_ort_session(self) -> None:
self.ort_session = ort.InferenceSession(
self.ort_model_path.as_posix(),
)
self.inputs: list[SessionNode] = self.ort_session.get_inputs()
self.outputs: list[SessionNode] = self.ort_session.get_outputs()
del self.ort_session
def get_inputs(self) -> list[SessionNode]: def get_inputs(self) -> list[SessionNode]:
try: raise NotImplementedError
return self.inputs
except AttributeError:
self._load_ort_session()
return self.inputs
def get_outputs(self) -> list[SessionNode]: def get_outputs(self) -> list[SessionNode]:
try: raise NotImplementedError
return self.outputs
except AttributeError:
self._load_ort_session()
return self.outputs
def run( def run(
self, self,