raise NotImplementedError for now

This commit is contained in:
yoni13 2025-01-19 12:02:08 +08:00
parent 1653cd9cd7
commit 2b967ca358

View File

@ -4,7 +4,6 @@ from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from app.schemas import SessionNode
@ -22,7 +21,6 @@ def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArr
class RknnSession:
def __init__(self, model_path: Path | str):
self.model_path = Path(str(model_path).replace("model", soc_name))
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))
self.tpe = settings.rknn_threads
@ -33,28 +31,12 @@ class RknnSession:
def __del__(self) -> None:
self.rknnpool.release()
def _load_ort_session(self) -> None:
self.ort_session = ort.InferenceSession(
self.ort_model_path.as_posix(),
)
self.inputs: list[SessionNode] = self.ort_session.get_inputs()
self.outputs: list[SessionNode] = self.ort_session.get_outputs()
del self.ort_session
def get_inputs(self) -> list[SessionNode]:
try:
return self.inputs
except AttributeError:
self._load_ort_session()
return self.inputs
raise NotImplementedError
def get_outputs(self) -> list[SessionNode]:
try:
return self.outputs
except AttributeError:
self._load_ort_session()
return self.outputs
raise NotImplementedError
def run(
self,
output_names: list[str] | None,