mirror of
https://github.com/immich-app/immich.git
synced 2025-07-09 03:04:16 -04:00
Load model by SOC name
This commit is contained in:
parent
daf886088a
commit
ebdfe1b7b6
@ -8,7 +8,7 @@ import onnxruntime as ort
|
|||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from app.schemas import SessionNode
|
from app.schemas import SessionNode
|
||||||
from rknn.rknnpool import rknnPoolExecutor
|
from rknn.rknnpool import rknnPoolExecutor, soc_name
|
||||||
|
|
||||||
from ..config import log, settings
|
from ..config import log, settings
|
||||||
|
|
||||||
@ -21,8 +21,8 @@ def runInfrence(rknn_lite, input):
|
|||||||
|
|
||||||
class RknnSession:
|
class RknnSession:
|
||||||
def __init__(self, model_path: Path | str):
|
def __init__(self, model_path: Path | str):
|
||||||
self.model_path = Path(model_path)
|
self.model_path = Path(str(model_path).replace("model", soc_name))
|
||||||
self.ort_model_path = str(self.model_path).replace(".rknn", ".onnx")
|
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))
|
||||||
self.inputs = None
|
self.inputs = None
|
||||||
self.outputs = None
|
self.outputs = None
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class RknnSession:
|
|||||||
def get_inputs(self) -> list[SessionNode]:
|
def get_inputs(self) -> list[SessionNode]:
|
||||||
if not self.inputs:
|
if not self.inputs:
|
||||||
self.ort_session = ort.InferenceSession(
|
self.ort_session = ort.InferenceSession(
|
||||||
self.ort_model_path,
|
self.ort_model_path.as_posix(),
|
||||||
)
|
)
|
||||||
self.inputs = self.ort_session.get_inputs()
|
self.inputs = self.ort_session.get_inputs()
|
||||||
self.outputs = self.ort_session.get_outputs()
|
self.outputs = self.ort_session.get_outputs()
|
||||||
@ -52,7 +52,7 @@ class RknnSession:
|
|||||||
def get_outputs(self) -> list[SessionNode]:
|
def get_outputs(self) -> list[SessionNode]:
|
||||||
if not self.outputs:
|
if not self.outputs:
|
||||||
self.ort_session = ort.InferenceSession(
|
self.ort_session = ort.InferenceSession(
|
||||||
self.ort_model_path,
|
self.ort_model_path.as_posix(),
|
||||||
)
|
)
|
||||||
self.inputs = self.ort_session.get_inputs()
|
self.inputs = self.ort_session.get_inputs()
|
||||||
self.outputs = self.ort_session.get_outputs()
|
self.outputs = self.ort_session.get_outputs()
|
||||||
|
@ -20,9 +20,11 @@ try:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
is_available = False
|
is_available = False
|
||||||
|
soc_name = None
|
||||||
is_available = os.path.exists("/sys/kernel/debug/rknpu/load")
|
is_available = os.path.exists("/sys/kernel/debug/rknpu/load")
|
||||||
except (FileNotFoundError, ImportError):
|
except (FileNotFoundError, ImportError):
|
||||||
is_available = False
|
is_available = False
|
||||||
|
soc_name = None
|
||||||
|
|
||||||
|
|
||||||
def initRKNN(rknnModel="./rknnModel/yolov5s.rknn", id=0):
|
def initRKNN(rknnModel="./rknnModel/yolov5s.rknn", id=0):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user