only load knnx model when required

This commit is contained in:
yoni13 2025-01-12 19:11:16 +00:00
parent 8965a9fb16
commit 4c7ac1438b

View File

@ -23,10 +23,12 @@ class RknnSession:
def __init__(self, model_path: Path | str):
self.model_path = Path(model_path)
self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx")
if 'textual' in str(self.model_path):
self.inputs = None
self.outputs = None
if "textual" in str(self.model_path):
self.tpe = settings.rknn_textual_threads
elif 'visual' in str(self.model_path):
elif "visual" in str(self.model_path):
self.tpe = settings.rknn_visual_threads
else:
self.tpe = settings.rknn_facial_detection_threads
@ -34,21 +36,27 @@ class RknnSession:
log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=runInfrence)
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
def __del__(self):
self.rknnpool.release()
def get_inputs(self) -> list[SessionNode]:
if not self.inputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
return self.inputs
def get_outputs(self) -> list[SessionNode]:
if not self.outputs:
self.ort_session = ort.InferenceSession(
self.ort_model_path,
)
self.inputs = self.ort_session.get_inputs()
self.outputs = self.ort_session.get_outputs()
del self.ort_session
return self.outputs
def run(