only load knnx model when required

This commit is contained in:
yoni13 2025-01-12 19:11:16 +00:00
parent 8965a9fb16
commit 4c7ac1438b

View File

@ -23,10 +23,12 @@ class RknnSession:
def __init__(self, model_path: Path | str): def __init__(self, model_path: Path | str):
self.model_path = Path(model_path) self.model_path = Path(model_path)
self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx") self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx")
self.inputs = None
if 'textual' in str(self.model_path): self.outputs = None
if "textual" in str(self.model_path):
self.tpe = settings.rknn_textual_threads self.tpe = settings.rknn_textual_threads
elif 'visual' in str(self.model_path): elif "visual" in str(self.model_path):
self.tpe = settings.rknn_visual_threads self.tpe = settings.rknn_visual_threads
else: else:
self.tpe = settings.rknn_facial_detection_threads self.tpe = settings.rknn_facial_detection_threads
@ -34,21 +36,27 @@ class RknnSession:
log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.") log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence) self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence)
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
def __del__(self): def __del__(self):
self.rknnpool.release() self.rknnpool.release()
def get_inputs(self) -> list[SessionNode]: def get_inputs(self) -> list[SessionNode]:
if not self.inputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
return self.inputs return self.inputs
def get_outputs(self) -> list[SessionNode]: def get_outputs(self) -> list[SessionNode]:
if not self.outputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
return self.outputs return self.outputs
def run( def run(