diff --git a/mealie/db/data_access_layer/_base_access_model.py b/mealie/db/data_access_layer/_base_access_model.py index cda6986db9ae..8ae531884255 100644 --- a/mealie/db/data_access_layer/_base_access_model.py +++ b/mealie/db/data_access_layer/_base_access_model.py @@ -1,20 +1,30 @@ -from typing import Callable, Union +from typing import Callable, Generic, TypeVar, Union from mealie.core.root_logger import get_logger -from mealie.db.models._model_base import SqlAlchemyBase -from pydantic import BaseModel from sqlalchemy import func from sqlalchemy.orm import load_only from sqlalchemy.orm.session import Session logger = get_logger() +T = TypeVar("T") +D = TypeVar("D") -class BaseAccessModel: - def __init__(self, primary_key, sql_model, schema) -> None: - self.primary_key: str = primary_key - self.sql_model: SqlAlchemyBase = sql_model - self.schema: BaseModel = schema + +class BaseAccessModel(Generic[T, D]): + """A Generic BaseAccess Model method to perform common operations on the database + + Args: + Generic ([T]): Represents the Pydantic Model + Generic ([D]): Represents the SqlAlchemyModel Model + """ + + def __init__(self, primary_key: Union[str, int], sql_model: D, schema: T) -> None: + self.primary_key = primary_key + + self.sql_model = sql_model + + self.schema = schema self.observers: list = [] @@ -29,7 +39,7 @@ class BaseAccessModel: def get_all( self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None - ) -> list[dict]: + ) -> list[T]: eff_schema = override_schema or self.schema if order_by: @@ -42,7 +52,7 @@ class BaseAccessModel: return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()] - def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]: + def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[D]: """Queries the database for the selected model. Restricts return responses to the keys specified under "fields" @@ -70,7 +80,7 @@ class BaseAccessModel: results_as_dict = [x.dict() for x in results] return [x.get(self.primary_key) for x in results_as_dict] - def _query_one(self, session: Session, match_value: str, match_key: str = None) -> SqlAlchemyBase: + def _query_one(self, session: Session, match_value: str, match_key: str = None) -> D: """Query the sql database for one item an return the sql alchemy model object. If no match key is provided the primary_key attribute will be used. @@ -89,7 +99,7 @@ class BaseAccessModel: def get( self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None - ) -> Union[BaseModel, list[BaseModel]]: + ) -> Union[T, list[T]]: """Retrieves an entry from the database by matching a key/value pair. If no key is provided the class objects primary key will be used to match against. @@ -121,9 +131,10 @@ class BaseAccessModel: return eff_schema.from_orm(result[0]) except IndexError: return None + return [eff_schema.from_orm(x) for x in result] - def create(self, session: Session, document: dict) -> BaseModel: + def create(self, session: Session, document: T) -> T: """Creates a new database entry for the given SQL Alchemy Model. Args: @@ -134,17 +145,17 @@ class BaseAccessModel: dict: A dictionary representation of the database entry """ document = document if isinstance(document, dict) else document.dict() - new_document = self.sql_model(session=session, **document) session.add(new_document) session.commit() + session.refresh(new_document) if self.observers: self.update_observers() return self.schema.from_orm(new_document) - def update(self, session: Session, match_value: str, new_data: dict) -> BaseModel: + def update(self, session: Session, match_value: str, new_data: dict) -> T: """Update a database entry. Args: session (Session): Database Session @@ -165,7 +176,7 @@ class BaseAccessModel: session.commit() return self.schema.from_orm(entry) - def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel: + def patch(self, session: Session, match_value: str, new_data: dict) -> T: new_data = new_data if isinstance(new_data, dict) else new_data.dict() entry = self._query_one(session=session, match_value=match_value) @@ -178,7 +189,7 @@ class BaseAccessModel: return self.update(session, match_value, entry_as_dict) - def delete(self, session: Session, primary_key_value) -> dict: + def delete(self, session: Session, primary_key_value) -> D: result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one() results_as_model = self.schema.from_orm(result) @@ -205,7 +216,7 @@ class BaseAccessModel: def _count_attribute( self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None - ) -> Union[int, BaseModel]: + ) -> Union[int, T]: eff_schema = override_schema or self.schema # attr_filter = getattr(self.sql_model, attribute_name) diff --git a/mealie/db/data_access_layer/group_access_model.py b/mealie/db/data_access_layer/group_access_model.py index b54ab99a2b5a..5fe86a06acbc 100644 --- a/mealie/db/data_access_layer/group_access_model.py +++ b/mealie/db/data_access_layer/group_access_model.py @@ -1,3 +1,4 @@ +from mealie.db.models.group import Group from mealie.schema.meal_plan.meal import MealPlanOut from mealie.schema.user.user import GroupInDB from sqlalchemy.orm.session import Session @@ -5,7 +6,7 @@ from sqlalchemy.orm.session import Session from ._base_access_model import BaseAccessModel -class GroupDataAccessModel(BaseAccessModel): +class GroupDataAccessModel(BaseAccessModel[GroupInDB, Group]): def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]: """A Helper function to get the group from the database and return a sorted list of diff --git a/mealie/db/data_access_layer/recipe_access_model.py b/mealie/db/data_access_layer/recipe_access_model.py index 1c34e4e86ef5..0903b0a1893a 100644 --- a/mealie/db/data_access_layer/recipe_access_model.py +++ b/mealie/db/data_access_layer/recipe_access_model.py @@ -2,12 +2,13 @@ from random import randint from mealie.db.models.recipe.recipe import RecipeModel from mealie.db.models.recipe.settings import RecipeSettings +from mealie.schema.recipe import Recipe from sqlalchemy.orm.session import Session from ._base_access_model import BaseAccessModel -class RecipeDataAccessModel(BaseAccessModel): +class RecipeDataAccessModel(BaseAccessModel[Recipe, RecipeModel]): def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None): eff_schema = override_schema or self.schema diff --git a/mealie/db/data_access_layer/user_access_model.py b/mealie/db/data_access_layer/user_access_model.py index 04c3ad486c79..067d6d7e41a7 100644 --- a/mealie/db/data_access_layer/user_access_model.py +++ b/mealie/db/data_access_layer/user_access_model.py @@ -1,7 +1,10 @@ +from mealie.db.models.users import User +from mealie.schema.user.user import UserInDB + from ._base_access_model import BaseAccessModel -class UserDataAccessModel(BaseAccessModel): +class UserDataAccessModel(BaseAccessModel[UserInDB, User]): def update_password(self, session, id, password: str): entry = self._query_one(session=session, match_value=id) entry.update_password(password)