feat(backend): Add Generic Type Hint Support for Data Access Layer

This commit is contained in:
hay-kot 2021-08-27 20:27:20 -08:00
parent 0675c570ce
commit a266a244d9
4 changed files with 37 additions and 21 deletions

View File

@ -1,20 +1,30 @@
from typing import Callable, Union from typing import Callable, Generic, TypeVar, Union
from mealie.core.root_logger import get_logger from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from pydantic import BaseModel
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import load_only from sqlalchemy.orm import load_only
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
logger = get_logger() logger = get_logger()
T = TypeVar("T")
D = TypeVar("D")
class BaseAccessModel:
def __init__(self, primary_key, sql_model, schema) -> None: class BaseAccessModel(Generic[T, D]):
self.primary_key: str = primary_key """A Generic BaseAccess Model method to perform common operations on the database
self.sql_model: SqlAlchemyBase = sql_model
self.schema: BaseModel = schema Args:
Generic ([T]): Represents the Pydantic Model
Generic ([D]): Represents the SqlAlchemyModel Model
"""
def __init__(self, primary_key: Union[str, int], sql_model: D, schema: T) -> None:
self.primary_key = primary_key
self.sql_model = sql_model
self.schema = schema
self.observers: list = [] self.observers: list = []
@ -29,7 +39,7 @@ class BaseAccessModel:
def get_all( def get_all(
self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None
) -> list[dict]: ) -> list[T]:
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
if order_by: if order_by:
@ -42,7 +52,7 @@ class BaseAccessModel:
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()] return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]: def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[D]:
"""Queries the database for the selected model. Restricts return responses to the """Queries the database for the selected model. Restricts return responses to the
keys specified under "fields" keys specified under "fields"
@ -70,7 +80,7 @@ class BaseAccessModel:
results_as_dict = [x.dict() for x in results] results_as_dict = [x.dict() for x in results]
return [x.get(self.primary_key) for x in results_as_dict] return [x.get(self.primary_key) for x in results_as_dict]
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> SqlAlchemyBase: def _query_one(self, session: Session, match_value: str, match_key: str = None) -> D:
"""Query the sql database for one item an return the sql alchemy model """Query the sql database for one item an return the sql alchemy model
object. If no match key is provided the primary_key attribute will be used. object. If no match key is provided the primary_key attribute will be used.
@ -89,7 +99,7 @@ class BaseAccessModel:
def get( def get(
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None
) -> Union[BaseModel, list[BaseModel]]: ) -> Union[T, list[T]]:
"""Retrieves an entry from the database by matching a key/value pair. If no """Retrieves an entry from the database by matching a key/value pair. If no
key is provided the class objects primary key will be used to match against. key is provided the class objects primary key will be used to match against.
@ -121,9 +131,10 @@ class BaseAccessModel:
return eff_schema.from_orm(result[0]) return eff_schema.from_orm(result[0])
except IndexError: except IndexError:
return None return None
return [eff_schema.from_orm(x) for x in result] return [eff_schema.from_orm(x) for x in result]
def create(self, session: Session, document: dict) -> BaseModel: def create(self, session: Session, document: T) -> T:
"""Creates a new database entry for the given SQL Alchemy Model. """Creates a new database entry for the given SQL Alchemy Model.
Args: Args:
@ -134,17 +145,17 @@ class BaseAccessModel:
dict: A dictionary representation of the database entry dict: A dictionary representation of the database entry
""" """
document = document if isinstance(document, dict) else document.dict() document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=session, **document) new_document = self.sql_model(session=session, **document)
session.add(new_document) session.add(new_document)
session.commit() session.commit()
session.refresh(new_document)
if self.observers: if self.observers:
self.update_observers() self.update_observers()
return self.schema.from_orm(new_document) return self.schema.from_orm(new_document)
def update(self, session: Session, match_value: str, new_data: dict) -> BaseModel: def update(self, session: Session, match_value: str, new_data: dict) -> T:
"""Update a database entry. """Update a database entry.
Args: Args:
session (Session): Database Session session (Session): Database Session
@ -165,7 +176,7 @@ class BaseAccessModel:
session.commit() session.commit()
return self.schema.from_orm(entry) return self.schema.from_orm(entry)
def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel: def patch(self, session: Session, match_value: str, new_data: dict) -> T:
new_data = new_data if isinstance(new_data, dict) else new_data.dict() new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(session=session, match_value=match_value) entry = self._query_one(session=session, match_value=match_value)
@ -178,7 +189,7 @@ class BaseAccessModel:
return self.update(session, match_value, entry_as_dict) return self.update(session, match_value, entry_as_dict)
def delete(self, session: Session, primary_key_value) -> dict: def delete(self, session: Session, primary_key_value) -> D:
result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one() result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one()
results_as_model = self.schema.from_orm(result) results_as_model = self.schema.from_orm(result)
@ -205,7 +216,7 @@ class BaseAccessModel:
def _count_attribute( def _count_attribute(
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
) -> Union[int, BaseModel]: ) -> Union[int, T]:
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
# attr_filter = getattr(self.sql_model, attribute_name) # attr_filter = getattr(self.sql_model, attribute_name)

View File

@ -1,3 +1,4 @@
from mealie.db.models.group import Group
from mealie.schema.meal_plan.meal import MealPlanOut from mealie.schema.meal_plan.meal import MealPlanOut
from mealie.schema.user.user import GroupInDB from mealie.schema.user.user import GroupInDB
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -5,7 +6,7 @@ from sqlalchemy.orm.session import Session
from ._base_access_model import BaseAccessModel from ._base_access_model import BaseAccessModel
class GroupDataAccessModel(BaseAccessModel): class GroupDataAccessModel(BaseAccessModel[GroupInDB, Group]):
def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]: def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]:
"""A Helper function to get the group from the database and return a sorted list of """A Helper function to get the group from the database and return a sorted list of

View File

@ -2,12 +2,13 @@ from random import randint
from mealie.db.models.recipe.recipe import RecipeModel from mealie.db.models.recipe.recipe import RecipeModel
from mealie.db.models.recipe.settings import RecipeSettings from mealie.db.models.recipe.settings import RecipeSettings
from mealie.schema.recipe import Recipe
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from ._base_access_model import BaseAccessModel from ._base_access_model import BaseAccessModel
class RecipeDataAccessModel(BaseAccessModel): class RecipeDataAccessModel(BaseAccessModel[Recipe, RecipeModel]):
def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None): def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None):
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema

View File

@ -1,7 +1,10 @@
from mealie.db.models.users import User
from mealie.schema.user.user import UserInDB
from ._base_access_model import BaseAccessModel from ._base_access_model import BaseAccessModel
class UserDataAccessModel(BaseAccessModel): class UserDataAccessModel(BaseAccessModel[UserInDB, User]):
def update_password(self, session, id, password: str): def update_password(self, session, id, password: str):
entry = self._query_one(session=session, match_value=id) entry = self._query_one(session=session, match_value=id)
entry.update_password(password) entry.update_password(password)