diff --git a/machine-learning/app/sessions/rknn.py b/machine-learning/app/sessions/rknn.py index 9348b3770f..4e5e85a242 100644 --- a/machine-learning/app/sessions/rknn.py +++ b/machine-learning/app/sessions/rknn.py @@ -23,10 +23,12 @@ class RknnSession: def __init__(self, model_path: Path | str): self.model_path = Path(model_path) self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx") - - if 'textual' in str(self.model_path): + self.inputs = None + self.outputs = None + + if "textual" in str(self.model_path): self.tpe = settings.rknn_textual_threads - elif 'visual' in str(self.model_path): + elif "visual" in str(self.model_path): self.tpe = settings.rknn_visual_threads else: self.tpe = settings.rknn_facial_detection_threads @@ -34,21 +36,27 @@ class RknnSession: log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.") self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence) - self.ort_session = ort.InferenceSession( - self.ort_model_path, - ) - self.inputs = self.ort_session.get_inputs() - self.outputs = self.ort_session.get_outputs() - - del self.ort_session - def __del__(self): self.rknnpool.release() def get_inputs(self) -> list[SessionNode]: + if not self.inputs: + self.ort_session = ort.InferenceSession( + self.ort_model_path, + ) + self.inputs = self.ort_session.get_inputs() + self.outputs = self.ort_session.get_outputs() + del self.ort_session return self.inputs def get_outputs(self) -> list[SessionNode]: + if not self.outputs: + self.ort_session = ort.InferenceSession( + self.ort_model_path, + ) + self.inputs = self.ort_session.get_inputs() + self.outputs = self.ort_session.get_outputs() + del self.ort_session return self.outputs def run(