mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-09 03:04:54 -04:00
feat(backend): ✨ Add Generic Type Hint Support for Data Access Layer
This commit is contained in:
parent
0675c570ce
commit
a266a244d9
@ -1,20 +1,30 @@
|
|||||||
from typing import Callable, Union
|
from typing import Callable, Generic, TypeVar, Union
|
||||||
|
|
||||||
from mealie.core.root_logger import get_logger
|
from mealie.core.root_logger import get_logger
|
||||||
from mealie.db.models._model_base import SqlAlchemyBase
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import load_only
|
from sqlalchemy.orm import load_only
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
D = TypeVar("D")
|
||||||
|
|
||||||
class BaseAccessModel:
|
|
||||||
def __init__(self, primary_key, sql_model, schema) -> None:
|
class BaseAccessModel(Generic[T, D]):
|
||||||
self.primary_key: str = primary_key
|
"""A Generic BaseAccess Model method to perform common operations on the database
|
||||||
self.sql_model: SqlAlchemyBase = sql_model
|
|
||||||
self.schema: BaseModel = schema
|
Args:
|
||||||
|
Generic ([T]): Represents the Pydantic Model
|
||||||
|
Generic ([D]): Represents the SqlAlchemyModel Model
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, primary_key: Union[str, int], sql_model: D, schema: T) -> None:
|
||||||
|
self.primary_key = primary_key
|
||||||
|
|
||||||
|
self.sql_model = sql_model
|
||||||
|
|
||||||
|
self.schema = schema
|
||||||
|
|
||||||
self.observers: list = []
|
self.observers: list = []
|
||||||
|
|
||||||
@ -29,7 +39,7 @@ class BaseAccessModel:
|
|||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None
|
self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None
|
||||||
) -> list[dict]:
|
) -> list[T]:
|
||||||
eff_schema = override_schema or self.schema
|
eff_schema = override_schema or self.schema
|
||||||
|
|
||||||
if order_by:
|
if order_by:
|
||||||
@ -42,7 +52,7 @@ class BaseAccessModel:
|
|||||||
|
|
||||||
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
|
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
|
||||||
|
|
||||||
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]:
|
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[D]:
|
||||||
"""Queries the database for the selected model. Restricts return responses to the
|
"""Queries the database for the selected model. Restricts return responses to the
|
||||||
keys specified under "fields"
|
keys specified under "fields"
|
||||||
|
|
||||||
@ -70,7 +80,7 @@ class BaseAccessModel:
|
|||||||
results_as_dict = [x.dict() for x in results]
|
results_as_dict = [x.dict() for x in results]
|
||||||
return [x.get(self.primary_key) for x in results_as_dict]
|
return [x.get(self.primary_key) for x in results_as_dict]
|
||||||
|
|
||||||
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> SqlAlchemyBase:
|
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> D:
|
||||||
"""Query the sql database for one item an return the sql alchemy model
|
"""Query the sql database for one item an return the sql alchemy model
|
||||||
object. If no match key is provided the primary_key attribute will be used.
|
object. If no match key is provided the primary_key attribute will be used.
|
||||||
|
|
||||||
@ -89,7 +99,7 @@ class BaseAccessModel:
|
|||||||
|
|
||||||
def get(
|
def get(
|
||||||
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None
|
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None
|
||||||
) -> Union[BaseModel, list[BaseModel]]:
|
) -> Union[T, list[T]]:
|
||||||
"""Retrieves an entry from the database by matching a key/value pair. If no
|
"""Retrieves an entry from the database by matching a key/value pair. If no
|
||||||
key is provided the class objects primary key will be used to match against.
|
key is provided the class objects primary key will be used to match against.
|
||||||
|
|
||||||
@ -121,9 +131,10 @@ class BaseAccessModel:
|
|||||||
return eff_schema.from_orm(result[0])
|
return eff_schema.from_orm(result[0])
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return [eff_schema.from_orm(x) for x in result]
|
return [eff_schema.from_orm(x) for x in result]
|
||||||
|
|
||||||
def create(self, session: Session, document: dict) -> BaseModel:
|
def create(self, session: Session, document: T) -> T:
|
||||||
"""Creates a new database entry for the given SQL Alchemy Model.
|
"""Creates a new database entry for the given SQL Alchemy Model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -134,17 +145,17 @@ class BaseAccessModel:
|
|||||||
dict: A dictionary representation of the database entry
|
dict: A dictionary representation of the database entry
|
||||||
"""
|
"""
|
||||||
document = document if isinstance(document, dict) else document.dict()
|
document = document if isinstance(document, dict) else document.dict()
|
||||||
|
|
||||||
new_document = self.sql_model(session=session, **document)
|
new_document = self.sql_model(session=session, **document)
|
||||||
session.add(new_document)
|
session.add(new_document)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
session.refresh(new_document)
|
||||||
|
|
||||||
if self.observers:
|
if self.observers:
|
||||||
self.update_observers()
|
self.update_observers()
|
||||||
|
|
||||||
return self.schema.from_orm(new_document)
|
return self.schema.from_orm(new_document)
|
||||||
|
|
||||||
def update(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
|
def update(self, session: Session, match_value: str, new_data: dict) -> T:
|
||||||
"""Update a database entry.
|
"""Update a database entry.
|
||||||
Args:
|
Args:
|
||||||
session (Session): Database Session
|
session (Session): Database Session
|
||||||
@ -165,7 +176,7 @@ class BaseAccessModel:
|
|||||||
session.commit()
|
session.commit()
|
||||||
return self.schema.from_orm(entry)
|
return self.schema.from_orm(entry)
|
||||||
|
|
||||||
def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
|
def patch(self, session: Session, match_value: str, new_data: dict) -> T:
|
||||||
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
|
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
|
||||||
|
|
||||||
entry = self._query_one(session=session, match_value=match_value)
|
entry = self._query_one(session=session, match_value=match_value)
|
||||||
@ -178,7 +189,7 @@ class BaseAccessModel:
|
|||||||
|
|
||||||
return self.update(session, match_value, entry_as_dict)
|
return self.update(session, match_value, entry_as_dict)
|
||||||
|
|
||||||
def delete(self, session: Session, primary_key_value) -> dict:
|
def delete(self, session: Session, primary_key_value) -> D:
|
||||||
result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one()
|
result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one()
|
||||||
results_as_model = self.schema.from_orm(result)
|
results_as_model = self.schema.from_orm(result)
|
||||||
|
|
||||||
@ -205,7 +216,7 @@ class BaseAccessModel:
|
|||||||
|
|
||||||
def _count_attribute(
|
def _count_attribute(
|
||||||
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
|
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
|
||||||
) -> Union[int, BaseModel]:
|
) -> Union[int, T]:
|
||||||
eff_schema = override_schema or self.schema
|
eff_schema = override_schema or self.schema
|
||||||
# attr_filter = getattr(self.sql_model, attribute_name)
|
# attr_filter = getattr(self.sql_model, attribute_name)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from mealie.db.models.group import Group
|
||||||
from mealie.schema.meal_plan.meal import MealPlanOut
|
from mealie.schema.meal_plan.meal import MealPlanOut
|
||||||
from mealie.schema.user.user import GroupInDB
|
from mealie.schema.user.user import GroupInDB
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
@ -5,7 +6,7 @@ from sqlalchemy.orm.session import Session
|
|||||||
from ._base_access_model import BaseAccessModel
|
from ._base_access_model import BaseAccessModel
|
||||||
|
|
||||||
|
|
||||||
class GroupDataAccessModel(BaseAccessModel):
|
class GroupDataAccessModel(BaseAccessModel[GroupInDB, Group]):
|
||||||
def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]:
|
def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]:
|
||||||
"""A Helper function to get the group from the database and return a sorted list of
|
"""A Helper function to get the group from the database and return a sorted list of
|
||||||
|
|
||||||
|
@ -2,12 +2,13 @@ from random import randint
|
|||||||
|
|
||||||
from mealie.db.models.recipe.recipe import RecipeModel
|
from mealie.db.models.recipe.recipe import RecipeModel
|
||||||
from mealie.db.models.recipe.settings import RecipeSettings
|
from mealie.db.models.recipe.settings import RecipeSettings
|
||||||
|
from mealie.schema.recipe import Recipe
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
from ._base_access_model import BaseAccessModel
|
from ._base_access_model import BaseAccessModel
|
||||||
|
|
||||||
|
|
||||||
class RecipeDataAccessModel(BaseAccessModel):
|
class RecipeDataAccessModel(BaseAccessModel[Recipe, RecipeModel]):
|
||||||
def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None):
|
def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None):
|
||||||
eff_schema = override_schema or self.schema
|
eff_schema = override_schema or self.schema
|
||||||
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
from mealie.db.models.users import User
|
||||||
|
from mealie.schema.user.user import UserInDB
|
||||||
|
|
||||||
from ._base_access_model import BaseAccessModel
|
from ._base_access_model import BaseAccessModel
|
||||||
|
|
||||||
|
|
||||||
class UserDataAccessModel(BaseAccessModel):
|
class UserDataAccessModel(BaseAccessModel[UserInDB, User]):
|
||||||
def update_password(self, session, id, password: str):
|
def update_password(self, session, id, password: str):
|
||||||
entry = self._query_one(session=session, match_value=id)
|
entry = self._query_one(session=session, match_value=id)
|
||||||
entry.update_password(password)
|
entry.update_password(password)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user